File size: 1,166 Bytes
fa128ec
 
 
 
 
 
50eace8
 
fa128ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torchvision
import gradio as gr
from PIL import Image
from cli import iterative_refinement
from viz import grid_of_images_default
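# Disabled step that would run download_models.sh (presumably to fetch the
# *.th checkpoint files loaded below -- confirm before re-enabling).
# import subprocess
# subprocess.call("download_models.sh", shell=True)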
# Registry of the selectable models, keyed by the name used to look them up.
# Each checkpoint is read from the current working directory with
# torch.load and mapped onto the CPU. Note that torch.load unpickles the
# file, so these checkpoints must come from a trusted source.
models = {
    name: torch.load(checkpoint, map_location="cpu")
    for name, checkpoint in (
        ("convae", "convae.th"),
        ("deep_convae", "deep_convae.th"),
    )
}

def gen(model, seed, nb_iter, nb_samples, width, height):
    """Generate samples with the selected model and return them as one image.

    Seeds torch's global RNG, looks the model up in the module-level
    ``models`` dict, runs ``iterative_refinement`` on it and tiles the
    result into a single picture with ``grid_of_images_default``.

    Args:
        model: Key into ``models`` selecting which loaded model to use.
        seed: Passed (cast to int) to ``torch.manual_seed``.
        nb_iter: Cast to int and forwarded as ``nb_iter``.
        nb_samples: Cast to int and forwarded as ``nb_examples``.
        width: Cast to int, forwarded as ``w`` and reused when reshaping.
        height: Cast to int, forwarded as ``h`` and reused when reshaping.

    Returns:
        A ``PIL.Image.Image`` built from the grid scaled by 255 and cast
        to ``uint8``.

    Raises:
        KeyError: If ``model`` is not a key of ``models``.
    """
    torch.manual_seed(int(seed))
    network = models[model]

    # The UI delivers numeric fields as numbers; cast once and reuse.
    n_iter = int(nb_iter)
    n_examples = int(nb_samples)
    img_w = int(width)
    img_h = int(height)

    samples = iterative_refinement(
        network,
        nb_iter=n_iter,
        nb_examples=n_examples,
        w=img_w,
        h=img_h,
        c=1,
        batch_size=64,
    )

    # The two leading axes of `samples` become the grid layout, and every
    # entry along them becomes one (height, width, 1) tile.
    # NOTE(review): presumably axis 0 is the refinement step and axis 1 the
    # sample index -- confirm against iterative_refinement.
    rows, cols = samples.shape[0], samples.shape[1]
    tiles = samples.reshape((rows * cols, img_h, img_w, 1)).numpy()
    grid = grid_of_images_default(tiles, shape=(rows, cols))

    # NOTE(review): the 255 scaling assumes grid values lie in [0, 1] -- confirm.
    return Image.fromarray((grid * 255).astype("uint8"))

# Web UI: one input widget per positional parameter of `gen`, in the same
# order, and a single image output. The server starts as soon as this
# module is executed.
iface = gr.Interface(
    fn=gen,
    inputs=[
        gr.Dropdown(list(models), value="deep_convae"),  # model
        gr.Number(value=0),  # seed
        gr.Number(value=20),  # nb_iter
        gr.Number(value=1),  # nb_samples
        gr.Number(value=28),  # width
        gr.Number(value=28),  # height
    ],
    outputs="image",
)
iface.launch()