from __future__ import division, annotations
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import FancyArrowPatch
from matplotlib.path import get_path_collection_extents
import scipy.spatial.distance
import logging
from timeit import default_timer as timer
import io
try:
from matplotlib.backend_bases import _get_renderer as matplot_get_renderer
except ImportError:
matplot_get_renderer = None
from .arrops import overlap_intervals
from ._version import __version__
def get_renderer(fig):
# If the backend support get_renderer() or renderer, use that.
if hasattr(fig.canvas, "get_renderer"):
return fig.canvas.get_renderer()
if hasattr(fig.canvas, "renderer"):
return fig.canvas.renderer
# Otherwise, if we have the matplotlib function available, use that.
if matplot_get_renderer:
return matplot_get_renderer(fig)
# No dice, try and guess.
# Write the figure to a temp location, and then retrieve whichever
# render was used (doesn't work in all matplotlib versions).
fig.canvas.print_figure(io.BytesIO())
try:
return fig._cachedRenderer
except AttributeError:
# No luck.
# We're out of options.
raise ValueError("Unable to determine renderer") from None
def get_bboxes_pathcollection(sc, ax):
"""Function to return a list of bounding boxes in display coordinates
for a scatter plot
Thank you to ImportanceOfBeingErnest
https://stackoverflow.com/a/55007838/1304161"""
# ax.figure.canvas.draw() # need to draw before the transforms are set.
transform = sc.get_transform()
transOffset = sc.get_offset_transform()
offsets = sc._offsets
paths = sc.get_paths()
transforms = sc.get_transforms()
if not transform.is_affine:
paths = [transform.transform_path_non_affine(p) for p in paths]
transform = transform.get_affine()
if not transOffset.is_affine:
offsets = transOffset.transform_non_affine(offsets)
transOffset = transOffset.get_affine()
if isinstance(offsets, np.ma.MaskedArray):
offsets = offsets.filled(np.nan)
bboxes = []
if len(paths) and len(offsets):
if len(paths) < len(offsets):
# for usual scatters you have one path, but several offsets
paths = [paths[0]] * len(offsets)
if len(transforms) < len(offsets):
# often you may have a single scatter size, but several offsets
transforms = [transforms[0]] * len(offsets)
for p, o, t in zip(paths, offsets, transforms):
result = get_path_collection_extents(
transform.frozen(), [p], [t], [o], transOffset.frozen()
)
bboxes.append(result.transformed(ax.transData.inverted()))
return bboxes
def get_bboxes(objs, r=None, expand=(1, 1), ax=None):
"""
Parameters
----------
objs : list, or PathCollection
List of objects to get bboxes from. Also works with mpl PathCollection.
r : renderer
Renderer. The default is None, then automatically deduced from ax.
expand : (float, float), optional
How much to expand bboxes in (x, y), in fractions. The default is (1, 1).
ax : Axes, optional
The default is None, then uses current axes.
Returns
-------
list
List of bboxes.
"""
ax = ax or plt.gca()
r = r or get_renderer(ax.get_figure())
try:
objs = [i.get_bbox() for i in objs]
except (AttributeError, TypeError):
pass
try:
return [i.get_window_extent(r).expanded(*expand) for i in objs]
except (AttributeError, TypeError):
try:
if all([isinstance(obj, matplotlib.transforms.BboxBase) for obj in objs]):
return objs
else:
raise ValueError("Something is wrong")
except TypeError:
return get_bboxes_pathcollection(objs, ax)
def get_2d_coordinates(objs, ax):
bboxes = get_bboxes(objs, get_renderer(ax.get_figure()), (1.0, 1.0), ax)
xs = [
(ax.convert_xunits(bbox.xmin), ax.convert_yunits(bbox.xmax)) for bbox in bboxes
]
ys = [
(ax.convert_xunits(bbox.ymin), ax.convert_yunits(bbox.ymax)) for bbox in bboxes
]
coords = np.hstack([np.array(xs), np.array(ys)])
return coords
def get_shifts_texts(coords):
N = coords.shape[0]
xoverlaps = overlap_intervals(
coords[:, 0], coords[:, 1], coords[:, 0], coords[:, 1]
)
xoverlaps = xoverlaps[xoverlaps[:, 0] != xoverlaps[:, 1]]
yoverlaps = overlap_intervals(
coords[:, 2], coords[:, 3], coords[:, 2], coords[:, 3]
)
yoverlaps = yoverlaps[yoverlaps[:, 0] != yoverlaps[:, 1]]
overlaps = yoverlaps[(yoverlaps[:, None] == xoverlaps).all(-1).any(-1)]
if len(overlaps) == 0:
return np.zeros((coords.shape[0])), np.zeros((coords.shape[0]))
diff = coords[overlaps[:, 1]] - coords[overlaps[:, 0]]
xshifts = np.where(np.abs(diff[:, 0]) < np.abs(diff[:, 1]), diff[:, 0], diff[:, 1])
yshifts = np.where(np.abs(diff[:, 2]) < np.abs(diff[:, 3]), diff[:, 2], diff[:, 3])
xshifts = np.bincount(overlaps[:, 0], xshifts, minlength=N)
yshifts = np.bincount(overlaps[:, 0], yshifts, minlength=N)
return xshifts, yshifts
def get_shifts_extra(coords, extra_coords):
N = coords.shape[0]
xoverlaps = overlap_intervals(
coords[:, 0], coords[:, 1], extra_coords[:, 0], extra_coords[:, 1]
)
yoverlaps = overlap_intervals(
coords[:, 2], coords[:, 3], extra_coords[:, 2], extra_coords[:, 3]
)
overlaps = yoverlaps[(yoverlaps[:, None] == xoverlaps).all(-1).any(-1)]
if len(overlaps) == 0:
return np.zeros((coords.shape[0])), np.zeros((coords.shape[0]))
diff_x = coords[overlaps[:, 0], :2] - extra_coords[overlaps[:, 1], -3::-1]
diff_y = coords[overlaps[:, 0], 2:] - extra_coords[overlaps[:, 1], -1:-3:-1]
xshifts = np.where(
np.abs(diff_x[:, 0]) < np.abs(diff_x[:, 1]), diff_x[:, 0], diff_x[:, 1]
)
yshifts = np.where(
np.abs(diff_y[:, 0]) < np.abs(diff_y[:, 1]), diff_y[:, 0], diff_y[:, 1]
)
xshifts = np.bincount(overlaps[:, 0], xshifts, minlength=N)
yshifts = np.bincount(overlaps[:, 0], yshifts, minlength=N)
return xshifts, yshifts
def expand_coords(coords, x_frac, y_frac):
mid_x = np.mean(coords[:, :2], axis=1)
mid_y = np.mean(coords[:, 2:], axis=1)
x = np.subtract(coords[:, :2], mid_x[:, np.newaxis]) * x_frac + mid_x[:, np.newaxis]
y = np.subtract(coords[:, 2:], mid_y[:, np.newaxis]) * y_frac + mid_y[:, np.newaxis]
return np.hstack([x, y])
def expand_axes_to_fit(coords, ax, transform):
max_x, max_y = np.max(transform.inverted().transform(coords[:, [1, 3]]), axis=0)
min_x, min_y = np.min(transform.inverted().transform(coords[:, [0, 2]]), axis=0)
if min_x < ax.get_xlim()[0]:
ax.set_xlim(xmin=min_x)
if min_y < ax.get_ylim()[0]:
ax.set_ylim(ymin=min_y)
if max_x > ax.get_xlim()[1]:
ax.set_xlim(xmax=max_x)
if max_y > ax.get_ylim()[1]:
ax.set_ylim(ymax=max_y)
def apply_shifts(coords, shifts_x, shifts_y):
coords[:, :2] = np.subtract(coords[:, :2], shifts_x[:, np.newaxis])
coords[:, 2:] = np.subtract(coords[:, 2:], shifts_y[:, np.newaxis])
return coords
def force_into_bbox(coords, bbox):
xmin, xmax, ymin, ymax = bbox
dx, dy = np.zeros((coords.shape[0])), np.zeros((coords.shape[0]))
if np.any((coords[:, 0] < xmin) & (coords[:, 1] > xmax)):
logging.warn("Some labels are too long, can't fit inside the X axis")
if np.any((coords[:, 2] < ymin) & (coords[:, 3] > ymax)):
logging.warn("Some labels are too tall, can't fit inside the Y axis")
dx[coords[:, 0] < xmin] = (xmin - coords[:, 0])[coords[:, 0] < xmin]
dx[coords[:, 1] > xmax] = (xmax - coords[:, 1])[coords[:, 1] > xmax]
dy[coords[:, 2] < ymin] = (ymin - coords[:, 2])[coords[:, 2] < ymin]
dy[coords[:, 3] > ymax] = (ymax - coords[:, 3])[coords[:, 3] > ymax]
return apply_shifts(coords, -dx, -dy)
def pull_back(coords, targets):
dx = np.max(np.subtract(targets[:, 0][:, np.newaxis], coords[:, :2]), axis=1)
dy = np.max(np.subtract(targets[:, 1][:, np.newaxis], coords[:, 2:]), axis=1)
return dx, dy
def explode(coords, static_coords, r=None):
N = coords.shape[0]
x = coords[:, [0, 1]].mean(axis=1)
y = coords[:, [2, 3]].mean(axis=1)
points = np.vstack([x, y]).T
if static_coords.shape[0] > 0:
static_x = np.mean(static_coords[:, [0, 1]], axis=1)
static_y = np.mean(static_coords[:, [2, 3]], axis=1)
static_centers = np.vstack([static_x, static_y]).T
points = np.vstack([points, static_centers])
tree = scipy.spatial.KDTree(points)
pairs = tree.query_pairs(r, output_type="ndarray")
pairs = pairs[pairs[:, 0] < N]
pairs = pairs[pairs[:, 0] != pairs[:, 1]]
diff = points[pairs[:, 0]] - points[pairs[:, 1]]
xshifts = np.bincount(pairs[:, 0], diff[:, 0], minlength=N)
yshifts = np.bincount(pairs[:, 0], diff[:, 1], minlength=N)
return xshifts, yshifts
def iterate(
coords,
target_coords,
static_coords=None,
force_text: tuple[float, float] = (0.1, 0.2),
force_static: tuple[float, float] = (0.05, 0.1),
force_pull: tuple[float, float] = (0.05, 0.1),
pull_threshold: float = 10,
expand: tuple[float, float] = (1.05, 1.1),
bbox_to_contain=False,
only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
):
text_shifts_x, text_shifts_y = get_shifts_texts(
expand_coords(coords, expand[0], expand[1])
)
if static_coords.shape[0] > 0:
static_shifts_x, static_shifts_y = get_shifts_extra(
expand_coords(coords, expand[0], expand[1]), static_coords
)
else:
static_shifts_x, static_shifts_y = np.zeros((1)), np.zeros((1))
error_x = np.abs(text_shifts_x) + np.abs(static_shifts_x)
error_y = np.abs(text_shifts_y) + np.abs(static_shifts_y)
error = np.sum(np.append(error_x, error_y))
pull_x, pull_y = pull_back(coords, target_coords)
pull_x[np.abs(pull_x) < pull_threshold] = 0
pull_y[np.abs(pull_y) < pull_threshold] = 0
text_shifts_x *= force_text[0]
text_shifts_y *= force_text[1]
static_shifts_x *= force_static[0]
static_shifts_y *= force_static[1]
# Pull is in the opposite direction, so need to negate it
pull_x *= -force_pull[0]
pull_y *= -force_pull[1]
pull_x[error_x != 0] = 0
pull_y[error_y != 0] = 0
if only_move:
if "x" not in only_move.get("text", "xy"):
text_shifts_x = np.zeros_like(text_shifts_x)
elif "x+" in only_move.get("text", "xy"):
text_shifts_x[text_shifts_x > 0] = 0
elif "x-" in only_move.get("text", "xy"):
text_shifts_x[text_shifts_x < 0] = 0
if "y" not in only_move.get("text", "xy"):
text_shifts_y = np.zeros_like(text_shifts_y)
elif "y+" in only_move.get("text", "xy"):
text_shifts_y[text_shifts_y > 0] = 0
elif "y-" in only_move.get("text", "xy"):
text_shifts_y[text_shifts_y < 0] = 0
if "x" not in only_move.get("static", "xy"):
static_shifts_x = np.zeros_like(static_shifts_x)
elif "x+" in only_move.get("static", "xy"):
static_shifts_x[static_shifts_x > 0] = 0
elif "x-" in only_move.get("static", "xy"):
static_shifts_x[static_shifts_x < 0] = 0
if "y" not in only_move.get("static", "xy"):
static_shifts_y = np.zeros_like(static_shifts_y)
elif "y+" in only_move.get("static", "xy"):
static_shifts_y[static_shifts_y > 0] = 0
elif "y-" in only_move.get("static", "xy"):
static_shifts_y[static_shifts_y < 0] = 0
if "x" not in only_move.get("pull", "xy"):
pull_x = np.zeros_like(pull_x)
elif "x+" in only_move.get("pull", "xy"):
pull_x[pull_x > 0] = 0
elif "x-" in only_move.get("pull", "xy"):
pull_x[pull_x < 0] = 0
if "y" not in only_move.get("pull", "xy"):
pull_y = np.zeros_like(pull_y)
elif "y+" in only_move.get("pull", "xy"):
pull_y[pull_y > 0] = 0
elif "y-" in only_move.get("pull", "xy"):
pull_y[pull_y < 0] = 0
shifts_x = text_shifts_x + static_shifts_x + pull_x
shifts_y = text_shifts_y + static_shifts_y + pull_y
# shifts_x = np.ceil(shifts_x)
# shifts_y = np.ceil(shifts_y)
shifts_x = np.sign(shifts_x) * np.ceil(np.abs(shifts_x))
shifts_y = np.sign(shifts_y) * np.ceil(np.abs(shifts_y))
coords = apply_shifts(coords, shifts_x, shifts_y)
if bbox_to_contain:
coords = force_into_bbox(coords, bbox_to_contain)
return coords, error
def force_draw(ax):
try:
ax.figure.draw_without_rendering()
except AttributeError:
logging.warn(
"""Looks like you are using an old matplotlib version.
In some cases adjust_text might fail, if possible update
matplotlib to version >=3.5.0"""
)
ax.figure.canvas.draw()
[docs]
def adjust_text(
texts,
x=None,
y=None,
objects=None,
target_x=None,
target_y=None,
avoid_self=True,
force_text: tuple[float, float] = (0.1, 0.2),
force_static: tuple[float, float] = (0.1, 0.2),
force_pull: tuple[float, float] = (0.01, 0.01),
force_explode: tuple[float, float] = (0.05, 0.05),
pull_threshold: float = 10,
expand: tuple[float, float] = (1.05, 1.2),
explode_radius: str | float = "auto",
ensure_inside_axes: bool = True,
expand_axes: bool = False,
only_move: dict = {"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
ax: matplotlib.axes.Axes | None = None,
min_arrow_len: float = 5,
time_lim: float | None = None,
iter_lim: int | None = None,
*args,
**kwargs,
):
"""Iteratively adjusts the locations of texts.
Call adjust_text the very last, after all plotting (especially
anything that can change the axes limits) has been done. This is
because to move texts the function needs to use the dimensions of
the axes, and without knowing the final size of the plots the
results will be completely nonsensical, or suboptimal.
First "explodes" all texts to move them apart.
Then in each iteration pushes all texts away from each other, and any specified
points or objects. At the same time slowly tries to pull the texts closer to their
origianal locations that they label (this reduces chances that a text ends up super
far away). In the end adds arrows connecting the texts to the respective points.
Parameters
----------
texts : list
A list of :obj:`matplotlib.text.Text` objects to adjust.
Other Parameters
----------------
x : array_like
x-coordinates of points to repel from; with avoid_self=True, the original
text coordinates will be added to this array
y : array_like
y-coordinates of points to repel from; with avoid_self=True, the original
text coordinates will be added to this array
objects : list or PathCollection
a list of additional matplotlib objects to avoid; they must have a
`.get_window_extent()` method; alternatively, a PathCollection or a
list of Bbox objects.
target_x : array_like
if provided, x-coordinates of points to connect adjusted texts to; if not
provided, uses the original text coordinates.
Provide together with target_y.
Should be the same length as texts and in the same order, or None.
target_y : array_like
if provided, y-coordinates of points to connect adjusted texts to; if not
provided, uses the original text coordinates.
Provide together with target_x.
Should be the same length as texts and in the same order, or None.
avoid_self : bool, default True
whether to repel texts from its original positions.
force_text : tuple, default (0.1, 0.2)
the repel force from texts is multiplied by this value
force_static : tuple, default (0.1, 0.2)
the repel force from points and objects is multiplied by this value
force_pull : tuple, default (0.1, 0.1)
same as other forces, but for pulling texts back to original positions
force_explode : float, default (0.1, 0.2)
same as other forces, but for the forced move of texts away from nearby texts
and static positions before iterative adjustment
pull_threshold : float, default 10
how close to the original position the text should be pulled (if it's closer
along one of the axes, don't pull along it) - in display coordinates
expand : array_like, default (1.05, 1.2)
a tuple/list/... with 2 multipliers (x, y) by which to expand the
bounding box of texts when repelling them from each other.
explode_radius : float or "auto", default "auto"
how far to check for nearest objects to move the texts away in the beginning
in display units, so on the order of 100 is the typical value.
"auto" uses the mean size of the texts
ensure_inside_axes : bool, default True
Whether to force texts to stay inside the axes
expand_axes : bool, default False
Whether to expand the axes to fit all texts before adjusting there positions
only_move : dict, default {"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"}
a dict to restrict movement of texts to only certain axes for certain
types of overlaps.
Valid keys are 'text', 'static', 'explode' and 'pull'.
Can contain 'x', 'y', 'x+', 'x-', 'y+', 'y-', or combinations of one 'x?' and
one 'y?'. 'x' and 'y' mean that the text can move in that direction, 'x+' and
'x-' mean that the text can move in the positive or negative direction along
the x axis, and similarly for 'y+' and 'y-'.
ax : matplotlib axes, default is current axes (plt.gca())
ax object with the plot
min_arrow_len : float, default 5
If the text is closer than this to the target point, don't add an arrow
(in display units)
time_lim : float, default None
How much time to allow for the adjustments, in seconds.
If both `time_lim` and iter_lim are set, faster will be used.
If both are None, `time_lim` is set to 1 seconds.
iter_lim : int, default None
How many iterations to allow for the adjustments.
If both `time_lim` and iter_lim are set, faster will be used.
If both are None, `time_lim` is set to 1 seconds.
args and kwargs :
any arguments will be fed into obj:`FancyArrowPatch` after all the
optimization is done just for plotting the connecting arrows if
required.
"""
if not texts:
return
if ax is None:
ax = plt.gca()
force_draw(ax)
try:
transform = texts[0].get_transform()
except IndexError:
logging.warn(
"Something wrong with the texts. Did you pass a list of matplotlib text objects?"
)
return
if time_lim is None and iter_lim is None:
time_lim = 1
elif time_lim is not None and iter_lim is not None:
logging.warn("Both time_lim and iter_lim are set, faster will be used")
start_time = timer()
coords = get_2d_coordinates(texts, ax)
if expand_axes:
expand_axes_to_fit(coords, ax, transform)
force_draw(ax)
transform = texts[0].get_transform()
coords = get_2d_coordinates(texts, ax)
original_coords = [text.get_unitless_position() for text in texts]
original_coords_disp_coord = transform.transform(original_coords)
target_xy = (
list(zip(target_x, target_y))
if (target_x is not None and target_y is not None)
else original_coords
)
target_xy_disp_coord = transform.transform(target_xy)
if isinstance(only_move, str):
only_move = {
"text": only_move,
"static": only_move,
"explode": only_move,
"pull": only_move,
}
# coords += np.random.rand(*coords.shape)*1e-6
if x is not None and y is not None:
point_coords = transform.transform(np.vstack([x, y]).T)
else:
point_coords = np.empty((0, 2))
if avoid_self:
point_coords = np.vstack([point_coords, original_coords_disp_coord])
if objects is None:
obj_coords = np.empty((0, 4))
else:
obj_coords = get_2d_coordinates(objects, ax)
obj_coords[:, [0, 2]] = transform.transform(obj_coords[:, [0, 2]])
obj_coords[:, [1, 3]] = transform.transform(obj_coords[:, [1, 3]])
static_coords = np.vstack([point_coords[:, [0, 0, 1, 1]], obj_coords])
if explode_radius == "auto":
explode_radius = max(
(coords[:, 1] - coords[:, 0]).mean(), (coords[:, 3] - coords[:, 2]).mean()
)
logging.debug(f"Auto-explode radius: {explode_radius}")
if explode_radius > 0 and np.all(np.asarray(force_explode) > 0):
explode_x, explode_y = explode(coords, static_coords, explode_radius)
if "x" not in only_move.get("explode", "xy"):
explode_x = np.zeros_like(explode_x)
if "y" not in only_move.get("explode", "xy"):
explode_y = np.zeros_like(explode_y)
coords = apply_shifts(
coords, -explode_x * force_explode[0], -explode_y * force_explode[1]
)
error = np.Inf
# i_0 = 100
# i = i_0
# expand_start = 1.05, 1.5
# expand_end = 1.05, 1.5
# expand_steps = 100
# expands = list(zip(np.linspace(expand_start[0], expand_end[0], expand_steps),
# np.linspace(expand_start[1], expand_end[1], expand_steps)))
if ensure_inside_axes:
ax_bbox = ax.patch.get_extents()
ax_bbox = ax_bbox.xmin, ax_bbox.xmax, ax_bbox.ymin, ax_bbox.ymax
else:
ax_bbox = False
i = 0
while error > 0:
# expand = expands[min(i, expand_steps-1)]
coords, error = iterate(
coords,
target_xy_disp_coord,
static_coords,
force_text=force_text,
force_static=force_static,
force_pull=force_pull,
pull_threshold=pull_threshold,
expand=expand,
bbox_to_contain=ax_bbox,
only_move=only_move,
)
i += 1
if time_lim is not None and timer() - start_time > time_lim:
break
if iter_lim is not None and i == iter_lim:
break
logging.debug(f"Adjustment took {i} iterations")
logging.debug(f"Time: {timer() - start_time}")
logging.debug(f"Error: {error}")
xdists = np.min(
np.abs(np.subtract(coords[:, :2], target_xy_disp_coord[:, 0][:, np.newaxis])),
axis=1,
)
ydists = np.min(
np.abs(np.subtract(coords[:, 2:], target_xy_disp_coord[:, 1][:, np.newaxis])),
axis=1,
)
display_dists = np.max(np.vstack([xdists, ydists]), axis=0)
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
) # For the future to move into the loop and resolve crossing connections
transformed_connections = np.empty_like(connections)
transformed_connections[:, :2] = transform.inverted().transform(connections[:, :2])
transformed_connections[:, 2:] = transform.inverted().transform(connections[:, 2:])
if "arrowprops" in kwargs:
ap = kwargs.pop("arrowprops")
else:
ap = {}
for i, text in enumerate(texts):
text_mid = transformed_connections[i, :2]
target = transformed_connections[i, 2:]
text.set_verticalalignment("center")
text.set_horizontalalignment("center")
text.set_position(text_mid)
if ap and display_dists[i] >= min_arrow_len:
arrowpatch = FancyArrowPatch(
posA=text_mid, posB=target, patchA=text, *args, **kwargs, **ap
)
ax.add_patch(arrowpatch)