"""
Compute reference videos for IC-LoRA training.

This script provides a command-line interface for generating reference videos to be used
for IC-LoRA training. Note that it reads and writes to the same file (the output of
caption_videos.py), where it adds the "reference_path" field to the JSON.

Basic usage:
    # Compute reference videos for all videos in a directory
    compute_reference.py videos_dir/ --output videos_dir/captions.json
"""

import json
from pathlib import Path
from typing import Dict

import cv2
import torch
import torchvision.transforms.functional as TF
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_trainer.video_utils import read_video, save_video

console = Console()
disable_progress_bar()


def compute_reference(
    images: torch.Tensor,
) -> torch.Tensor:
    """Compute Canny edge detection on a batch of images.

    Args:
        images: Batch of images tensor of shape [B, C, H, W]

    Returns:
        Edge map tensor of shape [B, 3, H, W], with the single-channel Canny
        edges replicated across 3 channels to match an RGB video layout
    """
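    # Collapse RGB inputs to a single grayscale channel before edge detection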
    if images.shape[1] == 3:
        images = TF.rgb_to_grayscale(images)

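    # Normalize uint8-range pixel values to [0, 1]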
    if images.max() > 1.0:
        images = images / 255.0

    edge_masks = []
    for image in images:
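        # OpenCV's Canny expects a uint8 numpy array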
        image_np = (image.squeeze().cpu().numpy() * 255).astype("uint8")

        edges = cv2.Canny(
            image_np,
            threshold1=100,
            threshold2=200,
        )

        edge_mask = torch.from_numpy(edges).float()
        edge_masks.append(edge_mask)

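    # Stack per-frame edges and replicate across 3 channels to match the RGB video layout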
    edges = torch.stack(edge_masks)
    edges = torch.stack([edges] * 3, dim=1)
    return edges


def _get_meta_data(
    output_path: Path,
) -> list[dict[str, str]]:
    """Read the dataset metadata (captions JSON) without loading any media files.

    Args:
        output_path: Path to the dataset JSON file

    Returns:
        List of dataset records (empty if the file is missing or unreadable)
    """
    if not output_path.exists():
        return []

    console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]")

    try:
        with output_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        console.print(f"[bold yellow]Warning: Could not read meta data: {e}[/]")
        return []


def _save_dataset_json(
    reference_paths: Dict[str, str],
    output_path: Path,
) -> None:
    """Save dataset json with reference video paths.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths
        output_path: Path to save the output file
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)
    new_json_data = json_data.copy()
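    # Attach the computed reference path to each dataset record (skip entries that failed)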
    for i, item in enumerate(json_data):
        media_path = item["media_path"]
        reference_path = reference_paths.get(media_path)
        if reference_path is not None:
            new_json_data[i]["reference_path"] = reference_path

with output_path.open("w", encoding="utf-8") as f: |
|
|
json.dump(new_json_data, f, indent=2, ensure_ascii=False) |
|
|
|
|
|
console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]") |
|
|
console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:") |
|
|
console.print(" reference_column='[cyan]reference_path[/]'") |
|
|
console.print(" video_column='[cyan]media_path[/]'") |
|
|
|
|
|
|
|
|

def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Process videos and images and compute reference (condition) videos for them.

    Args:
        input_path: Path to input video/image file or directory
        output_path: Path to the dataset JSON file to read and update
        override: Whether to override existing reference video files
        batch_size: Number of frames passed to compute_reference at a time
    """
    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )

    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    media_to_process = []
    skipped_media = []
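    # Reference videos are written next to the source media, with a "_reference" suffix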
    def media_path_to_reference_path(media_file: Path) -> Path:
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    media_files = [base_dir / Path(sample["media_path"]) for sample in meta_data]
    for media_file in media_files:
        media_to_process.append(media_file)

console.print(f"Processing [bold]{len(media_to_process)}[/] media.") |
|
|
|
|
|
|
|
|
progress = Progress( |
|
|
SpinnerColumn(), |
|
|
TextColumn("{task.description}"), |
|
|
BarColumn(bar_width=40), |
|
|
MofNCompleteColumn(), |
|
|
TimeElapsedColumn(), |
|
|
TextColumn("•"), |
|
|
TimeRemainingColumn(), |
|
|
console=console, |
|
|
) |
|
|
|
|
|
|
|
|
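    # Pre-compute relative reference paths for every media entry in the dataset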
    media_paths = [item["media_path"] for item in meta_data]
    reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths}

    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_to_process))

        for media_file in media_to_process:
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")

            rel_path = str(media_file.resolve().relative_to(base_dir))
            reference_path = media_path_to_reference_path(media_file)
            reference_paths[rel_path] = str(reference_path.relative_to(base_dir))

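            # Only (re)compute the reference video if it is missing or --override was given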
            if not reference_path.resolve().exists() or override:
                try:
                    video, fps = read_video(media_file)

                    condition_frames = []
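                    # Process the video in batches of frames to bound memory usage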
                    for i in range(0, len(video), batch_size):
                        batch = video[i : i + batch_size]
                        condition_batch = compute_reference(batch)
                        condition_frames.append(condition_batch)

                    all_condition = torch.cat(condition_frames, dim=0)

                    save_video(all_condition, reference_path.resolve(), fps=fps)

                except Exception as e:
                    console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                    reference_paths.pop(rel_path)
            else:
                skipped_media.append(media_file)

            progress.advance(task)

    _save_dataset_json(reference_paths, output_path)

    total_to_process = len(media_files) - len(skipped_media)
    console.print(
        f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_files)}[/] media successfully.",
    )


app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Compute reference videos for IC-LoRA training.",
)


@app.command()
def main(
    input_path: Path = typer.Argument(
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of compute_captions.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The paths in the output file will be relative to the output file's directory.

    Examples:
        # Process all videos in a directory
        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    if output is None:
        raise typer.BadParameter("Missing required option '--output' / '-o' (the dataset JSON file).")

    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )


if __name__ == "__main__":
    app()