"""Render comparison figures of Diffusers vs. reference-implementation outputs.

For each model name, one figure is written to ``plots/figure_<model>.png``.
Every image type occupies two rows: the first shows the control image
followed by the Diffusers output for each seed; the second shows the
reference implementation's output for the same seeds (its first cell is
left blank so the seed columns line up).
"""
import os

import matplotlib.pyplot as plt
from PIL import Image

plt.rcParams["figure.figsize"] = (10, 5)
plt.rcParams['figure.facecolor'] = 'white'

IMAGE_TYPES = ('bird', 'human', 'room', 'vermeer')
MODEL_NAMES = ("canny", "normal", "depth", "openpose", "hed", "scribble", "mlsd", "seg")


def _show_image(ax, path):
    """Draw the image file at *path* on *ax* and close the file afterwards."""
    # matplotlib converts a PIL image to an array when it is set on the
    # AxesImage, so the file can be closed as soon as imshow returns.
    with Image.open(path) as img:
        ax.imshow(img)


def _plot_row(axs, output_fn_prefix, name, control_path=None):
    """Fill one row: column 0 gets the control image (if given), then one output per seed.

    Args:
        axs: The row of Axes to draw into; column 0 is reserved for the control.
        output_fn_prefix: Path prefix; seed ``s`` is read from ``<prefix>_<s>.png``.
        name: Label shown in each seed column's title.
        control_path: Path of the control image, or None to leave column 0 blank.
    """
    if control_path is not None:
        axs[0].set_title('Control')
        _show_image(axs[0], control_path)
    for seed, ax in enumerate(axs[1:]):
        ax.set_title(f'Seed={seed} ({name})')
        _show_image(ax, f'{output_fn_prefix}_{seed}.png')


def render_figure(model_name, fn, *, image_types=IMAGE_TYPES, num_seeds=4):
    """Render and save the comparison figure for one model.

    Args:
        model_name: Model identifier used in the input file names and the title.
        fn: Output path of the saved figure. Its parent directory is created
            if it does not exist.
        image_types: Image types to plot; each one contributes two rows.
        num_seeds: Number of seed columns per row (plus one control column).

    Raises:
        ValueError: If *image_types* is empty.
        FileNotFoundError: If any expected input image is missing.
    """
    image_types = list(image_types)
    if not image_types:
        raise ValueError("image_types must contain at least one entry")

    fig, axs = plt.subplots(
        2 * len(image_types), num_seeds + 1,
        layout="constrained",
        figsize=(10, 5 * len(image_types)),
        squeeze=False,  # always a 2-D grid, whatever the row/column counts
    )
    try:
        for ax in axs.flat:
            ax.set_aspect('equal', 'box')
            ax.axis('off')

        for k, image_type in enumerate(image_types):
            control_path = f'./control_images/converted/control_{image_type}_{model_name}.png'
            _plot_row(axs[2 * k],
                      f'./output_images/diffusers/output_{image_type}_{model_name}',
                      'Diffusers',
                      control_path=control_path)
            _plot_row(axs[2 * k + 1],
                      f'./output_images/ref/output_{image_type}_{model_name}',
                      'ref impl.')

        fig.suptitle(f'Model: {model_name}', fontsize=16)

        out_dir = os.path.dirname(fn)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(fn, dpi=144)
    finally:
        # Release the figure; otherwise every figure stays open and
        # memory grows with each model rendered.
        plt.close(fig)


if __name__ == '__main__':
    for model in MODEL_NAMES:
        fn = f"plots/figure_{model}.png"
        render_figure(model, fn)
        print(fn)