|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
from typing import List |
|
import io |
|
|
|
def fig2img(fig: plt.Figure) -> Image.Image:
    """Render a Matplotlib figure into a PIL Image and return it.

    The figure is first detached from pyplot with ``plt.close(fig)``; the
    ``Figure`` object itself stays usable, so ``fig.savefig`` still renders it.

    Args:
        fig: The figure to render.

    Returns:
        A PIL image decoded from ``fig`` saved into an in-memory buffer using
        Matplotlib's default ``savefig`` format.
    """
    # Close *this* figure explicitly. A bare plt.close() closes whichever
    # figure pyplot considers current, which may be unrelated to ``fig``.
    plt.close(fig)

    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)

    img = Image.open(buf)
    # Image.open is lazy; decode now so errors surface here rather than on
    # first pixel access later.
    img.load()
    return img
|
|
|
def show_tile_images(
    images: List[np.ndarray | Image.Image],
    width_parts: int,
    figsize=None,
    space=0.0,
    pad=False,
    figcolor='white',
    titles: List[str] | None = None,
    title_color: str | None = None,
    title_background_color: str | None = None,
    title_font_size: int | None = None,
) -> Image.Image:
    """Lay out images in a grid and return the rendered grid as a PIL image.

    Args:
        images: Images to draw, filled row by row. NumPy arrays are passed to
            ``imshow`` unchanged; any other object has ``.convert("RGB")``
            called on it first (i.e. it is expected to be a PIL image).
        width_parts: Number of images per row; must be at least 1.
        figsize: Figure size passed to ``plt.subplots``. Defaults to
            ``(16, 12 * rows)``.
        space: Spacing used for ``wspace``/``hspace`` in ``subplots_adjust``
            and for ``h_pad``/``w_pad`` in ``tight_layout``.
        pad: If True, ``space`` is also used as the outer ``tight_layout``
            padding; otherwise the outer padding is 0.
        figcolor: Background color of the figure.
        titles: Optional per-image titles, matched to ``images`` by position.
            Images without a title (missing or falsy entry) get none; extra
            titles beyond ``len(images)`` are ignored.
        title_color: Title text color; Matplotlib's default if None.
        title_background_color: Title background color; default if None.
        title_font_size: Title font size; default if None.

    Returns:
        Image: The rendered figure as a PIL image.

    Raises:
        ValueError: If ``images`` is empty or ``width_parts`` is less than 1.
    """
    if width_parts < 1:
        raise ValueError(f"width_parts must be >= 1, got {width_parts}")
    # len() rather than truthiness so a stacked ndarray of images also works.
    n_images = len(images)
    if n_images == 0:
        raise ValueError("images must contain at least one image")

    # Integer ceiling division: number of rows needed for all images.
    n_rows = (n_images + width_parts - 1) // width_parts
    if figsize is None:
        figsize = (8 * 2, 12 * n_rows)

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axs = plt.subplots(n_rows, width_parts, figsize=figsize, squeeze=False)
    try:
        fig.patch.set_facecolor(figcolor)
        axes = axs.ravel()

        # Forward only the title options the caller set, so Matplotlib's
        # defaults apply to the rest. Identical for every tile, so built once.
        title_kwargs = {
            key: value
            for key, value in (
                ('color', title_color),
                ('backgroundcolor', title_background_color),
                ('fontsize', title_font_size),
            )
            if value is not None
        }

        # Pad titles with None up to one per image (negative repeat -> []),
        # so zip() below does not stop early when fewer titles are given.
        title_list = list(titles) if titles is not None else []
        title_list += [None] * (n_images - len(title_list))

        for ax, img, title in zip(axes, images, title_list):
            if title:
                ax.set_title(title, **title_kwargs)
            ax.imshow(img if isinstance(img, np.ndarray) else img.convert("RGB"))

        # Hide every cell, including unused trailing cells of a partial last row.
        for ax in axes:
            ax.axis('off')

        # tight_layout recomputes the subplot parameters; the explicit
        # subplots_adjust is kept as the layout used if tight_layout cannot apply.
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=space, hspace=space)
        fig.tight_layout(h_pad=space, w_pad=space, pad=space if pad else 0)

        # ``fig`` is still pyplot's current figure here, so fig2img's close
        # affects only this figure and never an unrelated one.
        return fig2img(fig)
    finally:
        # Release the figure even if drawing fails; closing an
        # already-closed figure is a no-op.
        plt.close(fig)