RamAnanth1 committed
Commit 514015e • 1 Parent(s): b6b5d48

Create utils.py

Files changed (1)
  1. utils.py +129 -0
utils.py ADDED
import os

import numpy as np
import torch
from PIL import Image

from lvdm.models.modules.lora import net_load_lora
from lvdm.utils.common_utils import instantiate_from_config


# ------------------------------------------------------------------------------------------
def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''):
    print(f"Loading model from {ckpt_path}")

    # load checkpoint
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    try:
        global_step = pl_sd["global_step"]
        epoch = pl_sd["epoch"]
    except KeyError:
        global_step = -1
        epoch = -1

    # load state dict into the instantiated model
    try:
        sd = pl_sd["state_dict"]
    except KeyError:
        sd = pl_sd
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=True)

    if inject_lora:
        net_load_lora(model, lora_path, alpha=lora_scale)

    # move to device & eval
    if gpu_id is not None:
        model.to(f"cuda:{gpu_id}")
    else:
        model.cuda()
    model.eval()

    return model, global_step, epoch


# ------------------------------------------------------------------------------------------
@torch.no_grad()
def get_conditions(prompts, model, batch_size, cond_fps=None):

    if isinstance(prompts, (str, int)):
        prompts = [prompts]
    if isinstance(prompts, list):
        if len(prompts) == 1:
            prompts = prompts * batch_size
        elif len(prompts) == batch_size:
            pass
        else:
            raise ValueError(f"invalid prompts length: {len(prompts)}")
    else:
        raise ValueError(f"invalid prompts: {prompts}")
    assert len(prompts) == batch_size

    # content condition: text / class label
    c = model.get_learned_conditioning(prompts)
    key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
    c = {key: [c]}

    # temporal condition: fps
    if getattr(model, 'cond_stage2_config', None) is not None:
        if model.cond_stage2_key == "temporal_context":
            assert cond_fps is not None
            batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
            fps_embd = model.cond_stage2_model(batch)
            c[model.cond_stage2_key] = fps_embd

    return c


# ------------------------------------------------------------------------------------------
def make_model_input_shape(model, batch_size, T=None):
    image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size
    C = model.model.diffusion_model.in_channels
    if T is None:
        T = model.model.diffusion_model.temporal_length
    shape = [batch_size, C, T, *image_size]
    return shape


# ------------------------------------------------------------------------------------------
def custom_to_pil(x):
    # map a single [C, H, W] tensor in [-1, 1] to an RGB PIL image
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x


def torch_to_np(x):
    # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    if sample.dim() == 5:
        sample = sample.permute(0, 2, 3, 4, 1)
    else:
        sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample


def make_sample_dir(opt, global_step=None, epoch=None):
    if not getattr(opt, 'not_automatic_logdir', False):
        gs_str = f"globalstep{global_step:09}" if global_step is not None else "None"
        e_str = f"epoch{epoch:06}" if epoch is not None else "None"
        ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}")

        # subdir name
        if opt.prompt_file is not None:
            subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}"
        else:
            subdir = f"prompt_{opt.prompt[:10]}"
        subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps"
        subdir += f"_CfgScale{opt.scale}"
        if opt.cond_fps is not None:
            subdir += f"_fps{opt.cond_fps}"
        if opt.seed is not None:
            subdir += f"_seed{opt.seed}"

        return os.path.join(ckpt_dir, subdir)
    else:
        return opt.logdir
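
For reference, a minimal sketch of how these helpers might be chained in a sampling script. The OmegaConf-based config loading and both paths are assumptions (the repo's actual entry points may differ), and the sampler call itself lives outside utils.py, so it is left as a placeholder comment.

from omegaconf import OmegaConf  # assumed: the model config is a YAML file with a `model` section

import utils

# hypothetical paths -- substitute the real config / checkpoint locations
config = OmegaConf.load("configs/text2video.yaml")
model, global_step, epoch = utils.load_model(config, "models/t2v/model.ckpt", gpu_id=0)

batch_size = 2
cond = utils.get_conditions("a corgi running on the beach", model, batch_size, cond_fps=8)
shape = utils.make_model_input_shape(model, batch_size)  # [B, C, T, H, W] in latent space

# ... run the repo's DDPM/DDIM sampler here to produce `samples` with shape `shape` ...
# frames = utils.torch_to_np(samples)  # -> uint8 frames, [B, T, H, W, C], ready for saving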