import torch

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler

from .unet_2d_condition import UNet2DConditionModel

from easydict import EasyDict
import numpy as np

from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
from utils import torch_device


def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True): |
|
""" |
|
Keys: |
|
key = "CompVis/stable-diffusion-v1-4" |
|
key = "runwayml/stable-diffusion-v1-5" |
|
key = "stabilityai/stable-diffusion-2-1-base" |
|
|
|
Unpack with: |
|
``` |
|
model_dict = load_sd(key=key, use_fp16=use_fp16) |
|
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype |
|
``` |
|
|
|
use_fp16: fp16 might have degraded performance |
|
""" |
|
|
|
|
|
if use_fp16: |
|
dtype = torch.float16 |
|
revision = "fp16" |
|
else: |
|
dtype = torch.float |
|
revision = "main" |
|
|
|
vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device) |
|
tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision, torch_dtype=dtype) |
|
text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device) |
|
unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device) |
|
dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype) |
|
scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype) |
|
|
|
model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dpm_scheduler=dpm_scheduler, dtype=dtype) |
|
|
|
if load_inverse_scheduler: |
|
inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config) |
|
model_dict.inverse_scheduler = inverse_scheduler |
|
|
|
return model_dict |
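
# A minimal usage sketch (not part of the original module): it shows how the returned
# model_dict is typically unpacked and what each scheduler field is for. The helper
# name `_example_load_models` is hypothetical.
def _example_load_models():
    model_dict = load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False)
    vae, tokenizer, text_encoder, unet, dtype = (
        model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.dtype,
    )
    scheduler = model_dict.scheduler                  # DDIM scheduler for the denoising loop
    dpm_scheduler = model_dict.dpm_scheduler          # alternative multistep DPM-Solver sampler
    inverse_scheduler = model_dict.inverse_scheduler  # runs DDIM in reverse (image latents -> noise)
    return vae, tokenizer, text_encoder, unet, scheduler, dpm_scheduler, inverse_scheduler, dtype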


def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
    """
    Encode `prompts` (and `negative_prompt`) into CLIP text embeddings.

    Returns (text_embeddings, uncond_embeddings, cond_embeddings) by default,
    only text_embeddings if `return_full_only`, or (uncond_embeddings, cond_embeddings)
    if `one_uncond_input_only` (a single unconditional embedding is produced).
    """
    if negative_prompt == "":
        print("Note that negative_prompt is an empty string")

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )

    max_length = text_input.input_ids.shape[-1]
    if one_uncond_input_only:
        num_uncond_input = 1
    else:
        num_uncond_input = len(prompts)
    uncond_input = tokenizer([negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, return_tensors="pt")

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    if one_uncond_input_only:
        return uncond_embeddings, cond_embeddings

    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    if return_full_only:
        return text_embeddings
    return text_embeddings, uncond_embeddings, cond_embeddings
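
# A minimal sketch (assumptions flagged below) of how the concatenated embeddings from
# encode_prompts are typically consumed in one classifier-free-guidance denoising step:
# the unconditional and conditional halves share a single UNet forward pass and are
# re-split to form the guided noise prediction. The helper name `_example_cfg_step` is
# hypothetical, and the UNet call assumes the diffusers-style forward signature
# unet(latents, t, encoder_hidden_states=...).sample.
def _example_cfg_step(unet, scheduler, latents, t, text_embeddings, guidance_scale=7.5):
    # Duplicate the latents so the uncond/cond embeddings are processed in one batch.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # Split back into unconditional and conditional predictions and apply guidance.
    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    # One scheduler step toward the previous (less noisy) latents.
    return scheduler.step(noise_pred, t, latents).prev_sample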


def process_input_embeddings(input_embeddings):
    assert isinstance(input_embeddings, (tuple, list))
    if len(input_embeddings) == 3:
        # Already in (text_embeddings, uncond_embeddings, cond_embeddings) form;
        # just check that the unconditional and conditional batch sizes match.
        _, uncond_embeddings, cond_embeddings = input_embeddings
        assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
        return input_embeddings
    elif len(input_embeddings) == 2:
        # (uncond_embeddings, cond_embeddings): broadcast a single unconditional
        # embedding to the conditional batch size, then build the full embeddings.
        uncond_embeddings, cond_embeddings = input_embeddings
        if uncond_embeddings.shape[0] == 1:
            uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)

        text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
        return text_embeddings, uncond_embeddings, cond_embeddings
    else:
        raise ValueError(f"input_embeddings length: {len(input_embeddings)}")
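
# A small hedged example (names here are illustrative, not from the original module) tying
# the two helpers together: encode_prompts with one_uncond_input_only=True yields a 2-tuple
# whose single unconditional embedding process_input_embeddings expands to the conditional
# batch size before concatenating.
def _example_process_embeddings(tokenizer, text_encoder, prompts):
    uncond_cond = encode_prompts(tokenizer, text_encoder, prompts, one_uncond_input_only=True)
    text_embeddings, uncond_embeddings, cond_embeddings = process_input_embeddings(uncond_cond)
    # text_embeddings stacks [uncond, cond] along the batch dimension.
    assert text_embeddings.shape[0] == uncond_embeddings.shape[0] + cond_embeddings.shape[0]
    return text_embeddings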