"""
Adapted from:

https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
"""
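
# Example invocation (the script filename here is illustrative; point it at the
# directory holding your <clip>.mp4 / <clip>.txt pairs):
#
#   python encode_videos.py /path/to/dataset --shape 163x480x848
#
# VAE latents (*.latent.pt) and caption embeddings (*.embed.pt) are written next
# to each video.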

import traceback
from pathlib import Path

import click
import torch
import torchvision
from diffusers import AutoencoderKLMochi, MochiPipeline
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer


def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
    """Encode a single video into VAE latents and save them next to the video."""
    T, H, W = [int(s) for s in shape.split("x")]
    # Mochi's VAE compresses time 6x, so the frame count must satisfy T % 6 == 1.
    assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"

    video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="sec")
    fps = metadata["video_fps"]  # currently unused, but handy if you want to store it alongside the latents
    video = video.permute(3, 0, 1, 2)  # THWC -> CTHW
    og_shape = video.shape
    assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
    assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
    assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
    if video.shape[1] > T:
        video = video[:, :T]
        print(f"Trimmed video from {og_shape[1]} to first {T} frames")

    # Add a batch dimension and scale uint8 pixels to [-1, 1].
    video = video.unsqueeze(0)
    video = video.float() / 127.5 - 1.0
    video = video.to(model.device)
    assert video.ndim == 5

    with torch.inference_mode():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            ldist = model._encode(video)

    torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))


@click.command()
@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
@click.option(
    "--model_id",
    type=str,
    help="Hugging Face repo id; should be genmo/mochi-1-preview",
    default="genmo/mochi-1-preview",
)
@click.option("--shape", default="163x480x848", help="Shape of the video to encode, as <frames>x<height>x<width>")
@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
def batch_process(output_dir: Path, model_id: str, shape: str, overwrite: bool) -> None:
    """Process all videos and captions in a directory using a single GPU."""
    # Allow TF32 matmuls/convolutions for faster inference on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    video_paths = list(output_dir.glob("**/*.mp4"))
    if not video_paths:
        print(f"No MP4 files found in {output_dir}")
        return

    text_paths = list(output_dir.glob("**/*.txt"))
    if not text_paths:
        print(f"No text files found in {output_dir}")
        return

    # Load only the pieces we need: the VAE for video latents, and the T5 text
    # encoder/tokenizer (via MochiPipeline with transformer=None, vae=None) for
    # caption embeddings.
    vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
    pipeline = MochiPipeline.from_pretrained(
        model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
    ).to("cuda")

    for video_path in tqdm(sorted(video_paths)):
        print(f"Processing {video_path}")
        try:
            if video_path.with_suffix(".latent.pt").exists() and not overwrite:
                print(f"Skipping {video_path} - latents already exist")
                continue

            encode_videos(vae, vid_path=video_path, shape=shape)

            prompt_path = video_path.with_suffix(".txt")
            embed_path = prompt_path.with_suffix(".embed.pt")

            if embed_path.exists() and not overwrite:
                print(f"Skipping {prompt_path} - embeddings already exist")
                continue

            with open(prompt_path) as f:
                text = f.read().strip()
            with torch.inference_mode():
                conditioning = pipeline.encode_prompt(prompt=[text])

            # Keep only the positive prompt embeddings and their attention mask.
            conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
            torch.save(conditioning, embed_path)
        except Exception as e:
            traceback.print_exc()
            print(f"Error processing {video_path}: {e}")
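

# Illustrative helper (not used by batch_process): one way downstream code might
# load the cached artifacts. The key names match the torch.save() calls above.
def load_cached_example(video_path: Path) -> dict:
    """Load the precomputed VAE latent and T5 caption embedding for one clip."""
    latents = torch.load(video_path.with_suffix(".latent.pt"))["ldist"]
    conditioning = torch.load(video_path.with_suffix(".embed.pt"))
    return {
        "ldist": latents,
        "prompt_embeds": conditioning["prompt_embeds"],
        "prompt_attention_mask": conditioning["prompt_attention_mask"],
    }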


if __name__ == "__main__":
    batch_process()