daspartho committed
Commit d7b9c73
1 parent: 26ac83f

Upload 3 files

Files changed (3)
  1. app.py +22 -0
  2. magic_mix.py +202 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,22 @@
+ import gradio as gr
+ from magic_mix import magic_mix
+
+ iface = gr.Interface(
+     description="Implementation of the MagicMix: Semantic Mixing with Diffusion Models paper",
+     article="<p style='text-align: center'><a href='https://github.com/daspartho/MagicMix' target='_blank'>Github</a></p>",
+     fn=magic_mix,
+     inputs=[
+         gr.Image(shape=(512, 512), type="pil"),                 # layout image
+         gr.Text(),                                              # content prompt
+         gr.Slider(value=0.3, minimum=0, maximum=1, step=0.1),   # kmin
+         gr.Slider(value=0.5, minimum=0, maximum=1, step=0.1),   # kmax
+         gr.Slider(value=0.5, minimum=0, maximum=1, step=0.1),   # v (interpolation constant)
+         gr.Number(value=42, maximum=2**64 - 1),                 # seed
+         gr.Slider(value=50),                                    # denoising steps
+         gr.Slider(value=7.5, minimum=1, maximum=15, step=0.1),  # guidance scale
+     ],
+     outputs=gr.Image(),
+     title="MagicMix",
+ )
+
+ iface.launch()
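The interface above passes its inputs positionally to magic_mix. As a minimal sketch of the equivalent direct call (the file names and the prompt here are placeholders, not part of the repo):

    from PIL import Image
    from magic_mix import magic_mix

    # Same argument order the Gradio inputs are wired to:
    # image, prompt, kmin, kmax, v, seed, steps, guidance_scale
    layout_img = Image.open("layout.jpg").resize((512, 512))  # hypothetical input file
    result = magic_mix(layout_img, "a bowl of fruits", 0.3, 0.5, 0.5, 42, 50, 7.5)
    result.save("result.jpg")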
magic_mix.py ADDED
@@ -0,0 +1,202 @@
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import torch
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from PIL import Image
+
+ # Suppress some unnecessary warnings when loading the CLIPTextModel
+ logging.set_verbosity_error()
+
+ # Set device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the components we'll use
+
+ tokenizer = CLIPTokenizer.from_pretrained(
+     "openai/clip-vit-large-patch14",
+ )
+
+ text_encoder = CLIPTextModel.from_pretrained(
+     "openai/clip-vit-large-patch14",
+ ).to(device)
+
+ vae = AutoencoderKL.from_pretrained(
+     "CompVis/stable-diffusion-v1-4",
+     subfolder="vae",
+ ).to(device)
+
+ unet = UNet2DConditionModel.from_pretrained(
+     "CompVis/stable-diffusion-v1-4",
+     subfolder="unet",
+ ).to(device)
+
+ beta_start, beta_end = 0.00085, 0.012
+ scheduler = DDIMScheduler(
+     beta_start=beta_start,
+     beta_end=beta_end,
+     beta_schedule="scaled_linear",
+     num_train_timesteps=1000,
+     clip_sample=False,
+     set_alpha_to_one=False,
+ )
+
+
+ # convert a PIL image to latents
+ def encode(img):
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(device) * 2 - 1)
+         latent = 0.18215 * latent.latent_dist.sample()
+     return latent
+
+
+ # convert latents to a PIL image
+ def decode(latent):
+     latent = (1 / 0.18215) * latent
+     with torch.no_grad():
+         img = vae.decode(latent).sample
+     img = (img / 2 + 0.5).clamp(0, 1)
+     img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
+     img = (img * 255).round().astype("uint8")
+     return Image.fromarray(img[0])
+
+
+ # convert the prompt into text embeddings, plus unconditional embeddings for classifier-free guidance
+ def prep_text(prompt):
+
+     text_input = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+
+     text_embedding = text_encoder(
+         text_input.input_ids.to(device)
+     )[0]
+
+     uncond_input = tokenizer(
+         "",
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+
+     uncond_embedding = text_encoder(
+         uncond_input.input_ids.to(device)
+     )[0]
+
+     return torch.cat([uncond_embedding, text_embedding])
+
+
+ def magic_mix(
+     img,  # specifies the layout semantics
+     prompt,  # specifies the content semantics
+     kmin=0.3,
+     kmax=0.6,
+     v=0.5,  # interpolation constant
+     seed=42,
+     steps=50,
+     guidance_scale=7.5,
+ ):
+
+     tmin = steps - int(kmin * steps)
+     tmax = steps - int(kmax * steps)
+
+     text_embeddings = prep_text(prompt)
+
+     scheduler.set_timesteps(steps)
+
+     width, height = img.size
+     encoded = encode(img)
+
+     torch.manual_seed(seed)
+     noise = torch.randn(
+         (1, unet.in_channels, height // 8, width // 8),
+     ).to(device)
+
+     # noise the layout latents up to the t_max timestep
+     latents = scheduler.add_noise(
+         encoded,
+         noise,
+         timesteps=scheduler.timesteps[tmax],
+     )
+
+     input = torch.cat([latents] * 2)
+
+     input = scheduler.scale_model_input(input, scheduler.timesteps[tmax])
+
+     # first denoising step at t_max with classifier-free guidance
+     with torch.no_grad():
+         pred = unet(
+             input,
+             scheduler.timesteps[tmax],
+             encoder_hidden_states=text_embeddings,
+         ).sample
+
+     pred_uncond, pred_text = pred.chunk(2)
+     pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+     latents = scheduler.step(pred, scheduler.timesteps[tmax], latents).prev_sample
+
+     for i, t in enumerate(tqdm(scheduler.timesteps)):
+         if i > tmax:
+             if i < tmin:  # layout generation phase
+                 orig_latents = scheduler.add_noise(
+                     encoded,
+                     noise,
+                     timesteps=t,
+                 )
+
+                 # interpolate between the layout noise and the conditionally generated noise to preserve the layout semantics
+                 input = v * latents + (1 - v) * orig_latents
+                 input = torch.cat([input] * 2)
+
+             else:  # content generation phase
+                 input = torch.cat([latents] * 2)
+
+             input = scheduler.scale_model_input(input, t)
+
+             with torch.no_grad():
+                 pred = unet(
+                     input,
+                     t,
+                     encoder_hidden_states=text_embeddings,
+                 ).sample
+
+             pred_uncond, pred_text = pred.chunk(2)
+             pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+             latents = scheduler.step(pred, t, latents).prev_sample
+
+     return decode(latents)
+
+
+ if __name__ == "__main__":
+
+     import argparse
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("img_file", type=str, help="image file providing the layout semantics for the mixing process")
+     parser.add_argument("prompt", type=str, help="prompt providing the content semantics for the mixing process")
+     parser.add_argument("out_file", type=str, help="filename to save the generation to")
+     parser.add_argument("--kmin", type=float, default=0.3)
+     parser.add_argument("--kmax", type=float, default=0.6)
+     parser.add_argument("--v", type=float, default=0.5)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--steps", type=int, default=50)
+     parser.add_argument("--guidance_scale", type=float, default=7.5)
+
+     args = parser.parse_args()
+
+     img = Image.open(args.img_file)
+     out_img = magic_mix(
+         img,
+         args.prompt,
+         args.kmin,
+         args.kmax,
+         args.v,
+         args.seed,
+         args.steps,
+         args.guidance_scale,
+     )
+     out_img.save(args.out_file)
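As a quick sanity check of how kmin and kmax translate into the two phases inside magic_mix, here is a minimal sketch using the default values; the index arithmetic mirrors the function body:

    steps, kmin, kmax = 50, 0.3, 0.6
    tmin = steps - int(kmin * steps)  # 35
    tmax = steps - int(kmax * steps)  # 20
    # Denoising indices i with tmax < i < tmin (21..34) run the layout phase,
    # where the latents are re-interpolated with the freshly noised layout image;
    # indices i >= tmin (35..49) run the content phase with plain guided denoising.
    print(tmax, tmin)  # 20 35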
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ diffusers
+ transformers
+ accelerate
+ tqdm
+ pillow
+ gradio