File size: 4,600 Bytes
1f39cf9
 
 
 
 
 
 
 
 
 
ec7f11c
1f39cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec7f11c
 
1f39cf9
ec7f11c
1f39cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61ac46b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler
from .unet_2d_condition import UNet2DConditionModel
from easydict import EasyDict
import numpy as np
# For compatibility
from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
from utils import torch_device

def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True):
    """
    Load the Stable Diffusion components (VAE, tokenizer, text encoder, UNet, schedulers).

    Keys:
     key = "CompVis/stable-diffusion-v1-4"
     key = "runwayml/stable-diffusion-v1-5"
     key = "stabilityai/stable-diffusion-2-1-base"

    Unpack with:
    ```
    model_dict = load_sd(key=key, use_fp16=use_fp16)
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    ```

    Args:
        key: model id passed to every ``from_pretrained`` call (e.g. one of the keys above).
        use_fp16: load the VAE, text encoder and UNet in float16 from the "fp16"
            revision instead of float32 from "main". fp16 might have degraded performance.
        load_inverse_scheduler: also build a ``DDIMInverseScheduler`` from the DDIM
            scheduler's config and store it as ``model_dict.inverse_scheduler``.

    Returns:
        EasyDict with ``vae``, ``tokenizer``, ``text_encoder``, ``unet``, ``scheduler``
        (DDIM), ``dpm_scheduler`` (DPM-Solver multistep), ``dtype`` and, when
        ``load_inverse_scheduler`` is true, ``inverse_scheduler``.
        The VAE, text encoder and UNet are moved to ``torch_device``.
    """
    # fp32 is the default; fp16 is opt-in and also switches the repo revision.
    # NOTE(review): newer diffusers releases prefer variant="fp16" over revision="fp16";
    # confirm against the installed version before changing.
    if use_fp16:
        dtype = torch.float16
        revision = "fp16"
    else:
        dtype = torch.float
        revision = "main"

    # Only modules that carry weights take torch_dtype. The tokenizer and the schedulers
    # hold no tensors, so a dtype argument has no effect on them.
    vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device)
    tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision)
    text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device)
    unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device)
    dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision)
    scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision)

    model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dpm_scheduler=dpm_scheduler, dtype=dtype)

    if load_inverse_scheduler:
        # The inverse scheduler shares the DDIM scheduler's configuration.
        inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)
        model_dict.inverse_scheduler = inverse_scheduler

    return model_dict

def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
    """
    Encode prompts and a shared negative prompt with the text encoder.

    Args:
        tokenizer: tokenizer producing ``input_ids``; prompts are padded/truncated to
            ``tokenizer.model_max_length``.
        text_encoder: text encoder; element ``[0]`` of its output is used as the embeddings.
        prompts: list of prompt strings, or a single prompt string (treated as a
            one-item list).
        negative_prompt: negative (unconditional) prompt used for every item.
            An informational message is printed when it is empty.
        return_full_only: return only the concatenated embeddings.
        one_uncond_input_only: encode the negative prompt once (batch size 1) instead
            of once per prompt, and return ``(uncond_embeddings, cond_embeddings)``.

    Returns:
        - ``(uncond_embeddings, cond_embeddings)`` if ``one_uncond_input_only``;
        - otherwise ``text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])``
          alone if ``return_full_only``;
        - otherwise ``(text_embeddings, uncond_embeddings, cond_embeddings)``.
    """
    if negative_prompt == "":
        print("Note that negative_prompt is an empty string")

    # A bare string would be counted by characters in len(prompts) below, giving one
    # unconditional row per character but only one conditional row.
    if isinstance(prompts, str):
        prompts = [prompts]

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )

    max_length = text_input.input_ids.shape[-1]
    num_uncond_input = 1 if one_uncond_input_only else len(prompts)
    # truncation=True keeps an over-long negative prompt at max_length, matching the
    # conditional branch above; without it the tokenizer emits a longer sequence.
    uncond_input = tokenizer(
        [negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
    )

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    if one_uncond_input_only:
        return uncond_embeddings, cond_embeddings

    # Convention: negative (unconditional) embeddings first, then conditional ones.
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    if return_full_only:
        return text_embeddings
    return text_embeddings, uncond_embeddings, cond_embeddings

def process_input_embeddings(input_embeddings):
    """
    Normalize precomputed prompt embeddings to a (text, uncond, cond) triple.

    Accepts either the 3-tuple returned by ``encode_prompts`` with default flags or the
    2-tuple it returns with ``one_uncond_input_only=True``.

    Args:
        input_embeddings: tuple or list of either
            - ``(text_embeddings, uncond_embeddings, cond_embeddings)``: returned as-is
              after checking that uncond and cond have the same batch size; or
            - ``(uncond_embeddings, cond_embeddings)``: uncond may have batch size 1,
              in which case it is expanded to cond's shape.

    Returns:
        ``(text_embeddings, uncond_embeddings, cond_embeddings)`` where
        ``text_embeddings`` is uncond followed by cond along dim 0. For a 3-item input
        the input object itself is returned.

    Raises:
        AssertionError: input is not a tuple/list, or a 3-item input has mismatched
            batch sizes (these checks are ``assert`` statements, stripped under ``-O``).
        ValueError: a 2-item input whose uncond batch size is neither 1 nor equal to
            cond's, or an input of any other length.
    """
    assert isinstance(input_embeddings, (tuple, list))
    num_items = len(input_embeddings)

    if num_items == 3:
        # Assume `uncond_embeddings` is full (same batch size as cond_embeddings).
        _, uncond_embeddings, cond_embeddings = input_embeddings
        assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
        return input_embeddings

    if num_items == 2:
        uncond_embeddings, cond_embeddings = input_embeddings
        if uncond_embeddings.shape[0] == 1:
            # expand returns a broadcast view (no copy); torch.cat below materializes it.
            uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)
        if uncond_embeddings.shape[0] != cond_embeddings.shape[0]:
            # torch.cat along dim 0 would accept this and silently yield unequal halves.
            raise ValueError(
                f"uncond batch size {uncond_embeddings.shape[0]} does not match "
                f"cond batch size {cond_embeddings.shape[0]} (only 1 is broadcast)"
            )
        # We follow the convention: negative (unconditional) prompt comes first
        text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
        return text_embeddings, uncond_embeddings, cond_embeddings

    raise ValueError(f"input_embeddings length: {len(input_embeddings)}")