File size: 1,832 Bytes
3de264f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

import matplotlib.pyplot as plt
from PIL import Image

# Module-wide matplotlib defaults: 10x5 inch figures on a white background.
plt.rcParams.update({
    "figure.figsize": (10, 5),
    'figure.facecolor': 'white',
})


def render_figure(model_name, fn, *, image_types=('bird', 'human', 'room', 'vermeer')):
    """Render a Diffusers vs. reference-implementation comparison grid and save it.

    For every image type two rows are drawn in a 5-column grid:

    * the first row shows the control image in column 0, followed by the
      Diffusers outputs for seeds 0-3;
    * the second row leaves column 0 empty and shows the reference
      implementation's outputs for seeds 0-3.

    Images are read from::

        ./control_images/converted/control_{image_type}_{model_name}.png
        ./output_images/diffusers/output_{image_type}_{model_name}_{seed}.png
        ./output_images/ref/output_{image_type}_{model_name}_{seed}.png

    Args:
        model_name: Model identifier embedded in the input file names and in
            the figure title (e.g. ``"canny"``).
        fn: Output path handed to ``Figure.savefig``.
        image_types: Image-type identifiers; one pair of rows is drawn per
            entry. Defaults to bird, human, room and vermeer.

    Raises:
        ValueError: If ``image_types`` is empty.
    """
    image_types = list(image_types)
    if not image_types:
        raise ValueError("image_types must not be empty")

    n_cols = 5  # column 0: control image; columns 1-4: seeds 0-3

    def plot_row(axs, control_fn_prefix, output_fn_prefix, name, show_control=False):
        """Fill one row: optional control image in column 0, seed outputs after it."""
        for i, ax in enumerate(axs):
            if i == 0:
                if show_control:
                    ax.set_title('Control')
                    # Use a context manager so the file handle is released
                    # right away instead of waiting for garbage collection.
                    with Image.open(f'{control_fn_prefix}.png') as img:
                        ax.imshow(img)
            else:
                seed = i - 1
                ax.set_title(f'Seed={seed} ({name})')
                with Image.open(f'{output_fn_prefix}_{seed}.png') as img:
                    ax.imshow(img)

    fig, axs = plt.subplots(
        2 * len(image_types), n_cols, layout="constrained",
        figsize=(10, 5 * len(image_types)))
    try:
        for ax in axs.flat:
            ax.set_aspect('equal', 'box')
            ax.axis('off')

        # Rows come in (Diffusers, reference) pairs: even rows, then odd rows.
        for image_type, diffusers_axs, ref_axs in zip(image_types, axs[::2], axs[1::2]):
            control_prefix = f'./control_images/converted/control_{image_type}_{model_name}'
            plot_row(diffusers_axs,
                     control_prefix,
                     f'./output_images/diffusers/output_{image_type}_{model_name}',
                     'Diffusers', show_control=True)
            plot_row(ref_axs,
                     control_prefix,
                     f'./output_images/ref/output_{image_type}_{model_name}',
                     'ref impl.')

        fig.suptitle(f'Model: {model_name}', fontsize=16)
        fig.savefig(fn, dpi=144)
    finally:
        # Drop the figure from pyplot's global registry so repeated calls
        # (one per model) do not accumulate open figures.
        plt.close(fig)


if __name__ == '__main__':
    import os

    # Render one comparison figure per ControlNet model and print each output path.
    model_names = ["canny", "normal", "depth",
                   "openpose", "hed", "scribble", "mlsd", "seg"]
    out_dir = "plots"
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    for model in model_names:
        fn = f"{out_dir}/figure_{model}.png"
        render_figure(model, fn)
        print(fn)