aayushmnit commited on
Commit
cb16212
·
1 Parent(s): 5233387

Uploading diffedit app

Browse files
Files changed (7) hide show
  1. .gitattributes +4 -0
  2. Gradio Demo.ipynb +0 -0
  3. app.py +179 -0
  4. fruitbowl.jpg +0 -0
  5. horse.jpg +3 -0
  6. packages.txt +1 -0
  7. requirements.txt +10 -0
.gitattributes CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ horse.jpg filter=lfs diff=lfs merge=lfs -text
36
+ fruitbowl.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpg filter=lfs diff=lfs merge=lfs -text
Gradio Demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms as tfms
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
8
+ from diffusers import StableDiffusionInpaintPipeline
9
+ import gradio as gr
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ def load_artifacts():
13
+ '''
14
+ A function to load all diffusion artifacts
15
+ '''
16
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16).to(device)
17
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
18
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
19
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to(device)
20
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
21
+ return vae, unet, tokenizer, text_encoder, scheduler
22
+
23
+ def load_image(p):
24
+ '''
25
+ Function to load images from a defined path
26
+ '''
27
+ return Image.open(p).convert('RGB').resize((512,512))
28
+
29
+ def pil_to_latents(image):
30
+ '''
31
+ Function to convert image to latents
32
+ '''
33
+ init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
34
+ init_image = init_image.to(device=device, dtype=torch.float16)
35
+ init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215
36
+ return init_latent_dist
37
+
38
+ def latents_to_pil(latents):
39
+ '''
40
+ Function to convert latents to images
41
+ '''
42
+ latents = (1 / 0.18215) * latents
43
+ with torch.no_grad():
44
+ image = vae.decode(latents).sample
45
+ image = (image / 2 + 0.5).clamp(0, 1)
46
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
47
+ images = (image * 255).round().astype("uint8")
48
+ pil_images = [Image.fromarray(image) for image in images]
49
+ return pil_images
50
+
51
+ def text_enc(prompts, maxlen=None):
52
+ '''
53
+ A function to take a texual promt and convert it into embeddings
54
+ '''
55
+ if maxlen is None: maxlen = tokenizer.model_max_length
56
+ inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
57
+ return text_encoder(inp.input_ids.to(device))[0].half()
58
+
59
+ def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength =0.5, steps=50, dim=512):
60
+ """
61
+ Diffusion process to convert prompt to image
62
+ """
63
+ # Converting textual prompts to embedding
64
+ text = text_enc(prompts)
65
+
66
+ # Adding an unconditional prompt , helps in the generation process
67
+ uncond = text_enc([""], text.shape[1])
68
+ emb = torch.cat([uncond, text])
69
+
70
+ # Setting the seed
71
+ if seed: torch.manual_seed(seed)
72
+
73
+ # Setting number of steps in scheduler
74
+ scheduler.set_timesteps(steps)
75
+
76
+ # Convert the seed image to latent
77
+ init_latents = pil_to_latents(init_img)
78
+
79
+ # Figuring initial time step based on strength
80
+ init_timestep = int(steps * strength)
81
+ timesteps = scheduler.timesteps[-init_timestep]
82
+ timesteps = torch.tensor([timesteps], device=device)
83
+
84
+ # Adding noise to the latents
85
+ noise = torch.randn(init_latents.shape, generator=None, device=device, dtype=init_latents.dtype)
86
+ init_latents = scheduler.add_noise(init_latents, noise, timesteps)
87
+ latents = init_latents
88
+
89
+ # We need to scale the i/p latents to match the variance
90
+ inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
91
+ # Predicting noise residual using U-Net
92
+ with torch.no_grad(): u,t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)
93
+
94
+ # Performing Guidance
95
+ pred = u + g*(t-u)
96
+
97
+ # Zero shot prediction
98
+ latents = scheduler.step(pred, timesteps, latents).pred_original_sample
99
+
100
+ # Returning the latent representation to output an array of 4x64x64
101
+ return latents.detach().cpu()
102
+
103
+ def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
104
+ ## Initialize a dictionary to save n iterations
105
+ diff = {}
106
+
107
+ ## Repeating the difference process n times
108
+ for idx in range(n):
109
+ ## Creating denoised sample using reference / original text
110
+ orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed = 100*idx)[0]
111
+ ## Creating denoised sample using query / target text
112
+ query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed = 100*idx)[0]
113
+ ## Taking the difference
114
+ diff[idx] = (np.array(orig_noise)-np.array(query_noise))
115
+
116
+ ## Creating a mask placeholder
117
+ mask = np.zeros_like(diff[0])
118
+
119
+ ## Taking an average of 10 iterations
120
+ for idx in range(n):
121
+ ## Note np.abs is a key step
122
+ mask += np.abs(diff[idx])
123
+
124
+ ## Averaging multiple channels
125
+ mask = mask.mean(0)
126
+
127
+ ## Normalizing
128
+ mask = (mask - mask.mean()) / np.std(mask)
129
+
130
+ ## Binarizing and returning the mask object
131
+ return (mask > 0).astype("uint8")
132
+
133
+ def improve_mask(mask):
134
+ mask = cv2.GaussianBlur(mask*255,(3,3),1) > 0
135
+ return mask.astype('uint8')
136
+
137
+ vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
138
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
139
+ "runwayml/stable-diffusion-inpainting",
140
+ revision="fp16",
141
+ torch_dtype=torch.float16,
142
+ ).to(device)
143
+
144
+ def fastDiffEdit(init_img, reference_prompt , query_prompt, g=7.5, seed=100, strength =0.7, steps=20, dim=512):
145
+
146
+ ## Step 1: Create mask
147
+ mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)
148
+
149
+ ## Improve masking using CV trick
150
+ mask = improve_mask(mask)
151
+
152
+ ## Step 2 and 3: Diffusion process using mask
153
+ output = pipe(
154
+ prompt=query_prompt,
155
+ image=init_img,
156
+ mask_image=Image.fromarray(mask*255).resize((512,512)),
157
+ generator=torch.Generator(device).manual_seed(100),
158
+ num_inference_steps = steps
159
+ ).images
160
+ return output[0]
161
+
162
+
163
+
164
+ demo = gr.Interface(
165
+ fn=fastDiffEdit,
166
+ inputs=[
167
+ gr.inputs.Image(shape=(512, 512), type="pil", label = "Upload your image photo"),
168
+ gr.Textbox(label="Describe your image. Ex: a horse image"),
169
+ gr.Textbox(label="Retype the description with target output. Ex: a zebra image")],
170
+ outputs="image",
171
+ title = "DiffEdit demo",
172
+ description = "DiffEdit paper demo. Upload an image, pass reference prompt describing the image, pass query prompt to replace the object with target object",
173
+ examples = [
174
+ ["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
175
+ ["horse.jpg", "a horse image", "a zebra image"]],
176
+ enable_queue=True
177
+ )
178
+
179
+ demo.launch()
fruitbowl.jpg ADDED
horse.jpg ADDED

Git LFS Details

  • SHA256: a0d8f57d09e54128ede2bed560f28124c23e32b544bcb8414235c736439aea6d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch
3
+ torchvision
4
+
5
+ Pillow
6
+ opencv-python
7
+ ftfy
8
+ transformers==4.23.1
9
+ diffusers==0.6.0
10
+