omerbartal committed
Commit 7b90989
1 Parent(s): 2104e5b

Upload region_control.py

Files changed (1)
  1. region_control.py +208 -0
region_control.py ADDED
@@ -0,0 +1,208 @@
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
import numpy as np
from PIL import Image


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True


def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    panorama_height /= 8
    panorama_width /= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

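# Illustrative example: with the script defaults H=768 and W=512 the latent grid is
# 96x64, so get_views(768, 512) returns five overlapping 64x64 windows:
# (0, 64, 0, 64), (8, 72, 0, 64), (16, 80, 0, 64), (24, 88, 0, 64), (32, 96, 0, 64).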

class MultiDiffusion(nn.Module):
    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_random_background(self, n_samples):
        # sample random background with a constant rgb value
        backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
        return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')

        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def encode_imgs(self, imgs):
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215
        return latents

    @torch.no_grad()
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

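    # Note: 0.18215 is the latent scaling factor of the Stable Diffusion VAE;
    # encode_imgs multiplies by it and decode_latents divides it back out.
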
    @torch.no_grad()
    def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                 guidance_scale=7.5, bootstrapping=20):

        # get bootstrapping backgrounds
        # can move this outside of the function to speed up generation. i.e., calculate in init
        bootstrapping_backgrounds = self.get_random_background(bootstrapping)

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2 * len(prompts), 77, embed_dim]

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    masks_view = masks[:, :, h_start:h_end, w_start:w_end]
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                    if i < bootstrapping:
                        bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                        bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
                        latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                    value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(dim=0, keepdim=True)
                    count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdim=True)

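                # Fuse the per-prompt predictions: `value` holds the mask-weighted denoised
                # latents summed over all prompts and overlapping windows, `count` holds the
                # summed mask weights, so value / count is a per-pixel weighted average.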
                # take the MultiDiffusion step
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs
        imgs = self.decode_latents(latent)  # [1, 3, height, width]
        img = T.ToPILImage()(imgs[0].cpu())
        return img


def preprocess_mask(mask_path, h, w, device):
    mask = np.array(Image.open(mask_path).convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask).to(device)
    mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
    return mask

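# Note: masks are read as grayscale, binarized at 0.5, and resized to the latent
# resolution with nearest-neighbour interpolation; white pixels mark the region
# controlled by the corresponding foreground prompt.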

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # list arguments are passed as space-separated values, e.g. --mask_paths mask1.png mask2.png
    parser.add_argument('--mask_paths', type=str, nargs='+')
    parser.add_argument('--bg_prompt', type=str)
    parser.add_argument('--bg_negative', type=str)  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--fg_prompts', type=str, nargs='+')
    parser.add_argument('--fg_negative', type=str, nargs='+')  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0', '2.1'],
                        help="stable diffusion version")
    parser.add_argument('--H', type=int, default=768)
    parser.add_argument('--W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--bootstrapping', type=int, default=20)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = MultiDiffusion(device, opt.sd_version)

    fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0
    masks = torch.cat([bg_mask, fg_masks])

    prompts = [opt.bg_prompt] + opt.fg_prompts
    neg_prompts = [opt.bg_negative] + opt.fg_negative

    img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)

    # save image
    img.save('out.png')
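
For reference, here is a minimal sketch of how the class above can be driven programmatically instead of through the CLI. The mask paths, prompts, and output filename are illustrative placeholders, and the snippet assumes region_control.py is importable from the working directory and that a CUDA device is available:

import torch
from region_control import MultiDiffusion, preprocess_mask, seed_everything

seed_everything(0)
device = torch.device('cuda')
sd = MultiDiffusion(device, sd_version='2.0')

H, W = 768, 512
# one binary mask per foreground prompt (white = region for that prompt); paths are placeholders
fg_masks = torch.cat([preprocess_mask(p, H // 8, W // 8, device)
                      for p in ['masks/fg1.png', 'masks/fg2.png']])
bg_mask = (1 - fg_masks.sum(dim=0, keepdim=True)).clamp(min=0)
masks = torch.cat([bg_mask, fg_masks])

prompts = ['a photo of a misty mountain lake',   # background prompt
           'a red rowing boat',                  # prompt for masks/fg1.png
           'a wooden cabin on the shore']        # prompt for masks/fg2.png
neg_prompts = ['blurry, low quality, distorted'] * len(prompts)

img = sd.generate(masks, prompts, neg_prompts, height=H, width=W,
                  num_inference_steps=50, bootstrapping=20)
img.save('region_out.png')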