anilbhatt1 committed on
Commit
52702ae
1 Parent(s): dc38bce

Initial commit

Files changed (2)
  1. app.py +199 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,199 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ from torchvision import transforms as tfms
+ import torchvision.models as models
+
+ from PIL import Image
+ import numpy as np
+ from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+
+ import random
+ import os
+ import subprocess
+
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from torch import autocast
+ from tqdm.auto import tqdm
+
+ # Set device
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load a pre-trained VGG model (you can use other models as well)
+ vgg_model = models.vgg16(pretrained=True).features
+ vgg_model = vgg_model.to(torch_device)
+
+ # Create a new model that extracts features from the chosen layers.
+ # Note: breaking at name '0' adds no layers at all, so feature_extractor stays empty and
+ # behaves as an identity; the "perceptual" loss below then compares raw pixels. Break at a
+ # later layer index (e.g. '23') to compare actual VGG feature maps instead.
+ feature_extractor = nn.Sequential()
+ for name, layer in vgg_model._modules.items():
+     if name == '0':  # Stop before layer '0'
+         break
+     feature_extractor.add_module(name, layer)
+ feature_extractor = feature_extractor.to(torch_device)
+
+ pretrained_model_name_or_path = "segmind/tiny-sd"
+ pipe = DiffusionPipeline.from_pretrained(
+     pretrained_model_name_or_path,
+     torch_dtype=torch.float32
+ ).to(torch_device)
+
+ # The noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+ # Mapping: style name -> (textual-inversion repo, placeholder token, seed)
+ concept_dict = {'anime_bg_v2': ('sd-concepts-library/anime-background-style-v2', '<anime-background-style-v2>', 31),
+                 'birb': ('sd-concepts-library/birb-style', '<birb-style>', 32),
+                 'depthmap': ('sd-concepts-library/depthmap', '<depthmap>', 33),
+                 'gta5_artwork': ('sd-concepts-library/gta5-artwork', '<gta5_artwork>', 34),
+                 'midjourney': ('sd-concepts-library/midjourney-style', '<midjourney-style>', 35),
+                 'beetlejuice': ('sd-concepts-library/beetlejuice-cartoon-style', '<beetlejuice-cartoon>', 36)}
+
+ cache_style_list = []
+
+ def transform_pattern_image(pattern_image):
+     # Resize the reference pattern image and convert it to a batched tensor for the feature extractor
+     preprocess = tfms.Compose([
+         tfms.Resize((320, 320)),
+         tfms.ToTensor(),
+     ])
+     tfms_pattern_image = preprocess(pattern_image).unsqueeze(0)
+     return tfms_pattern_image
+
+ def load_required_style(style):
+     # Return the placeholder token and seed for the chosen style, loading its
+     # textual-inversion embedding into the pipeline only on first use
+     for concept, value in concept_dict.items():
+         if style in concept:
+             concept_key = value[1]
+             concept_seed = value[2]
+             if style not in cache_style_list:
+                 pipe.load_textual_inversion(value[0])
+                 cache_style_list.append(style)
+             break
+     return concept_key, concept_seed
+
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (1, 4, H/8, W/8)
+     with torch.no_grad():
+         latent = pipe.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling to [-1, 1]
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # Batch of latents -> list of PIL images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = pipe.vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def perceptual_loss(images, pattern):
+     """
+     Calculate the perceptual loss between the denoised images and the reference pattern.
+
+     Parameters:
+         images: features (or pixels) extracted from the denoised images
+         pattern: features (or pixels) extracted from the reference pattern image
+     Returns:
+         MSE between the two tensors
+     """
+     criterion = nn.MSELoss()
+     mse_loss = criterion(images, pattern)
+     return mse_loss
+
+ # Generate an image from the text embeddings with pattern-loss guidance
+ def generate_with_embs_pattern_loss(prompt, concept_seed, tfm_pattern_image, num_inf_steps):
+     height = 320                                  # image height (smaller than Stable Diffusion's usual 512)
+     width = 320                                   # image width
+     num_inference_steps = num_inf_steps           # Number of denoising steps
+     guidance_scale = 8                            # Scale for classifier-free guidance
+     generator = torch.manual_seed(concept_seed)   # Seed generator to create the initial latent noise
+     batch_size = 1
+     pattern_loss_scale = 20
+
+     text_input = pipe.tokenizer(prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     input_ids = text_input.input_ids.to(torch_device)
+     with torch.no_grad():
+         text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]
+
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = pipe.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+     with torch.no_grad():
+         uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep scheduler
+     scheduler.set_timesteps(num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn((batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
+                           generator=generator,)
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Denoising loop
+     for i, t in tqdm(enumerate(scheduler.timesteps)):
+         # Expand the latents for classifier-free guidance to avoid doing two forward passes
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         with torch.no_grad():
+             noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # Perform CFG (Classifier-Free Guidance)
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         #### ADDITIONAL GUIDANCE ###
+         if i % 3 == 0:
+             # Require grad on the latents
+             latents = latents.detach().requires_grad_()
+
+             # Get the predicted x0:
+             latents_x0 = latents - sigma * noise_pred
+             # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+             # Decode to image space
+             denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
+             # Calculate loss between the denoised image and the reference pattern
+             denoised_images_extr = feature_extractor(denoised_images)
+             reference_img_extr = feature_extractor(tfm_pattern_image)
+             loss = perceptual_loss(denoised_images_extr, reference_img_extr) * pattern_loss_scale
+             # Get gradient of the loss w.r.t. the latents
+             cond_grad = torch.autograd.grad(loss, latents)[0]
+
+             # Modify the latents based on this gradient: latents <- latents - sigma^2 * grad
+             latents = latents.detach() - cond_grad * sigma**2
+
+         # Now step with the scheduler: compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents
+
+ def generate_image(prompt, pattern_image, style, num_inf_steps):
+     tfm_pattern_image = transform_pattern_image(pattern_image)  # Transform the pattern image to be fed to the feature extractor
+     tfm_pattern_image = tfm_pattern_image.to(torch_device)
+     if style == "no-style":
+         concept_seed = 40
+         main_prompt = str(prompt)
+     else:
+         concept_key, concept_seed = load_required_style(style)
+         main_prompt = f"{str(prompt)} in the style of {concept_key}"
+     latents = generate_with_embs_pattern_loss(main_prompt, concept_seed, tfm_pattern_image, num_inf_steps)
+     generated_image = latents_to_pil(latents)[0]
+     return generated_image
+
+ def gradio_fn(prompt, pattern_image, style, num_inf_steps):
+     output_pil_image = generate_image(prompt, pattern_image, style, num_inf_steps)
+     return output_pil_image
+
+ demo = gr.Interface(fn=gradio_fn,
+                     inputs=[gr.Textbox(info="Example prompt: 'A toddler gazing at sky'"),
+                             gr.Image(type="pil", height=224, width=224, label="Sample image to emulate the pattern"),
+                             gr.Radio(["anime","birb","depthmap","gta5","midjourney","beetlejuice","no-style"], label="Style",
+                                      info="Choose the style in which the image should be generated"),
+                             gr.Slider(50, 100, value=50, label="Num_inference_steps", info="Choose between 50 & 100")],
+                     outputs=gr.Image(height=320, width=320),
+                     title="ImageAlchemy using Stable Diffusion",
+                     description="- Stable Diffusion model that generates a single image to fit \
+                                  (a) a given text prompt, (b) a given reference image and (c) a selected style.")
+
+ demo.launch(share=True)
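
For a quick local sanity check outside the Gradio UI, the helper functions above can be called directly. A minimal sketch, assuming app.py's definitions are already loaded in the session (with demo.launch() skipped); the prompt, style choice and "pattern.jpg" path are illustrative, not part of the Space:

from PIL import Image

# Hypothetical reference pattern image; any RGB image works as the pattern to emulate
pattern = Image.open("pattern.jpg").convert("RGB")

# "birb" resolves to sd-concepts-library/birb-style via concept_dict; 50 is the slider default
result = generate_image("A toddler gazing at sky", pattern, "birb", 50)
result.save("toddler_birb.png")
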
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers==4.34.1
+ diffusers==0.21.4
+ ftfy==6.1.1
+ accelerate==0.23.0
+ scipy
+ torch==2.1.0
+ torchvision==0.16.0
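
Note that gradio itself is not pinned here; on a Hugging Face Space it is typically supplied by the Space's Gradio SDK, while a local run needs it installed separately. A small sketch to confirm an environment matches the pins above (assumes the packages are already installed):

import torch, torchvision, diffusers, transformers, gradio

# Print installed versions to compare against the pinned requirements
for pkg in (torch, torchvision, diffusers, transformers, gradio):
    print(pkg.__name__, pkg.__version__)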