# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import sys
import argparse
import random

from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torchvision
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    load_model_checkpoint,
    load_image_batch,
    get_filelist,
)
from utils.utils import instantiate_from_config


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory to make running multiple predictions efficient"""
        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
        config_base = "configs/inference_t2v_1024_v1.0.yaml"
        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
        config_i2v = "configs/inference_i2v_512_v1.0.yaml"

        # Text-to-video base model (1024x576)
        config_base = OmegaConf.load(config_base)
        model_config_base = config_base.pop("model", OmegaConf.create())
        self.model_base = instantiate_from_config(model_config_base)
        self.model_base = self.model_base.cuda()
        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
        self.model_base.eval()

        # Image-to-video model (512x320)
        config_i2v = OmegaConf.load(config_i2v)
        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
        self.model_i2v = instantiate_from_config(model_config_i2v)
        self.model_i2v = self.model_i2v.cuda()
        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
        self.model_i2v.eval()

    def predict(
        self,
        task: str = Input(
            description="Choose the task.",
            choices=["text2video", "image2video"],
            default="text2video",
        ),
        prompt: str = Input(
            description="Prompt for video generation.",
            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
        ),
        image: Path = Input(
            description="Input image for the image2video task.", default=None
        ),
        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
        unconditional_guidance_scale: float = Input(
            description="Classifier-free guidance scale.", default=12.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
        save_fps: int = Input(
            description="Frames per second for the generated video.", default=10
        ),
    ) -> Path:
        width = 1024 if task == "text2video" else 512
        height = 576 if task == "text2video" else 320
        model = self.model_base if task == "text2video" else self.model_i2v

        if task == "image2video":
            assert image is not None, "Please provide an image for image2video generation."
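        # Fall back to a random 16-bit seed when none is provided; log it so a run
        # can be reproduced by passing the same value back in.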
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)

        args = argparse.Namespace(
            mode="base" if task == "text2video" else "i2v",
            savefps=save_fps,
            n_samples=1,
            ddim_steps=ddim_steps,
            ddim_eta=1.0,
            bs=1,
            height=height,
            width=width,
            frames=-1,
            fps=28 if task == "text2video" else 8,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_guidance_scale_temporal=None,
        )

        ## latent noise shape
        h, w = args.height // 8, args.width // 8
        frames = model.temporal_length if args.frames < 0 else args.frames
        channels = model.channels
        batch_size = 1
        noise_shape = [batch_size, channels, frames, h, w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()

        prompts = [prompt]
        text_emb = model.get_learned_conditioning(prompts)

        if args.mode == "base":
            # Text-to-video: condition on the text embedding only
            cond = {"c_crossattn": [text_emb], "fps": fps}
        elif args.mode == "i2v":
            # Image-to-video: concatenate text and image embeddings for cross-attention
            cond_images = load_image_batch([str(image)], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            cond = {"c_crossattn": [imtext_cond], "fps": fps}
        else:
            raise NotImplementedError

        ## inference
        batch_samples = batch_ddim_sampling(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
        )

        ## save the first (and only) batch sample as an mp4
        out_path = "/tmp/output.mp4"
        vid_tensor = batch_samples[0]
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(2, 0, 1, 3, 4)  # [t, n, c, h, w]

        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
            for framesheet in video
        ]  # each grid: [3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim: [t, 3, 1*h, n*w]
        grid = (grid + 1.0) / 2.0  # map from [-1, 1] to [0, 1]
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [t, h, w, 3] for write_video
        torchvision.io.write_video(
            out_path,
            grid,
            fps=args.savefps,
            video_codec="h264",
            options={"crf": "10"},
        )
        return Path(out_path)
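

# A minimal local smoke test, offered as a sketch rather than part of the Cog
# interface: it assumes the checkpoints/configs referenced in setup() are present
# and that a CUDA GPU is available. Under Cog itself you would normally run
# something like `cog predict -i task=text2video -i prompt="..."` instead.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    output = predictor.predict(
        task="text2video",
        prompt="A tiger walks in the forest, photorealistic, 4k, high definition.",
        image=None,
        ddim_steps=50,
        unconditional_guidance_scale=12.0,
        seed=42,
        save_fps=10,
    )
    print(f"Saved video to: {output}")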