import textwrap

from matplotlib import pyplot as plt
from PIL import Image


def to_gif(images, path):
    """Write a sequence of PIL images to ``path`` as a looping animation.

    Args:
        images: Non-empty sequence of PIL images; the first one is the
            frame that ``save`` is called on, the rest are appended.
        path: Destination passed straight to ``Image.save``.

    Raises:
        IndexError: If ``images`` is empty.
    """
    # NOTE(review): the value handed to ``duration`` grows with the number
    # of frames (20 per frame in the sequence). Pillow documents
    # ``duration`` as the display time of each frame, so longer sequences
    # play slower per frame -- confirm this is intended.
    frame_duration = len(images) * 20
    images[0].save(
        path,
        save_all=True,
        append_images=images[1:],
        loop=0,  # 0 means loop forever
        duration=frame_duration,
    )


def figure_to_image(figure):
    """Render a matplotlib figure at 300 dpi and return it as an RGB image.

    The figure's dpi is changed in place before drawing.

    Args:
        figure: Figure whose canvas can be drawn and read back as pixels
            (presumably an Agg-based canvas -- verify against callers).

    Returns:
        A PIL image in mode ``'RGB'``.
    """
    figure.set_dpi(300)
    canvas = figure.canvas
    canvas.draw()

    if hasattr(canvas, "buffer_rgba"):
        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later
        # removed; ``buffer_rgba`` is the supported way to read the pixels.
        # The buffer is indexed (row, column, channel), so its own shape
        # gives the exact pixel size of the rendered data.
        rgba = canvas.buffer_rgba()
        height, width = rgba.shape[:2]
        rendered = Image.frombytes("RGBA", (width, height), rgba.tobytes())
        return rendered.convert("RGB")

    # Fallback for canvases that only expose the legacy RGB accessor.
    return Image.frombytes(
        'RGB', canvas.get_width_height(), canvas.tostring_rgb())


def image_grid(images, outpath=None, column_titles=None, row_titles=None):
    """Lay out a 2-D collection of images on a tight grid of axes.

    Args:
        images: Sequence of rows, each row a sequence of objects accepted
            by ``Axes.imshow``. The column count is taken from the first
            row.
        outpath: If given, the grid is saved there (300 dpi, tight bounding
            box) and nothing is returned.
        column_titles: Optional titles shown above the first row; each is
            wrapped to 12 characters per line.
        row_titles: Optional labels shown left of the first column.

    Returns:
        ``None`` when ``outpath`` is given, otherwise the rendered grid as
        a PIL RGB image.

    Raises:
        IndexError: If ``images`` is empty.
    """
    n_rows = len(images)
    n_cols = len(images[0])

    # squeeze=False keeps ``axs`` two-dimensional even for a single
    # row or column, so it can always be indexed as axs[row][column].
    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(n_cols, n_rows),
        squeeze=False,
    )
    try:
        for row, row_images in enumerate(images):
            for column, image in enumerate(row_images):
                ax = axs[row][column]
                ax.imshow(image)
                if column_titles and row == 0:
                    ax.set_title(
                        textwrap.fill(column_titles[column], width=12),
                        fontsize='x-small',
                    )
                if row_titles and column == 0:
                    label = row_titles[row]
                    # The label is horizontal, so the padding is scaled
                    # with its length to keep it clear of the image.
                    ax.set_ylabel(
                        label,
                        rotation=0,
                        fontsize='x-small',
                        labelpad=1.6 * len(label),
                    )
                ax.set_xticks([])
                ax.set_yticks([])

        fig.subplots_adjust(wspace=0, hspace=0)

        if outpath is not None:
            fig.savefig(outpath, bbox_inches='tight', dpi=300)
            return None

        fig.tight_layout(pad=0)
        return figure_to_image(fig)
    finally:
        # Always release the figure, even if drawing or saving failed;
        # pyplot otherwise keeps it alive in its global registry.
        plt.close(fig)


def get_module(module, module_name):
    """Resolve a dotted attribute path starting from ``module``.

    Args:
        module: Object to start the lookup from.
        module_name: Either a dotted string such as ``'a.b.c'`` or a
            sequence of attribute names. An empty sequence returns
            ``module`` itself.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If any attribute along the path is missing.
    """
    if isinstance(module_name, str):
        module_name = module_name.split('.')
    for attribute in module_name:
        module = getattr(module, attribute)
    return module


def set_module(module, module_name, new_module):
    """Assign ``new_module`` at a dotted attribute path under ``module``.

    Args:
        module: Object to start the lookup from.
        module_name: Either a dotted string such as ``'a.b.c'`` or a
            non-empty sequence of attribute names; the last name is the
            attribute that gets assigned.
        new_module: Value to store.

    Raises:
        AttributeError: If a parent attribute along the path is missing.
        IndexError: If ``module_name`` is an empty sequence.
    """
    if isinstance(module_name, str):
        module_name = module_name.split('.')
    leaf = module_name[-1]
    parent = module
    for attribute in module_name[:-1]:
        parent = getattr(parent, attribute)
    setattr(parent, leaf, new_module)


def _set_requires_grad(module, enabled):
    """Set ``requires_grad`` on every parameter yielded by ``module``."""
    for parameter in module.parameters():
        parameter.requires_grad = enabled


def freeze(module):
    """Set ``requires_grad = False`` on every parameter of ``module``.

    Args:
        module: Object exposing ``parameters()`` (presumably a
            ``torch.nn.Module`` -- verify against callers).
    """
    _set_requires_grad(module, False)


def unfreeze(module):
    """Set ``requires_grad = True`` on every parameter of ``module``.

    Args:
        module: Object exposing ``parameters()`` (presumably a
            ``torch.nn.Module`` -- verify against callers).
    """
    _set_requires_grad(module, True)


def get_concat_h(im1, im2):
    """Return a new RGB image with ``im2`` pasted to the right of ``im1``.

    The canvas height is ``im1.height``; a taller ``im2`` is cropped at the
    bottom, and a shorter one leaves the default fill below it.
    """
    canvas = Image.new('RGB', (im1.width + im2.width, im1.height))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (im1.width, 0))
    return canvas


def get_concat_v(im1, im2):
    """Return a new RGB image with ``im2`` pasted below ``im1``.

    The canvas width is ``im1.width``; a wider ``im2`` is cropped on the
    right, and a narrower one leaves the default fill beside it.
    """
    canvas = Image.new('RGB', (im1.width, im1.height + im2.height))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (0, im1.height))
    return canvas