Tony Lian
Allow using different schedulers and negative prompts
ec7f11c
raw history blame
No virus
4.23 kB
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler
from .unet_2d_condition import UNet2DConditionModel
from easydict import EasyDict
import numpy as np
# For compatibility
from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
from utils import torch_device
def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True):
    """
    Load the Stable Diffusion components stored under `key` and bundle them in an EasyDict.

    Keys:
    key = "CompVis/stable-diffusion-v1-4"
    key = "runwayml/stable-diffusion-v1-5"
    key = "stabilityai/stable-diffusion-2-1-base"

    Unpack with:
    ```
    model_dict = load_sd(key=key, use_fp16=use_fp16)
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    ```

    Args:
        key: repository id passed to every `from_pretrained` call.
        use_fp16: if True, load with `torch.float16` from the "fp16" revision; otherwise
            `torch.float` from the "main" revision. fp16 might have degraded performance.
        load_inverse_scheduler: if True, also build a `DDIMInverseScheduler` from the DDIM
            scheduler's config and store it under `inverse_scheduler`.

    Returns:
        EasyDict with `vae`, `tokenizer`, `text_encoder`, `unet`, `scheduler` (DDIM),
        `dpm_scheduler` (DPM-Solver multistep), `dtype`, and optionally `inverse_scheduler`.
        The vae, text encoder and unet are moved to `torch_device`.
    """
    # Default path runs in fp32; fp16 is opt-in.
    dtype, revision = (torch.float16, "fp16") if use_fp16 else (torch.float, "main")
    hub_kwargs = dict(revision=revision, torch_dtype=dtype)

    def _from_hub(component_cls, subfolder):
        # Every component lives in its own subfolder of the same repository/revision.
        return component_cls.from_pretrained(key, subfolder=subfolder, **hub_kwargs)

    vae = _from_hub(AutoencoderKL, "vae").to(torch_device)
    tokenizer = _from_hub(CLIPTokenizer, "tokenizer")
    text_encoder = _from_hub(CLIPTextModel, "text_encoder").to(torch_device)
    unet = _from_hub(UNet2DConditionModel, "unet").to(torch_device)
    # Both schedulers read the same "scheduler" config subfolder.
    dpm_scheduler = _from_hub(DPMSolverMultistepScheduler, "scheduler")
    scheduler = _from_hub(DDIMScheduler, "scheduler")

    model_dict = EasyDict(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        dpm_scheduler=dpm_scheduler,
        dtype=dtype,
    )

    if load_inverse_scheduler:
        # The inverse scheduler mirrors the forward DDIM scheduler's configuration.
        model_dict.inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)

    return model_dict
def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
    """
    Encode prompts and a negative prompt into text embeddings for classifier-free guidance.

    Args:
        tokenizer: callable tokenizer exposing `model_max_length` (e.g. a CLIPTokenizer).
        text_encoder: text model; the first element of its output is used as the embeddings.
        prompts: list of prompt strings, or a single prompt string (treated as a batch of one).
        negative_prompt: text for the unconditional branch. An empty string is allowed;
            a note is printed in that case.
        return_full_only: if True, return only the concatenated embeddings.
        one_uncond_input_only: if True, encode the negative prompt once (batch of 1) and
            return `(uncond_embeddings, cond_embeddings)` without concatenating them.
            Takes precedence over `return_full_only`.

    Returns:
        - `(uncond_embeddings, cond_embeddings)` when `one_uncond_input_only` is set;
        - `text_embeddings` when `return_full_only` is set;
        - otherwise `(text_embeddings, uncond_embeddings, cond_embeddings)`,
        where `text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])` (dim 0).
    """
    if negative_prompt == "":
        print("Note that negative_prompt is an empty string")

    # A bare string would make len(prompts) count characters, producing one
    # unconditional row per character instead of one per prompt.
    if isinstance(prompts, str):
        prompts = [prompts]

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    max_length = text_input.input_ids.shape[-1]

    num_uncond_input = 1 if one_uncond_input_only else len(prompts)
    # truncation=True caps a long negative prompt at `max_length` tokens so the
    # unconditional embeddings have the same sequence length as the conditional
    # ones; without it torch.cat below fails on a long negative prompt.
    uncond_input = tokenizer(
        [negative_prompt] * num_uncond_input,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    if one_uncond_input_only:
        return uncond_embeddings, cond_embeddings

    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    if return_full_only:
        return text_embeddings

    return text_embeddings, uncond_embeddings, cond_embeddings
def attn_list_to_tensor(cross_attention_probs):
    """
    Regroup per-timestep attention maps into one stacked tensor per cross-attention block.

    The input is nested as [timestep][block][item], where each item is a tensor. The
    original layout note reads: timestep, CrossAttnBlock, Transformer2DModel,
    1xBasicTransformerBlock (NOTE(review): confirm which level each item corresponds to).
    For each block index `i`, the items of block `i` are stacked within every timestep.
    Those per-timestep stacks are then stacked across timesteps.

    Args:
        cross_attention_probs: sequence over timesteps. Each entry is a sequence over blocks,
            and each block is a sequence of tensors. The block count is taken from the first
            timestep. Within a block, the tensors must be stackable (same shape) and appear
            in the same number at every timestep.

    Returns:
        List with one tensor per block, shaped (num_timesteps, num_items, *item_shape).
        An empty input returns an empty list.
    """
    # Indexing [0] below would raise on an empty input; there is simply nothing to group.
    if len(cross_attention_probs) == 0:
        return []

    num_cross_attn_block = len(cross_attention_probs[0])
    return [
        torch.stack(
            [
                torch.stack(list(timestep_probs[block_idx]), dim=0)
                for timestep_probs in cross_attention_probs
            ],
            dim=0,
        )
        for block_idx in range(num_cross_attn_block)
    ]