|
from transformers import CLIPTextModel, CLIPTokenizer, logging |
|
from diffusers import ( |
|
AutoencoderKL, |
|
UNet2DConditionModel, |
|
PNDMScheduler, |
|
DDIMScheduler, |
|
StableDiffusionPipeline, |
|
) |
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
|
|
|
# Keep the transformers library quiet: only error-level log messages are printed.
logging.set_verbosity_error()
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random number generators.

    Args:
        seed: Integer seed. PyTorch and ``random`` receive it unchanged; NumPy
            receives ``seed % 2**32`` because ``np.random.seed`` only accepts
            values in ``[0, 2**32)``.
    """
    import random

    random.seed(seed)
    # np.random.seed rejects values outside [0, 2**32); reduce so any int that
    # torch accepts is also accepted here.
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    # Applied lazily by PyTorch and ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
|
|
|
|
|
|
|
|
|
class StableDiffusion(nn.Module):
    """Stable Diffusion wrapper for SDS guidance, img2img refinement and text-to-image.

    The VAE, tokenizer, text encoder and UNet are taken from a
    ``StableDiffusionPipeline``; a ``DDIMScheduler`` is loaded from the same
    checkpoint. Text conditioning is cached in ``self.embeddings`` by
    :meth:`get_text_embeds`, laid out along dim 0 as ``[negative, positive]``.
    """

    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        sd_version="2.1",
        hf_key=None,
        t_range=(0.02, 0.98),
    ):
        """Load the diffusion components.

        Args:
            device: Device the pipeline is moved to when ``vram_O`` is False; also
                used for tokens, sampled noise/timesteps and ``self.alphas``.
            fp16: Load weights as float16 if True, else float32.
            vram_O: Enable sequential CPU offload, VAE slicing, channels-last UNet
                and attention slicing instead of moving the pipeline to ``device``.
            sd_version: "2.1", "2.0" or "1.5"; selects the default checkpoint.
            hf_key: Custom Hugging Face model id; takes precedence over ``sd_version``.
            t_range: ``(low, high)`` fractions of the training timesteps that
                :meth:`train_step` samples from. (Tuple default avoids a shared
                mutable default; any 2-item sequence works.)

        Raises:
            ValueError: If ``hf_key`` is None and ``sd_version`` is unsupported.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
        elif self.sd_version == "2.1":
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == "2.0":
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == "1.5":
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )

        self.dtype = torch.float16 if fp16 else torch.float32

        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        # Schedulers hold no weights, so no dtype argument is passed here.
        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler"
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        # Set by get_text_embeds(); [negative, positive] along dim 0.
        self.embeddings = None

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        """Encode prompts and cache ``cat([negative, positive])`` in ``self.embeddings``."""
        pos_embeds = self.encode_text(prompts)
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)

    def encode_text(self, prompt):
        """Tokenize ``prompt`` (padded/truncated to the tokenizer max length) and encode it.

        Returns:
            The first output of the text encoder (its hidden states).
        """
        inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            # Longer prompts would otherwise yield more tokens than the text
            # encoder accepts.
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(inputs.input_ids.to(self.device))[0]

    def _cfg_embeddings(self, batch_size):
        """Return cached embeddings arranged as ``[neg] * B + [pos] * B`` for a CFG batch.

        The CFG batch is ``cat([latents] * 2)``, so row ``i`` (uncond half) and row
        ``B + i`` (cond half) must both describe latent ``i``. If the cache already
        has ``2 * B`` rows it is used as-is; otherwise (a single neg/pos pair)
        each half is repeated ``B`` times in place.
        """
        embeddings = self.embeddings
        if embeddings.shape[0] == 2 * batch_size:
            return embeddings
        return embeddings.repeat_interleave(batch_size, dim=0)

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
               ):
        """Re-noise images to an intermediate timestep and denoise them with CFG.

        NOTE: ``strength`` indexes into ``self.scheduler.timesteps``, which DDIM
        orders from most to least noisy. A *larger* strength therefore starts at
        a *less* noisy timestep and runs fewer denoising steps (the reverse of
        the diffusers img2img convention).

        Args:
            pred_rgb: Image batch, presumably (B, 3, H, W) in [0, 1] (it is
                rescaled to [-1, 1] by :meth:`encode_imgs`) — confirm with callers.
            guidance_scale: Classifier-free guidance weight.
            steps: Number of scheduler inference steps.
            strength: Fraction of the schedule to skip before denoising; values
                that would index past the last timestep are clamped to it.

        Returns:
            Decoded images from :meth:`decode_latents`, clamped to [0, 1].
        """
        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))

        self.scheduler.set_timesteps(steps)
        timesteps = self.scheduler.timesteps
        # Clamp so strength=1.0 does not index one past the final timestep.
        init_step = min(int(steps * strength), len(timesteps) - 1)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), timesteps[init_step])

        embeddings = self._cfg_embeddings(batch_size)

        for t in timesteps[init_step:]:
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings,
            ).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.decode_latents(latents)

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=100,
        as_latent=False,
    ):
        """Compute an SDS-style loss for a batch of rendered images.

        The returned loss's gradient w.r.t. the latents is
        ``w(t) * (eps_hat - eps) / B`` with ``w(t) = 1 - alphas_cumprod[t]``,
        where ``eps_hat`` is the CFG noise prediction and ``eps`` the added noise.

        Args:
            pred_rgb: Rendered batch. When ``as_latent`` is False it is resized to
                512x512 and VAE-encoded (presumably RGB in [0, 1]); when True it is
                resized to 64x64 and mapped to [-1, 1] directly as latents.
            step_ratio: Optional progress in [0, 1]; if given, the timestep is
                ``round((1 - step_ratio) * T)`` clipped to ``[min_step, max_step]``,
                otherwise it is sampled uniformly from that range per sample.
            guidance_scale: Classifier-free guidance weight.
            as_latent: Treat ``pred_rgb`` as latents instead of images.

        Returns:
            Scalar float32 loss tensor.
        """
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            latents = self.encode_imgs(pred_rgb_512)

        if step_ratio is not None:
            # Annealed schedule: step_ratio 0 -> noisiest allowed t, 1 -> least noisy.
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)

            # [neg x B, pos x B] so the uncond half really is unconditional for
            # every sample (a plain .repeat would interleave neg/pos rows).
            noise_pred = self.unet(
                latent_model_input, tt, encoder_hidden_states=self._cfg_embeddings(batch_size)
            ).sample

            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_pos - noise_pred_uncond
            )

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # Reparameterized so d(loss)/d(latents) == grad / batch_size.
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction="sum") / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        """Run the full CFG denoising loop and return the final latents.

        Args:
            height, width: Output image size in pixels; latents are 1/8 of it.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight.
            latents: Optional initial latents; if None, standard-normal latents
                are drawn in ``self.dtype`` with one sample per positive prompt.
        """
        if latents is None:
            latents = torch.randn(
                (
                    self.embeddings.shape[0] // 2,
                    self.unet.config.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
                # Match the UNet weights; float32 noise fails under fp16.
                dtype=self.dtype,
            )

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=self.embeddings
            ).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result from [-1, 1] to [0, 1]."""
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        """Encode images (assumed in [0, 1]) into scaled VAE latents.

        Inputs are mapped to [-1, 1], and a sample from the VAE posterior is
        multiplied by ``vae.config.scaling_factor``. The sampling makes the result
        stochastic.
        """
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    @torch.no_grad()
    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        """Generate images from text prompts.

        Args:
            prompts: A prompt string or a list of prompts.
            negative_prompts: A string or list; a single negative prompt is
                applied to every positive prompt.
            height, width, num_inference_steps, guidance_scale, latents: Passed
                to :meth:`produce_latents`.

        Returns:
            uint8 NumPy array of shape (N, height, width, C).
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Keep the negative and positive halves of the embedding batch the same
        # size so produce_latents' CFG split lines up.
        if len(negative_prompts) == 1 and len(prompts) > 1:
            negative_prompts = negative_prompts * len(prompts)

        self.get_text_embeds(prompts, negative_prompts)

        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        imgs = self.decode_latents(latents)

        # Quantize in float32 so fp16 rounding does not shift pixel values.
        imgs = imgs.detach().float().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    import matplotlib.pyplot as plt

    # Command-line demo: sample one batch of images for a prompt and display the first.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("prompt", type=str)
    arg_parser.add_argument("--negative", default="", type=str)
    arg_parser.add_argument(
        "--sd_version",
        type=str,
        default="2.1",
        choices=["1.5", "2.0", "2.1"],
        help="stable diffusion version",
    )
    arg_parser.add_argument(
        "--hf_key",
        type=str,
        default=None,
        help="hugging face Stable diffusion model key",
    )
    arg_parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    arg_parser.add_argument(
        "--vram_O", action="store_true", help="optimization for low VRAM usage"
    )
    arg_parser.add_argument("-H", type=int, default=512)
    arg_parser.add_argument("-W", type=int, default=512)
    arg_parser.add_argument("--seed", type=int, default=0)
    arg_parser.add_argument("--steps", type=int, default=50)
    args = arg_parser.parse_args()

    seed_everything(args.seed)

    model = StableDiffusion(
        torch.device("cuda"),
        fp16=args.fp16,
        vram_O=args.vram_O,
        sd_version=args.sd_version,
        hf_key=args.hf_key,
    )

    images = model.prompt_to_img(
        args.prompt,
        negative_prompts=args.negative,
        height=args.H,
        width=args.W,
        num_inference_steps=args.steps,
    )

    first_image = images[0]
    plt.imshow(first_image)
    plt.show()
|
|