awacke1 committed on
Commit
4849edf
1 Parent(s): 5c3b8f4

Update app.py

Files changed (1)
  1. app.py +116 -109
app.py CHANGED
@@ -1,17 +1,10 @@
  import os
-
- os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
- os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
-
  import argparse
  from functools import partial
  from pathlib import Path
  import sys
- sys.path.append('./cloob-latent-diffusion')
- sys.path.append('./cloob-latent-diffusion/cloob-training')
- sys.path.append('./cloob-latent-diffusion/latent-diffusion')
- sys.path.append('./cloob-latent-diffusion/taming-transformers')
- sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
  from omegaconf import OmegaConf
  from PIL import Image
  import torch
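
The deleted lines above bootstrapped the CLOOB latent-diffusion code at startup by cloning the repository and appending its subfolders to sys.path. The updated app assumes those dependencies are already present in the Space's environment. A minimal, hypothetical startup check along those lines (not part of this commit; module names taken from the removed pip install) could be:

import importlib.util

# Hypothetical guard: fail early if any of the previously pip-installed packages is missing.
for mod in ('omegaconf', 'PIL', 'pytorch_lightning', 'einops', 'wandb', 'ftfy', 'regex', 'clip'):
    if importlib.util.find_spec(mod) is None:
        raise ImportError(f'missing dependency: {mod}')
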
@@ -26,17 +19,21 @@ import ldm.models.autoencoder
  from diffusion import sampling, utils
  import train_latent_diffusion as train
  from huggingface_hub import hf_hub_url, cached_download
- import random

- # Download the model files
  checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
  ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
  ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))

- # Define a few utility functions
-

  def parse_prompt(prompt, default_weight=3.):
      if prompt.startswith('http://') or prompt.startswith('https://'):
          vals = prompt.rsplit(':', 2)
          vals = [vals[0] + ':' + vals[1], *vals[2:]]
@@ -45,31 +42,35 @@ def parse_prompt(prompt, default_weight=3.):
      vals = vals + ['', default_weight][len(vals):]
      return vals[0], float(vals[1])

-
  def resize_and_center_crop(image, size):
      fac = max(size[0] / image.size[0], size[1] / image.size[1])
      image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
      return TF.center_crop(image, size[::-1])


- # Load the models
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  print('Using device:', device)
- print('loading models')

- # autoencoder
  ae_config = OmegaConf.load(ae_config_path)
  ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
  ae_model.eval().requires_grad_(False).to(device)
  ae_model.load_state_dict(torch.load(ae_model_path))
  n_ch, side_y, side_x = 4, 32, 32

- # diffusion model
  model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
  model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
  model = model.to(device).eval().requires_grad_(False)

- # CLOOB
  cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
  cloob = model_pt.get_pt_model(cloob_config)
  checkpoint = pretrained.download_checkpoint(cloob_config)
@@ -77,93 +78,98 @@ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
  cloob.eval().requires_grad_(False).to(device)


- # The key function: returns a list of n PIL images
- def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
-              method='plms', eta=None):
-     zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
-     target_embeds, weights = [zero_embed], []
-
-     for prompt in prompts:
-         txt, weight = parse_prompt(prompt)
-         target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
-         weights.append(weight)
-
-     for prompt in images:
-         path, weight = parse_prompt(prompt)
-         img = Image.open(utils.fetch(path)).convert('RGB')
-         clip_size = cloob.config['image_encoder']['image_size']
-         img = resize_and_center_crop(img, (clip_size, clip_size))
-         batch = TF.to_tensor(img)[None].to(device)
-         embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
-         target_embeds.append(embed)
-         weights.append(weight)
-
-     weights = torch.tensor([1 - sum(weights), *weights], device=device)
-
-     torch.manual_seed(seed)
-
-     def cfg_model_fn(x, t):
-         n = x.shape[0]
-         n_conds = len(target_embeds)
-         x_in = x.repeat([n_conds, 1, 1, 1])
-         t_in = t.repeat([n_conds])
-         clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
-         vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
-         v = vs.mul(weights[:, None, None, None, None]).sum(0)
-         return v
-
-     def run(x, steps):
-         if method == 'ddpm':
-             return sampling.sample(cfg_model_fn, x, steps, 1., {})
-         if method == 'ddim':
-             return sampling.sample(cfg_model_fn, x, steps, eta, {})
-         if method == 'prk':
-             return sampling.prk_sample(cfg_model_fn, x, steps, {})
-         if method == 'plms':
-             return sampling.plms_sample(cfg_model_fn, x, steps, {})
-         if method == 'pie':
-             return sampling.pie_sample(cfg_model_fn, x, steps, {})
-         if method == 'plms2':
-             return sampling.plms2_sample(cfg_model_fn, x, steps, {})
-         assert False
-
-     batch_size = n
-     x = torch.randn([n, n_ch, side_y, side_x], device=device)
-     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
-     steps = utils.get_spliced_ddpm_cosine_schedule(t)
-     pil_ims = []
-     for i in trange(0, n, batch_size):
-         cur_batch_size = min(n - i, batch_size)
-         out_latents = run(x[i:i+cur_batch_size], steps)
-         outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
-         for j, out in enumerate(outs):
-             pil_ims.append(utils.to_pil_image(out))
-
-     return pil_ims
-
-
- import gradio as gr

  def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
-     if seed == None :
-         seed = random.randint(0, 10000)
-     print( prompt, im_prompt, seed, n_steps)
-     prompts = [prompt]
-     im_prompts = []
-     if im_prompt != None:
-         im_prompts = [im_prompt]
-     pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
-     return pil_ims[0]
-
- iface = gr.Interface(fn=gen_ims,
-     inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1,label="Number of images"),
-         #gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
-         gr.inputs.Textbox(label="Text prompt"),
-         gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
-         #gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15,label="Number of steps")
-     ],
-     outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
-     examples=[
          ["Virgin and Child, in the style of Jacopo Bellini"],
          ["Art Nouveau, in the style of John Singer Sargent"],
          ["Neoclassicism, in the style of Gustav Klimt"],
@@ -212,9 +218,10 @@ iface = gr.Interface(fn=gen_ims,
          ["Futurism, in the style of Zdzislaw Beksinski"],
          ["Aaron Wacker, oil on canvas"],
      ],
-     title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
-     description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
-     article = 'Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..'
-
  )
- iface.launch(enable_queue=True) # , debug=True for colab debugging
@@ -1,17 +1,10 @@
+ # 🚀 Import all necessary libraries
  import os
  import argparse
  from functools import partial
  from pathlib import Path
  import sys
+ import random
  from omegaconf import OmegaConf
  from PIL import Image
  import torch
@@ -26,17 +19,21 @@ import ldm.models.autoencoder
  from diffusion import sampling, utils
  import train_latent_diffusion as train
  from huggingface_hub import hf_hub_url, cached_download
+ import gradio as gr # 🎨 The magic canvas for AI-powered image generation!

+ # 🖼️ Download the necessary model files
+ # These files are loaded from HuggingFace's repository
  checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
  ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
  ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))

+ # 📐 Utility Functions: Math and images, what could go wrong?
+ # These functions help parse prompts and resize/crop images to fit nicely

  def parse_prompt(prompt, default_weight=3.):
+     """
+     🎯 Parses a prompt into text and weight.
+     """
      if prompt.startswith('http://') or prompt.startswith('https://'):
          vals = prompt.rsplit(':', 2)
          vals = [vals[0] + ':' + vals[1], *vals[2:]]
@@ -45,31 +42,35 @@ def parse_prompt(prompt, default_weight=3.):
      vals = vals + ['', default_weight][len(vals):]
      return vals[0], float(vals[1])

  def resize_and_center_crop(image, size):
+     """
+     ✂️ Resize and crop image to center it beautifully.
+     """
      fac = max(size[0] / image.size[0], size[1] / image.size[1])
      image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
      return TF.center_crop(image, size[::-1])


+ # 🧠 Model loading: the brain of our operation! 🔥
+ # Load all the models: autoencoder, diffusion, and CLOOB
+
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  print('Using device:', device)
+ print('loading models... 🛠️')

+ # 🔧 Autoencoder Setup: Let’s decode the madness into images
  ae_config = OmegaConf.load(ae_config_path)
  ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
  ae_model.eval().requires_grad_(False).to(device)
  ae_model.load_state_dict(torch.load(ae_model_path))
  n_ch, side_y, side_x = 4, 32, 32

+ # 🌀 Diffusion Model Setup: The artist behind the scenes
  model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
  model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
  model = model.to(device).eval().requires_grad_(False)

+ # 👁️ CLOOB Setup: Our vision model to understand art in human style
  cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
  cloob = model_pt.get_pt_model(cloob_config)
  checkpoint = pretrained.download_checkpoint(cloob_config)
@@ -77,93 +78,98 @@ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
  cloob.eval().requires_grad_(False).to(device)


+ # 🎨 The key function: Where the magic happens!
+ # This is where we generate images based on text and image prompts
+
+ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='plms', eta=None):
+     """
+     🖼️ Generates a list of PIL images based on given text and image prompts.
+     """
+     zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
+     target_embeds, weights = [zero_embed], []
+
+     # Parse text prompts
+     for prompt in prompts:
+         txt, weight = parse_prompt(prompt)
+         target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
+         weights.append(weight)
+
+     # Parse image prompts
+     for prompt in images:
+         path, weight = parse_prompt(prompt)
+         img = Image.open(utils.fetch(path)).convert('RGB')
+         clip_size = cloob.config['image_encoder']['image_size']
+         img = resize_and_center_crop(img, (clip_size, clip_size))
+         batch = TF.to_tensor(img)[None].to(device)
+         embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
+         target_embeds.append(embed)
+         weights.append(weight)
+
+     # Adjust weights and set seed
+     weights = torch.tensor([1 - sum(weights), *weights], device=device)
+     torch.manual_seed(seed)
+
+     # 💡 Model function with classifier-free guidance
+     def cfg_model_fn(x, t):
+         n = x.shape[0]
+         n_conds = len(target_embeds)
+         x_in = x.repeat([n_conds, 1, 1, 1])
+         t_in = t.repeat([n_conds])
+         clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+         vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+         v = vs.mul(weights[:, None, None, None, None]).sum(0)
+         return v
+
+     # 🎞️ Run the sampler to generate images
+     def run(x, steps):
+         if method == 'ddpm':
+             return sampling.sample(cfg_model_fn, x, steps, 1., {})
+         if method == 'ddim':
+             return sampling.sample(cfg_model_fn, x, steps, eta, {})
+         if method == 'plms':
+             return sampling.plms_sample(cfg_model_fn, x, steps, {})
+         assert False
+
+     # 🏃‍♂️ Generate the output images
+     batch_size = n
+     x = torch.randn([n, n_ch, side_y, side_x], device=device)
+     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+     pil_ims = []
+     for i in trange(0, n, batch_size):
+         cur_batch_size = min(n - i, batch_size)
+         out_latents = run(x[i:i + cur_batch_size], steps)
+         outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
+         for j, out in enumerate(outs):
+             pil_ims.append(utils.to_pil_image(out))
+
+     return pil_ims
+
+
+ # 🖌️ Interface: Gradio's brush to paint the UI
+ # Gradio is used here to create a user-friendly interface for art generation.

  def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
+     """
+     💡 Gradio function to wrap image generation.
+     """
+     if seed is None:
+         seed = random.randint(0, 10000)
+     prompts = [prompt]
+     im_prompts = []
+     if im_prompt is not None:
+         im_prompts = [im_prompt]
+     pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
+     return pil_ims[0]
+
+ # 🖼️ Gradio UI: The interface where users can input text or image prompts
+ iface = gr.Interface(
+     fn=gen_ims,
+     inputs=[
+         gr.Textbox(label="Text prompt"),
+         gr.Image(optional=True, label="Image prompt", type='filepath')
+     ],
+     outputs=gr.Image(type="pil", label="Generated Image"),
+     examples=[
          ["Virgin and Child, in the style of Jacopo Bellini"],
          ["Art Nouveau, in the style of John Singer Sargent"],
          ["Neoclassicism, in the style of Gustav Klimt"],
@@ -212,9 +218,10 @@ iface = gr.Interface(fn=gen_ims,
          ["Futurism, in the style of Zdzislaw Beksinski"],
          ["Aaron Wacker, oil on canvas"],
      ],
+     title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia',
+     description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
+     article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
  )
+
+ # 🚀 Launch the Gradio interface
+ iface.launch(enable_queue=True)
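
For reference, a brief sketch of how the updated helpers behave. This is an illustration only, not part of the commit, and it assumes app.py has been imported so that parse_prompt and generate are in scope:

# Prompt parsing: an optional trailing ':<number>' becomes the weight (default 3.0),
# and URL prompts keep their 'https://' prefix intact.
print(parse_prompt('https://example.com/ref.png'))      # ('https://example.com/ref.png', 3.0)
print(parse_prompt('https://example.com/ref.png:1.5'))  # ('https://example.com/ref.png', 1.5)

# Classifier-free guidance weighting: the zero (unconditional) embedding receives
# 1 - sum(prompt weights), so one prompt at the default weight 3 gives weights [-2, 3].

# Generating directly, without the Gradio UI, using the defaults shown above:
ims = generate(n=1, prompts=['a red circle:3'], seed=42, steps=15, method='plms')
ims[0].save('sample.png')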