File size: 4,424 Bytes
153e804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import torch
from PIL import Image

from lvdm.models.modules.lora import net_load_lora
from lvdm.utils.common_utils import instantiate_from_config


# ------------------------------------------------------------------------------------------
def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''):
    """Instantiate the model from `config`, load weights from `ckpt_path`,
    optionally inject LoRA weights, and move the model to GPU in eval mode.

    Args:
        config: OmegaConf-style config with a `model` node understood by
            `instantiate_from_config`.
        ckpt_path: path to a torch checkpoint (Lightning-style dict with
            "state_dict", or a raw state dict).
        gpu_id: target CUDA device index; `None` uses the default device.
        inject_lora: when True, load LoRA weights from `lora_path`.
        lora_scale: LoRA alpha/merge scale passed to `net_load_lora`.
        lora_path: path to the LoRA checkpoint.

    Returns:
        (model, global_step, epoch) — bookkeeping values are -1 when the
        checkpoint does not carry them.
    """
    print(f"Loading model from {ckpt_path}")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    pl_sd = torch.load(ckpt_path, map_location="cpu")

    # Lightning checkpoints carry training bookkeeping; fall back to -1
    # when absent. (Original used bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit.)
    global_step = pl_sd.get("global_step", -1)
    epoch = pl_sd.get("epoch", -1)

    # A Lightning checkpoint nests the weights under "state_dict";
    # a raw state dict is used as-is.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=True)

    if inject_lora:
        net_load_lora(model, lora_path, alpha=lora_scale)

    # move to device & eval
    if gpu_id is not None:
        model.to(f"cuda:{gpu_id}")
    else:
        model.cuda()
    model.eval()

    return model, global_step, epoch


# ------------------------------------------------------------------------------------------
@torch.no_grad()
def get_conditions(prompts, model, batch_size, cond_fps=None,):
    """Build the conditioning dict consumed by the sampler.

    Args:
        prompts: a single prompt (str for text, int for class label), or a
            list of prompts of length 1 or `batch_size`.
        model: diffusion wrapper exposing `get_learned_conditioning`,
            `conditioning_key`, and optionally a second conditioning stage
            (`cond_stage2_*` attributes) for fps conditioning.
        batch_size: number of samples; prompts are broadcast to this size.
        cond_fps: fps value, required when the model has a
            "temporal_context" second conditioning stage.

    Returns:
        dict mapping conditioning keys ('c_concat' or 'c_crossattn', plus
        optionally the fps key) to conditioning tensors.

    Raises:
        ValueError: when `prompts` is not a str/int/list, or its length is
            neither 1 nor `batch_size`.
    """
    # normalize to a list of length batch_size
    if isinstance(prompts, (str, int)):
        prompts = [prompts]
    if not isinstance(prompts, list):
        raise ValueError(f"invalid prompts: {prompts}")
    if len(prompts) == 1:
        prompts = prompts * batch_size
    elif len(prompts) != batch_size:
        raise ValueError(f"invalid prompts length: {len(prompts)}")
    assert(len(prompts) == batch_size)

    # content condition: text / class label
    c = model.get_learned_conditioning(prompts)
    key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
    c = {key: [c]}

    # temporal condition: fps embedding from the second conditioning stage
    if getattr(model, 'cond_stage2_config', None) is not None:
        if model.cond_stage2_key == "temporal_context":
            assert(cond_fps is not None)
            batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
            fps_embd = model.cond_stage2_model(batch)
            c[model.cond_stage2_key] = fps_embd

    return c


# ------------------------------------------------------------------------------------------
def make_model_input_shape(model, batch_size, T=None):
    """Return the latent input shape [B, C, T, H, W] for the diffusion net.

    `model.image_size` may be an int (square) or an [H, W] pair; when `T`
    is not given, the net's configured temporal length is used.
    """
    if isinstance(model.image_size, int):
        spatial = [model.image_size, model.image_size]
    else:
        spatial = model.image_size
    diffusion_net = model.model.diffusion_model
    num_frames = T if T is not None else diffusion_net.temporal_length
    return [batch_size, diffusion_net.in_channels, num_frames, *spatial]


# ------------------------------------------------------------------------------------------
def custom_to_pil(x):
    """Convert a (C, H, W) tensor in [-1, 1] to an RGB PIL image.

    Fix: the original called `np.uint8`, but numpy is never imported in
    this file, so the function raised NameError at runtime. The cast is
    now done in torch (`.to(torch.uint8)` truncates exactly like numpy's
    `astype(np.uint8)`), so no new import is needed.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.                 # map [-1, 1] -> [0, 1]
    x = (255 * x).to(torch.uint8)     # truncating cast, as astype would do
    x = Image.fromarray(x.permute(1, 2, 0).numpy())
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x

def torch_to_np(x):
    # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    # Map [-1, 1] floats to uint8 [0, 255] and move the channel axis last.
    out = x.detach().cpu()
    out = ((out + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # (B, C, T, H, W) -> (B, T, H, W, C) for video; (B, C, H, W) -> (B, H, W, C) for images
    perm = (0, 2, 3, 4, 1) if out.dim() == 5 else (0, 2, 3, 1)
    return out.permute(*perm).contiguous()

def make_sample_dir(opt, global_step=None, epoch=None):
    """Compose the output directory path for generated samples.

    When `opt.not_automatic_logdir` is set, `opt.logdir` is used directly;
    otherwise the path encodes checkpoint step/epoch and the sampling
    settings (prompt, sampler, cfg scale, fps, seed).
    """
    if getattr(opt, 'not_automatic_logdir', False):
        return opt.logdir

    gs_str = "None" if global_step is None else f"globalstep{global_step:09}"
    e_str = "None" if epoch is None else f"epoch{epoch:06}"
    ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}")

    # subdir name: prompt source, then sampler/scale/fps/seed tags
    if opt.prompt_file is not None:
        parts = [f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}"]
    else:
        parts = [f"prompt_{opt.prompt[:10]}"]
    parts.append("_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps")
    parts.append(f"_CfgScale{opt.scale}")
    if opt.cond_fps is not None:
        parts.append(f"_fps{opt.cond_fps}")
    if opt.seed is not None:
        parts.append(f"_seed{opt.seed}")

    return os.path.join(ckpt_dir, "".join(parts))