import sys
from pathlib import Path

# Add the grandparent directory to sys.path so the repo's sibling packages (models/, utils/) can be imported.
sys.path.append(str(Path(__file__).parent.parent))

import importlib
import json
import os
from functools import partial
from pprint import pprint
from uuid import uuid4

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.export_utils import export_to_video
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision import transforms
from transformers import AutoTokenizer, T5EncoderModel
from torchcodec.decoders import VideoDecoder

def relative_pose(rt: Tensor, mode, ref_index) -> Tensor:
    '''
    :param rt: F,4,4 stack of camera matrices
    :param mode: "left" or "right" -- the side on which the inverse of the reference pose is applied
    :param ref_index: index of the reference frame
    :return: F,4,4 poses expressed relative to the reference frame
    '''
    if mode == "left":
        rt = rt[ref_index].inverse() @ rt
    elif mode == "right":
        rt = rt @ rt[ref_index].inverse()
    return rt
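
# Usage sketch for relative_pose (values are illustrative, not from the repo): re-expressing a
# trajectory relative to its first frame turns the reference pose into the identity.
#   poses = torch.randn(16, 4, 4)                          # F,4,4 stack of camera matrices
#   rel = relative_pose(poses, mode="left", ref_index=0)   # rel[0] == poses[0]^-1 @ poses[0] == I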

def camera_pose_lerp(c2w: Tensor, target_frames: int):
    weights = torch.linspace(0, c2w.size(0) - 1, target_frames, dtype=c2w.dtype)
    left_indices = weights.floor().long()
    right_indices = weights.ceil().long()
    return torch.lerp(c2w[left_indices], c2w[right_indices], weights.unsqueeze(-1).unsqueeze(-1).frac())
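
# Sketch of camera_pose_lerp (illustrative shapes): it resamples F key poses to `target_frames`
# poses by element-wise linear interpolation between the two nearest key matrices.
#   key_c2ws = torch.eye(4).repeat(10, 1, 1)               # 10 key camera-to-world matrices
#   dense_c2ws = camera_pose_lerp(key_c2ws, 49)            # -> [49, 4, 4]
# Lerping rotation components is an approximation: interpolated matrices need not lie exactly on SO(3).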

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
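
# get_obj_from_str resolves a dotted "module.ClassName" path to the object itself, e.g. (standard
# library example, not something this repo requires):
#   optimizer_cls = get_obj_from_str("torch.optim.AdamW")  # same object as torch.optim.AdamW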

def _resize_for_rectangle_crop(frames, H, W):
    '''
    Resize while keeping the aspect ratio so the frames cover (H, W), then center-crop.
    :param frames: C,F,H,W
    :param H, W: target height and width
    :return: frames: C,F,H,W cropped to the target size; resized_H, resized_W: size after resizing, before cropping
    '''
    ori_H, ori_W = frames.shape[-2:]
    # if ori_W / ori_H < 1.0:
    #     tmp_H, tmp_W = int(H), int(W)
    #     H, W = tmp_W, tmp_H
    if ori_W / ori_H > W / H:
        frames = transforms.functional.resize(frames, size=[H, int(ori_W * H / ori_H)])
    else:
        frames = transforms.functional.resize(frames, size=[int(ori_H * W / ori_W), W])
    resized_H, resized_W = frames.shape[2], frames.shape[3]
    frames = frames.squeeze(0)
    delta_H = resized_H - H
    delta_W = resized_W - W
    top, left = delta_H // 2, delta_W // 2
    frames = transforms.functional.crop(frames, top=top, left=left, height=H, width=W)
    return frames, resized_H, resized_W
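
# Worked example for _resize_for_rectangle_crop (hypothetical input size): a 1080x1920 frame resized
# toward a 768x1360 target has aspect 1920/1080 ≈ 1.78 > 1360/768 ≈ 1.77, so it is first resized to
# 768 x int(1920 * 768 / 1080) = 768x1365 and then center-cropped to 768x1360.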

def _resize(frames, H, W):
    '''
    Resize directly to (H, W) without preserving the aspect ratio.
    :param frames: C,F,H,W
    :return: frames: C,F,H,W resized to the target size; resized_H, resized_W: size after resizing
    '''
    frames = transforms.functional.resize(frames, size=[H, W])
    resized_H, resized_W = frames.shape[2], frames.shape[3]
    frames = frames.squeeze(0)
    return frames, resized_H, resized_W

class Image2Video:
    def __init__(
        self,
        result_dir: str = "results",
        model_meta_path: str = "models.json",
        camera_pose_meta_path: str = "camera_poses.json",
        save_fps: int = 16,
        device: str = "cuda",
    ):
        self.result_dir = result_dir
        self.model_meta_file = model_meta_path
        self.camera_pose_meta_path = camera_pose_meta_path
        self.save_fps = save_fps
        self.device = torch.device(device)
        self.pipe = None

    def init_model(self, model_name):
        from models.camera_controller.cogvideox_with_controlnetxs import CogVideoXTransformer3DModel
        from models.camera_controller.controlnetxs import ControlnetXs

        with open(self.model_meta_file, "r", encoding="utf-8") as f:
            model_metadata = json.load(f)[model_name]
        pretrained_model_path = model_metadata["pretrained_model_path"]
        controlnetxs_model_path = model_metadata["controlnetxs_model_path"]

        # Base CogVideoX components plus the ControlNet-XS camera branch.
        self.transformer = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_path, subfolder="transformer", torch_dtype=torch.bfloat16)
        self.controlnetxs = ControlnetXs("models/camera_controller/CogVideoX1.5-5B-I2V", self.transformer.config)
        self.controlnetxs.load_state_dict(torch.load(controlnetxs_model_path)['module'], strict=True)
        self.controlnetxs.to(torch.bfloat16)
        # self.controlnetxs.to(torch.float32)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        self.text_encoder = T5EncoderModel.from_pretrained(pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16)
        self.vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=torch.bfloat16)
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.scheduler = CogVideoXDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")

        self.controlnetxs.eval()
        self.text_encoder.eval()
        self.vae.eval()
        self.transformer.eval()
        self.prepare_models()
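    # models.json is expected to map a model name to the two checkpoint locations read above; a
    # minimal entry could look like this (the model name and paths are placeholders):
    #   {"my_model": {"pretrained_model_path": "path/to/CogVideoX1.5-5B-I2V",
    #                 "controlnetxs_model_path": "path/to/controlnetxs.pt"}}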

    def prepare_models(self) -> None:
        if self.vae is not None:
            self.vae.enable_slicing()
            self.vae.enable_tiling()
        if self.controlnetxs.vae_encoder is not None:
            self.controlnetxs.vae_encoder.enable_slicing()
            self.controlnetxs.vae_encoder.enable_tiling()
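    # enable_slicing() encodes/decodes one sample of a batch at a time and enable_tiling() processes
    # frames in spatial tiles; both diffusers options trade some speed for a lower peak VRAM footprint.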

    def init_pipe(self):
        from models.camera_controller.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline

        # The text encoder is omitted here because prompts are pre-encoded in get_image().
        self.pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.tokenizer,
            text_encoder=None,
            vae=self.vae,
            transformer=self.transformer,
            scheduler=self.scheduler,
        )
        self.pipe.scaling_flag = True
        self.pipe.to(self.device)

    def offload_cpu(self):
        if hasattr(self, "transformer"):
            self.transformer.cpu()
        if hasattr(self, "controlnetxs"):
            self.controlnetxs.cpu()
        if hasattr(self, "text_encoder"):
            self.text_encoder.cpu()
        if hasattr(self, "vae"):
            self.vae.cpu()
        torch.cuda.empty_cache()

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )
        return freqs_cos, freqs_sin
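    # Worked example (assuming the stock CogVideoX1.5 config: patch_size=2, patch_size_t=2, and a VAE
    # spatial scale factor of 8): a 768x1360 target gives grid_height = 768 // 16 = 48 and
    # grid_width = 1360 // 16 = 85, and 22 padded latent frames give temporal_size = (22 + 1) // 2 = 11.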

    def validation_step(self, input_kwargs: dict):
        '''
        Run the image-to-video pipeline with ControlNet-XS camera conditioning and return the
        generated video as a list of PIL images (the format produced by the pipeline).
        image: shape=(1,c,h,w), value in [0, 1]
        '''
        plucker_embedding = input_kwargs["plucker_embedding"]
        image = input_kwargs["image"]
        H, W = image.shape[-2:]

        # camera conditioning: encode the Plucker embeddings with the ControlNet-XS VAE encoder
        plucker_embedding = plucker_embedding.to(self.controlnetxs.vae_encoder.device, dtype=self.controlnetxs.vae_encoder.dtype)  # [B, C=6, F, H, W]
        latent_plucker_embedding_dist = self.controlnetxs.vae_encoder.encode(plucker_embedding).latent_dist  # B,C=6,F,H,W --> B,128,(F-1)//4+1,H//4,W//4
        latent_plucker_embedding = latent_plucker_embedding_dist.sample()
        latent_plucker_embedding = latent_plucker_embedding.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W] to [B, F, C, H, W]
        latent_plucker_embedding = latent_plucker_embedding.repeat(2, 1, 1, 1, 1)  # duplicate along the batch dim for classifier-free guidance

        patch_size_t = self.transformer.config.patch_size_t
        if patch_size_t is not None:
            # Pad the frame dimension to a multiple of patch_size_t (no-op when already aligned).
            ncopy = (patch_size_t - latent_plucker_embedding.shape[1] % patch_size_t) % patch_size_t
            if ncopy > 0:
                # Copy the first frame ncopy times to match patch_size_t
                first_frame = latent_plucker_embedding[:, :1, :, :, :]  # first frame, [B, 1, C, H, W]
                latent_plucker_embedding = torch.cat([first_frame.repeat(1, ncopy, 1, 1, 1), latent_plucker_embedding], dim=1)
                if 'latent_scene_frames' in input_kwargs:
                    input_kwargs['latent_scene_frames'] = torch.cat([
                        input_kwargs['latent_scene_frames'][:, :1, :, :, :].repeat(1, ncopy, 1, 1, 1),
                        input_kwargs['latent_scene_frames'],
                    ], dim=1)
                    input_kwargs['latent_scene_mask'] = torch.cat([
                        input_kwargs['latent_scene_mask'][:, :1, :, :, :].repeat(1, ncopy, 1, 1, 1),
                        input_kwargs['latent_scene_mask'],
                    ], dim=1)
            assert latent_plucker_embedding.shape[1] % patch_size_t == 0

        num_latent_frames = latent_plucker_embedding.shape[1]
        vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        rotary_emb_for_controlnetxs = (
            self.prepare_rotary_positional_embeddings(
                height=H,
                width=W,
                num_frames=num_latent_frames,
                transformer_config=self.controlnetxs.transformer.config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.device,
            )
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        self.init_pipe()
        original_forward = self.pipe.transformer.forward
        self.pipe.transformer.forward = partial(
            self.pipe.transformer.forward,
            controlnetxs=self.controlnetxs,
            latent_plucker_embedding=latent_plucker_embedding,
            image_rotary_emb_for_controlnetxs=rotary_emb_for_controlnetxs,
        )
        forward_kwargs = dict(
            num_frames=input_kwargs["video_length"],
            height=H,
            width=W,
            prompt_embeds=input_kwargs["prompt_embedding"],
            negative_prompt_embeds=input_kwargs["negative_prompt_embedding"],
            image=image.to(self.device).to_dense(),
            num_inference_steps=input_kwargs['num_inference_steps'],
            guidance_scale=input_kwargs['text_cfg'],
            noise_shaping=input_kwargs['noise_shaping'],
            noise_shaping_minimum_timesteps=input_kwargs['noise_shaping_minimum_timesteps'],
            latent_scene_frames=input_kwargs.get('latent_scene_frames', None),  # B,F,C,H,W
            latent_scene_mask=input_kwargs.get('latent_scene_mask', None),
            generator=input_kwargs['generator'],
        )
        video_generate = self.pipe(**forward_kwargs).frames[0]
        self.pipe.transformer.forward = original_forward
        return video_generate
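    # Note: the functools.partial wrapper above pre-binds the ControlNet-XS branch and its camera
    # latents as extra keyword arguments for every denoising step, and the original forward is
    # restored right after sampling so the transformer can be reused without camera conditioning.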

    def encode_text(self, prompt: str) -> torch.Tensor:
        prompt_token_ids = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.transformer.config.max_text_seq_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        prompt_token_ids = prompt_token_ids.input_ids
        prompt_embedding = self.text_encoder(prompt_token_ids.to(self.device))[0]
        return prompt_embedding.to(torch.bfloat16).to(self.device)

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W]
        video = video.to(self.vae.device, dtype=self.vae.dtype)
        latent_dist = self.vae.encode(video).latent_dist
        latent = latent_dist.sample() * self.vae.config.scaling_factor
        return latent.to(torch.bfloat16).to(self.device)
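    # The CogVideoX VAE compresses 8x spatially and 4x temporally, so F input frames map to roughly
    # (F - 1) // 4 + 1 latent frames -- the same relation noted for the plucker encoder in
    # validation_step above.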

    def get_image(
        self,
        model_name: str,
        ref_img: np.ndarray,
        prompt: str,
        negative_prompt: str,
        camera_pose_type: str,
        preview_video: str = None,
        steps: int = 25,
        trace_extract_ratio: float = 1.0,
        trace_scale_factor: float = 1.0,
        camera_cfg: float = 1.0,
        text_cfg: float = 6.0,
        seed: int = 123,
        noise_shaping: bool = False,
        noise_shaping_minimum_timesteps: int = 800,
        video_shape: tuple[int, int, int] = (81, 768, 1360),
        resize_for_rectangle_crop: bool = True,
    ):
        if self.pipe is None:
            self.init_model(model_name)
        video_length, self.sample_height, self.sample_width = video_shape
        print(video_length, self.sample_height, self.sample_width)
        seed_everything(seed)

        input_kwargs = {
            'video_length': video_length,
            'camera_cfg': camera_cfg,
            'num_inference_steps': steps,
            'text_cfg': text_cfg,
            'noise_shaping': noise_shaping,
            'noise_shaping_minimum_timesteps': noise_shaping_minimum_timesteps,
            'generator': torch.Generator(device=self.device).manual_seed(seed),
        }

        # Reference image: H,W,C uint8 array -> C,1,H,W tensor so the resize helpers can treat it as a one-frame clip.
        ref_img = rearrange(torch.from_numpy(ref_img), 'h w c -> c 1 h w')
        if resize_for_rectangle_crop:
            ref_img, resized_H, resized_W = _resize_for_rectangle_crop(
                ref_img, self.sample_height, self.sample_width,
            )
        else:
            ref_img, resized_H, resized_W = _resize(
                ref_img, self.sample_height, self.sample_width,
            )
        ref_img = rearrange(ref_img / 255, "c 1 h w -> 1 c h w")
        H, W = ref_img.shape[-2:]
        input_kwargs["image"] = ref_img.to(self.device).to(torch.bfloat16)

        with open(self.camera_pose_meta_path, "r", encoding="utf-8") as f:
            camera_pose_file_path = json.load(f)[camera_pose_type]
        # comments="https" makes np.loadtxt skip header lines that start with a URL.
        camera_data = torch.from_numpy(np.loadtxt(camera_pose_file_path, comments="https"))  # [t, -1]

        # Pinhole intrinsics for the cropped frame.
        fx = 0.5 * max(resized_H, resized_W)
        fy = fx
        cx = 0.5 * W
        cy = 0.5 * H
        intrinsics_matrix = torch.tensor([
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1.0],
        ])

        # The trailing columns of each row hold a flattened 3x4 world-to-camera matrix.
        w2cs_3x4 = camera_data[:, 7:].reshape(-1, 3, 4)  # [t, 3, 4]
        dummy = torch.tensor([[[0, 0, 0, 1]]] * w2cs_3x4.shape[0])  # [t, 1, 4]
        w2cs_4x4 = torch.cat([w2cs_3x4, dummy], dim=1)  # [t, 4, 4]
        c2ws_4x4 = w2cs_4x4.inverse()  # [t, 4, 4]
        c2ws_lerp_4x4 = camera_pose_lerp(c2ws_4x4, round(video_length / trace_extract_ratio))[:video_length]

        from utils.camera_utils import get_camera_condition
        plucker_embedding, relative_c2w_RT_4x4 = get_camera_condition(
            H, W, intrinsics_matrix[None, None], c2ws_lerp_4x4[None], mode="c2w",
            cond_frame_index=0, align_factor=trace_scale_factor
        )  # [B=1, C=6, F, H, W]
        input_kwargs["plucker_embedding"] = plucker_embedding.to(self.device).to(torch.bfloat16)

        uid = uuid4().fields[0]

        if noise_shaping:
            # Encode the preview (rendered scene) video and build a latent-space mask from its non-white pixels.
            scene_frames = VideoDecoder(preview_video, device=str(self.device))[:]
            scene_frames = rearrange(scene_frames / 255 * 2 - 1, "t c h w -> c t h w")  # c,f,h,w, value in [-1, 1]
            latent_scene_frames = self.encode_video(scene_frames.unsqueeze(0))  # b=1,c,f,h,w
            input_kwargs['latent_scene_frames'] = latent_scene_frames.permute(0, 2, 1, 3, 4)  # b=1,c,f,h,w --> b=1,f,c,h,w

            from models.camera_controller.utils import apply_thresholded_conv
            scene_mask = (scene_frames < 1).float().amax(0, keepdim=True)  # c,f,h,w --> 1,f,h,w
            scene_mask = apply_thresholded_conv(scene_mask, kernel_size=5, threshold=1.0)  # 1,f,h,w
            # Downsample the mask to the latent grid: the first frame stays separate, the remaining
            # frames follow the VAE's 4x temporal compression.
            latent_scene_mask = torch.cat([
                F.interpolate(scene_mask[:, :1].unsqueeze(1), (1, H // 8, W // 8), mode="trilinear", align_corners=True),
                F.interpolate(scene_mask[:, 1:].unsqueeze(1), ((video_length - 1) // 4, H // 8, W // 8), mode="trilinear", align_corners=True),
            ], dim=2).bool()
            input_kwargs['latent_scene_mask'] = latent_scene_mask.permute(0, 2, 1, 3, 4)

        # Stage modules on the GPU in turn to bound peak memory: text encoder first for the prompts,
        # then the VAE / transformer / ControlNet-XS for sampling.
        self.vae.cpu()
        self.transformer.cpu()
        self.controlnetxs.cpu()
        torch.cuda.empty_cache()

        self.text_encoder.to(self.device)
        input_kwargs |= {
            "prompt_embedding": self.encode_text(prompt),
            "negative_prompt_embedding": self.encode_text(negative_prompt),
        }
        self.text_encoder.cpu()
        torch.cuda.empty_cache()

        self.vae.to(self.device)
        self.transformer.to(self.device)
        self.controlnetxs.to(self.device)
        generated_video = self.validation_step(input_kwargs)

        video_path = f"{self.result_dir}/{model_name}_{uid:08x}.mp4"
        os.makedirs(self.result_dir, exist_ok=True)
        export_to_video(generated_video, video_path, fps=self.save_fps)
        torch.cuda.empty_cache()
        return video_path
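

# Minimal usage sketch, not part of the original interface: the model name, camera pose type, and
# image path below are placeholders that must exist in models.json / camera_poses.json / on disk.
if __name__ == "__main__":
    from PIL import Image

    i2v = Image2Video(result_dir="results", model_meta_path="models.json",
                      camera_pose_meta_path="camera_poses.json")
    ref = np.array(Image.open("example.png").convert("RGB"))  # H,W,C uint8 reference image
    out_path = i2v.get_image(
        model_name="my_model",
        ref_img=ref,
        prompt="a slow camera move through a sunlit forest",
        negative_prompt="",
        camera_pose_type="my_pose_type",
    )
    print(out_path)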