Spaces:
Runtime error
Runtime error
import os | |
import time | |
import argparse | |
import yaml, math | |
from tqdm import trange | |
import torch | |
import numpy as np | |
from omegaconf import OmegaConf | |
import torch.distributed as dist | |
from pytorch_lightning import seed_everything | |
from lvdm.samplers.ddim import DDIMSampler | |
from lvdm.utils.common_utils import str2bool | |
from lvdm.utils.dist_utils import setup_dist, gather_data | |
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d | |
from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np | |
# ------------------------------------------------------------------------------------------ | |
def get_parser(): | |
parser = argparse.ArgumentParser() | |
# basic args | |
parser.add_argument("--ckpt_path", type=str, help="model checkpoint path") | |
parser.add_argument("--config_path", type=str, help="model config path (a yaml file)") | |
parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).") | |
parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/") | |
# device args | |
parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for multi-gpu case)", default=False) | |
parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0) | |
parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0) | |
# sampling args | |
parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2) | |
parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1) | |
parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1) | |
parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"]) | |
parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50) | |
parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0) | |
parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale") | |
parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)") | |
parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",) | |
# lora args | |
parser.add_argument("--lora_path", type=str, help="lora checkpoint path") | |
parser.add_argument("--inject_lora", action='store_true', default=False, help="",) | |
parser.add_argument("--lora_scale", type=float, default=None, help="scale for lora weight") | |
parser.add_argument("--lora_trigger_word", type=str, default="", help="",) | |
# saving args | |
parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files", choices=["True", "true", "False", "false"]) | |
parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",) | |
parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",) | |
parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",) | |
parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",) | |
return parser | |
# ------------------------------------------------------------------------------------------ | |
def sample_denoising_batch(model, noise_shape, condition, *args, | |
sample_type="ddim", sampler=None, | |
ddim_steps=None, eta=None, | |
unconditional_guidance_scale=1.0, uc=None, | |
denoising_progress=False, | |
**kwargs, | |
): | |
if sample_type == "ddpm": | |
samples = model.p_sample_loop(cond=condition, shape=noise_shape, | |
return_intermediates=False, | |
verbose=denoising_progress, | |
**kwargs, | |
) | |
elif sample_type == "ddim": | |
assert(sampler is not None) | |
assert(ddim_steps is not None) | |
assert(eta is not None) | |
ddim_sampler = sampler | |
samples, _ = ddim_sampler.sample(S=ddim_steps, | |
conditioning=condition, | |
batch_size=noise_shape[0], | |
shape=noise_shape[1:], | |
verbose=denoising_progress, | |
unconditional_guidance_scale=unconditional_guidance_scale, | |
unconditional_conditioning=uc, | |
eta=eta, | |
**kwargs, | |
) | |
else: | |
raise ValueError | |
return samples | |
# ------------------------------------------------------------------------------------------ | |
def sample_text2video(model, prompt, n_samples, batch_size, | |
sample_type="ddim", sampler=None, | |
ddim_steps=50, eta=1.0, cfg_scale=7.5, | |
decode_frame_bs=1, | |
ddp=False, all_gather=True, | |
batch_progress=True, show_denoising_progress=False, | |
): | |
# get cond vector | |
assert(model.cond_stage_model is not None) | |
cond_embd = get_conditions(prompt, model, batch_size) | |
uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None | |
# sample batches | |
all_videos = [] | |
n_iter = math.ceil(n_samples / batch_size) | |
iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter) | |
for _ in iterator: | |
noise_shape = make_model_input_shape(model, batch_size) | |
samples_latent = sample_denoising_batch(model, noise_shape, cond_embd, | |
sample_type=sample_type, | |
sampler=sampler, | |
ddim_steps=ddim_steps, | |
eta=eta, | |
unconditional_guidance_scale=cfg_scale, | |
uc=uncond_embd, | |
denoising_progress=show_denoising_progress, | |
) | |
samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False) | |
# gather samples from multiple gpus | |
if ddp and all_gather: | |
data_list = gather_data(samples, return_np=False) | |
all_videos.extend([torch_to_np(data) for data in data_list]) | |
else: | |
all_videos.append(torch_to_np(samples)) | |
all_videos = np.concatenate(all_videos, axis=0) | |
assert(all_videos.shape[0] >= n_samples) | |
return all_videos | |
# ------------------------------------------------------------------------------------------ | |
def save_results(videos, save_dir, | |
save_name="results", save_fps=8, save_mp4=True, | |
save_npz=False, save_mp4_sheet=False, save_jpg=False | |
): | |
if save_mp4: | |
save_subdir = os.path.join(save_dir, "videos") | |
os.makedirs(save_subdir, exist_ok=True) | |
for i in range(videos.shape[0]): | |
npz_to_video_grid(videos[i:i+1,...], | |
os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4"), | |
fps=save_fps) | |
print(f'Successfully saved videos in {save_subdir}') | |
if save_npz: | |
save_path = os.path.join(save_dir, f"{save_name}.npz") | |
np.savez(save_path, videos) | |
print(f'Successfully saved npz in {save_path}') | |
if save_mp4_sheet: | |
save_path = os.path.join(save_dir, f"{save_name}.mp4") | |
npz_to_video_grid(videos, save_path, fps=save_fps) | |
print(f'Successfully saved mp4 sheet in {save_path}') | |
if save_jpg: | |
save_path = os.path.join(save_dir, f"{save_name}.jpg") | |
npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1]) | |
print(f'Successfully saved jpg sheet in {save_path}') | |
# ------------------------------------------------------------------------------------------ | |
def main(): | |
""" | |
text-to-video generation | |
""" | |
parser = get_parser() | |
opt, unknown = parser.parse_known_args() | |
os.makedirs(opt.save_dir, exist_ok=True) | |
# set device | |
if opt.ddp: | |
setup_dist(opt.local_rank) | |
opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size()) | |
gpu_id = None | |
else: | |
gpu_id = opt.gpu_id | |
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" | |
# set random seed | |
if opt.seed is not None: | |
if opt.ddp: | |
seed = opt.local_rank + opt.seed | |
else: | |
seed = opt.seed | |
seed_everything(seed) | |
# dump args | |
fpath = os.path.join(opt.save_dir, "sampling_args.yaml") | |
with open(fpath, 'w') as f: | |
yaml.dump(vars(opt), f, default_flow_style=False) | |
# load & merge config | |
config = OmegaConf.load(opt.config_path) | |
cli = OmegaConf.from_dotlist(unknown) | |
config = OmegaConf.merge(config, cli) | |
print("config: \n", config) | |
# get model & sampler | |
model, _, _ = load_model(config, opt.ckpt_path, | |
inject_lora=opt.inject_lora, | |
lora_scale=opt.lora_scale, | |
lora_path=opt.lora_path | |
) | |
ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None | |
# prepare prompt | |
if opt.prompt.endswith(".txt"): | |
opt.prompt_file = opt.prompt | |
opt.prompt = None | |
else: | |
opt.prompt_file = None | |
if opt.prompt_file is not None: | |
f = open(opt.prompt_file, 'r') | |
prompts, line_idx = [], [] | |
for idx, line in enumerate(f.readlines()): | |
l = line.strip() | |
if len(l) != 0: | |
prompts.append(l) | |
line_idx.append(idx) | |
f.close() | |
cmd = f"cp {opt.prompt_file} {opt.save_dir}" | |
os.system(cmd) | |
else: | |
prompts = [opt.prompt] | |
line_idx = [None] | |
if opt.inject_lora: | |
assert(opt.lora_trigger_word != '') | |
prompts = [p + opt.lora_trigger_word for p in prompts] | |
# go | |
start = time.time() | |
for prompt in prompts: | |
# sample | |
samples = sample_text2video(model, prompt, opt.n_samples, opt.batch_size, | |
sample_type=opt.sample_type, sampler=ddim_sampler, | |
ddim_steps=opt.ddim_steps, eta=opt.eta, | |
cfg_scale=opt.cfg_scale, | |
decode_frame_bs=opt.decode_frame_bs, | |
ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress, | |
) | |
# save | |
if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp): | |
prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt | |
save_name = prompt_str.replace(" ", "_") if " " in prompt else prompt_str | |
if opt.seed is not None: | |
save_name = save_name + f"_seed{seed:05d}" | |
save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps) | |
print("Finish sampling!") | |
print(f"Run time = {(time.time() - start):.2f} seconds") | |
if opt.ddp: | |
dist.destroy_process_group() | |
# ------------------------------------------------------------------------------------------ | |
if __name__ == "__main__": | |
main() |