# FreeTraj/scripts/evaluation/inference_freetraj.py
# Anonymous
# init
# 2a50f45
import argparse, os, sys, glob, yaml, math, random
import datetime, time
import numpy as np
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import trange, tqdm
from einops import repeat
from einops import rearrange, repeat
from functools import partial
import torch
from pytorch_lightning import seed_everything
from funcs import load_model_checkpoint, load_prompts, load_idx, load_traj, load_image_batch, get_filelist, save_videos, save_videos_with_bbox
from funcs import batch_ddim_sampling_freetraj
from utils.utils import instantiate_from_config
def get_parser():
    """Build the command-line parser for FreeTraj inference.

    Returns:
        argparse.ArgumentParser: parser exposing the model/config paths,
        sampling hyper-parameters, output options and the FreeTraj-specific
        index/trajectory inputs. No argument is marked as required here;
        run_inference validates the paths it needs.
    """
    parser = argparse.ArgumentParser()

    ## general
    parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything")
    parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--prompt_file", type=str, default=None, help="a text file containing many prompts")
    parser.add_argument("--savedir", type=str, default=None, help="results saving path")
    # Parsed as int so that a user-supplied value has the same type as the
    # integer default (it used to arrive as a str when given on the CLI).
    parser.add_argument("--savefps", type=int, default=10, help="video fps to generate")

    ## sampling
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt")
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM")
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)")
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frames", type=int, default=-1, help="frames num to inference")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
    parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance")

    ## for conditional i2v only
    parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input")

    ## FreeTraj
    parser.add_argument("--ddim_edit", type=int, default=6, help="steps of ddim for edited attention")
    parser.add_argument("--idx_file", type=str, default=None, help="an index file with one entry per prompt")
    parser.add_argument("--traj_file", type=str, default=None, help="a trajectory file applied to every prompt")
    return parser
def run_inference(args, gpu_num, gpu_no, **kwargs):
    """Run FreeTraj sampling for every prompt assigned to this rank and save the videos.

    Args:
        args: parsed command-line namespace (see get_parser).
        gpu_num: number of ranks the prompt list is split across.
        gpu_no: index of this rank; also passed to ``model.cuda`` as the device.
        **kwargs: forwarded unchanged to ``batch_ddim_sampling_freetraj``.

    Raises:
        AssertionError: if the checkpoint or prompt file is missing, the image
            size is not a multiple of 16, the i2v inputs do not match the
            prompts, or the index list and file names are not paired.
        NotImplementedError: if ``args.mode`` is neither 'base' nor 'i2v'
            (raised inside the batch loop, i.e. only when there is a batch).
    """
    ## step 1: model config
    ## -----------------------------------------------------------------
    full_config = OmegaConf.load(args.config)
    model = instantiate_from_config(full_config.pop("model", OmegaConf.create()))
    model = model.cuda(gpu_no)
    assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()

    ## sample shape
    assert args.height % 16 == 0 and args.width % 16 == 0, "Error: image size [h,w] should be multiples of 16!"
    ## latent noise shape: pixel size divided by 8
    latent_h = args.height // 8
    latent_w = args.width // 8
    if args.frames < 0:
        # a negative --frames means "use the model's own temporal length"
        frames = model.temporal_length
    else:
        frames = args.frames
    channels = model.channels

    ## saving folders
    os.makedirs(args.savedir, exist_ok=True)
    bboxdir = os.path.join(args.savedir, "bbox")
    os.makedirs(bboxdir, exist_ok=True)

    ## step 2: load data
    ## -----------------------------------------------------------------
    assert os.path.exists(args.prompt_file), "Error: prompt file NOT Found!"
    prompt_list = load_prompts(args.prompt_file)
    # NOTE(review): unlike prompts and file names, the index list is used as
    # loaded and is not subset by this rank's indices -- confirm this is
    # intended when gpu_num > 1.
    idx_list_rank = load_idx(args.idx_file)
    input_traj = load_traj(args.traj_file)
    for loaded in (prompt_list, idx_list_rank, input_traj):
        print(loaded)

    num_samples = len(prompt_list)
    # default output names are 1-based, zero-padded to four digits
    filename_list = [f"{number:04d}" for number in range(1, num_samples + 1)]

    samples_split, residual_tail = divmod(num_samples, gpu_num)
    print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.')
    shard_start = samples_split * gpu_no
    indices = list(range(shard_start, shard_start + samples_split))
    if residual_tail != 0 and gpu_no == 0:
        # rank 0 additionally takes the samples left over by the even split
        indices.extend(range(num_samples - residual_tail, num_samples))
    prompt_list_rank = [prompt_list[i] for i in indices]

    ## conditional input
    if args.mode == "i2v":
        ## each video or frames dir per prompt
        cond_inputs = get_filelist(args.cond_input, ext='[mpj][pn][4gj]')
        assert len(cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!"
        # output names become the input file names minus their last 4 characters
        filename_list = [f"{os.path.split(cond_inputs[k])[1][:-4]}" for k in range(num_samples)]
        cond_inputs_rank = [cond_inputs[i] for i in indices]

    filename_list_rank = [filename_list[i] for i in indices]
    assert len(idx_list_rank) == len(filename_list_rank), "Error: metas are not paired!"

    ## step 3: run over samples
    ## -----------------------------------------------------------------
    start = time.time()
    rank_total = len(prompt_list_rank)
    n_rounds, leftover = divmod(rank_total, args.bs)
    if leftover != 0:
        # one extra, smaller batch for the remainder
        n_rounds += 1

    for round_no in range(n_rounds):
        print(f'[rank:{gpu_no}] batch-{round_no+1} ({args.bs})x{args.n_samples} ...', flush=True)
        lo = round_no * args.bs
        hi = min(lo + args.bs, rank_total)
        batch_size = hi - lo
        filenames = filename_list_rank[lo:hi]
        noise_shape = [batch_size, channels, frames, latent_h, latent_w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
        # NOTE(review): only the first index entry of the batch is passed on,
        # so with --bs > 1 the remaining entries are ignored -- confirm.
        idx_list = idx_list_rank[lo:hi][0]

        prompts = prompt_list_rank[lo:hi]
        if isinstance(prompts, str):
            prompts = [prompts]
        text_emb = model.get_learned_conditioning(prompts)

        if args.mode == 'i2v':
            cond_images = load_image_batch(cond_inputs_rank[lo:hi], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            # text and image embeddings are concatenated along dim 1
            crossattn = torch.cat([text_emb, img_emb], dim=1)
        elif args.mode == 'base':
            crossattn = text_emb
        else:
            raise NotImplementedError
        cond = {"c_crossattn": [crossattn], "fps": fps}

        ## inference
        batch_samples = batch_ddim_sampling_freetraj(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
            idx_list=idx_list,
            input_traj=input_traj,
            args=args,
            **kwargs,
        )
        ## b,samples,c,t,h,w
        save_videos_with_bbox(
            batch_samples,
            args.savedir,
            bboxdir,
            filenames,
            fps=args.savefps,
            input_traj=input_traj,
        )

    elapsed = time.time() - start
    print(f"Saved in {args.savedir}. Time used: {elapsed:.2f} seconds")
if __name__ == '__main__':
    # Print a timestamped banner so separate runs can be told apart in logs.
    launch_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print(f"@CoLVDM Inference: {launch_stamp}")

    args = get_parser().parse_args()
    seed_everything(args.seed)

    # Single-process launch: one rank in total, and this process is rank 0.
    gpu_num = 1
    rank = 0
    run_inference(args, gpu_num=gpu_num, gpu_no=rank)