|
"""2D visualization primitives based on Matplotlib. |
|
|
|
1) Plot images with `plot_images`. |
|
2) Call functions to plot heatmaps, vector fields, and horizon lines. |
|
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. |
|
""" |
|
|
|
import matplotlib.patheffects as path_effects |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
|
|
from geocalib.perspective_fields import get_perspective_field |
|
from geocalib.utils import rad2deg |
|
|
|
|
|
|
|
|
|
def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
    """Show several images side by side in a single-row figure.

    Args:
        imgs (List[np.ndarray]): Images to display, one per column.
        titles (List[str], optional): One title per image. Defaults to None.
        cmaps (str | List[str], optional): A colormap name shared by all images,
            or a list/tuple with one entry per image. Defaults to "gray".
        dpi (int, optional): Dots per inch of the figure. Defaults to 200.
        pad (float, optional): Padding passed to ``tight_layout``. Defaults to 0.5.
        adaptive (bool, optional): If True, each column width follows the
            width/height ratio of its image; otherwise every column is 4:3.
            Defaults to True.

    Returns:
        plt.Figure: The newly created figure.
    """
    num_images = len(imgs)
    if isinstance(cmaps, (list, tuple)):
        colormaps = cmaps
    else:
        colormaps = [cmaps] * num_images

    if adaptive:
        width_ratios = [img.shape[1] / img.shape[0] for img in imgs]
    else:
        width_ratios = [4 / 3] * num_images

    fig, axs = plt.subplots(
        1,
        num_images,
        figsize=[sum(width_ratios) * 4.5, 4.5],
        dpi=dpi,
        gridspec_kw={"width_ratios": width_ratios},
    )
    # With a single column, plt.subplots hands back one Axes instead of an array.
    axes_row = [axs] if num_images == 1 else axs

    for idx, (img, ax) in enumerate(zip(imgs, axes_row)):
        ax.imshow(img, cmap=plt.get_cmap(colormaps[idx]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[idx])

    fig.tight_layout(pad=pad)
    return fig
|
|
|
|
|
def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=3.0,
    return_fig=False,
    set_lim=False,
):
    """Plot a grid of images (rows x columns).

    Args:
        imgs (List[List[np.ndarray]]): Rows of images; every row is expected to
            hold as many images as the first one.
        titles (List[List[str]], optional): Titles, indexed as ``titles[row][col]``.
            Defaults to None.
        cmaps (str | List[str], optional): A colormap name shared by all images,
            or a list/tuple with one entry per column. Defaults to "gray".
        dpi (int, optional): Dots per inch (only used when a new figure is
            created). Defaults to 100.
        pad (float, optional): Padding passed to ``tight_layout``. Defaults to 0.5.
        fig (plt.Figure, optional): Existing figure to add the subplots to; a new
            figure is created when None. Defaults to None.
        adaptive (bool, optional): If True, column widths follow the aspect
            ratios of the images in the first row; otherwise every column is 4:3.
            Defaults to True.
        figs (float, optional): Size in inches of one grid cell (row height).
            Defaults to 3.0.
        return_fig (bool, optional): Whether to return the figure as well.
            Defaults to False.
        set_lim (bool, optional): Whether to pin the axis limits to the image
            extent. Defaults to False.

    Returns:
        The grid of axes, indexable as ``axs[row][col]``, or the tuple
        ``(fig, axs)`` when ``return_fig`` is True.
    """
    nr, n = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs[0]]
    else:
        ratios = [4 / 3] * n

    figsize = [sum(ratios) * figs, nr * figs]
    if fig is None:
        fig, axs = plt.subplots(
            nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
        )
    else:
        axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
        # NOTE(review): `fig.figure` is presumably used so that sub-figures resize
        # their parent figure -- confirm against callers.
        fig.figure.set_size_inches(figsize)

    # Normalise the subplots result so that it can always be indexed as axs[row][col].
    if nr == 1 and n == 1:
        axs = [[axs]]
    elif n == 1:
        axs = axs[:, None]
    elif nr == 1:
        axs = [axs]

    for j in range(nr):
        for i in range(n):
            ax = axs[j][i]
            ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, imgs[j][i].shape[1]])
                # Reversed y-limits keep the image origin at the top-left corner.
                ax.set_ylim([imgs[j][i].shape[0], 0])
            if titles:
                ax.set_title(titles[j][i])

    # tight_layout is only applied to real Figure instances.
    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    return (fig, axs) if return_fig else axs
|
|
|
|
|
def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=4,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Write an (optionally outlined) label onto one of the axes.

    Args:
        idx (int): Index of the target axes within ``axes``.
        text (str): Text to draw.
        pos (tuple, optional): Position in axes coordinates. Defaults to (0.01, 0.99).
        fs (int, optional): Font size. Defaults to 15.
        color (str, optional): Text color. Defaults to "w".
        lcolor (str, optional): Outline color; pass None to skip the outline.
            Defaults to "k".
        lwidth (int, optional): Outline width. Defaults to 4.
        ha (str, optional): Horizontal alignment. Defaults to "left".
        va (str, optional): Vertical alignment. Defaults to "top".
        axes (List[plt.Axes], optional): Axes to choose from; the axes of the
            current figure are used when None. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``Axes.text``.

    Returns:
        plt.Text: The created text object.
    """
    candidates = plt.gcf().axes if axes is None else axes
    target = candidates[idx]

    label = target.text(
        *pos,
        text,
        fontsize=fs,
        ha=ha,
        va=va,
        color=color,
        transform=target.transAxes,
        zorder=5,
        **kwargs,
    )
    if lcolor is None:
        return label

    # Stroke first, then the normal rendering on top, gives an outlined glyph.
    outline = [
        path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
        path_effects.Normal(),
    ]
    label.set_path_effects(outline)
    return label
|
|
|
|
|
def plot_heatmaps(
    heatmaps,
    vmin=-1e-6,
    vmax=None,
    cmap="Spectral",
    a=0.5,
    axes=None,
    contours_every=None,
    contour_style="solid",
    colorbar=False,
):
    """Plot heatmaps with optional contours.

    To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.

    Args:
        heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps, one per
            axes. The list itself is left untouched; tensors are converted to
            NumPy on a local copy.
        vmin (float, optional): Min Value. Defaults to -1e-6.
        vmax (float, optional): Max Value. Defaults to None.
        cmap (str, optional): Colormap. Defaults to "Spectral".
        a (float | Sequence, optional): Alpha value, either one number shared by
            all heatmaps or one entry per axes. Defaults to 0.5.
        axes (List[plt.Axes], optional): Axes to plot on; the axes of the current
            figure are used when None. Defaults to None.
        contours_every (int, optional): If not none, will draw contours with this
            spacing between ``vmin`` and ``vmax``. Defaults to None.
        contour_style (str, optional): Style of the contours. Defaults to "solid".
        colorbar (bool, optional): Whether to show colorbar. Defaults to False.

    Returns:
        List[plt.Artist]: List of artists (per axes: the colorbar if requested,
        the image, then the contour set if requested).

    Raises:
        ValueError: If contours are requested without explicit ``vmin``/``vmax``.
    """
    if axes is None:
        axes = plt.gcf().axes

    # Contour levels, colors and labels only depend on the value range, so they
    # are computed once instead of once per axes.
    draw_contours = contours_every is not None and len(axes) > 0
    if draw_contours:
        if vmin is None or vmax is None:
            raise ValueError("contours_every requires explicit vmin and vmax values.")
        contour_data = np.arange(vmin, vmax + contours_every, contours_every)
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        colormap = plt.colormaps.get_cmap(cmap)
        contour_colors = [colormap(norm(level)) for level in contour_data]
        fmt = {
            level: f"{label}°"
            for level, label in zip(contour_data, contour_data.astype(int).astype(str))
        }

    artists = []
    for i, ax in enumerate(axes):
        # Any real number (not only `float`) counts as a global alpha.
        alpha = a if isinstance(a, (int, float, np.number)) else a[i]

        heatmap = heatmaps[i]
        if isinstance(heatmap, torch.Tensor):
            # detach() lets tensors that require grad be converted as well.
            heatmap = heatmap.detach().cpu().numpy()

        art = ax.imshow(heatmap, alpha=alpha, vmin=vmin, vmax=vmax, cmap=cmap)
        if colorbar:
            # `vmax or ...` would wrongly discard an explicit vmax of 0.
            cmax = vmax if vmax is not None else np.percentile(heatmap, 99)
            art.set_clim(vmin, cmax)
            artists.append(plt.colorbar(art, ax=ax))

        artists.append(art)

        if draw_contours:
            contours = ax.contour(
                heatmap,
                levels=contour_data,
                linewidths=2,
                colors=contour_colors,
                linestyles=contour_style,
            )
            contours.set_clim(vmin, vmax)

            labels = ax.clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
            for label in labels:
                # Thin black outline keeps the white labels readable on any color.
                label.set_path_effects(
                    [
                        path_effects.Stroke(linewidth=1, foreground="k"),
                        path_effects.Normal(),
                    ]
                )
            artists.append(contours)

    return artists
|
|
|
|
|
def plot_horizon_lines(
    cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
):
    """Plot horizon lines on the perspective field.

    The horizon is drawn as the zero-level contour of the latitude map returned
    by ``get_perspective_field`` for each camera/gravity pair.

    Args:
        cameras (List[Camera]): List of cameras.
        gravities (List[Gravity]): Gravities, one per camera.
        line_colors (str | List[str], optional): Line color, shared or one per
            camera. Defaults to "orange".
        lw (int, optional): Line width. Defaults to 2.
        styles (str | List[str], optional): Line style, shared or one per camera.
            Defaults to "solid".
        alpha (float | List[float], optional): Line opacity, shared or one per
            camera. Defaults to 1.0.
        ax (plt.Axes | List[plt.Axes], optional): Axes to draw horizon line on. A
            single axes is reused for every camera; the current axes are used
            when None. Defaults to None.
    """
    num_cameras = len(cameras)

    if not isinstance(line_colors, list):
        line_colors = [line_colors] * num_cameras

    if not isinstance(styles, list):
        styles = [styles] * num_cameras

    if not isinstance(alpha, (list, tuple)):
        alpha = [alpha] * num_cameras

    # Only touch pyplot's global state when no axes were supplied; calling
    # plt.gcf() unconditionally could create a stray empty figure.
    if ax is None:
        ax = plt.gcf().gca()

    if isinstance(ax, plt.Axes):
        ax = [ax] * num_cameras

    assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"

    for i in range(num_cameras):
        _, lat = get_perspective_field(cameras[i], gravities[i])
        lat = lat[0, 0].detach().cpu().numpy()

        # The style goes through `linestyles` instead of iterating over
        # `ContourSet.collections`, which newer Matplotlib releases deprecate.
        # Wrapping it in a list keeps tuple dash specs intact.
        ax[i].contour(
            lat,
            levels=[0],
            linewidths=lw,
            colors=line_colors[i],
            linestyles=[styles[i]],
            alpha=alpha[i],
        )
|
|
|
|
|
def plot_vector_fields(
    vector_fields,
    cmap="lime",
    subsample=15,
    scale=None,
    lw=None,
    alphas=0.8,
    axes=None,
):
    """Plot vector fields.

    Args:
        vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
            All fields are sampled on the grid derived from the first one.
        cmap (str, optional): Color of the vectors. A non-string value is indexed
            as ``cmap[i]`` and treated as a per-pixel RGB map. Defaults to "lime".
        subsample (int, optional): Subsample the vector field; roughly the number
            of arrows along the shorter image side. Defaults to 15.
        scale (float, optional): Scale of the vectors. Defaults to None.
        lw (float, optional): Line width of the vectors. Defaults to None.
        alphas (float | np.ndarray, optional): Alpha per vector or global.
            Defaults to 0.8.
        axes (List[plt.Axes], optional): List of axes to draw on; the axes of the
            current figure are used when None. Defaults to None.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    # detach() lets tensors that require grad be converted as well.
    vector_fields = [
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields
    ]

    H, W = vector_fields[0].shape[-2:]
    short_side = min(H, W)

    if scale is None:
        scale = subsample / short_side

    if lw is None:
        lw = 0.1 / subsample

    if alphas is None or isinstance(alphas, (int, float, np.number)):
        # A missing or scalar alpha (any real number, not only `float`) is
        # broadcast to one full-resolution map per field.
        alpha_map = np.ones_like(vector_fields[0][0])
        if alphas is not None:
            alpha_map = alpha_map * alphas
        alphas = np.stack([alpha_map] * len(vector_fields), 0)
    else:
        alphas = np.array(alphas)

    # Pixel stride between arrows; clamped to 1 so that images smaller than
    # `subsample` pixels do not produce a zero stride (division by zero).
    step = max(1, short_side // subsample)
    offset_x = ((W % step) + step) // 2

    samples_x = np.arange(offset_x, W, step)
    samples_y = np.arange(int(step * 0.9), H, step)

    x_grid, y_grid = np.meshgrid(samples_x, samples_y)

    # Loop-invariant quiver settings.
    arrow_scale = scale * short_side
    units = "width" if H > W else "height"

    artists = []
    for i, ax in enumerate(axes):
        vector_field = vector_fields[i]

        a = alphas[i][samples_y][:, samples_x]
        x, y = vector_field[:, samples_y][:, :, samples_x]

        c = cmap
        if not isinstance(cmap, str):
            c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)

        arrows = ax.quiver(
            x_grid,
            y_grid,
            x,
            y,
            scale=arrow_scale,
            scale_units=units,
            units=units,
            alpha=a,
            color=c,
            angles="xy",
            antialiased=True,
            width=lw,
            headaxislength=3.5,
            zorder=5,
        )
        artists.append(arrows)

    return artists
|
|
|
|
|
def plot_latitudes(
    latitude,
    is_radians=True,
    vmin=-90,
    vmax=90,
    cmap="seismic",
    contours_every=15,
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Overlay latitude maps as heatmaps with contour lines.

    Args:
        latitude (List[torch.Tensor]): Latitude maps, one per axes.
        is_radians (bool, optional): If True, the maps are converted with
            ``rad2deg`` before plotting. Defaults to True.
        vmin (int, optional): Lower bound of the color range. Defaults to -90.
        vmax (int, optional): Upper bound of the color range. Defaults to 90.
        cmap (str, optional): Colormap. Defaults to "seismic".
        contours_every (int, optional): Spacing between contours. Defaults to 15.
        alpha (float, optional): Opacity of the heatmap. Defaults to 0.4.
        axes (List[plt.Axes], optional): Axes to plot on; the axes of the current
            figure are used when None. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``plot_heatmaps``.

    Returns:
        List[plt.Artist]: Artists returned by ``plot_heatmaps``.
    """
    target_axes = plt.gcf().axes if axes is None else axes

    assert len(target_axes) == len(latitude), f"{len(target_axes)}, {len(latitude)}"

    if is_radians:
        degrees = [rad2deg(field) for field in latitude]
    else:
        degrees = latitude

    heatmap_options = dict(
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        a=alpha,
        axes=target_axes,
        contours_every=contours_every,
    )
    return plot_heatmaps(degrees, **heatmap_options, **kwargs)
|
|
|
|
|
def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
    """Draw the up-vector field and the latitude map of each camera.

    Args:
        cameras (List[Camera]): List of cameras.
        gravities (List[Gravity]): List of gravities, one per camera.
        axes (List[plt.Axes], optional): Axes to plot on, one per camera; the axes
            of the current figure are used when None. Defaults to None.
        **kwargs: Forwarded unchanged to both ``plot_vector_fields`` and
            ``plot_latitudes``, so only keywords accepted by both are usable.

    Returns:
        List[plt.Artist]: Vector-field artists followed by latitude artists, for
        each axes in turn.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"

    collected = []
    for idx, ax in enumerate(axes):
        up, lat = get_perspective_field(cameras[idx], gravities[idx])
        # Each helper receives a one-element list so that it draws on `ax` only.
        collected.extend(plot_vector_fields([up[0]], axes=[ax], **kwargs))
        collected.extend(plot_latitudes([lat[0, 0]], axes=[ax], **kwargs))

    return collected
|
|
|
|
|
def plot_confidences(
    confidence,
    as_log=True,
    vmin=-4,
    vmax=0,
    cmap="turbo",
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Plot confidences.

    Each map is optionally converted to a clipped log10 scale and then min-max
    normalised to [0, 1] independently before being drawn as a heatmap.

    Args:
        confidence (List[torch.Tensor]): Confidence maps, one per axes.
        as_log (bool, optional): Whether to plot in log scale. Defaults to True.
        vmin (int, optional): Min value to clip the log-confidence to (only used
            when ``as_log`` is True). Defaults to -4.
        vmax (int, optional): Max value to clip the log-confidence to (only used
            when ``as_log`` is True). Defaults to 0.
        cmap (str, optional): Colormap. Defaults to "turbo".
        alpha (float, optional): Alpha value. Defaults to 0.4.
        axes (List[plt.Axes], optional): Axes to plot on; the axes of the current
            figure are used when None. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``plot_heatmaps``.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"

    if as_log:
        # Values are floored at 1e-5 before the log to avoid -inf.
        confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]

    normalized = []
    for c in confidence:
        low = c.min()
        span = c.max() - low
        # A constant map has zero range; dividing by 1 instead yields an
        # all-zero map rather than NaNs from 0 / 0.
        normalized.append((c - low) / (span if span > 0 else 1))

    return plot_heatmaps(normalized, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
|
|
|
|
|
def save_plot(path, **kw):
    """Write the current figure to disk with the surrounding white margin removed.

    Args:
        path: Destination handed to ``plt.savefig``.
        **kw: Extra keyword arguments forwarded to ``plt.savefig``.
    """
    plt.savefig(
        path,
        bbox_inches="tight",
        pad_inches=0,
        **kw,
    )
|
|