from __future__ import annotations

import argparse
import json
import os
import random

import PIL.Image  # import the submodule explicitly; `import PIL` alone does not expose PIL.Image
import torch
from pytorch_lightning import seed_everything
from torchvision import transforms

from . import sample_utils

VERSION2SPECS = {
    "vwm": {"config": "configs/inference/vista.yaml", "ckpt": "ckpts/vista.safetensors"}
}

DATASET2SOURCES = {
    "NUSCENES": {"data_root": "data/nuscenes", "anno_file": "annos/nuScenes_val.json"},
    "IMG": {"data_root": "image_folder"},
}


def parse_args(**parser_kwargs):
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--version", type=str, default="vwm", help="model version")
    parser.add_argument("--dataset", type=str, default="NUSCENES", help="dataset name")
    parser.add_argument(
        "--save", type=str, default="outputs", help="directory to save samples"
    )
    parser.add_argument(
        "--action",
        type=str,
        default="free",
        help="action mode for control, such as traj, cmd, steer, goal",
    )
    parser.add_argument(
        "--n_rounds", type=int, default=1, help="number of sampling rounds"
    )
    parser.add_argument(
        "--n_frames", type=int, default=25, help="number of frames for each round"
    )
    parser.add_argument(
        "--n_conds",
        type=int,
        default=1,
        help="number of initial condition frames for the first round",
    )
    parser.add_argument(
        "--seed", type=int, default=23, help="random seed for seed_everything"
    )
    parser.add_argument(
        "--height", type=int, default=576, help="target height of the generated video"
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="target width of the generated video"
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=2.5,
        help="scale of the classifier-free guidance",
    )
    parser.add_argument(
        "--cond_aug", type=float, default=0.0, help="strength of the noise augmentation"
    )
    parser.add_argument(
        "--n_steps", type=int, default=50, help="number of sampling steps"
    )
    parser.add_argument(
        "--rand_gen",
        action="store_false",
        help="sample indices randomly (default); pass this flag to iterate sequentially",
    )
    parser.add_argument(
        "--low_vram", action="store_true", help="whether to save memory or not"
    )
    return parser


def get_sample(
    selected_index=0, dataset_name="NUSCENES", num_frames=25, action_mode="free"
):
    dataset_dict = DATASET2SOURCES[dataset_name]
    action_dict = None
    if dataset_name == "IMG":
        image_list = os.listdir(dataset_dict["data_root"])
        total_length = len(image_list)
        # wrap around so any requested index maps into the dataset
        while selected_index >= total_length:
            selected_index -= total_length
        image_file = image_list[selected_index]
        # a single image is repeated to form the conditioning sequence
        path_list = [os.path.join(dataset_dict["data_root"], image_file)] * num_frames
    else:
        with open(dataset_dict["anno_file"]) as anno_json:
            all_samples = json.load(anno_json)
        total_length = len(all_samples)
        while selected_index >= total_length:
            selected_index -= total_length
        sample_dict = all_samples[selected_index]

        path_list = list()
        if dataset_name == "NUSCENES":
            for index in range(num_frames):
                image_path = os.path.join(
                    dataset_dict["data_root"], sample_dict["frames"][index]
                )
                assert os.path.exists(image_path), image_path
                path_list.append(image_path)

            if action_mode != "free":
                action_dict = dict()
                if action_mode == "traj" or action_mode == "trajectory":
                    action_dict["trajectory"] = torch.tensor(sample_dict["traj"][2:])
                elif action_mode == "cmd" or action_mode == "command":
                    action_dict["command"] = torch.tensor(sample_dict["cmd"])
                elif action_mode == "steer":
                    # scene might be empty
                    if sample_dict["speed"]:
                        action_dict["speed"] = torch.tensor(sample_dict["speed"][1:])
                    # scene might be empty
                    if sample_dict["angle"]:
                        action_dict["angle"] = (
                            torch.tensor(sample_dict["angle"][1:]) / 780
                        )
                elif action_mode == "goal":
                    # point might be invalid
                    if (
                        sample_dict["z"] > 0
                        and 0 < sample_dict["goal"][0] < 1600
                        and 0 < sample_dict["goal"][1] < 900
                    ):
                        action_dict["goal"] = torch.tensor(
                            [
                                sample_dict["goal"][0] / 1600,
                                sample_dict["goal"][1] / 900,
                            ]
                        )
                else:
                    raise ValueError(f"Unsupported action mode {action_mode}")
        else:
            raise ValueError(f"Invalid dataset {dataset_name}")
    return path_list, selected_index, total_length, action_dict
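
# Illustrative note on the normalization constants in get_sample(): 1600x900 is
# the nuScenes camera resolution, so a valid goal point is mapped into
# [0, 1] x [0, 1]; the steering angle is divided by 780 (presumably the maximum
# steering-wheel angle in degrees), giving values in roughly [-1, 1].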

def load_img(file_name, target_height=320, target_width=576, device="cuda"):
    if file_name is not None:
        image = PIL.Image.open(file_name)
        if not image.mode == "RGB":
            image = image.convert("RGB")
    else:
        raise ValueError(f"Invalid image file {file_name}")
    ori_w, ori_h = image.size
    # print(f"Loaded input image of size ({ori_w}, {ori_h})")

    # center-crop to the target aspect ratio, then resize
    if ori_w / ori_h > target_width / target_height:
        tmp_w = int(target_width / target_height * ori_h)
        left = (ori_w - tmp_w) // 2
        right = (ori_w + tmp_w) // 2
        image = image.crop((left, 0, right, ori_h))
    elif ori_w / ori_h < target_width / target_height:
        tmp_h = int(target_height / target_width * ori_w)
        top = (ori_h - tmp_h) // 2
        bottom = (ori_h + tmp_h) // 2
        image = image.crop((0, top, ori_w, bottom))
    image = image.resize((target_width, target_height), resample=PIL.Image.LANCZOS)

    image = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
    )(image)
    return image.to(device)
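
# Note: transforms.ToTensor() scales pixel values to [0, 1]; the Lambda in
# load_img() then maps them to [-1, 1], the input range commonly expected by
# Stable-Video-Diffusion-style backbones.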

if __name__ == "__main__":
    parser = parse_args()
    opt, unknown = parser.parse_known_args()

    sample_utils.set_lowvram_mode(opt.low_vram)
    version_dict = VERSION2SPECS[opt.version]
    model = sample_utils.init_model(version_dict)
    unique_keys = {emb.input_key for emb in model.conditioner.embedders}

    sample_index = 0
    while sample_index >= 0:
        seed_everything(opt.seed)

        frame_list, sample_index, dataset_length, action_dict = get_sample(
            sample_index, opt.dataset, opt.n_frames, opt.action
        )

        # load and preprocess the conditioning frames
        img_seq = list()
        for each_path in frame_list:
            img = load_img(each_path, opt.height, opt.width)
            img_seq.append(img)
        images = torch.stack(img_seq)

        value_dict = sample_utils.init_embedder_options(unique_keys)
        cond_img = img_seq[0][None]
        value_dict["cond_frames_without_noise"] = cond_img
        value_dict["cond_aug"] = opt.cond_aug
        value_dict["cond_frames"] = cond_img + opt.cond_aug * torch.randn_like(cond_img)
        if action_dict is not None:
            for key, value in action_dict.items():
                value_dict[key] = value

        # multi-round (long-horizon) rollouts use the triangle prediction
        # guider; single rounds fall back to vanilla classifier-free guidance
        if opt.n_rounds > 1:
            guider = "TrianglePredictionGuider"
        else:
            guider = "VanillaCFG"
        sampler = sample_utils.init_sampling(
            guider=guider,
            steps=opt.n_steps,
            cfg_scale=opt.cfg_scale,
            num_frames=opt.n_frames,
        )

        # embeddings zeroed out for the unconditional branch of CFG
        uc_keys = [
            "cond_frames",
            "cond_frames_without_noise",
            "command",
            "trajectory",
            "speed",
            "angle",
            "goal",
        ]

        out = sample_utils.do_sample(
            images,
            model,
            sampler,
            value_dict,
            num_rounds=opt.n_rounds,
            num_frames=opt.n_frames,
            force_uc_zero_embeddings=uc_keys,
            initial_cond_indices=list(range(opt.n_conds)),
        )

        if isinstance(out, (tuple, list)):
            samples, samples_z, inputs = out
            virtual_path = os.path.join(opt.save, "virtual")
            real_path = os.path.join(opt.save, "real")
            sample_utils.perform_save_locally(
                virtual_path, samples, "videos", opt.dataset, sample_index
            )
            sample_utils.perform_save_locally(
                virtual_path, samples, "grids", opt.dataset, sample_index
            )
            sample_utils.perform_save_locally(
                virtual_path, samples, "images", opt.dataset, sample_index
            )
            sample_utils.perform_save_locally(
                real_path, inputs, "videos", opt.dataset, sample_index
            )
            sample_utils.perform_save_locally(
                real_path, inputs, "grids", opt.dataset, sample_index
            )
            sample_utils.perform_save_locally(
                real_path, inputs, "images", opt.dataset, sample_index
            )
        else:
            raise TypeError(f"Unexpected output type from do_sample: {type(out)}")

        if opt.rand_gen:
            # jump to a random other sample and loop indefinitely
            sample_index += random.randint(1, dataset_length - 1)
        else:
            # iterate sequentially and stop after one full pass
            sample_index += 1
            if dataset_length <= sample_index:
                sample_index = -1
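
# Example invocation (illustrative; assumes this file is <package>/sample.py so
# that the relative `from . import sample_utils` resolves):
#   python -m <package>.sample --dataset NUSCENES --action traj \
#       --n_rounds 2 --n_steps 50 --cfg_scale 2.5 --save outputs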