JadenFK's picture
Inference
843b14b
raw
history blame
No virus
2.67 kB
from PIL import Image
from matplotlib import pyplot as plt
import textwrap
def to_gif(images, path, duration=None, loop=0):
    """Write a sequence of frames to ``path`` as an animated GIF.

    Args:
        images: Iterable of frames. The first frame must provide a
            PIL-style ``save`` method; the remaining frames are handed to it
            through ``append_images``. Any iterable is accepted, not only
            indexable sequences.
        path: Destination passed straight to the first frame's ``save``.
        duration: Value forwarded as the ``duration`` save option (Pillow
            interprets it as the per-frame display time in milliseconds).
            When ``None`` the historical default of ``20 * number_of_frames``
            is used, so existing callers see no change.
        loop: Value forwarded as the ``loop`` save option (Pillow treats 0
            as "repeat forever").

    Raises:
        IndexError: If ``images`` is empty.
    """
    # Materialize once so generators work and so len()/slicing are valid.
    frames = list(images)
    if duration is None:
        duration = len(frames) * 20
    frames[0].save(
        path,
        save_all=True,
        append_images=frames[1:],
        loop=loop,
        duration=duration,
    )
def figure_to_image(figure):
    """Rasterize a Matplotlib figure into an RGB PIL image.

    The figure's DPI is set to 300 (this mutates ``figure``) and its canvas
    is drawn before the pixels are read back.

    Args:
        figure: Figure whose canvas can be drawn and read back (an Agg-based
            canvas is assumed, as with the original ``tostring_rgb`` call).

    Returns:
        A ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    figure.set_dpi(300)
    canvas = figure.canvas
    canvas.draw()

    if hasattr(canvas, 'buffer_rgba'):
        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
        # ``buffer_rgba`` is the supported way to read the rendered pixels.
        rgba = memoryview(canvas.buffer_rgba())
        # Take the size from the buffer itself (rows, columns, channels) so
        # it always matches the pixel data, even on HiDPI canvases.
        height, width = rgba.shape[:2]
        # Converting RGBA -> RGB drops the alpha channel, which mirrors what
        # ``tostring_rgb`` used to return.
        return Image.frombytes('RGBA', (width, height), bytes(rgba)).convert('RGB')

    # Fallback for canvases that only expose the legacy API.
    return Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())
def image_grid(images, outpath=None, column_titles=None, row_titles=None):
    """Lay out a 2-D collection of images as a tick-free subplot grid.

    Args:
        images: Sequence of rows, each row a sequence of objects accepted by
            ``Axes.imshow``. The column count is taken from the first row.
        outpath: If given, the grid is saved there (300 DPI, tight bounding
            box) and nothing is returned.
        column_titles: Optional per-column titles, drawn above the first row
            and wrapped at 12 characters.
        row_titles: Optional per-row labels, drawn horizontally to the left
            of the first column.

    Returns:
        ``None`` when ``outpath`` is given; otherwise the rendered grid as
        produced by ``figure_to_image``.
    """
    row_count = len(images)
    col_count = len(images[0])

    # One unit of figure size per cell; squeeze=False keeps ``axes`` 2-D even
    # for a single row or column.
    fig, axes = plt.subplots(
        nrows=row_count,
        ncols=col_count,
        figsize=(col_count, row_count),
        squeeze=False,
    )

    for r, row_images in enumerate(images):
        for c, picture in enumerate(row_images):
            cell = axes[r][c]
            cell.imshow(picture)

            if column_titles and r == 0:
                wrapped = textwrap.fill(column_titles[c], width=12)
                cell.set_title(wrapped, fontsize='x-small')

            if row_titles and c == 0:
                label = row_titles[r]
                # Padding grows with the label length so longer labels are
                # pushed further from the image.
                cell.set_ylabel(
                    label,
                    rotation=0,
                    fontsize='x-small',
                    labelpad=1.6 * len(label),
                )

            cell.set_xticks([])
            cell.set_yticks([])

    fig.subplots_adjust(wspace=0, hspace=0)

    if outpath is None:
        fig.tight_layout(pad=0)
        rendered = figure_to_image(fig)
        plt.close(fig)
        return rendered

    fig.savefig(outpath, bbox_inches='tight', dpi=300)
    plt.close(fig)
def get_module(module, module_name):
    """Resolve a nested attribute path starting from ``module``.

    Args:
        module: Root object to start the lookup from.
        module_name: Either a dotted string such as ``'a.b.c'`` or a sequence
            of attribute names. An empty sequence yields ``module`` itself.

    Returns:
        The object reached after following every name in the path.

    Raises:
        AttributeError: If any name along the path does not exist.
    """
    path = module_name.split('.') if isinstance(module_name, str) else module_name

    target = module
    for attribute in path:
        target = getattr(target, attribute)
    return target
def set_module(module, module_name, new_module):
    """Replace the attribute at a nested path below ``module``.

    Args:
        module: Root object to start the lookup from.
        module_name: Either a dotted string such as ``'a.b.c'`` or a
            non-empty sequence of attribute names.
        new_module: Value assigned to the final name of the path.

    Returns:
        ``None``.

    Raises:
        AttributeError: If an intermediate name does not exist.
        IndexError: If the path is an empty sequence.
    """
    path = module_name.split('.') if isinstance(module_name, str) else module_name

    # Descend to the object that owns the final attribute.
    owner = module
    for attribute in path[:-1]:
        owner = getattr(owner, attribute)

    setattr(owner, path[-1], new_module)
def freeze(module):
    """Turn off ``requires_grad`` on everything ``module.parameters()`` yields."""
    trainable = False
    for weight in module.parameters():
        weight.requires_grad = trainable
def unfreeze(module):
    """Turn on ``requires_grad`` on everything ``module.parameters()`` yields."""
    trainable = True
    for weight in module.parameters():
        weight.requires_grad = trainable
def get_concat_h(im1, im2):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is as wide as both images together and as tall as ``im1``;
    ``im2``'s height is not consulted when sizing the canvas.

    Returns:
        The newly created combined image.
    """
    canvas_size = (im1.width + im2.width, im1.height)
    combined = Image.new('RGB', canvas_size)
    for tile, left in ((im1, 0), (im2, im1.width)):
        combined.paste(tile, (left, 0))
    return combined
def get_concat_v(im1, im2):
    """Place ``im2`` underneath ``im1`` on a new RGB canvas.

    The canvas is as tall as both images together and as wide as ``im1``;
    ``im2``'s width is not consulted when sizing the canvas.

    Returns:
        The newly created combined image.
    """
    canvas_size = (im1.width, im1.height + im2.height)
    combined = Image.new('RGB', canvas_size)
    for tile, top in ((im1, 0), (im2, im1.height)):
        combined.paste(tile, (0, top))
    return combined