"""
This script demonstrates how to generate a video using the CogVideoX model with the Hugging Face `diffusers` pipeline.
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
and video-to-video (v2v), depending on the input data and the model weights used:

- text-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
- video-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
- image-to-video: THUDM/CogVideoX-5b-I2V or THUDM/CogVideoX1.5-5b-I2V

Running the Script:
This variant runs the FloVD FVSM controlnet image-to-video pipeline on top of the CogVideoX-5b-I2V backbone,
reading its inputs (first frames, prompt embeddings, flow latents) from a dataset root rather than a single
prompt. Launch it under a distributed launcher, for example:

```bash
$ torchrun --nproc_per_node=1 cli_demo.py --data_root <path/to/dataset> --model_path <path/to/controlnet_checkpoint> --output_path ./results --launcher pytorch
```

You can change `pipe.enable_sequential_cpu_offload()` to `pipe.enable_model_cpu_offload()` to speed up inference, but this will use more GPU memory.
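
For example (a minimal sketch; `pipe` stands for any diffusers CogVideoX pipeline, including the FloVD
pipeline constructed in this script, and only one of the two calls should be active):

```python
# Lowest peak GPU memory, slower inference: offload sub-modules to CPU one by one.
pipe.enable_sequential_cpu_offload()

# Faster alternative that keeps each whole sub-model on the GPU while it runs,
# at the cost of higher GPU memory usage:
# pipe.enable_model_cpu_offload()
```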

Additional options are available to specify the guidance scale, number of inference steps, camera pose type
and speed, controlnet guidance end, and output path.
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple

import argparse
import logging
import os
import subprocess
import sys
import json
import math
import random
from datetime import timedelta
from pathlib import Path

from safetensors.torch import load_file, save_file
from tqdm import tqdm
from einops import rearrange, repeat

import torch

from diffusers import (
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXPipeline,
    CogVideoXVideoToVideoPipeline,
    AutoencoderKLCogVideoX,
)
from diffusers.utils import export_to_video, load_image, load_video

sys.path.append(os.path.abspath(os.path.join(sys.path[0], "../")))
from finetune.pipeline.flovd_FVSM_cogvideox_controlnet_pipeline import FloVDCogVideoXControlnetImageToVideoPipeline
from finetune.schemas import Components, Args
from finetune.modules.cogvideox_controlnet import CogVideoXControlnet
from finetune.modules.cogvideox_custom_model import CustomCogVideoXTransformer3DModel
from transformers import AutoTokenizer, T5EncoderModel

from finetune.modules.camera_sampler import SampleManualCam
from finetune.modules.camera_flow_generator import CameraFlowGenerator
from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting

from finetune.datasets.utils import (
    preprocess_image_with_resize,
    preprocess_video_with_resize,
)

from torch.utils.data import Dataset
from torchvision import transforms

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import pdb

sys.path.append(os.path.abspath(os.path.join(sys.path[-1], 'finetune')))

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(level=logging.INFO)

RESOLUTION_MAP = {
    "cogvideox1.5-5b-i2v": (768, 1360),
    "cogvideox1.5-5b": (768, 1360),
    "cogvideox-5b-i2v": (480, 720),
    "cogvideox-5b": (480, 720),
    "cogvideox-2b": (480, 720),
}


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == 'pytorch':
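        # torchrun / torch.distributed.launch export RANK (plus WORLD_SIZE, MASTER_ADDR and
        # MASTER_PORT); map the global rank onto a local GPU index.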
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, timeout=timedelta(minutes=30), **kwargs)

    elif launcher == 'slurm':
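        # Under SLURM, derive the global rank from SLURM_PROCID and resolve the master
        # address from the first host in the node list via `scontrol`.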
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend, timeout=timedelta(minutes=30))

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank


def load_cogvideox_flovd_FVSM_controlnet_pipeline(controlnet_path, backbone_path, device, dtype):
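    """Builds the FloVD FVSM controlnet image-to-video pipeline.

    The CogVideoX backbone components (tokenizer, text encoder, transformer, VAE, scheduler)
    are loaded from `backbone_path`, while the controlnet weights are read from a local
    checkpoint whose state dict is stored under the 'module' key.
    """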
    controlnet_sd = torch.load(controlnet_path)['module']

    tokenizer = AutoTokenizer.from_pretrained(backbone_path, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(backbone_path, subfolder="text_encoder")
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(backbone_path, subfolder="transformer")
    vae = AutoencoderKLCogVideoX.from_pretrained(backbone_path, subfolder="vae")
    scheduler = CogVideoXDPMScheduler.from_pretrained(backbone_path, subfolder="scheduler")

    additional_kwargs = {
        'num_layers': 6,
        'out_proj_dim_factor': 64,
        'out_proj_dim_zero_init': True,
        'notextinflow': True,
    }
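    # The controlnet is instantiated from the backbone transformer weights, with the extra
    # kwargs above configuring its reduced depth (num_layers=6) and output projection.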
    controlnet = CogVideoXControlnet.from_pretrained(backbone_path, subfolder="transformer", **additional_kwargs)
    controlnet.eval()

    missing, unexpected = controlnet.load_state_dict(controlnet_sd, strict=False)

    if len(missing) != 0 or len(unexpected) != 0:
        print(f"Missing keys : {missing}")
        print(f"Unexpected keys : {unexpected}")

    pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        controlnet=controlnet,
        scheduler=scheduler,
    )

    pipe = pipe.to(device, dtype)

    return pipe


class I2VFlowDataset_Inference(Dataset):
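    """Inference dataset that serves precomputed assets from `data_root`.

    Each sample provides the first frame, the caption, its precomputed prompt embedding, the
    encoded video latent, and the encoded forward-flow latent, all indexed by the hash codes
    listed in metadata_revised.jsonl.
    """
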
    def __init__(
        self,
        max_num_frames: int,
        height: int,
        width: int,
        data_root: str,
        max_num_videos: int = None,
    ) -> None:
        self.train_resolution = (int(max_num_frames), int(height), int(width))

        data_root = Path(data_root)
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, you need metadata_revised.jsonl in the root path"

        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))

        if max_num_videos is not None:
            metadata = random.sample(metadata, max_num_videos)

        self.prompts = [x["prompt"] for x in metadata]
        self.prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
        self.videos = [data_root / "video_latent" / "x".join(str(x) for x in self.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
        self.images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
        self.flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]
        self.masks = [data_root / "valid_mask" / (x["hash_code"] + '.bin') for x in metadata]

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

        self.length = len(self.videos)

        print(f"Dataset size: {self.length}")

    def __len__(self) -> int:
        return self.length

    def load_data_pair(self, index):
        prompt_embedding_path = self.prompt_embeddings[index]
        encoded_video_path = self.videos[index]
        encoded_flow_path = self.flows[index]

        prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
        encoded_video = load_file(encoded_video_path)["encoded_video"]
        encoded_flow = load_file(encoded_flow_path)["encoded_flow_f"]

        return prompt_embedding, encoded_video, encoded_flow

    def __getitem__(self, index: int) -> Dict[str, Any]:
        while True:
            try:
                prompt_embedding, encoded_video, encoded_flow = self.load_data_pair(index)
                break
            except Exception as e:
                print(f"Error loading {self.prompt_embeddings[index]}: {str(e)}")
                index = random.randint(0, self.length - 1)

        image_path = self.images[index]
        prompt = self.prompts[index]

        _, image = self.preprocess(None, image_path)
        image = self.image_transform(image)

        return {
            "image": image,
            "prompt": prompt,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "encoded_flow": encoded_flow,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)


def initialize_flow_generator(target):
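    """Builds the camera flow generator around a metric depth estimator.

    Note that the Depth-Anything-V2 checkpoint path and model config below are hard-coded;
    adjust them to your local setup.
    """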
    depth_estimator_kwargs = {
        "target": target,
        "kwargs": {
            "ckpt_path": '/workspace/workspace/checkpoints/depth_anything/depth_anything_v2_metric_hypersim_vitb.pth',
            "model_config": {
                "max_depth": 20,
                "encoder": 'vitb',
                "features": 128,
                "out_channels": [96, 192, 384, 768],
            }
        }
    }

    return CameraFlowGenerator(depth_estimator_kwargs)


def generate_video(
    launcher: str,
    port: int,
    data_root: str,
    model_path: str,
    num_frames: int = 81,
    width: Optional[int] = None,
    height: Optional[int] = None,
    output_path: str = "./output.mp4",
    image_path: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 42,
    fps: int = 16,
    controlnet_guidance_end: float = 0.4,
    max_num_videos: int = None,
    use_dynamic_cfg: bool = False,
    pose_type: str = "manual",
    speed: float = 0.5,
):
    """
    Generates a video for every sample in the dataset and saves it to the specified path.

    Parameters:
    - launcher (str): The distributed launcher type ('pytorch' or 'slurm').
    - port (int): The master port used for distributed initialization.
    - data_root (str): The root directory of the inference dataset.
    - model_path (str): The path of the FloVD controlnet checkpoint to be loaded.
    - num_frames (int): Number of frames to generate. CogVideoX1.0 generates 49 frames for 6 seconds at 8 fps, while CogVideoX1.5 produces either 81 or 161 frames, corresponding to 5 seconds or 10 seconds at 16 fps.
    - width (int): The width of the generated video; defaults to the recommended resolution of the backbone.
    - height (int): The height of the generated video; defaults to the recommended resolution of the backbone.
    - output_path (str): The directory where the generated videos will be saved.
    - image_path (str): Not used in this pipeline; first frames are taken from the dataset.
    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
    - seed (int): The seed for reproducibility.
    - fps (int): The frames per second for the generated video.
    - controlnet_guidance_end (float): The fraction of sampling steps after which controlnet guidance is disabled.
    - max_num_videos (int): Maximum number of videos sampled from the dataset.
    - use_dynamic_cfg (bool): Whether to use dynamic classifier-free guidance.
    - pose_type (str): Source of camera poses ('re10k' or 'manual').
    - speed (float): Camera motion speed passed to the flow generator.
    """

    local_rank = init_dist(launcher=launcher, port=port)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    torch.manual_seed(seed)
    random.seed(seed)

    if is_main_process:
        os.makedirs(os.path.join(output_path, 'generated_videos'), exist_ok=True)

    image = None
    video = None

    model_name = "cogvideox-5b-i2v"
    desired_resolution = RESOLUTION_MAP[model_name]
    if width is None or height is None:
        height, width = desired_resolution
        logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
    elif (height, width) != desired_resolution:
        logging.warning(
            f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
        )

    """
    # Prepare Dataset Class
    """

    dataset = I2VFlowDataset_Inference(
        max_num_frames=num_frames,
        height=height,
        width=width,
        data_root=data_root,
        max_num_videos=max_num_videos,
    )

    distributed_sampler = DistributedSampler(
        dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=False,
        seed=seed,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    """
    # Prepare Pipeline
    """
    print('Constructing pipeline')
    pipe = load_cogvideox_flovd_FVSM_controlnet_pipeline(model_path, backbone_path="THUDM/CogVideoX-5b-I2V", device=local_rank, dtype=dtype)

    assert pose_type in ['re10k', 'manual'], "pose_type must be one of ['re10k', 'manual']"
    if pose_type == 're10k':
        root_path = "./manual_poses_re10k"
    else:
        root_path = "./manual_poses"

    CameraSampler = SampleManualCam(pose_type=pose_type, root_path=root_path)
    camera_flow_generator_target = 'finetune.modules.depth_warping.depth_warping.DepthWarping_wrapper'
    camera_flow_generator = initialize_flow_generator(camera_flow_generator_target).to(local_rank)

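    # Use the DPM scheduler with trailing timestep spacing, as in the reference CogVideoX sampling setup.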
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    dataloader.sampler.set_epoch(1)
    dist.barrier()

    output_path = os.path.join(output_path, 'generated_videos')

    data_iter = iter(dataloader)
    for step in tqdm(range(0, len(dataloader))):
        batch = next(data_iter)

        prompt = batch["prompt"][0]
        image = batch["image"].to(local_rank)
        prompt_embedding = batch["prompt_embedding"].to(local_rank)
        prompt_short = prompt[:20].strip()

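        # Sample a camera trajectory and synthesize the camera-induced optical flow by
        # depth-based warping of the first frame.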
        camparam, cam_name = CameraSampler.sample()
        image_torch = ((image.detach().clone() + 1) / 2. * 255.).squeeze(0)
        camera_flow_generator_input = get_camera_flow_generator_input(image_torch, camparam, device=local_rank, speed=speed)
        image_torch = ((image_torch.unsqueeze(0) / 255.) * 2. - 1.).to(local_rank)

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
                camera_flow, log_dict = camera_flow_generator(image_torch, camera_flow_generator_input)
                camera_flow = camera_flow.to(local_rank, dtype)

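        # Encode the camera flow into the VAE latent space (scaled per axis) and reshape it
        # to the (B, F, C, H, W) layout expected by the pipeline's `flow_latent` input.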
        camera_flow_latent = rearrange(encode_flow(camera_flow, pipe.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(local_rank, dtype)

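        # Run the FloVD controlnet pipeline conditioned on the first frame, the precomputed
        # prompt embedding, and the camera-flow latent.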
        video_generate = pipe(
            num_frames=num_frames,
            height=height,
            width=width,
            prompt=None,
            prompt_embeds=prompt_embedding,
            image=image,
            flow_latent=camera_flow_latent,
            valid_mask=None,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_guidance_start=0.0,
            controlnet_guidance_end=controlnet_guidance_end,
            use_dynamic_cfg=use_dynamic_cfg,
        ).frames[0]

        save_path = os.path.join(output_path, f"{prompt_short}_DCFG-{use_dynamic_cfg}_ContGuide-{controlnet_guidance_end}_{cam_name}.mp4")
        export_to_video(video_generate, save_path, fps=fps)

        dist.barrier()


def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
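    """Encodes a (B, C, F, H, W) tensor into VAE latents scaled by the VAE scaling factor."""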
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent


def encode_flow(flow, vae, flow_scale_factor):
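    """Encodes a (F, 2, H, W) optical-flow tensor into the VAE latent space.

    The two flow channels are adaptively normalized, padded with a third channel (their mean)
    to match the 3-channel input expected by the VAE, and then encoded as a video.
    """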
    assert flow.ndim == 4
    num_frames, _, height, width = flow.shape

    flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
    flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)

    num_frames, _, H, W = flow_norm.shape
    flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
    flow_norm_extended[:, :2] = flow_norm
    flow_norm_extended[:, -1:] = flow_norm.mean(dim=1, keepdim=True)
    flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)

    return encode_video(flow_norm_extended, vae)


def decode_flow(flow_latent, vae, flow_scale_factor):
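    """Decodes flow latents back into unnormalized two-channel optical flow of shape ((B*F), 2, H, W)."""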
    flow_latent = flow_latent.permute(0, 2, 1, 3, 4)
    flow_latent = 1 / vae.config.scaling_factor * flow_latent

    flow = vae.decode(flow_latent).sample

    flow = flow[:, :2].detach().clone()

    flow = rearrange(flow, 'b c f h w -> b f c h w')
    flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    flow = rearrange(flow, 'b f c h w -> (b f) c h w')
    return flow


def adaptive_normalize(flow, sf_x, sf_y):
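    """Applies a signed square-root compression to the flow, scaled by sf_x/sf_y and clamped per axis."""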
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None
    b, f, c, h, w = flow.shape

    max_clip_x = math.sqrt(w / sf_x) * 1.0
    max_clip_y = math.sqrt(h / sf_y) * 1.0

    flow_norm = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x) / sf_x + 1e-7)
    flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y) / sf_y + 1e-7)

    flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
    flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)

    return flow_norm


def adaptive_unnormalize(flow, sf_x, sf_y):
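    """Inverts adaptive_normalize: flow = sign(flow_norm) * sf * (flow_norm**2 - eps)."""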
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None

    flow_orig = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7)
    flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7)

    return flow_orig


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate videos with the FloVD CogVideoX controlnet pipeline")

    parser.add_argument("--image_path", type=str, default=None, help="The path of the image to be used as the background of the video")
    parser.add_argument("--data_root", type=str, required=True, help="The path of the dataset root")
    parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="The path of the controlnet checkpoint to load")
    parser.add_argument("--output_path", type=str, default="./output.mp4", help="The directory where the generated videos are saved")
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of steps for the inference process")
    parser.add_argument("--num_frames", type=int, default=49, help="Number of frames to generate")
    parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
    parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
    parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--controlnet_guidance_end", type=float, default=0.4, help="Controlnet guidance end during sampling")
    parser.add_argument("--max_num_videos", type=int, default=None, help="Number of videos to sample for inference")
    parser.add_argument("--use_dynamic_cfg", action='store_true')
    parser.add_argument("--pose_type", type=str, default='manual', help="Camera pose type at inference time ('re10k' or 'manual')")
    parser.add_argument("--speed", type=float, default=0.5, help="Camera motion speed at inference time")

    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--world_size", default=1, type=int,
                        help="number of the distributed processes.")
    parser.add_argument('--local-rank', type=int, default=-1,
                        help='Replica rank on the current node. This field is required '
                             'by `torch.distributed.launch`.')
    parser.add_argument("--global_seed", default=42, type=int,
                        help="seed")
    parser.add_argument("--port", type=int, default=29500)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")

    args = parser.parse_args()
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    generate_video(
        launcher=args.launcher,
        port=args.port,
        data_root=args.data_root,
        model_path=args.model_path,
        output_path=args.output_path,
        num_frames=args.num_frames,
        width=args.width,
        height=args.height,
        image_path=args.image_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        seed=args.seed,
        fps=args.fps,
        controlnet_guidance_end=args.controlnet_guidance_end,
        max_num_videos=args.max_num_videos,
        use_dynamic_cfg=args.use_dynamic_cfg,
        pose_type=args.pose_type,
        speed=args.speed,
    )