nkanungo committed
Commit cd1d66e · 1 Parent(s): 7d59329

Upload 2 files

Files changed (2)
  1. Era_s20_updt.py +373 -0
  2. app.py +42 -0
Era_s20_updt.py ADDED
@@ -0,0 +1,373 @@
+ import transformers as t
+ assert t.__version__ == '4.25.1', "Transformers version must be 4.25.1"
+
+
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+ from huggingface_hub import notebook_login
+
+ # For video display:
+ from IPython.display import HTML
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from PIL import Image
+ from torch import autocast
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import os
+ import io
+ import base64
+ import torch.nn.functional as F
+ #from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ torch.manual_seed(1)
+
+ if not (Path.home()/'.cache/huggingface'/'token').exists(): notebook_login()
+
+ # Suppress some unnecessary warnings when loading the CLIPTextModel
+ logging.set_verbosity_error()
+
+ # Set device
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+
+ import sys, gc, traceback
+ import fastcore.all as fc
+
+ # %% ../nbs/11_initializing.ipynb 11
+ def clean_ipython_hist():
+     # Code in this function mainly copied from IPython source
+     if 'get_ipython' not in globals(): return
+     ip = get_ipython()
+     user_ns = ip.user_ns
+     ip.displayhook.flush()
+     pc = ip.displayhook.prompt_count + 1
+     for n in range(1, pc): user_ns.pop('_i'+repr(n), None)
+     user_ns.update(dict(_i='', _ii='', _iii=''))
+     hm = ip.history_manager
+     hm.input_hist_parsed[:] = [''] * pc
+     hm.input_hist_raw[:] = [''] * pc
+     hm._i = hm._ii = hm._iii = hm._i00 = ''
+
+ # %% ../nbs/11_initializing.ipynb 12
+ def clean_tb():
+     # h/t Piotr Czapla
+     if hasattr(sys, 'last_traceback'):
+         traceback.clear_frames(sys.last_traceback)
+         delattr(sys, 'last_traceback')
+     if hasattr(sys, 'last_type'): delattr(sys, 'last_type')
+     if hasattr(sys, 'last_value'): delattr(sys, 'last_value')
+
+ # %% ../nbs/11_initializing.ipynb 13
+ def clean_mem():
+     clean_tb()
+     clean_ipython_hist()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+ clean_mem()
+
+ # Load the autoencoder model which will be used to decode the latents into image space.
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
+
+ # Load the tokenizer and text encoder to tokenize and encode the text.
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # The UNet model for generating the latents.
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+
+ # The noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+ # To the GPU we go!
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
+ embeds_folder = Path('C:/Users/shivs/Downloads/paintings_embed')
+ file_names = [path.name for path in embeds_folder.glob('*') if path.is_file()]
+ print(file_names)
+
+ style_names = [list(torch.load(embeds_folder/file).keys())[0] for file in file_names]
+ style_names
+ num_added_tokens = tokenizer.add_tokens(style_names)
+
+ added_tokens = list(map(tokenizer.added_tokens_encoder.get, style_names))
+ added_tokens, style_names
+
+
+ text_encoder.resize_token_embeddings(len(tokenizer))
+ text_encoder.text_model.embeddings.token_embedding
+
+
+ style_dict = {}
+
+ list_styles = [torch.load(embeds_folder/file) for file in file_names]
+
+
+ for k, v in list_styles[0].items():
+     print(k, v.shape)
+
+ style_dict = {style: embedding for each_style in list_styles for style, embedding in each_style.items()}
+
+ list(style_dict)
+
+ for token, style in zip(added_tokens, style_names):
+     text_encoder.text_model.embeddings.token_embedding.weight.data[token] = style_dict[style]
+
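+ # Note: the block above is textual-inversion style loading. Each *.bin file maps a
+ # placeholder token (e.g. '<fairy-tale-painting-style>') to a learned 768-dim CLIP
+ # embedding. The placeholder tokens are appended to the tokenizer, the text encoder's
+ # embedding matrix is resized to match, and each learned vector is copied into the row
+ # for its new token id, so prompts like "... in the style of <token>" can use it.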
+ # # Checking if we added the embeddings properly to text_encoder
+ # ft_dict = torch.load(embeds_folder/'fairy-tale-painting_embeds.bin')
+
+ # list(ft_dict.keys())[0]
+
+ # ft_dict['<fairy-tale-painting-style>'][:10]
+
+ clean_mem()
+
+ # text_encoder.get_input_embeddings()(torch.tensor(49408, device=torch_device))[:10]
+
+
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1)  # Note scaling
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # batch of latents -> list of images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
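+ # Note: 0.18215 is the latent scaling factor used by the Stable Diffusion v1 VAE;
+ # pil_to_latent multiplies by it after encoding and latents_to_pil divides by it
+ # before decoding, so the two functions act as approximate inverses of each other.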
+ # Access the embedding layer
+ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+ token_emb_layer  # Vocab size 49408 plus the added style tokens, emb_dim 768
+
+ pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+ position_embeddings = pos_emb_layer(position_ids)
+ print(position_embeddings.shape)
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses a causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None,  # We aren't using an attention mask so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True,  # We want the output embs not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
+
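+ # Note: get_output_embeds re-runs the pieces that CLIPTextModel applies internally
+ # (causal mask, transformer encoder, final layer norm) so that hand-built token
+ # embeddings (including the injected style embeddings above) can be fed in directly
+ # instead of token ids.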
+ # Generating an image with these modified embeddings
+
+ def generate_with_embs_custom(text_embeddings, seed):
+     height = 512                         # default height of Stable Diffusion
+     width = 512                          # default width of Stable Diffusion
+     num_inference_steps = 1              # Number of denoising steps
+     guidance_scale = 7.5                 # Scale for classifier-free guidance
+     generator = torch.manual_seed(seed)  # Seed generator to create the initial latent noise
+     batch_size = 1
+
+     max_length = text_embeddings.shape[1]
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # perform guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+
+ # ref_image = Image.open('C:/Users/shivs/Downloads/lg.jpg').resize((512,512))
+ # ref_latent = pil_to_latent(ref_image)
+
+ ## Guidance through Custom Loss Function
+ def custom_loss(latent, ref_latent):
+     # MSE between the (scaled) current latent and the (scaled) reference latent
+     error = F.mse_loss(0.5 * latent, 0.8 * ref_latent)
+     return error
+
+
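+ # Note: this loss provides the additional guidance used in
+ # generate_styles_with_custom_loss below: every few denoising steps the predicted
+ # clean latent x0 is scored against the reference image's latent, and the gradient
+ # of the loss with respect to the current latents nudges them toward the reference.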
+ class Styles_paintings():
+     def __init__(self, prompt):
+         self.output_styles = []
+         self.prompt = prompt
+         self.style_names = list(style_dict)
+         self.seeds = [1024 + i for i in range(len(self.style_names))]
+
+     def generate_styles(self):
+         # print('The Values are ', list(style_dict)[0])
+
+         for seed, style_name in zip(self.seeds, self.style_names):
+             # Tokenize
+             prompt = f'{self.prompt} in the style of {style_name}'
+             text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+             input_ids = text_input.input_ids.to(torch_device)
+
+             # Get token embeddings
+             token_embeddings = token_emb_layer(input_ids)
+
+
+             # Combine with pos embs
+             input_embeddings = token_embeddings + position_embeddings
+
+             # Feed through to get final output embs
+             modified_output_embeddings = get_output_embeds(input_embeddings)
+
+             # And generate an image with this:
+             self.output_styles.append(generate_with_embs_custom(modified_output_embeddings, seed))
+
+     def generate_styles_with_custom_loss(self, image):
+         height = 512                 # default height of Stable Diffusion
+         width = 512                  # default width of Stable Diffusion
+         num_inference_steps = 1      #@param # Number of denoising steps
+         guidance_scale = 8           #@param # Scale for classifier-free guidance
+         batch_size = 1
+         custom_loss_scale = 200      #@param
+         # print('image shape there is', image.size)
+         self.output_styles_with_custom_loss = []
+         # The reference latent for the custom loss comes from the image passed in by the caller
+         ref_latent = pil_to_latent(image)
+         for seed, style_name in zip(self.seeds, self.style_names):
+             # Tokenize
+             prompt = f'{self.prompt} in the style of {style_name}'
+             generator = torch.manual_seed(seed)  # Seed generator to create the initial latent noise
+             print(f'The prompt is: {prompt} with seed value: {seed}')
+             # Prep text
+             text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+             with torch.no_grad():
+                 text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+             # And the uncond. input as before:
+             max_length = text_input.input_ids.shape[-1]
+             uncond_input = tokenizer(
+                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+             )
+             with torch.no_grad():
+                 uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+             # Prep Scheduler
+             set_timesteps(scheduler, num_inference_steps)
+
+             # Prep latents
+             latents = torch.randn(
+                 (batch_size, unet.in_channels, height // 8, width // 8),
+                 generator=generator,
+             )
+             latents = latents.to(torch_device)
+             latents = latents * scheduler.init_noise_sigma
+
+             # Loop
+             for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+                 latent_model_input = torch.cat([latents] * 2)
+                 sigma = scheduler.sigmas[i]
+                 latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+                 # predict the noise residual
+                 with torch.no_grad():
+                     noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+                 # perform CFG
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 #### ADDITIONAL GUIDANCE ###
+                 if i % 5 == 0:
+                     # Requires grad on the latents
+                     latents = latents.detach().requires_grad_()
+
+                     # Get the predicted x0:
+                     latents_x0 = latents - sigma * noise_pred
+                     # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+                     # Decode to image space
+                     # denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
+
+                     # Calculate loss
+                     loss = custom_loss(latents_x0, ref_latent) * custom_loss_scale
+                     # loss = blue_loss(denoised_images) * blue_loss_scale
+
+                     # Occasionally print it out
+                     if i % 10 == 0:
+                         print(i, 'loss:', loss.item())
+
+                     # Get gradient
+                     cond_grad = torch.autograd.grad(loss, latents)[0]
+
+                     # Modify the latents based on this gradient
+                     latents = latents.detach() - cond_grad * sigma**2
+
+                 # Now step with scheduler
+                 latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+             self.output_styles_with_custom_loss.append(latents_to_pil(latents)[0])
+
+ def generate_final_image(im1, in_prompt):
+     paintings = Styles_paintings(in_prompt)
+     paintings.generate_styles()
+     r_image = im1.resize((512, 512))
+     print('image size is', r_image.size)
+     paintings.generate_styles_with_custom_loss(r_image)
+
+     # print(len(paintings.output_styles))
+
+     # One single-image list per Gallery output: five base styles, then five custom-loss-guided versions
+     return ([paintings.output_styles[0]], [paintings.output_styles[1]], [paintings.output_styles[2]],
+             [paintings.output_styles[3]], [paintings.output_styles[4]],
+             [paintings.output_styles_with_custom_loss[0]], [paintings.output_styles_with_custom_loss[1]],
+             [paintings.output_styles_with_custom_loss[2]], [paintings.output_styles_with_custom_loss[3]],
+             [paintings.output_styles_with_custom_loss[4]])
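+ # Example usage (sketch only, not part of the Gradio flow; 'ref.jpg' and the prompt are placeholders):
+ # from PIL import Image
+ # outputs = generate_final_image(Image.open('ref.jpg'), 'a dog playing in a park')
+ # outputs[0][0].save('style_0.png')  # first base style; outputs[5:] hold the loss-guided versions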
app.py ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[5]:
+
+
+ import numpy as np
+ import gradio as gr
+ from Era_s20_updt import generate_final_image
+
+
+ gr.Interface(
+
+     generate_final_image,
+     inputs=[
+         # gr.Image(label="Input Image"),
+         gr.Image(type='pil', label="Guided Image for Loss"),
+         gr.Text(label="Input Prompt")
+
+         # gr.Slider(0, 1, value=0.5, label="IOU Threshold"),
+         # gr.Slider(0, 1, value=0.4, label="Threshold"),
+         # gr.Checkbox(label="Show Grad Cam"),
+         # gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
+     ],
+     outputs=[
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1),
+         gr.Gallery(rows=2, columns=1)
+
+     ],
+     title="Stable Diffusion",
+     layout="Vertical"
+ ).launch()
+
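+ # Note: the ten Gallery outputs correspond one-to-one to the ten single-image lists
+ # returned by generate_final_image (five base styles, then five custom-loss-guided).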