Spaces:
Runtime error
Runtime error
import argparse
import math
import os
import time

import gradio as gr
import numpy as np
import torch
import torch.distributed as dist
import yaml
# NOTE: `cached_download` is deprecated upstream and removed in newer
# huggingface_hub releases; it is kept here only because it was imported
# originally. `hf_hub_download` is the supported replacement.
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from tqdm import trange

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.dist_utils import setup_dist, gather_data
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
# Module-level setup, executed once at import time: load the inference config,
# fetch the checkpoint from the Hugging Face Hub, build the model and wrap it
# in a DDIM sampler shared by every request.
config_path = "model_config.yaml"
config = OmegaConf.load(config_path)
REPO_ID = "RamAnanth1/videocrafter-text2video"
# `cached_download(hf_hub_url(...))` is deprecated and removed in newer
# huggingface_hub releases; `hf_hub_download` downloads (or reuses the cached
# copy of) the file and returns its local path.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="model.ckpt")

# Build the model (no LoRA injection) and its sampler.
model, _, _ = load_model(
    config,
    ckpt_path,
    inject_lora=False,
    lora_scale=None,
)
ddim_sampler = DDIMSampler(model)
def sample_denoising_batch(model, noise_shape, condition, *args,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False,
                           **kwargs,
                           ):
    """Denoise one batch of latents with the supplied sampler.

    ``model``, ``sample_type`` and any extra positional ``*args`` are accepted
    for call-site compatibility but are not used here. Every extra keyword
    argument is forwarded unchanged to ``sampler.sample``.

    Args:
        model: Unused.
        noise_shape: Indexable shape; element 0 is used as the batch size and
            the remainder (``noise_shape[1:]``) as the per-sample shape.
        condition: Conditioning passed to the sampler as ``conditioning``.
        sampler: Object exposing ``sample(**kw)`` that returns a 2-tuple.
        ddim_steps: Number of sampling steps, forwarded as ``S``.
        eta: Forwarded as ``eta``.
        unconditional_guidance_scale: Forwarded unchanged.
        uc: Forwarded as ``unconditional_conditioning``.
        denoising_progress: Forwarded as ``verbose``.

    Returns:
        The first element of the tuple returned by ``sampler.sample``.

    Raises:
        AssertionError: if ``sampler``, ``ddim_steps`` or ``eta`` is None.
    """
    assert sampler is not None
    assert ddim_steps is not None
    assert eta is not None

    sampler_args = dict(
        S=ddim_steps,
        conditioning=condition,
        batch_size=noise_shape[0],
        shape=noise_shape[1:],
        verbose=denoising_progress,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=uc,
        eta=eta,
    )
    # Unpacking both mappings in the call keeps the original behavior of a
    # TypeError if **kwargs repeats one of the explicit sampler arguments.
    latents, _intermediates = sampler.sample(**sampler_args, **kwargs)
    return latents
def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=7.5,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Generate videos for a text prompt, one batch at a time.

    Runs ``ceil(n_samples / batch_size)`` sampling rounds. Each round samples
    latents via ``sample_denoising_batch`` and decodes them with
    ``model.decode_first_stage``, then converts the result with ``torch_to_np``.
    When both ``ddp`` and ``all_gather`` are true, each round's samples are
    first gathered from multiple GPUs with ``gather_data``.

    Returns:
        A numpy array formed by concatenating every converted batch along
        axis 0. Its leading dimension is at least ``n_samples``. It is NOT
        truncated, so it can be larger when ``batch_size`` does not divide
        ``n_samples`` or when samples are gathered across processes.

    Raises:
        AssertionError: if ``model.cond_stage_model`` is None.
    """
    assert model.cond_stage_model is not None

    # Embedding for the prompt, plus an empty-prompt embedding used as the
    # unconditional input; the latter is only built when cfg_scale != 1.0.
    text_cond = get_conditions(prompt, model, batch_size)
    if cfg_scale != 1.0:
        null_cond = get_conditions("", model, batch_size)
    else:
        null_cond = None

    num_batches = math.ceil(n_samples / batch_size)
    if batch_progress:
        batches = trange(num_batches, desc="Sampling Batches (text-to-video)")
    else:
        batches = range(num_batches)

    chunks = []
    for _ in batches:
        latent_shape = make_model_input_shape(model, batch_size)
        latents = sample_denoising_batch(
            model, latent_shape, text_cond,
            sample_type=sample_type,
            sampler=sampler,
            ddim_steps=ddim_steps,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            uc=null_cond,
            denoising_progress=show_denoising_progress,
        )
        decoded = model.decode_first_stage(latents, decode_bs=decode_frame_bs, return_cpu=False)

        if ddp and all_gather:
            # Gather this round's samples from multiple GPUs before converting.
            for part in gather_data(decoded, return_np=False):
                chunks.append(torch_to_np(part))
        else:
            chunks.append(torch_to_np(decoded))

    videos = np.concatenate(chunks, axis=0)
    assert videos.shape[0] >= n_samples
    return videos
def save_results(videos,
                 save_name="results", save_fps=8, save_mp4=True,
                 save_npz=False, save_mp4_sheet=False, save_jpg=False,
                 ):
    """Write each video in ``videos`` to its own MP4 file under ``videos/``.

    Files are named ``{save_name}_{index:03d}.mp4``, where the index runs over
    the first axis of ``videos``. Each slice ``videos[i:i+1, ...]`` is passed
    to ``npz_to_video_grid``.

    Args:
        videos: Array whose first axis indexes the videos to save.
        save_name: Filename prefix.
        save_fps: Frame rate passed to ``npz_to_video_grid``.
        save_mp4, save_npz, save_mp4_sheet, save_jpg: Currently ignored; only
            MP4 files are written. Kept for call-site compatibility.

    Returns:
        The path of the last MP4 file written.

    Raises:
        ValueError: if ``videos`` contains no videos. Previously this case
            crashed with an ``UnboundLocalError`` at the return statement.
    """
    # NOTE(review): with a fixed save_name, concurrent requests write to the
    # same file names and can overwrite each other's output.
    if videos.shape[0] == 0:
        raise ValueError("save_results: received no videos to save")

    save_subdir = "videos"
    os.makedirs(save_subdir, exist_ok=True)

    last_path = None
    for i in range(videos.shape[0]):
        last_path = os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
        npz_to_video_grid(videos[i:i + 1, ...], last_path, fps=save_fps)
    return last_path
def get_video(prompt):
    """Generate a single video for ``prompt`` and return the saved file path.

    The global RNGs are re-seeded with a fixed value (via ``seed_everything``)
    on every call, so repeated calls use the same seed. Sampling uses the
    module-level ``model`` and ``ddim_sampler``, with one sample in a batch of
    one.
    """
    fixed_seed = 1000
    seed_everything(fixed_seed)
    videos = sample_text2video(
        model,
        prompt,
        n_samples=1,
        batch_size=1,
        sampler=ddim_sampler,
    )
    output_path = save_results(videos)
    return output_path
# Gradio UI: one text prompt in, one generated video out.
prompt_inp = gr.Textbox(label="Prompt")
result = gr.Video(label='Result')

title = 'Latent Video Diffusion Models'
DESCRIPTION = '<p>This model can only be used for non-commercial purposes. To learn more about the model, take a look at the <a href="https://github.com/VideoCrafter/VideoCrafter" style="text-decoration: underline;" target="_blank">model card</a>.</p>'

iface = gr.Interface(
    fn=get_video,
    inputs=[prompt_inp],
    outputs=[result],
    title=title,
    description=DESCRIPTION,
    examples=[["An astronaut riding a horse"]],
)
# Starts the web server as soon as this module is executed.
iface.launch()