Spaces:
Runtime error
Runtime error
"""Helpers for visualization""" | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import cv2 | |
from PIL import Image | |
# Predominant colors, keyed by name, as (R, G, B) tuples with 0-255 channels.
COLORS = dict(
    pink=(242, 116, 223),
    cyan=(46, 242, 203),
    red=(255, 0, 0),
    green=(0, 255, 0),
    blue=(0, 0, 255),
    yellow=(255, 255, 0),
)
def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, titlesize=18, cmap: str = None, ticks=False, save=False, save_path=None):
    """Display one image on its own matplotlib figure.

    Args:
        image: Image to display; a PIL image is converted to a numpy array
            before plotting, anything else is handed to ``imshow`` as is.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Text placed above the image.
        titlesize: Font size of the title.
        cmap: Colormap forwarded to ``imshow``.
        ticks: When false, the x and y tick marks are removed.
        save: When true, the figure is written to ``save_path`` before it
            is shown.
        save_path: Destination passed to ``plt.savefig``.
    """
    _, axis = plt.subplots(1, 1, figsize=figsize)

    pixels = np.asarray(image) if isinstance(image, Image.Image) else image

    axis.set_title(title, fontsize=titlesize)
    axis.imshow(pixels, cmap=cmap)

    if not ticks:
        # An empty tick list hides the tick marks on that axis.
        for set_ticks in (axis.set_xticks, axis.set_yticks):
            set_ticks([])

    if save:
        plt.savefig(save_path, bbox_inches='tight')

    plt.show()
def show_grid_of_images(
    images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8),
    cmap=None, subtitles=None, title=None, subtitlesize=18,
    save=False, save_path=None, titlesize=20,
):
    """Show a grid of images.

    Args:
        images: Non-empty sequence of images. PIL images are converted to
            numpy arrays for plotting; the caller's sequence is left
            untouched.
        n_cols: Maximum number of columns; capped at the number of images.
        figsize: Figure size forwarded to ``plt.subplots``.
        cmap: Colormap forwarded to ``imshow``. When it is None, images
            with a 2-D shape are drawn with the "gray" colormap.
        subtitles: Optional per-image titles, indexed like ``images``.
        title: Optional title for the whole figure.
        subtitlesize: Font size of the per-image titles.
        save: When true, the figure is written to ``save_path`` and closed
            instead of being shown.
        save_path: Destination passed to ``plt.savefig``.
        titlesize: Font size of the figure title.

    Raises:
        ValueError: If ``images`` is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(images) == 0:
        raise ValueError("show_grid_of_images needs at least one image.")

    # Build a private list instead of writing converted arrays back into
    # the caller's container.
    arrays = [
        np.asarray(image) if isinstance(image, Image.Image) else image
        for image in images
    ]
    n_images = len(arrays)

    n_cols = min(n_cols, n_images)
    n_rows = int(np.ceil(n_images / n_cols))

    if subtitles is None:
        subtitles = [None] * n_images

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without `.flat`.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < n_images:
            # Choose the colormap per image rather than overwriting the
            # `cmap` argument for the rest of the loop.
            image_cmap = cmap
            if image_cmap is None and len(arrays[i].shape) == 2:
                image_cmap = "gray"
            ax.imshow(arrays[i], cmap=image_cmap)
            ax.set_title(subtitles[i], fontsize=subtitlesize)
        # Cells beyond the last image are blanked too.
        ax.axis('off')

    fig.set_tight_layout(True)
    plt.suptitle(title, y=0.8, fontsize=titlesize)

    if save:
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
    else:
        plt.show()
def show_keypoint_matches(
    img1, kp1, img2, kp2, matches,
    K=10, figsize=(10, 5), drawMatches_args=None,
    choose_matches="random",
):
    """Display matches found in a pair of images.

    Args:
        img1: First image, forwarded to ``cv2.drawMatches``.
        kp1: Keypoints of the first image.
        img2: Second image.
        kp2: Keypoints of the second image.
        matches: Sequence of matches between ``kp1`` and ``kp2``.
        K: Number of matches to draw for the "random" and "topk" modes.
        figsize: Figure size used to display the result.
        drawMatches_args: Extra keyword arguments for ``cv2.drawMatches``.
            Defaults to ``dict(matchesThickness=3, singlePointColor=(0, 0, 0))``.
            The dict is copied, never modified. ``matchColor`` and
            ``singlePointColor`` are always overridden with ``-1`` and
            ``(100, 100, 100)`` respectively.
        choose_matches: "random" draws up to K distinct matches picked at
            random, "all" draws every match, "topk" draws ``matches[:K]``.

    Returns:
        The image produced by ``cv2.drawMatches``.

    Raises:
        ValueError: If ``choose_matches`` is not one of the known modes.
    """
    if choose_matches == "random":
        # Sample indices without replacement so no match is drawn twice,
        # and never ask for more matches than exist.
        n_selected = min(K, len(matches))
        picked = np.random.choice(len(matches), size=n_selected, replace=False)
        selected_matches = [matches[i] for i in picked]
    elif choose_matches == "all":
        selected_matches = matches
    elif choose_matches == "topk":
        selected_matches = matches[:K]
    else:
        raise ValueError(f"Unknown value for choose_matches: {choose_matches}")

    # Report how many matches are actually drawn.
    K = len(selected_matches)

    # Work on a copy: updating the argument in place would leak the forced
    # overrides into the caller's dict (or into a shared default).
    if drawMatches_args is None:
        draw_args = dict(matchesThickness=3, singlePointColor=(0, 0, 0))
    else:
        draw_args = dict(drawMatches_args)
    draw_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})

    img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **draw_args)
    show_single_image(
        img3,
        figsize=figsize,
        title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
    )
    return img3
def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="numpy"):
    """
    Draw keypoints on a copy of an image.

    Args:
        image: Image to draw keypoints on (numpy array or PIL image). The
            input itself is not modified.
        kps: Keypoints to draw. Note these should be in (x, y) format.
        color: Circle color forwarded to ``cv2.circle``.
        radius: Circle radius forwarded to ``cv2.circle``.
        thickness: Circle thickness forwarded to ``cv2.circle``.
        return_as: "PIL" returns a PIL image; any other value returns the
            numpy array.

    Returns:
        The annotated image, as a PIL image or numpy array per ``return_as``.
    """
    # np.array always copies, unlike np.asarray. That gives OpenCV a
    # writable buffer (an array viewed from a PIL image may be read-only)
    # and keeps the caller's image free of drawn circles.
    canvas = np.array(image)

    for kp in kps:
        # OpenCV expects integer pixel coordinates for the circle center.
        center = (int(kp[0]), int(kp[1]))
        canvas = cv2.circle(
            canvas, center, radius=radius, color=color, thickness=thickness)

    if return_as == "PIL":
        return Image.fromarray(canvas)
    return canvas
def get_concat_h(im1, im2):
    """Concatenate two images horizontally.

    ``im1`` is placed on the left and ``im2`` immediately to its right,
    both aligned to the top edge.

    Args:
        im1: Left image (PIL image).
        im2: Right image (PIL image).

    Returns:
        A new RGB PIL image wide enough for both inputs and as tall as the
        taller one.
    """
    # Use the taller height so a taller right-hand image is not cropped.
    # Any area not covered by an input keeps Image.new's fill.
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    dst = Image.new('RGB', canvas_size)
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst
def get_concat_v(im1, im2):
    """Concatenate two images vertically.

    ``im1`` is placed on top and ``im2`` immediately below it, both
    aligned to the left edge.

    Args:
        im1: Top image (PIL image).
        im2: Bottom image (PIL image).

    Returns:
        A new RGB PIL image tall enough for both inputs and as wide as the
        wider one.
    """
    # Use the wider width so a wider bottom image is not cropped.
    # Any area not covered by an input keeps Image.new's fill.
    canvas_size = (max(im1.width, im2.width), im1.height + im2.height)
    dst = Image.new('RGB', canvas_size)
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst
def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8), return_images=False, save=False, save_path="sample.png"):
    """Draw keypoints on each image and show the results side by side.

    Args:
        images: Images to annotate, one per entry of ``kps``.
        kps: Keypoints for each image, in (x, y) format.
        radius: Radius of each drawn keypoint.
        color: Color of each drawn keypoint.
        figsize: Figure size of the displayed grid.
        return_images: When true, the annotated PIL images are returned.
        save: Forwarded to ``show_grid_of_images``.
        save_path: Forwarded to ``show_grid_of_images``.

    Returns:
        The list of annotated PIL images when ``return_images`` is true,
        otherwise None.
    """
    assert len(images) == len(kps), "images and kps must have the same length"

    # generate
    images_with_kps = [
        draw_kps_on_image(image, image_kps, radius=radius, color=color, return_as="PIL")
        for image, image_kps in zip(images, kps)
    ]

    # show
    # Hand the grid helper a shallow copy: it may convert entries of the
    # list it receives in place, and the list returned below must keep
    # holding PIL images.
    show_grid_of_images(list(images_with_kps), n_cols=len(images), figsize=figsize, save=save, save_path=save_path)

    if return_images:
        return images_with_kps
    return None
def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
    """Configure matplotlib to use LaTeX-style serif fonts (best effort).

    Args:
        usetex: Value for the ``text.usetex`` rcParam.
        fontsize: Value for the ``font.size`` rcParam.
        show_sample: When true, draw a small sample plot with the new
            settings.
        **kwargs: Additional rcParams to set; they take precedence over the
            values above.

    If applying the settings or drawing the sample fails, the rcParams
    touched here are restored to their previous values and a message is
    printed; the error is not propagated.
    """
    new_params = {
        "text.usetex": usetex,
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman"],
        "font.size": fontsize,
        **kwargs,
    }
    # Remember the current values of the keys about to change so a failure
    # can be rolled back and later plots do not inherit a broken setup.
    previous = {
        key: plt.rcParams[key] for key in new_params if key in plt.rcParams
    }

    try:
        plt.rcParams.update(new_params)
        if show_sample:
            plt.figure()
            plt.title("Sample $y = x^2$")
            plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
            plt.grid()
            plt.show()
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        plt.rcParams.update(previous)
        print("Failed to setup LaTeX fonts. Proceeding without.")
def get_colors(num_colors, palette="jet"):
    """Return ``num_colors`` colors sampled evenly from a matplotlib colormap.

    Args:
        num_colors: Number of colors to produce.
        palette: Name of the colormap passed to ``plt.get_cmap``.

    Returns:
        A list with one colormap lookup result per sample position, taken
        at evenly spaced positions from 0 to 1 inclusive.
    """
    colormap = plt.get_cmap(palette)
    sample_positions = np.linspace(0, 1, num_colors)
    return list(map(colormap, sample_positions))