File size: 3,208 Bytes
1e5aadc
 
640f9c9
 
 
 
 
 
 
 
b5b3814
640f9c9
 
ace1938
 
640f9c9
 
301a4a3
 
 
 
 
 
df470dc
 
 
301a4a3
 
 
 
640f9c9
df470dc
 
640f9c9
 
 
5b567cd
 
 
640f9c9
af5ee77
301a4a3
af5ee77
640f9c9
 
 
 
 
 
 
 
 
 
 
 
 
d8a4058
640f9c9
 
 
 
 
 
 
df470dc
 
 
 
 
 
 
 
 
 
640f9c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb17c34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
os.environ["USE_NATIVE"] = "1"
import math
import torch
import torchvision
import gradio as gr
from PIL import Image
import torchvision
from test_ddgan import load_model, sample
from model_configs import get_model_config
from subprocess import call

def download(filename):
    """Return the local path for a checkpoint file name.

    No network access happens here despite the name: the result is simply
    ``filename`` prefixed with the ``models/`` directory.
    """
    return f"models/{filename}"


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Models already built by `load`, keyed by checkpoint name, so later requests reuse them.
cache = dict()

def load(name):
    """Return the model registered under ``name``, building it on first use.

    ``models[name]`` selects the config passed to ``get_model_config``.
    The weights path comes from ``download(name + ".th")``.
    The built model is stored in ``cache``, so later calls with the same
    name return the same object without rebuilding it.

    Raises:
        KeyError: if ``name`` is not a key of ``models`` (only checked on a cache miss).
    """
    if name not in cache:
        config = get_model_config(models[name])
        checkpoint_path = download(name + ".th")
        cache[name] = load_model(config, checkpoint_path, device=device)
    return cache[name]

# Model preselected in the UI dropdown and loaded eagerly at startup.
default = "diffusion_db_128ch_1timesteps_openclip_vith14"

# Checkpoint name (loaded from "models/<name>.th") -> model config name for get_model_config.
models = {
    default: "ddgan_ddb_v2",
    "diffusion_db_192ch_2timesteps_openclip_vith14": "ddgan_ddb_v3",
}


# Build the default model eagerly; `load` caches it, so the first request reuses it.
load(name=default)

def gen(md, model_name, md2, text, seed, nb_samples, width, height):
    """Generate a grid of images for a text caption with the selected model.

    Args:
        md, md2: Values of the two ``gr.Markdown`` input slots. They are unused;
            they exist only because the interface passes one argument per
            input component.
        model_name: Key of ``models`` selecting which checkpoint to use.
        text: Caption to condition on. When empty, the first two tensors
            returned by ``model.text_encoder`` are refilled with random normal
            values, and every sample shares the first sample's draw.
        seed: Seed passed to ``torch.manual_seed`` (coerced with ``int``).
        nb_samples: Number of images to generate (coerced with ``int``).
        width, height: Spatial size of the initial noise (coerced with ``int``).

    Returns:
        PIL.Image.Image: the samples tiled by ``make_grid`` with 4 per row.

    Raises:
        ValueError: if ``nb_samples``, ``width`` or ``height`` is not positive.
    """
    nb_samples = int(nb_samples)
    height = int(height)
    width = int(width)
    # Fail fast with a readable message instead of an obscure torch error
    # (or an empty grid) after the model has already been loaded.
    if nb_samples < 1 or width < 1 or height < 1:
        raise ValueError(
            "nb_samples, width and height must be positive, got "
            f"{nb_samples}, {width}, {height}"
        )

    print("load ", model_name)
    model = load(model_name)
    torch.manual_seed(int(seed))
    with torch.no_grad():
        cond = model.text_encoder([text] * nb_samples)
        if text == "":
            # No caption: use a random embedding instead, drawn once and
            # copied to every sample so all images share the same condition.
            cond[0].normal_()
            cond[1].normal_()
            cond[0][1:] = cond[0][0:1]
            cond[1][1:] = cond[1][0:1]

        # Draw the noise with the CPU generator before moving it, so a seed
        # yields the same starting noise regardless of `device`.
        x_init = torch.randn(nb_samples, 3, height, width).to(device)
        print(x_init.shape)
        fake_sample = sample(model, x_init=x_init, cond=cond)
        # Presumably the generator outputs [-1, 1]; map that to [0, 1].
        fake_sample = (fake_sample + 1) / 2

    grid = torchvision.utils.make_grid(fake_sample, nrow=4)
    # Round and clamp before the uint8 cast (same as torchvision's save_image):
    # values outside [0, 255] are not representable in uint8 and an unclamped
    # cast produces garbage pixels (typically wrap-around).
    grid = grid.mul(255).add(0.5).clamp(0, 255)
    grid = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(grid)
# Markdown description rendered at the top of the demo page.
text = """
Text-to-Image Denoising Diffusion GANs is a text-to-image model
based on Denoising Diffusion GANs <https://arxiv.org/abs/2112.07804>.
The code is based on their official code <https://nvlabs.github.io/denoising-diffusion-gan/>,
which is updated to support text conditioning. Many thanks to the authors of DDGAN for releasing
the code.

The provided models are trained on DiffusionDB <https://arxiv.org/abs/2210.14896>, which is a dataset that was synthetically
generated with Stable Diffusion, many thanks to the authors for releasing the dataset.

Models were trained on JURECA-DC supercomputer at Jülich Supercomputing Centre (JSC), many thanks for the compute provided to train the models.
"""
# Build the demo UI and start serving it.
iface = gr.Interface(
    fn=gen,
    # One component per positional parameter of `gen`, in the same order.
    inputs=[
        gr.Markdown(text),  # md: page description, value unused by gen
        gr.Dropdown(list(models.keys()), value=default, label="Model"),  # model_name
        gr.Markdown("If text caption is empty, random CLIP embeddings will be used as input"),  # md2
        gr.Textbox(  # text caption
            lines=1,
            placeholder="Enter text caption here, or leave empty",
            value="Painting of a hamster king  with a crown and a cape  in a magical forest.",
            label="Text caption",
        ),
        gr.Number(value=0, label="Seed"),  # seed
        gr.Number(value=4, label="Number of samples"),  # nb_samples
        gr.Number(value=256, label="Width"),  # width
        gr.Number(value=256, label="Height"),  # height
    ],
    outputs="image",
)
iface.launch(debug=True)