import torch
import torchvision
import gradio as gr
from PIL import Image

from cli import iterative_refinement
from viz import grid_of_images_default

# The checkpoint files below must already exist in the working directory.
# They were originally fetched with something like:
#   import subprocess
#   subprocess.call("download_models.sh", shell=True)

# NOTE(review): torch.load unpickles arbitrary objects -- only ever point it at
# trusted checkpoint files. These objects are passed straight to
# iterative_refinement, so they are presumably full pickled modules; on
# PyTorch >= 2.6 that may require weights_only=False -- confirm against the
# installed torch version.
models = {
    "convae": torch.load("convae.th", map_location="cpu"),
    "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
}

# Batch size handed to iterative_refinement for every request.
BATCH_SIZE = 64


def gen(model, seed, nb_iter, nb_samples, width, height):
    """Sample images from one of the preloaded models and return them as a grid.

    Args:
        model: Key into ``models`` selecting the checkpoint to sample from.
        seed: Seed for ``torch.manual_seed``; truncated to ``int`` because
            Gradio ``Number`` inputs arrive as floats.
        nb_iter: Number of refinement iterations (truncated to ``int``).
        nb_samples: Number of examples to generate (truncated to ``int``).
        width: Width of each generated image in pixels (truncated to ``int``).
        height: Height of each generated image in pixels (truncated to ``int``).

    Returns:
        A ``PIL.Image`` built from the uint8 grid produced by
        ``grid_of_images_default``.

    Raises:
        KeyError: If ``model`` is not a key of ``models``.
        ValueError: If ``nb_samples``, ``width`` or ``height`` is below 1.
    """
    torch.manual_seed(int(seed))
    net = models[model]

    # Convert the Gradio floats once instead of at every use site.
    n_examples = int(nb_samples)
    w = int(width)
    h = int(height)
    if n_examples < 1 or w < 1 or h < 1:
        raise ValueError(
            "nb_samples, width and height must all be >= 1, got "
            f"nb_samples={n_examples}, width={w}, height={h}"
        )

    samples = iterative_refinement(
        net,
        nb_iter=int(nb_iter),
        nb_examples=n_examples,
        w=w,
        h=h,
        c=1,
        batch_size=BATCH_SIZE,
    )

    # Merge the two leading axes into a single image axis (single channel,
    # c=1 above); the original two axes are passed as the grid shape.
    # NOTE(review): presumably rows = refinement steps, cols = examples --
    # confirm against iterative_refinement.
    rows, cols = samples.shape[0], samples.shape[1]
    images = samples.reshape((rows * cols, h, w, 1)).numpy()
    grid = grid_of_images_default(images, shape=(rows, cols))

    # The *255 scaling implies the grid is in [0, 1]; clip first so that any
    # slight overshoot saturates instead of wrapping around in the uint8 cast.
    grid = (grid.clip(0, 1) * 255).astype("uint8")
    return Image.fromarray(grid)


iface = gr.Interface(
    fn=gen,
    inputs=[
        gr.Dropdown(list(models.keys()), value="deep_convae"),
        gr.Number(value=0),
        gr.Number(value=20),
        gr.Number(value=1),
        gr.Number(value=28),
        gr.Number(value=28),
    ],
    outputs="image",
)
iface.launch()