#! /usr/bin/env python
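"""Text-to-video and image-to-video generation CLI for the Mochi preview pipeline.

Given a prompt (or a directory of prompt .txt files) and, optionally, a
conditioning image (or a directory of images), this script encodes the image
into the VAE latent space, runs the Mochi DiT sampling loop, and saves each
resulting video together with a JSON file recording the generation arguments.
"""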
import json
import os
import time

import click
import numpy as np
import torch

from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines_ti2v_release import (
    DecoderModelFactory,
    EncoderModelFactory,
    DitModelFactory,
    MochiMultiGPUPipeline,
    MochiSingleGPUPipeline,
    T5ModelFactory,
    linear_quadratic_schedule,
)
from torch.utils.data import Dataset, DataLoader
import random
import string
from lightning.pytorch import LightningDataModule
from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
import torchvision
from einops import rearrange
from safetensors.torch import load_file
# NOTE: this re-imports DecoderModelFactory, rebinding the name imported from
# pipelines_ti2v_release above.
from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial, decode_latents, decode_latents_tiled_full
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents
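
# Module-level state shared by configure_model() and load_model(). The heavy
# pipeline object is created lazily on first use (see load_model below), so
# configure_model() can be called before any GPU memory is allocated.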
pipeline = None
model_dir_path = None
num_gpus = torch.cuda.device_count()
cpu_offload = False
dit_path = None


def configure_model(model_dir_path_, dit_path_, cpu_offload_):
    global model_dir_path, dit_path, cpu_offload
    model_dir_path = model_dir_path_
    dit_path = dit_path_
    cpu_offload = cpu_offload_
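

# Build the generation pipeline on first call. A single-GPU pipeline is used
# when only one device is visible; otherwise the multi-GPU pipeline runs the
# model across all visible GPUs (world_size = num_gpus).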
def load_model():
    global num_gpus, pipeline, model_dir_path, dit_path
    if pipeline is None:
        MOCHI_DIR = model_dir_path
        print(f"Launching with {num_gpus} GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.")
        klass = MochiSingleGPUPipeline if num_gpus == 1 else MochiMultiGPUPipeline
        kwargs = dict(
            text_encoder_factory=T5ModelFactory(),
            dit_factory=DitModelFactory(
                model_path=dit_path,
                model_dtype="bf16",
            ),
            decoder_factory=DecoderModelFactory(
                model_path=f"{MOCHI_DIR}/decoder.safetensors",
            ),
            encoder_factory=EncoderModelFactory(
                model_path=f"{MOCHI_DIR}/encoder.safetensors",
            ),
        )
        if num_gpus > 1:
            assert not cpu_offload, "CPU offload not supported in multi-GPU mode"
            kwargs["world_size"] = num_gpus
        else:
            kwargs["cpu_offload"] = cpu_offload
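            # Spatially tiled VAE decoding keeps peak decoder memory low, which
            # matters in the single-GPU / CPU-offload regime; "tiled_full" is the
            # alternative kept below for reference.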
            # kwargs["decode_type"] = "tiled_full"
            kwargs["decode_type"] = "tiled_spatial"
        pipeline = klass(**kwargs)


def generate_video(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_inference_steps,
    data_path,
    input_image=None,
    noise_multiplier=0,
):
    load_model()
    global dit_path

    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
    # such that sigma_schedule[0] == 1.0, sigma_schedule[-1] == 0.0, and the values
    # decrease monotonically.
    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)

    # cfg_schedule should be a list of floats of length num_inference_steps.
    # For simplicity, we use the same cfg scale at every timestep, but better
    # schedules may vary it, e.g.:
    # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
    cfg_schedule = [cfg_scale] * num_inference_steps

    args = {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "sigma_schedule": sigma_schedule,
        "cfg_schedule": cfg_schedule,
        "num_inference_steps": num_inference_steps,
        # Batched cfg requires flash attention and is only worth doing in a
        # high-memory regime (assume multiple GPUs), so it is disabled here.
        "batch_cfg": False,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "seed": seed,
        "data_path": data_path,
        "noise_multiplier": noise_multiplier,
    }

    # Image conditioning: pass the encoded image latent and the frame index it
    # conditions to the pipeline.
    if input_image is not None:
        args["condition_image"] = input_image["tensor"]
        args["condition_frame_idx"] = input_image["cond_position"]

    with progress_bar(type="tqdm"):
        final_frames = pipeline(**args)

    final_frames = final_frames[0]

    assert isinstance(final_frames, np.ndarray)
    assert final_frames.dtype == np.float32

    # Build a results directory name from the DiT checkpoint path, the
    # conditioning settings, and a timestamp.
    model_name = os.path.basename(dit_path.split("/")[-2])
    checkpoint_name = dit_path.split("/")[-1].split("train_loss")[0]
    from datetime import datetime

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    cond_position = input_image["cond_position"] if input_image is not None else ""
    results_base_dir = "/home/dyvm6xra/dyvm6xrauser02/raphael/mochi-1-preview/models/video_test_demos_results"
    results_dir = os.path.join(
        results_base_dir,
        f"{model_name}_{checkpoint_name}_github_user_demo_{cond_position}pos_{num_inference_steps}steps_crop_{noise_multiplier}sigma",
    )
    os.makedirs(results_dir, exist_ok=True)

    # Prefix the output filename with the conditioning image's name, if any.
    filename_prefix = ""
    if isinstance(input_image, dict) and "filename" in input_image:
        filename_prefix = f"{os.path.basename(input_image['filename']).split('.')[0]}_"
    output_path = os.path.join(results_dir, f"{filename_prefix}{timestamp_str}.mp4")
    save_video(final_frames, output_path)

    # Save the generation arguments next to the video, stripping tensors and
    # other non-JSON-serializable objects first.
    json_path = os.path.splitext(output_path)[0] + ".json"
    json_args = args.copy()
    if "input_image" in json_args:
        json_args["input_image"] = None
    if "condition_image" in json_args:
        json_args["condition_image"] = "Image tensor (removed for JSON)"
    if isinstance(input_image, dict):
        json_args["input_filename"] = input_image.get("filename", None)
        if "cond_position" in input_image:
            json_args["condition_frame_idx"] = input_image["cond_position"]
    # Convert sigma_schedule and cfg_schedule from tensors to lists if needed.
    if isinstance(json_args["sigma_schedule"], torch.Tensor):
        json_args["sigma_schedule"] = json_args["sigma_schedule"].tolist()
    if isinstance(json_args["cfg_schedule"], torch.Tensor):
        json_args["cfg_schedule"] = json_args["cfg_schedule"].tolist()
    # Replace prompt / negative_prompt if they are tensors or other
    # non-serializable objects.
    if not isinstance(json_args["prompt"], (str, type(None))):
        if hasattr(json_args["prompt"], "tolist"):
            json_args["prompt"] = "Tensor prompt (converted to string for JSON)"
        else:
            json_args["prompt"] = str(json_args["prompt"])
    if not isinstance(json_args["negative_prompt"], (str, type(None))):
        if hasattr(json_args["negative_prompt"], "tolist"):
            json_args["negative_prompt"] = "Tensor negative prompt (converted to string for JSON)"
        else:
            json_args["negative_prompt"] = str(json_args["negative_prompt"])
    with open(json_path, "w") as f:
        json.dump(json_args, f, indent=4)
    return output_path
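

# NOTE: the original @click.option definitions are not present in this file as
# extracted; the decorators below are a minimal, assumed reconstruction so that
# generate_cli() can actually be invoked from the command line. Flag names
# mirror the function parameters; defaults are placeholders and should be
# adjusted to your setup.
@click.command()
@click.option("--prompt", default="", help="Text prompt for generation.")
@click.option("--negative_prompt", default="", help="Negative text prompt.")
@click.option("--width", default=848, type=int, help="Output video width.")
@click.option("--height", default=480, type=int, help="Output video height.")
@click.option("--num_frames", default=163, type=int, help="Number of frames to generate.")
@click.option("--seed", default=12345, type=int, help="Random seed.")
@click.option("--cfg_scale", default=6.0, type=float, help="Classifier-free guidance scale.")
@click.option("--num_steps", default=64, type=int, help="Number of inference steps.")
@click.option("--model_dir", required=True, help="Directory containing encoder/decoder weights.")
@click.option("--dit_path", required=True, help="Path to the DiT checkpoint.")
@click.option("--cpu_offload", is_flag=True, default=False, help="Offload model weights to CPU (single-GPU only).")
@click.option("--data_path", default=None, help="Optional data path forwarded to the pipeline.")
@click.option("--image_dir", default=None, help="Conditioning image file or directory (enables image-to-video).")
@click.option("--prompt_dir", default=None, help="Directory of .txt prompt files.")
@click.option("--cond_position", default=0, type=int, help="Frame index the conditioning image is placed at.")
@click.option("--noise_multiplier", default=0.0, type=float, help="Noise multiplier applied to the conditioning latent.")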
def generate_cli(
    prompt, negative_prompt, width, height, num_frames, seed, cfg_scale, num_steps, model_dir,
    dit_path, cpu_offload, data_path, image_dir, prompt_dir, cond_position, noise_multiplier,
):
    configure_model(model_dir, dit_path, cpu_offload)

    # Case 1: text-to-video generation (no conditioning image).
    if image_dir is None:
        # If a prompt directory is given, process every .txt prompt file in it.
        if prompt_dir is not None and os.path.isdir(prompt_dir):
            prompt_files = [f for f in os.listdir(prompt_dir) if f.lower().endswith(".txt")]
            prompt_files = prompt_files[210:]  # TODO: remove this temporary offset
            if not prompt_files:
                click.echo(f"No prompt files found in {prompt_dir}")
                return
            click.echo(f"Found {len(prompt_files)} prompt files to process")
            for i, prompt_file in enumerate(prompt_files):
                file_path = os.path.join(prompt_dir, prompt_file)
                click.echo(f"Processing prompt file {i+1}/{len(prompt_files)}: {file_path}")

                # Read the prompt from the file.
                with open(file_path, "r") as f:
                    file_prompt = f.read().strip()
                click.echo(f"Using prompt: {file_prompt}")

                with torch.inference_mode():
                    output = generate_video(
                        file_prompt,
                        negative_prompt,
                        width,
                        height,
                        num_frames,
                        seed,
                        cfg_scale,
                        num_steps,
                        data_path,
                        input_image=None,
                        noise_multiplier=noise_multiplier,
                    )
                click.echo(f"Video generated at: {output}")
        else:
            # Process the single prompt passed on the command line.
            click.echo("Running text-to-video generation with provided prompt")
            with torch.inference_mode():
                output = generate_video(
                    prompt,
                    negative_prompt,
                    width,
                    height,
                    num_frames,
                    seed,
                    cfg_scale,
                    num_steps,
                    data_path,
                    input_image=None,
                )
            click.echo(f"Video generated at: {output}")
        return
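
    # From here on a conditioning image (or directory of images) was supplied,
    # so build the VAE encoder used to map images into the latent space the DiT
    # is conditioned on.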
    config = dict(
        prune_bottlenecks=[False, False, False, False, False],
        has_attentions=[False, True, True, True, True],
        affine=True,
        bias=True,
        input_is_conv_1x1=True,
        padding_mode="replicate",
    )

    # Create the VAE encoder and load its weights.
    encoder = Encoder(
        in_channels=15,
        base_channels=64,
        channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3],
        latent_dim=12,
        temporal_reductions=[1, 2, 3],
        spatial_reductions=[2, 2, 2],
        **config,
    )
    device = torch.device("cuda:0")
    encoder = encoder.to(device, memory_format=torch.channels_last_3d)
    encoder.load_state_dict(load_file(f"{model_dir}/encoder.safetensors"))
    encoder.eval()

    # Case 2: image-to-video with a single image file.
    if image_dir is not None and os.path.isfile(image_dir) and image_dir.lower().endswith((".jpg", ".jpeg", ".png")):
        click.echo(f"Processing single image: {image_dir}")

        from PIL import Image
        import torchvision.transforms as transforms

        # Load the image.
        image = Image.open(image_dir)

        # Center-crop to the target aspect ratio before resizing, rather than
        # resizing directly (which would distort the image).
        target_ratio = width / height
        current_ratio = image.width / image.height
        if current_ratio > target_ratio:
            # Image is wider than the target ratio: crop the width.
            new_width = int(image.height * target_ratio)
            x1 = (image.width - new_width) // 2
            image = image.crop((x1, 0, x1 + new_width, image.height))
        else:
            # Image is taller than the target ratio: crop the height.
            new_height = int(image.width / target_ratio)
            y1 = (image.height - new_height) // 2
            image = image.crop((0, y1, image.width, y1 + new_height))

        # Resize the cropped image and convert it to a tensor.
        transform = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
        ])
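
        # ToTensor() yields values in [0, 1]; rescale to [-1, 1], add frame and
        # batch dimensions ([C, H, W] -> [1, C, 1, H, W]), then append the
        # Fourier features the encoder expects (it takes in_channels=15).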
        image_tensor = (transform(image) * 2 - 1).unsqueeze(1).unsqueeze(0)
        print("image_tensor.shape", image_tensor.shape)
        image_tensor = add_fourier_features(image_tensor.to(device))

        # Encode the image into a latent with the VAE encoder.
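        # The encoder returns a latent distribution; sampling it yields the
        # conditioning latent, and the encoder is then moved back to the CPU to
        # free GPU memory for the DiT.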
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                t0 = time.time()
                encoder = encoder.to(device)
                ldist = encoder(image_tensor)
                image_tensor = ldist.sample()
                torch.cuda.empty_cache()
                encoder = encoder.to("cpu")
                del ldist

        # Package the conditioning latent for generate_video().
        input_image = {
            "tensor": image_tensor,
            "filename": os.path.basename(image_dir),
            "cond_position": cond_position,
        }
        with torch.inference_mode():
            output = generate_video(
                prompt,
                negative_prompt,
                width,
                height,
                num_frames,
                seed,
                cfg_scale,
                num_steps,
                data_path,
                input_image,
                noise_multiplier=noise_multiplier,
            )
        click.echo(f"Video generated at: {output}")
        return

    # Case 3: image-to-video with a directory of images.
    if image_dir is not None and os.path.isdir(image_dir):
        # Gather all image files in the directory.
        image_files = [f for f in os.listdir(image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))]
        if not image_files:
            click.echo(f"No image files found in {image_dir}")
            return
        click.echo(f"Found {len(image_files)} image files to process")

        from PIL import Image
        import torchvision.transforms as transforms

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
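
        # Each image is center-cropped to the target aspect ratio, resized,
        # encoded to a latent, and paired with a prompt file of the same
        # basename in prompt_dir before generating its video.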
        with torch.inference_mode():
            for i, image_file in enumerate(image_files):
                file_path = os.path.join(image_dir, image_file)
                click.echo(f"Processing file {i+1}/{len(image_files)}: {file_path}")

                # Load the image.
                image = Image.open(file_path)

                # Center-crop to the target aspect ratio.
                target_ratio = width / height
                current_ratio = image.width / image.height
                if current_ratio > target_ratio:
                    # Image is wider than the target ratio: crop the width.
                    new_width = int(image.height * target_ratio)
                    x1 = (image.width - new_width) // 2
                    image = image.crop((x1, 0, x1 + new_width, image.height))
                else:
                    # Image is taller than the target ratio: crop the height.
                    new_height = int(image.width / target_ratio)
                    y1 = (image.height - new_height) // 2
                    image = image.crop((0, y1, image.width, y1 + new_height))

                # Resize the cropped image, rescale to [-1, 1], and add frame
                # and batch dimensions.
                image = image.resize((width, height))
                image_tensor = (transform(image) * 2 - 1).unsqueeze(1).unsqueeze(0)
                print("image_tensor.shape", image_tensor.shape)
                image_tensor = add_fourier_features(image_tensor.to(device))

                # Encode the image into a conditioning latent.
                with torch.inference_mode():
                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        t0 = time.time()
                        encoder = encoder.to(device)
                        ldist = encoder(image_tensor)
                        image_tensor = ldist.sample()
                        torch.cuda.empty_cache()
                        encoder = encoder.to("cpu")
                        del ldist

                # Look up the prompt file matching this image's basename.
                img_basename = os.path.basename(file_path).split(".")[0]
                prompt_file = os.path.join(prompt_dir, f"{img_basename}.txt")
                if os.path.exists(prompt_file):
                    with open(prompt_file, "r") as f:
                        file_prompt = f.read().strip()
                    click.echo(f"Using prompt from file: {file_prompt}")
                else:
                    click.echo(f"Warning: Prompt file not found for {file_path}. Using default prompt.")
                    file_prompt = prompt

                # Package the conditioning latent for generate_video().
                input_image = {
                    "tensor": image_tensor,
                    "filename": os.path.basename(file_path),
                    "cond_position": cond_position,
                }
                output = generate_video(
                    file_prompt,
                    negative_prompt,
                    width,
                    height,
                    num_frames,
                    seed,
                    cfg_scale,
                    num_steps,
                    data_path,
                    input_image,
                    noise_multiplier=noise_multiplier,
                )
                click.echo(f"Video generated at: {output}")
        return


if __name__ == "__main__":
    generate_cli()
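
# Example invocation (assuming the click options sketched above; the script
# name and all paths are placeholders):
#   python mochi_ti2v_cli.py \
#       --model_dir /path/to/mochi_weights \
#       --dit_path /path/to/dit_checkpoint.safetensors \
#       --prompt "A timelapse of clouds drifting over a mountain range" \
#       --image_dir ./conditioning_images \
#       --prompt_dir ./prompts \
#       --cond_position 0 \
#       --num_steps 64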