Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from PIL import Image | |
from lvdm.models.modules.lora import net_load_lora | |
from lvdm.utils.common_utils import instantiate_from_config | |
# ------------------------------------------------------------------------------------------ | |
def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''):
    """Instantiate the model described by ``config`` and load weights from ``ckpt_path``.

    Args:
        config: Config object; ``config.model`` is passed to ``instantiate_from_config``.
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It may be a dict
            holding a ``"state_dict"`` entry (optionally with ``"global_step"`` and
            ``"epoch"``), or a bare state dict.
        gpu_id: CUDA device index to move the model to. When ``None`` the model is
            moved with ``model.cuda()`` (the current default CUDA device).
        inject_lora: When True, apply LoRA weights from ``lora_path`` via ``net_load_lora``.
        lora_scale: Forwarded to ``net_load_lora`` as ``alpha``.
        lora_path: LoRA weights location; only used when ``inject_lora`` is True.

    Returns:
        Tuple ``(model, global_step, epoch)``. The model is in eval mode.
        ``global_step`` and ``epoch`` are both -1 unless the checkpoint records both.
    """
    print(f"Loading model from {ckpt_path}")
    # NOTE(security): torch.load unpickles arbitrary Python objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt_path, map_location="cpu")

    # The two counters are read together: if either is missing, report neither.
    try:
        global_step = pl_sd["global_step"]
        epoch = pl_sd["epoch"]
    except (KeyError, TypeError):
        global_step = -1
        epoch = -1

    # Wrapped checkpoints keep the weights under "state_dict" (presumably the
    # PyTorch Lightning layout, judging by the name); otherwise use the object as-is.
    try:
        sd = pl_sd["state_dict"]
    except (KeyError, TypeError):
        sd = pl_sd

    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=True)
    if inject_lora:
        net_load_lora(model, lora_path, alpha=lora_scale)

    # move to device & eval
    if gpu_id is not None:
        model.to(f"cuda:{gpu_id}")
    else:
        model.cuda()
    model.eval()
    return model, global_step, epoch
# ------------------------------------------------------------------------------------------ | |
def get_conditions(prompts, model, batch_size, cond_fps=None,):
    """Build the conditioning dict for a batch of ``batch_size`` samples.

    Args:
        prompts: A single prompt (``str``, or ``int`` such as a class label), or a
            list/tuple of prompts whose length is 1 (repeated across the batch)
            or exactly ``batch_size``.
        model: Object providing ``get_learned_conditioning`` and ``conditioning_key``.
            It may also provide ``cond_stage2_config`` / ``cond_stage2_key`` /
            ``cond_stage2_model`` / ``device`` for a second, fps-based condition.
        batch_size: Number of samples in the batch.
        cond_fps: Frame-rate value. Required when ``model.cond_stage2_key`` is
            ``"temporal_context"``.

    Returns:
        dict mapping ``'c_concat'`` (if ``model.conditioning_key == 'concat'``) or
        ``'c_crossattn'`` to a one-element list holding the learned conditioning.
        When the temporal condition applies, it also maps ``model.cond_stage2_key``
        to the output of ``model.cond_stage2_model``.

    Raises:
        ValueError: If ``prompts`` has an unsupported type or length, or if
            ``cond_fps`` is required but not given.
    """
    if isinstance(prompts, (str, int)):
        prompts = [prompts]
    if not isinstance(prompts, (list, tuple)):
        raise ValueError(f"invalid prompts: {prompts}")
    prompts = list(prompts)
    if len(prompts) == 1:
        # One prompt is shared by every sample in the batch.
        prompts = prompts * batch_size
    elif len(prompts) != batch_size:
        raise ValueError(f"invalid prompts length: {len(prompts)}")

    # content condition: text / class label
    cond = model.get_learned_conditioning(prompts)
    key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
    c = {key: [cond]}

    # temporal condition: fps (only when the model has a "temporal_context" second stage)
    has_stage2 = getattr(model, 'cond_stage2_config', None) is not None
    if has_stage2 and model.cond_stage2_key == "temporal_context":
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if cond_fps is None:
            raise ValueError("cond_fps is required when cond_stage2_key is 'temporal_context'")
        batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
        c[model.cond_stage2_key] = model.cond_stage2_model(batch)
    return c
# ------------------------------------------------------------------------------------------ | |
def make_model_input_shape(model, batch_size, T=None):
    """Return the sampling shape ``[batch_size, C, T, H, W]`` as a list.

    ``C`` is ``model.model.diffusion_model.in_channels``. ``T`` falls back to that
    network's ``temporal_length`` when not given. ``(H, W)`` comes from
    ``model.image_size``; a single int is treated as a square size.
    """
    net = model.model.diffusion_model
    spatial = model.image_size
    if isinstance(spatial, int):
        spatial = [spatial, spatial]
    frames = net.temporal_length if T is None else T
    return [batch_size, net.in_channels, frames, *spatial]
# ------------------------------------------------------------------------------------------ | |
def custom_to_pil(x):
    """Convert one channel-first image tensor with values in [-1, 1] to an RGB PIL image.

    Values outside [-1, 1] are clamped, then mapped linearly to [0, 255] and
    truncated to ``uint8``. Single-channel inputs are accepted and converted to RGB.

    Args:
        x: 3-D tensor laid out as ``(C, H, W)``; it is permuted to ``(H, W, C)``.

    Returns:
        ``PIL.Image.Image`` in mode ``"RGB"``.
    """
    # numpy was used here without being imported at module level (NameError).
    import numpy as np

    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    if x.shape[2] == 1:
        # Image.fromarray cannot handle an (H, W, 1) array; drop the channel axis.
        x = x[:, :, 0]
    x = Image.fromarray(x)
    if x.mode != "RGB":
        x = x.convert("RGB")
    return x
def torch_to_np(x):
    """Map a [-1, 1] batch to channels-last ``uint8`` pixels (ADM saving convention).

    Mirrors how samples are saved in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py.
    Despite the name, the result is a contiguous CPU ``torch.Tensor`` of dtype
    ``torch.uint8``, not a NumPy array.

    Axis 1 (channels) is moved to the end: a 5-D input is permuted with
    ``(0, 2, 3, 4, 1)``, anything else with ``(0, 2, 3, 1)``.
    """
    pixels = ((x.detach().cpu() + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    axes = (0, 2, 3, 4, 1) if pixels.dim() == 5 else (0, 2, 3, 1)
    return pixels.permute(*axes).contiguous()
def make_sample_dir(opt, global_step=None, epoch=None):
    """Return the directory where samples for this run should be written.

    If ``opt.not_automatic_logdir`` is truthy, ``opt.logdir`` is returned unchanged.
    Otherwise the path is ``<logdir>/<step>_<epoch>/<run-tag>``:

    * ``<step>``: ``globalstep`` plus the 9-digit zero-padded step, or ``None``.
    * ``<epoch>``: ``epoch`` plus the 6-digit zero-padded epoch, or ``None``.
    * ``<run-tag>``: names the prompt source (prompt-file stem, or the first 10
      characters of ``opt.prompt``), the sampler (DDPM, or DDIM with step count),
      the CFG scale, and, when set, the fps and seed.
    """
    if getattr(opt, 'not_automatic_logdir', False):
        return opt.logdir

    step_tag = "None" if global_step is None else f"globalstep{global_step:09}"
    epoch_tag = "None" if epoch is None else f"epoch{epoch:06}"

    if opt.prompt_file is None:
        pieces = [f"prompt_{opt.prompt[:10]}"]
    else:
        stem = os.path.splitext(os.path.basename(opt.prompt_file))[0]
        pieces = [f"prompts_{stem}"]
    pieces.append("_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps")
    pieces.append(f"_CfgScale{opt.scale}")
    if opt.cond_fps is not None:
        pieces.append(f"_fps{opt.cond_fps}")
    if opt.seed is not None:
        pieces.append(f"_seed{opt.seed}")

    return os.path.join(opt.logdir, f"{step_tag}_{epoch_tag}", "".join(pieces))