# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os

import numpy as np
import torch
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
from cog import BasePredictor, Input, Path

from models.pipeline_controlvideo import ControlVideoPipeline
from models.util import save_videos_grid, read_video, get_annotation
from models.unet import UNet3DConditionModel
from models.controlnet import ControlNetModel3D
from models.RIFE.IFNet_HDv3 import IFNet
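
# Local checkpoint locations. These paths are assumed to be populated at
# build time (e.g. by a download script); adjust them if your layout differs.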
sd_path = "checkpoints/stable-diffusion-v1-5"
inter_path = "checkpoints/flownet.pkl"
controlnet_dict = {
    "pose": "checkpoints/sd-controlnet-openpose",
    "depth": "checkpoints/sd-controlnet-depth",
    "canny": "checkpoints/sd-controlnet-canny",
}
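
# Map each condition name to the controlnet_aux detector that extracts the
# corresponding structure sequence from RGB frames.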
controlnet_parser_dict = {
    "pose": OpenposeDetector,
    "depth": MidasDetector,
    "canny": CannyDetector,
}
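
# Quality boosters spliced onto every user prompt; NEG_PROMPT suppresses
# common failure modes of the base model.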
POS_PROMPT = ", best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
NEG_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            sd_path, subfolder="text_encoder"
        ).to(dtype=torch.float16)
        self.vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(
            dtype=torch.float16
        )
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            sd_path, subfolder="unet"
        ).to(dtype=torch.float16)
        self.interpolater = IFNet(ckpt_path=inter_path).to(dtype=torch.float16)
        self.scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
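
        # Load all three ControlNets up front so predict() can switch between
        # depth/canny/pose conditions without reloading weights.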
        self.controlnet = {
            k: ControlNetModel3D.from_pretrained_2d(controlnet_dict[k]).to(
                dtype=torch.float16
            )
            for k in ["depth", "canny", "pose"]
        }
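
        # The Canny and Midas detectors are constructed directly; OpenPose
        # loads pretrained weights, cached under checkpoints/.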
        self.annotator = {k: controlnet_parser_dict[k]() for k in ["depth", "canny"]}
        self.annotator["pose"] = controlnet_parser_dict["pose"].from_pretrained(
            "lllyasviel/ControlNet", cache_dir="checkpoints"
        )

    def predict(
        self,
        prompt: str = Input(
            description="Text description of target video",
            default="A striking mallard floats effortlessly on the sparkling pond.",
        ),
        video_path: Path = Input(description="Source video"),
        condition: str = Input(
            default="depth",
            choices=["depth", "canny", "pose"],
            description="Condition of structure sequence",
        ),
        video_length: int = Input(
            default=15, description="Length of synthesized video"
        ),
        smoother_steps: str = Input(
            default="19, 20",
            description="Timesteps at which to apply the interleaved-frame smoother, separated by commas",
        ),
        is_long_video: bool = Input(
            default=False,
            description="Whether to use the hierarchical sampler to produce a long video",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=12.5
        ),
        seed: str = Input(
            default=None, description="Random seed. Leave blank to randomize the seed"
        ),
    ) -> Path:
"""Run a single prediction on the model""" | |
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        else:
            seed = int(seed)
        print(f"Using seed: {seed}")
        generator = torch.Generator(device="cuda")
        generator.manual_seed(seed)
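
        # Assemble the pipeline from the preloaded components, attaching only
        # the ControlNet that matches the requested condition.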
        pipe = ControlVideoPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            controlnet=self.controlnet[condition],
            interpolater=self.interpolater,
            scheduler=self.scheduler,
        )
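        # VAE slicing and xformers attention trade a little speed for a large
        # reduction in peak VRAM when decoding video-length batches.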
        pipe.enable_vae_slicing()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.to("cuda")

        # Step 1. Read the source video
        video = read_video(video_path=str(video_path), video_length=video_length)

        # Step 2. Parse the video into conditional frames
        pil_annotation = get_annotation(video, self.annotator[condition])

        # Step 3. Inference
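        # Parse the comma-separated smoother timesteps, e.g. "19, 20" -> [19, 20].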
        smoother_steps = [int(s) for s in smoother_steps.split(",")]
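        # For long videos, the hierarchical sampler generates key frames first
        # and then fills in each clip; a window of sqrt(video_length) is
        # presumably chosen to balance key-frame count against clip length.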
        if is_long_video:
            window_size = int(np.sqrt(video_length))
            sample = pipe.generate_long_video(
                prompt + POS_PROMPT,
                video_length=video_length,
                frames=pil_annotation,
                num_inference_steps=num_inference_steps,
                smooth_steps=smoother_steps,
                window_size=window_size,
                generator=generator,
                guidance_scale=guidance_scale,
                negative_prompt=NEG_PROMPT,
            ).videos
        else:
            sample = pipe(
                prompt + POS_PROMPT,
                video_length=video_length,
                frames=pil_annotation,
                num_inference_steps=num_inference_steps,
                smooth_steps=smoother_steps,
                generator=generator,
                guidance_scale=guidance_scale,
                negative_prompt=NEG_PROMPT,
            ).videos

        out_path = "/tmp/out.mp4"
        save_videos_grid(sample, out_path)

        # Free the pipeline so repeated predictions do not accumulate GPU memory.
        del pipe
        torch.cuda.empty_cache()
        return Path(out_path)
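

# A hypothetical invocation via the Cog CLI (file names are placeholders):
#   cog predict -i prompt="A striking mallard floats effortlessly on the sparkling pond." \
#       -i video_path=@input.mp4 -i condition=depth -i video_length=15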