from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
# transformers' `logging` can suppress partial model loading warnings via logging.set_verbosity_error()
import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
import numpy as np
from PIL import Image
def seed_everything(seed):
    """Seed torch's CPU and CUDA random number generators for reproducible sampling.

    Args:
        seed: integer seed applied to the global torch generators, which back
            ``torch.rand``, ``torch.randn`` and ``torch.randint``.
    """
    torch.manual_seed(seed)
    # Safe to call without CUDA: it is silently ignored in that case.
    torch.cuda.manual_seed_all(seed)
    # For bit-exact cuDNN results one could additionally set
    # torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False.
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    """Tile the latent canvas into overlapping square windows.

    Args:
        panorama_height: image height in pixels (the latent is 8x smaller).
        panorama_width: image width in pixels.
        window_size: side length of each window, in latent cells.
        stride: step between consecutive window starts, in latent cells.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` latent-space slices, in
        row-major order (height outer, width inner). ``h_end - h_start`` and
        ``w_end - w_start`` always equal ``window_size``. When the latent axis
        is shorter than the window, a single window starting at 0 is used.
        Slicing then simply truncates it.

    If ``(latent_size - window_size)`` is not a multiple of ``stride``, an
    extra window aligned with the far edge is appended. Otherwise the trailing
    strip would be covered by no window at all.
    """
    latent_height = int(panorama_height) // 8
    latent_width = int(panorama_width) // 8

    def _starts(length):
        # Start offsets along one axis that together cover [0, length).
        if length <= window_size:
            return [0]
        starts = list(range(0, length - window_size + 1, stride))
        if starts[-1] + window_size < length:
            starts.append(length - window_size)
        return starts

    return [
        (h_start, h_start + window_size, w_start, w_start + window_size)
        for h_start in _starts(latent_height)
        for w_start in _starts(latent_width)
    ]
class MultiDiffusion(nn.Module):
    """Region-based MultiDiffusion sampler on top of Stable Diffusion.

    Each prompt is paired with a mask. At every denoising step the latent is
    split into overlapping windows (``get_views``), each window is denoised
    once per prompt, and the per-prompt results are blended back into the
    canvas using the masks as weights.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load the Stable Diffusion components.

        Args:
            device: torch device the VAE, text encoder and UNet are moved to.
            sd_version: '1.5', '2.0' or '2.1' select a known checkpoint. Any
                other value is used directly as a Hugging Face model key, so
                custom models or fine-tunes can be loaded.
            hf_key: explicit Hugging Face model key; takes precedence over
                ``sd_version``.
        """
        # nn.Module must be initialised before sub-modules (vae, unet, ...) are assigned.
        super().__init__()
        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            # For custom models or fine-tunes, allow people to use arbitrary versions
            model_key = self.sd_version

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_random_background(self, n_samples):
        """Return VAE latents of ``n_samples`` constant-colour 512x512 images.

        Each image is one random RGB value broadcast over all pixels. The images
        are encoded one at a time and concatenated along the batch dimension.
        """
        # sample random background with a constant rgb value
        backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
        return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts for classifier-free guidance.

        Args:
            prompt: a string or a list of strings.
            negative_prompt: a string (broadcast to every prompt), a list of
                the same length as ``prompt``, or None (treated as '').

        Returns:
            ``torch.cat([uncond, cond])``: the negative-prompt embeddings first,
            then the prompt embeddings, each half with one row per prompt.

        Raises:
            ValueError: if list lengths of ``prompt`` and ``negative_prompt`` differ.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        if negative_prompt is None:
            negative_prompt = ''
        if isinstance(negative_prompt, str):
            # Both CFG halves must have the same batch size.
            negative_prompt = [negative_prompt] * len(prompt)
        if len(negative_prompt) != len(prompt):
            raise ValueError(f'got {len(prompt)} prompts but {len(negative_prompt)} negative prompts')

        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings (truncated so both halves share a sequence length)
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def encode_imgs(self, imgs):
        """Encode images in [0, 1] to scaled VAE latents (sampled from the posterior)."""
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        # 0.18215 is the latent scaling factor applied after encoding (inverted in decode_latents).
        latents = posterior.sample() * 0.18215
        return latents

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode scaled VAE latents back to images clamped to [0, 1]."""
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                 guidance_scale=7.5, bootstrapping=20):
        """Run region-based MultiDiffusion sampling and return a PIL image.

        Args:
            masks: per-prompt weights. Sliced as ``masks[:, :, h0:h1, w0:w1]``
                over the latent grid (height // 8, width // 8) and combined
                with a batch of ``len(prompts)`` latents. Presumably shaped
                (len(prompts), 1, height // 8, width // 8), with the background
                mask at index 0; confirm against the caller.
            prompts: a string or a list of strings. Index 0 is treated as the
                background prompt: only indices 1.. receive bootstrapping.
            negative_prompts: see ``get_text_embeds``.
            height, width: output size in pixels.
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight.
            bootstrapping: during the first ``bootstrapping`` steps, the
                foreground latents outside their mask are replaced by noised
                random constant-colour backgrounds. 0 disables this.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        # get bootstrapping backgrounds
        # can move this outside of the function to speed up generation. i.e., calculate in init
        bootstrapping_backgrounds = self.get_random_background(bootstrapping) if bootstrapping > 0 else None

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2 * len(prompts), 77, dim]

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        # DDIM's step() needs the inference schedule to be configured first.
        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                # Each step blends only this step's predictions.
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    masks_view = masks[:, :, h_start:h_end, w_start:w_end]
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                    if i < bootstrapping:
                        bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                        # Crop in case the view is truncated by a canvas smaller than the window.
                        bg = bg[:, :, :latent_view.shape[-2], :latent_view.shape[-1]]
                        bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
                        latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(dim=0, keepdim=True)
                    count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdim=True)

                # take the MultiDiffusion step: mask-weighted average over prompts and overlapping windows
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs
        imgs = self.decode_latents(latent)
        img = T.ToPILImage()(imgs[0].cpu())
        return img
def preprocess_mask(mask_path, h, w, device):
    """Load a mask and return it as a binary float tensor of shape (1, 1, h, w).

    Args:
        mask_path: path to a mask image (converted to 8-bit grayscale), or an
            already-loaded 2-D array holding 8-bit grayscale values (0-255).
        h, w: output size. The caller passes the latent resolution (image size // 8).
        device: torch device for the returned tensor.

    Returns:
        Tensor with 1.0 where the grayscale value is >= 50% and 0.0 elsewhere.
        It is resized with nearest-neighbour interpolation, so it stays binary.
    """
    if isinstance(mask_path, np.ndarray):
        mask = mask_path
    else:
        # Context manager closes the underlying file handle.
        with Image.open(mask_path) as im:
            mask = np.array(im.convert("L"))
    mask = (mask.astype(np.float32) / 255.0 >= 0.5).astype(np.float32)
    mask = torch.from_numpy(mask[None, None]).to(device)
    return torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
if __name__ == '__main__':
    # Command-line entry point: build per-region masks and prompts, run
    # MultiDiffusion, and write the result to out.png.
    parser = argparse.ArgumentParser()
    # nargs='+' (not type=list, which would split each argument into characters)
    parser.add_argument('--mask_paths', type=str, nargs='+', required=True,
                        help="one mask image per foreground prompt")
    parser.add_argument('--bg_prompt', type=str, required=True)
    parser.add_argument('--bg_negative', type=str, default='')  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--fg_prompts', type=str, nargs='+', required=True)
    parser.add_argument('--fg_negative', type=str, nargs='+', default=None,
                        help="one negative prompt per foreground prompt, or a single one for all")  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--sd_version', type=str, default='2.0',
                        help="stable diffusion version ('1.5', '2.0', '2.1') or a Hugging Face model key")
    parser.add_argument('--H', type=int, default=768)
    parser.add_argument('--W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--bootstrapping', type=int, default=20)
    opt = parser.parse_args()

    if len(opt.mask_paths) != len(opt.fg_prompts):
        parser.error(f'got {len(opt.mask_paths)} masks but {len(opt.fg_prompts)} foreground prompts')
    if opt.fg_negative is None:
        fg_negative = [''] * len(opt.fg_prompts)
    elif len(opt.fg_negative) == 1:
        fg_negative = opt.fg_negative * len(opt.fg_prompts)
    elif len(opt.fg_negative) == len(opt.fg_prompts):
        fg_negative = opt.fg_negative
    else:
        parser.error(f'got {len(opt.fg_negative)} foreground negatives for {len(opt.fg_prompts)} foreground prompts')

    # suppress partial model loading warning
    logging.set_verbosity_error()
    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = MultiDiffusion(device, opt.sd_version)

    fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])
    # The background covers whatever the foreground masks leave uncovered (never negative).
    bg_mask = (1 - torch.sum(fg_masks, dim=0, keepdim=True)).clamp(min=0)
    masks = torch.cat([bg_mask, fg_masks])

    prompts = [opt.bg_prompt] + opt.fg_prompts
    neg_prompts = [opt.bg_negative] + fg_negative

    img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)

    # save image
    img.save('out.png')