import os
import sys
import argparse
import random

from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torchvision
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    load_model_checkpoint,
    load_image_batch,
    get_filelist,
)
from utils.utils import instantiate_from_config


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory to make running multiple predictions efficient"""

        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
        config_base = "configs/inference_t2v_1024_v1.0.yaml"
        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
        config_i2v = "configs/inference_i2v_512_v1.0.yaml"
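
        # Text-to-video base model (1024x576).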
        config_base = OmegaConf.load(config_base)
        model_config_base = config_base.pop("model", OmegaConf.create())
        self.model_base = instantiate_from_config(model_config_base)
        self.model_base = self.model_base.cuda()
        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
        self.model_base.eval()
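
        # Image-to-video model (512x320).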
        config_i2v = OmegaConf.load(config_i2v)
        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
        self.model_i2v = instantiate_from_config(model_config_i2v)
        self.model_i2v = self.model_i2v.cuda()
        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
        self.model_i2v.eval()

    def predict(
        self,
        task: str = Input(
            description="Choose the task.",
            choices=["text2video", "image2video"],
            default="text2video",
        ),
        prompt: str = Input(
            description="Prompt for video generation.",
            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
        ),
        image: Path = Input(
            description="Input image for image2video task.", default=None
        ),
        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
        unconditional_guidance_scale: float = Input(
            description="Classifier-free guidance scale.", default=12.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
        save_fps: int = Input(
            description="Frames per second for the generated video.", default=10
        ),
    ) -> Path:
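        # Resolution and model depend on the selected task.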
        width = 1024 if task == "text2video" else 512
        height = 576 if task == "text2video" else 320
        model = self.model_base if task == "text2video" else self.model_i2v

        if task == "image2video":
            assert image is not None, "Please provide an image for image2video generation."
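
        # Pick a random seed when none is provided.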
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)
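
        # Inference settings collected into a Namespace.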
        args = argparse.Namespace(
            mode="base" if task == "text2video" else "i2v",
            savefps=save_fps,
            n_samples=1,
            ddim_steps=ddim_steps,
            ddim_eta=1.0,
            bs=1,
            height=height,
            width=width,
            frames=-1,
            fps=28 if task == "text2video" else 8,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_guidance_scale_temporal=None,
        )
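
        # Noise shape is [batch, channels, frames, height/8, width/8];
        # the frame count defaults to the model's temporal_length.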
        h, w = args.height // 8, args.width // 8
        frames = model.temporal_length if args.frames < 0 else args.frames
        channels = model.channels

        batch_size = 1
        noise_shape = [batch_size, channels, frames, h, w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
        prompts = [prompt]
        text_emb = model.get_learned_conditioning(prompts)
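
        # Conditioning: text embedding only for the base model; for i2v, the
        # image embedding is concatenated with the text embedding along dim=1.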
        if args.mode == "base":
            cond = {"c_crossattn": [text_emb], "fps": fps}
        elif args.mode == "i2v":
            cond_images = load_image_batch([str(image)], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            cond = {"c_crossattn": [imtext_cond], "fps": fps}
        else:
            raise NotImplementedError
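
        # DDIM sampling with classifier-free guidance.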
        batch_samples = batch_ddim_sampling(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
        )
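
        # Rescale samples from [-1, 1] to uint8 [0, 255], reorder to (T, H, W, C),
        # and write an h264 mp4.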
        out_path = "/tmp/output.mp4"
        vid_tensor = batch_samples[0]
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(2, 0, 1, 3, 4)

        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
            for framesheet in video
        ]
        grid = torch.stack(frame_grids, dim=0)
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        torchvision.io.write_video(
            out_path,
            grid,
            fps=args.savefps,
            video_codec="h264",
            options={"crf": "10"},
        )
        return Path(out_path)