from transformers import CLIPTextModel, CLIPTokenizer, logging

from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

# suppress partial model loading warnings from transformers
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torch.nn.functional as F

import time


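# Wraps the pretrained Stable Diffusion v1.4 components (VAE, CLIP tokenizer/text encoder,
# UNet, PNDM scheduler) behind a single nn.Module. It exposes a DreamFusion-style score
# distillation train_step() for guiding a differentiable renderer, plus a plain
# prompt_to_img() text-to-image sampling path for sanity checks.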
class StableDiffusion(nn.Module):

    def __init__(self, device):
        super().__init__()

        try:
            with open('./TOKEN', 'r') as f:
                # strip the trailing newline so the raw token is passed to the hub
                self.token = f.read().strip()
            print('[INFO] successfully loaded Hugging Face user token!')
        except FileNotFoundError as e:
            print(e)
            print('[INFO] Please first create a file called TOKEN and copy your Hugging Face access token into it to download the Stable Diffusion checkpoints.')
            # fall back to any token cached by `huggingface-cli login`
            self.token = True

        self.device = device
        self.num_train_timesteps = 1000
        # sample t from the middle 96% of the schedule, avoiding the extreme noise levels
        self.min_step = int(self.num_train_timesteps * 0.02)
        self.max_step = int(self.num_train_timesteps * 0.98)

        print('[INFO] loading stable diffusion...')

        # 1. the VAE that maps images to and from the latent space
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=self.token).to(self.device)

        # 2. the CLIP tokenizer and text encoder used to condition the UNet on prompts
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)

        # 3. the UNet that predicts noise in latent space
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=self.token).to(self.device)

        # 4. the noise scheduler, with the beta schedule used to train the v1 checkpoints
        self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=self.num_train_timesteps)

        print('[INFO] loaded stable diffusion!')

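    # Build the classifier-free guidance text embeddings: the prompt and an empty
    # ("unconditional") prompt are encoded separately and concatenated along the batch
    # dimension as [uncond, cond], matching the [uncond, cond] duplication of the latents
    # in train_step() and produce_latents().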
    def get_text_embeds(self, prompt):

        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer([''] * len(prompt), padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')

        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

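    # Score Distillation Sampling (SDS) step, following DreamFusion: encode the rendered
    # image into latents, perturb them with noise at a random timestep t, ask the frozen
    # UNet to predict that noise under the text condition, and use the guided prediction
    # error w(t) * (noise_pred - noise) as a gradient on the latents. The UNet Jacobian is
    # skipped, so the gradient is injected via latents.backward(gradient=...) instead of
    # being derived from a scalar loss.
    #
    # A rough, hypothetical usage sketch from an outer optimization loop (all names below
    # are illustrative only):
    #
    #   text_z = sd.get_text_embeds(['a photo of a hamburger'])
    #   for step in range(num_steps):
    #       pred_rgb = renderer(...)            # differentiable render in [0, 1], NCHW
    #       optimizer.zero_grad()
    #       sd.train_step(text_z, pred_rgb)     # accumulates grads into the renderer
    #       optimizer.step()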
    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100):

        # interpolate the rendered image to 512x512 so it can be fed to the VAE
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)

        # sample a random timestep in [min_step, max_step]
        t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)

        # encode the image into latents; gradients flow through this encoding
        latents = self.encode_imgs(pred_rgb_512)

        # predict the noise residual without tracking gradients through the UNet
        with torch.no_grad():
            # add noise to the latents at timestep t
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # duplicate for classifier-free guidance, matching the [uncond, cond] embeddings
            latent_model_input = torch.cat([latents_noisy] * 2)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # SDS weighting w(t) = 1 - alpha_bar_t; move alphas_cumprod to the target device
        # before indexing with the CUDA tensor t to avoid a CPU/CUDA mismatch
        w = (1 - self.scheduler.alphas_cumprod.to(self.device)[t])
        grad = w * (noise_pred - noise)

        # the UNet Jacobian is omitted (as in DreamFusion), so the gradient is injected
        # directly into the latents instead of being derived from a scalar loss
        latents.backward(gradient=grad, retain_graph=True)

        # dummy return value; the gradient has already been accumulated by backward()
        return 0

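    # Standard latent-space denoising loop with classifier-free guidance. This is the
    # ordinary text-to-image sampling path used by prompt_to_img() below; it is independent
    # of the SDS train_step() above.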
    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if latents is None:
            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                # expand the latents to match the [uncond, cond] text embeddings
                latent_model_input = torch.cat([latents] * 2)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform classifier-free guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous (less noisy) sample x_t -> x_{t-1}
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents

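    # Decode latents back to RGB with the VAE decoder. 0.18215 is the fixed latent scaling
    # factor of the Stable Diffusion v1 checkpoints; the decoder output in [-1, 1] is then
    # mapped back to [0, 1].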
    def decode_latents(self, latents):

        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

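    # Encode RGB images (expected in [0, 1]) into scaled VAE latents: map to [-1, 1],
    # sample from the VAE posterior, and apply the 0.18215 scaling. This is the inverse of
    # decode_latents() and stays differentiable so train_step() can backpropagate through it.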
    def encode_imgs(self, imgs):

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215

        return latents

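    # End-to-end text-to-image helper: embed the prompts, run the denoising loop, decode the
    # latents, and return uint8 images as a numpy array of shape [N, H, W, 3].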
    def prompt_to_img(self, prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if isinstance(prompts, str):
            prompts = [prompts]

        # prompts -> text embeddings
        text_embeds = self.get_text_embeds(prompts)

        # text embeddings -> latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

        # latents -> images
        imgs = self.decode_latents(latents)

        # images -> numpy uint8, NHWC
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs


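# Quick CLI sanity check. Something like the following should produce and display a single
# image (the script filename here is just an example):
#
#   python sd.py "a photograph of an astronaut riding a horse" -H 512 -W 512 --steps 50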
if __name__ == '__main__':

    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    device = torch.device('cuda')

    sd = StableDiffusion(device)

    imgs = sd.prompt_to_img(opt.prompt, opt.H, opt.W, opt.steps)

    plt.imshow(imgs[0])
    plt.show()