cadige committed
Commit 91bbc7e
1 Parent(s): d6695bf

Update app.py

Files changed (1): app.py +263 -0
app.py CHANGED
@@ -0,0 +1,263 @@
+ import os
+
+ os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
+ os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
+
+ import argparse
+ from functools import partial
+ from pathlib import Path
+ import sys
+ sys.path.append('./cloob-latent-diffusion')
+ sys.path.append('./cloob-latent-diffusion/cloob-training')
+ sys.path.append('./cloob-latent-diffusion/latent-diffusion')
+ sys.path.append('./cloob-latent-diffusion/taming-transformers')
+ sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
+ from omegaconf import OmegaConf
+ from PIL import Image
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm import trange
+ from CLIP import clip
+ from cloob_training import model_pt, pretrained
+ import ldm.models.autoencoder
+ from diffusion import sampling, utils
+ import train_latent_diffusion as train
+ from huggingface_hub import hf_hub_url, cached_download
+ import random
+
+ # Download the model files
+ checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
+ ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
+ ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
+
+ # Define a few utility functions
+
+
+ def parse_prompt(prompt, default_weight=3.):
+     if prompt.startswith('http://') or prompt.startswith('https://'):
+         vals = prompt.rsplit(':', 2)
+         vals = [vals[0] + ':' + vals[1], *vals[2:]]
+     else:
+         vals = prompt.rsplit(':', 1)
+     vals = vals + ['', default_weight][len(vals):]
+     return vals[0], float(vals[1])
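+ # Illustrative examples: parse_prompt('a red circle:2.5') -> ('a red circle', 2.5),
+ # parse_prompt('a red circle') -> ('a red circle', 3.0) via default_weight; the
+ # rsplit(':', 2) branch keeps the colon of an http(s):// URL intact.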
+
+
+ def resize_and_center_crop(image, size):
+     fac = max(size[0] / image.size[0], size[1] / image.size[1])
+     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
+     return TF.center_crop(image, size[::-1])
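+ # size is (width, height); TF.center_crop expects (height, width), hence the
+ # size[::-1] flip.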
+
+
+ # Load the models
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ print('Using device:', device)
+ print('loading models')
+
+ # autoencoder
+ ae_config = OmegaConf.load(ae_config_path)
+ ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
+ ae_model.eval().requires_grad_(False).to(device)
+ ae_model.load_state_dict(torch.load(ae_model_path))
+ n_ch, side_y, side_x = 4, 32, 32
+
+ # diffusion model
+ model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
+ model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
+ model = model.to(device).eval().requires_grad_(False)
+
+ # CLOOB
+ cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
+ cloob = model_pt.get_pt_model(cloob_config)
+ checkpoint = pretrained.download_checkpoint(cloob_config)
+ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
+ cloob.eval().requires_grad_(False).to(device)
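+
+ # Pipeline summary: CLOOB encodes text/image prompts into a shared embedding
+ # space, the distilled latent diffusion model generates 4x32x32 latents
+ # conditioned on those embeddings, and the KL autoencoder decodes the latents
+ # into full-size RGB images.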
+
+
+ # The key function: returns a list of n PIL images
+ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
+              method='plms', eta=None):
+     zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
+     target_embeds, weights = [zero_embed], []
+
+     for prompt in prompts:
+         txt, weight = parse_prompt(prompt)
+         target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
+         weights.append(weight)
+
+     for prompt in images:
+         path, weight = parse_prompt(prompt)
+         img = Image.open(utils.fetch(path)).convert('RGB')
+         clip_size = cloob.config['image_encoder']['image_size']
+         img = resize_and_center_crop(img, (clip_size, clip_size))
+         batch = TF.to_tensor(img)[None].to(device)
+         embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
+         target_embeds.append(embed)
+         weights.append(weight)
+
+     weights = torch.tensor([1 - sum(weights), *weights], device=device)
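+     # Classifier-free guidance: the zero (unconditional) embedding gets weight
+     # 1 - sum(weights), so all weights sum to 1; prompt weights above 1 push
+     # the sample away from the unconditional prediction.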
+
+     torch.manual_seed(seed)
+
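+     # cfg_model_fn evaluates the diffusion model once per conditioning
+     # embedding by stacking the batch n_conds times along dim 0, then takes
+     # the weighted sum of the per-condition predictions.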
+     def cfg_model_fn(x, t):
+         n = x.shape[0]
+         n_conds = len(target_embeds)
+         x_in = x.repeat([n_conds, 1, 1, 1])
+         t_in = t.repeat([n_conds])
+         clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+         vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+         v = vs.mul(weights[:, None, None, None, None]).sum(0)
+         return v
+
+     def run(x, steps):
+         if method == 'ddpm':
+             return sampling.sample(cfg_model_fn, x, steps, 1., {})
+         if method == 'ddim':
+             return sampling.sample(cfg_model_fn, x, steps, eta, {})
+         if method == 'prk':
+             return sampling.prk_sample(cfg_model_fn, x, steps, {})
+         if method == 'plms':
+             return sampling.plms_sample(cfg_model_fn, x, steps, {})
+         if method == 'pie':
+             return sampling.pie_sample(cfg_model_fn, x, steps, {})
+         if method == 'plms2':
+             return sampling.plms2_sample(cfg_model_fn, x, steps, {})
+         raise ValueError(f'unknown sampling method: {method!r}')
+
+     batch_size = n
+     x = torch.randn([n, n_ch, side_y, side_x], device=device)
+     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+     steps = utils.get_spliced_ddpm_cosine_schedule(t)
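+     # t is an evenly spaced grid from 1 toward 0; get_spliced_ddpm_cosine_schedule
+     # (from the v-diffusion-pytorch utils imported above) maps it onto the
+     # spliced DDPM/cosine noise schedule the samplers expect.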
+     pil_ims = []
+     for i in trange(0, n, batch_size):
+         cur_batch_size = min(n - i, batch_size)
+         out_latents = run(x[i:i+cur_batch_size], steps)
+         outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
+         for out in outs:
+             pil_ims.append(utils.to_pil_image(out))
+
+     return pil_ims
+
+
+ import gradio as gr
+
+ def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
+     if seed is None:
+         seed = random.randint(0, 10000)
+     print(prompt, im_prompt, seed, n_steps)
+     prompts = [prompt]
+     im_prompts = []
+     if im_prompt is not None:
+         im_prompts = [im_prompt]
+     pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
+     return pil_ims[0]
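+
+ # Illustrative usage: gen_ims('Cubism, in the style of Gustav Klimt', seed=123)
+ # returns a single PIL image generated from the text prompt alone.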
+
+ iface = gr.Interface(fn=gen_ims,
+                      inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1, label="Number of images"),
+                              #gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
+                              gr.inputs.Textbox(label="Text prompt"),
+                              gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
+                              #gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15, label="Number of steps")
+                              ],
+                      outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
+                      examples=[
+                          ["Futurism, in the style of Wassily Kandinsky"],
+                          ["Art Nouveau, in the style of John Singer Sargent"],
+                          ["Surrealism, in the style of Edgar Degas"],
+                          ["Expressionism, in the style of Wassily Kandinsky"],
+                          ["Futurism, in the style of Egon Schiele"],
+                          ["Neoclassicism, in the style of Gustav Klimt"],
+                          ["Cubism, in the style of Gustav Klimt"],
+                          ["Op Art, in the style of Marc Chagall"],
+                          ["Romanticism, in the style of M.C. Escher"],
+                          ["Futurism, in the style of M.C. Escher"],
+                          ["Abstract Art, in the style of M.C. Escher"],
+                          ["Mannerism, in the style of Paul Klee"],
+                          ["Romanesque Art, in the style of Leonardo da Vinci"],
+                          ["High Renaissance, in the style of Rembrandt"],
+                          ["Magic Realism, in the style of Gustave Dore"],
+                          ["Realism, in the style of Jean-Michel Basquiat"],
+                          ["Art Nouveau, in the style of Paul Gauguin"],
+                          ["Avant-garde, in the style of Pierre-Auguste Renoir"],
+                          ["Baroque, in the style of Edward Hopper"],
+                          ["Post-Impressionism, in the style of Wassily Kandinsky"],
+                          ["Naturalism, in the style of Rene Magritte"],
+                          ["Constructivism, in the style of Paul Cezanne"],
+                          ["Abstract Expressionism, in the style of Henri Matisse"],
+                          ["Pop Art, in the style of Vincent van Gogh"],
+                          ["Futurism, in the style of Zdzislaw Beksinski"],
+                          ["Surrealism, in the style of Salvador Dali"],
+                          ["Aaron Wacker, oil on canvas"],
+                          ["abstract"],
+                          ["landscape"],
+                          ["portrait"],
+                          ["sculpture"],
+                          ["genre painting"],
+                          ["installation"],
+                          ["photo"],
+                          ["figurative"],
+                          ["illustration"],
+                          ["still life"],
+                          ["history painting"],
+                          ["cityscape"],
+                          ["marina"],
+                          ["animal painting"],
+                          ["design"],
+                          ["calligraphy"],
+                          ["symbolic painting"],
+                          ["graffiti"],
+                          ["performance"],
+                          ["mythological painting"],
+                          ["battle painting"],
+                          ["self-portrait"],
+                          ["Impressionism, oil on canvas"]
+                      ],
+                      title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia',
+                      description="Trained on images from the [WikiArt](https://www.wikiart.org/) visual art dataset.",
+                      article='Model used: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
+                      )
+ iface.launch(enable_queue=True)  # add debug=True for Colab debugging