from __future__ import annotations import math import os import queue from typing import Optional, Union import numpy as np import rerun as rr import torch import torchvision from einops import rearrange, repeat from omegaconf import ListConfig, OmegaConf from PIL import Image from safetensors.torch import load_file as load_safetensors from torch import autocast from tqdm import tqdm from .vwm.modules.diffusionmodules.sampling import EulerEDMSampler from .vwm.util import default, instantiate_from_config def init_model(version_dict, load_ckpt=True): config = OmegaConf.load(version_dict["config"]) model = load_model_from_config(config, version_dict["ckpt"] if load_ckpt else None) return model lowvram_mode = True def set_lowvram_mode(mode): global lowvram_mode lowvram_mode = mode def initial_model_load(model): global lowvram_mode if lowvram_mode: model.model.half() else: model.cuda() return model def load_model(model): model.cuda() def unload_model(model): global lowvram_mode print(lowvram_mode) if lowvram_mode: model.cpu() torch.cuda.empty_cache() torch.cuda.synchronize() def load_model_from_config(config, ckpt=None): model = instantiate_from_config(config.model) print(ckpt) if ckpt is not None: print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_svd = torch.load(ckpt, map_location="cpu") # dict contains: # "epoch", "global_step", "pytorch-lightning_version", # "state_dict", "loops", "callbacks", "optimizer_states", "lr_schedulers" if "global_step" in pl_svd: print(f"Global step: {pl_svd['global_step']}") svd = pl_svd["state_dict"] else: svd = load_safetensors(ckpt) missing, unexpected = model.load_state_dict(svd, strict=False) if len(missing) > 0: print(f"Missing keys: {missing}") if len(unexpected) > 0: print(f"Unexpected keys: {unexpected}") model = initial_model_load(model) model.eval() return model def init_embedder_options(keys): # hardcoded demo settings, might undergo some changes in the future value_dict = dict() for key in keys: if key in ["fps_id", "fps"]: fps = 10 value_dict["fps"] = fps value_dict["fps_id"] = fps - 1 elif key == "motion_bucket_id": value_dict["motion_bucket_id"] = 127 # [0, 511] return value_dict def perform_save_locally(save_path, samples, mode, dataset_name, sample_index): assert mode in ["images", "grids", "videos"] merged_path = os.path.join(save_path, mode) os.makedirs(merged_path, exist_ok=True) samples = samples.cpu() if mode == "images": frame_count = 0 for sample in samples: sample = rearrange(sample.numpy(), "c h w -> h w c") if "real" in save_path: sample = 255.0 * (sample + 1.0) / 2.0 else: sample = 255.0 * sample image_save_path = os.path.join( merged_path, f"{dataset_name}_{sample_index:06}_{frame_count:04}.png" ) # if os.path.exists(image_save_path): # return Image.fromarray(sample.astype(np.uint8)).save(image_save_path) frame_count += 1 elif mode == "grids": grid = torchvision.utils.make_grid(samples, nrow=int(samples.shape[0] ** 0.5)) grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1).numpy() if "real" in save_path: grid = 255.0 * (grid + 1.0) / 2.0 else: grid = 255.0 * grid grid_save_path = os.path.join( merged_path, f"{dataset_name}_{sample_index:06}.png" ) # if os.path.exists(grid_save_path): # return Image.fromarray(grid.astype(np.uint8)).save(grid_save_path) elif mode == "videos": img_seq = rearrange(samples.numpy(), "t c h w -> t h w c") if "real" in save_path: img_seq = 255.0 * (img_seq + 1.0) / 2.0 else: img_seq = 255.0 * img_seq video_save_path = os.path.join( merged_path, f"{dataset_name}_{sample_index:06}.mp4" ) # if os.path.exists(video_save_path): # return save_img_seq_to_video(video_save_path, img_seq.astype(np.uint8), 10) else: raise NotImplementedError def init_sampling( sampler="EulerEDMSampler", guider="VanillaCFG", discretization="EDMDiscretization", steps=50, cfg_scale=2.5, num_frames=25, ): discretization_config = get_discretization(discretization) guider_config = get_guider(guider, cfg_scale, num_frames) sampler = get_sampler(sampler, steps, discretization_config, guider_config) return sampler def get_discretization(discretization): if discretization == "LegacyDDPMDiscretization": discretization_config = { "target": "vista.vwm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization" } elif discretization == "EDMDiscretization": discretization_config = { "target": "vista.vwm.modules.diffusionmodules.discretizer.EDMDiscretization", "params": {"sigma_min": 0.002, "sigma_max": 700.0, "rho": 7.0}, } else: raise NotImplementedError return discretization_config def get_guider(guider="LinearPredictionGuider", cfg_scale=2.5, num_frames=25): if guider == "IdentityGuider": guider_config = { "target": "vista.vwm.modules.diffusionmodules.guiders.IdentityGuider" } elif guider == "VanillaCFG": scale = cfg_scale guider_config = { "target": "vista.vwm.modules.diffusionmodules.guiders.VanillaCFG", "params": {"scale": scale}, } elif guider == "LinearPredictionGuider": max_scale = cfg_scale min_scale = 1.0 guider_config = { "target": "vista.vwm.modules.diffusionmodules.guiders.LinearPredictionGuider", "params": { "max_scale": max_scale, "min_scale": min_scale, "num_frames": num_frames, }, } elif guider == "TrianglePredictionGuider": max_scale = cfg_scale min_scale = 1.0 guider_config = { "target": "vista.vwm.modules.diffusionmodules.guiders.TrianglePredictionGuider", "params": { "max_scale": max_scale, "min_scale": min_scale, "num_frames": num_frames, }, } else: raise NotImplementedError return guider_config def get_sampler(sampler, steps, discretization_config, guider_config): if sampler == "EulerEDMSampler": s_churn = 0.0 s_tmin = 0.0 s_tmax = 999.0 s_noise = 1.0 sampler = EulerEDMSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, verbose=False, ) else: raise ValueError(f"Unknown sampler {sampler}") return sampler def get_batch(keys, value_dict, N: Union[list, ListConfig], device="cuda"): # hardcoded demo setups, might undergo some changes in the future batch = dict() batch_uc = dict() for key in keys: if key in value_dict: if key in ["fps", "fps_id", "motion_bucket_id", "cond_aug"]: batch[key] = repeat( torch.tensor([value_dict[key]]).to(device), "1 -> b", b=math.prod(N) ) elif key in ["command", "trajectory", "speed", "angle", "goal"]: batch[key] = repeat( value_dict[key][None].to(device), "1 ... -> b ...", b=N[0] ) elif key in ["cond_frames", "cond_frames_without_noise"]: batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) else: # batch[key] = value_dict[key] raise NotImplementedError for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def get_condition(model, value_dict, num_samples, force_uc_zero_embeddings, device): load_model(model.conditioner) batch, batch_uc = get_batch( list(set([x.input_key for x in model.conditioner.embedders])), value_dict, [num_samples], ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings ) unload_model(model.conditioner) for k in c: if isinstance(c[k], torch.Tensor): c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) if c[k].shape[0] < num_samples: c[k] = c[k][[0]] if uc[k].shape[0] < num_samples: uc[k] = uc[k][[0]] return c, uc def fill_latent(cond, length, cond_indices, device): latent = torch.zeros(length, *cond.shape[1:]).to(device) latent[cond_indices] = cond return latent @torch.no_grad() def do_sample( images, model, sampler, value_dict, num_rounds, num_frames, force_uc_zero_embeddings: Optional[list] = None, initial_cond_indices: Optional[list] = None, device="cuda", log_queue: queue.SimpleQueue = None, ): if initial_cond_indices is None: initial_cond_indices = [0] force_uc_zero_embeddings = default(force_uc_zero_embeddings, list()) precision_scope = autocast with torch.no_grad(), precision_scope(device), model.ema_scope("Sampling"): c, uc = get_condition( model, value_dict, num_frames, force_uc_zero_embeddings, device ) load_model(model.first_stage_model) z = model.encode_first_stage(images) unload_model(model.first_stage_model) samples_z = torch.zeros((num_rounds * (num_frames - 3) + 3, *z.shape[1:])).to( device ) sampling_progress = tqdm(total=num_rounds, desc="Compute sequences") def denoiser(x, sigma, cond, cond_mask): return model.denoiser(model.model, x, sigma, cond, cond_mask) load_model(model.denoiser) load_model(model.model) initial_cond_mask = torch.zeros(num_frames).to(device) prediction_cond_mask = torch.zeros(num_frames).to(device) initial_cond_mask[initial_cond_indices] = 1 prediction_cond_mask[[0, 1, 2]] = 1 generated_images = [] noise = torch.randn_like(z) sample = sampler( denoiser, noise, cond=c, uc=uc, cond_frame=z, # cond_frame will be rescaled when calling the sampler cond_mask=initial_cond_mask, num_sequence=0, log_queue=log_queue, ) sampling_progress.update(1) sample[0] = z[0] samples_z[:num_frames] = sample generated_images.append(decode_samples(sample[:num_frames], model)) for i, generated_image in enumerate(generated_images[-1]): log_queue.put( ( "generated_image", rr.Image(generated_image.cpu().permute(1, 2, 0)), [ ("frame_id", i), ("diffusion", 0), ( "combined", 1 + 2 * 0 + (i * 1.0 / len(generated_images[-1])), ), ], ) ) for n in range(num_rounds - 1): load_model(model.first_stage_model) samples_x_for_guidance = model.decode_first_stage(sample[-14:]) unload_model(model.first_stage_model) value_dict["cond_frames_without_noise"] = samples_x_for_guidance[[-3]] value_dict["cond_frames"] = sample[[-3]] / model.scale_factor for embedder in model.conditioner.embedders: if hasattr(embedder, "skip_encode"): embedder.skip_encode = True c, uc = get_condition( model, value_dict, num_frames, force_uc_zero_embeddings, device ) for embedder in model.conditioner.embedders: if hasattr(embedder, "skip_encode"): embedder.skip_encode = False filled_latent = fill_latent(sample[-3:], num_frames, [0, 1, 2], device) noise = torch.randn_like(filled_latent) sample = sampler( denoiser, noise, cond=c, uc=uc, cond_frame=filled_latent, # cond_frame will be rescaled when calling the sampler cond_mask=prediction_cond_mask, num_sequence=n + 1, log_queue=log_queue, ) sampling_progress.update(1) first_frame_id = (n + 1) * (num_frames - 3) + 3 last_frame_id = (n + 1) * (num_frames - 3) + num_frames samples_z[first_frame_id:last_frame_id] = sample[3:] generated_images.append(decode_samples(sample[3:], model)) for i, generated_image in enumerate(generated_images[-1]): log_queue.put( ( "generated_image", rr.Image(generated_image.cpu().permute(1, 2, 0)), [ ("frame_id", first_frame_id + i), ("diffusion", 0), ( "combined", 1 + 2 * (n + 1) + (i * 1.0 / len(generated_images[-1])), ), ], ) ) unload_model(model.model) unload_model(model.denoiser) generated_images = torch.concat(generated_images, dim=0) return generated_images, samples_z, images def decode_samples(samples, model): load_model(model.first_stage_model) samples_x = model.decode_first_stage(samples) unload_model(model.first_stage_model) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) return samples