gigant committed
Commit 4fbd61f
Parent: 12ee892

Create app.py

Files changed (1):
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@

import os

# Fetch the CLOOB latent diffusion code and its vendored dependencies at startup
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")

import argparse
from functools import partial
from pathlib import Path
import sys

# The cloned repos are not installed as packages, so make them importable
sys.path.append('./cloob-latent-diffusion')
sys.path.append('./cloob-latent-diffusion/cloob-training')
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
sys.path.append('./cloob-latent-diffusion/taming-transformers')
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')

from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, cached_download
import random

# Download the model files
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
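# Note: cached_download stores these files in the local Hugging Face cache,
# so later restarts reuse them instead of re-downloading.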

# Define a few utility functions


def parse_prompt(prompt, default_weight=3.):
    """Split a 'text:weight' prompt into its text and weight parts."""
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])


def resize_and_center_crop(image, size):
    """Resize so the image covers `size`, then center-crop to exactly `size`."""
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
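
# For example, parse_prompt('a red circle:2') returns ('a red circle', 2.0);
# a prompt with no ':weight' suffix gets the default weight of 3.0.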


# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
ae_model.load_state_dict(torch.load(ae_model_path))
n_ch, side_y, side_x = 4, 32, 32  # channels and spatial size of the latents

# diffusion model
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)
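
# Pipeline overview: prompts are embedded with CLOOB, the diffusion model
# denoises a 4x32x32 latent under classifier-free guidance, and the
# autoencoder decodes the final latent into an image.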

# The key function: returns a list of n PIL images
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
             method='plms', eta=None):
    # The zero embedding stands in for the unconditional model prediction
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # The unconditional weight is set so that all weights sum to 1
    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    torch.manual_seed(seed)

    # Classifier-free guidance: evaluate the model once per condition and
    # combine the v predictions as a weighted sum
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v
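
    # run() dispatches to one of the v-diffusion-pytorch samplers; 'prk',
    # 'plms', 'pie' and 'plms2' are pseudo-numerical methods that usually
    # reach good quality in fewer steps than 'ddpm' or 'ddim'.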
    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'prk':
            return sampling.prk_sample(cfg_model_fn, x, steps, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        if method == 'pie':
            return sampling.pie_sample(cfg_model_fn, x, steps, {})
        if method == 'plms2':
            return sampling.plms2_sample(cfg_model_fn, x, steps, {})
        raise ValueError(f'unknown sampling method: {method}')

    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
        # Rescale the latents before decoding them back to pixel space
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        for out in outs:
            pil_ims.append(utils.to_pil_image(out))

    return pil_ims
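
# Example (outside the Gradio UI): something along these lines should return
# a single PIL image for the prompt:
#   generate(n=1, prompts=['a ship at sea, oil on canvas:3'], steps=15)[0]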


import gradio as gr


def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]


iface = gr.Interface(fn=gen_ims,
                     inputs=[
                         # gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1, label="Number of images"),
                         # gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
                         gr.inputs.Textbox(label="Text prompt"),
                         gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
                         # gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15, label="Number of steps")
                     ],
                     outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
                     examples=[["An iceberg, oil on canvas"],
                               ["A martian landscape, in the style of Monet"],
                               ["A peaceful meadow, pastel crayons"],
                               ["A painting of a vase of flowers"],
                               ["A ship leaving the port in the summer, oil on canvas"]],
                     title='Generate art from text prompts',
                     description='Type a text prompt and/or provide an image prompt, then press Submit to generate an image. The model was trained on images from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset, which consists mostly of paintings.',
                     article='The model is a distilled version of a CLOOB-conditioned latent diffusion model fine-tuned on the WikiArt dataset. You can find more information on this model on the [model card](https://huggingface.co/huggan/distill-ccld-wa). As the [Latent Diffusion paper](https://arxiv.org/abs/2112.10752) notes: "Deep learning modules tend to reproduce or exacerbate biases that are already present in the data."')
iface.launch(enable_queue=True)  # add debug=True when debugging in Colab
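
# Note: enable_queue=True makes Gradio process incoming requests through a
# queue, which avoids request timeouts during long generations on a busy Space.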