# 孙振宇
# Initial HF Spaces deployment
# 060fbda
import argparse
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Sequence
def validate_args_and_show_help(argv: Optional[Sequence[str]] = None):
    """
    Parse CLI arguments, validate the folders, and return resolved paths and parsed args.

    Parses command-line options for input, output, pattern, quiet, and model;
    converts input and output to resolved Path objects; validates that the input
    path exists and is a directory and that the output path, if it already
    exists, is a directory. Exits the process with code 1 when any of these
    checks fails (argparse itself exits with code 2 on malformed options).

    Parameters:
        argv (Optional[Sequence[str]]): Argument list to parse, without the
            program name. When None (the default) argparse reads sys.argv[1:],
            which is the behavior callers had before this parameter existed.

    Returns:
        (input_folder, output_folder, args):
            input_folder (Path): Resolved Path to the input directory.
            output_folder (Path): Resolved Path to the output directory
                (not created here; it may not exist yet).
            args (argparse.Namespace): Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="🎬 Batch process videos to remove Sora watermarks",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process all .mp4 files in input folder
python batch_process.py -i /path/to/input -o /path/to/output
# Process all .mov files
python batch_process.py -i /path/to/input -o /path/to/output --pattern "*.mov"
# Process all video files (mp4, mov, avi)
python batch_process.py -i /path/to/input -o /path/to/output --pattern "*.{mp4,mov,avi}"
# Without displaying the Tqdm bar inside sorawm processing.
python batch_process.py -i /path/to/input -o /path/to/output --quiet
""",
    )
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="📁 Input folder containing video files",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="📁 Output folder for cleaned videos",
    )
    parser.add_argument(
        "-p",
        "--pattern",
        default="*.mp4",
        help="🔍 File pattern to match (default: *.mp4)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Run in quiet mode (suppress tqdm and most logs).",
    )
    parser.add_argument(
        "-m",
        "--model",
        default="lama",
        choices=["lama", "e2fgvi_hq"],
        help="🔧 Model to use for watermark removal (default: lama). Options: lama (fast, may flicker), e2fgvi_hq (time consistent, slower)",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Convert to absolute Path objects ("~" is expanded first).
    input_folder = Path(args.input).expanduser().resolve()
    output_folder = Path(args.output).expanduser().resolve()

    # Validate input folder
    if not input_folder.exists():
        print(f"❌ Error: Input folder does not exist: {input_folder}", file=sys.stderr)
        sys.exit(1)
    if not input_folder.is_dir():
        print(
            f"❌ Error: Input path is not a directory: {input_folder}", file=sys.stderr
        )
        sys.exit(1)

    # Fail fast: an existing non-directory output path can never be used as
    # the output folder, so reject it before any heavy work starts.
    if output_folder.exists() and not output_folder.is_dir():
        print(
            f"❌ Error: Output path exists and is not a directory: {output_folder}",
            file=sys.stderr,
        )
        sys.exit(1)

    return input_folder, output_folder, args
# Classes are now defined inside main() after imports
def main():
    """
    Orchestrate CLI argument validation, lazy-load heavy dependencies, and run the batch video processing workflow.

    Validates and processes command-line arguments, imports runtime-only
    dependencies (rich, sorawm), selects the watermark removal model, constructs
    and runs the batch processor, and handles termination: exits with code 130
    on user interrupt and with code 1 on other fatal errors. Failures of
    individual videos are recorded and reported in the summary; they do not
    abort the batch.
    """
    # Validate arguments BEFORE loading heavy dependencies (ffmpeg, torch, etc.)
    input_folder, output_folder, args = validate_args_and_show_help()
    pattern = args.pattern

    # Only NOW import heavy dependencies after validation passed
    from rich import box
    from rich.console import Console
    from rich.markup import escape
    from rich.panel import Panel
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        ProgressColumn,
        SpinnerColumn,
        TaskProgressColumn,
        TextColumn,
        TimeElapsedColumn,
        TimeRemainingColumn,
    )
    from rich.table import Table
    from rich.text import Text

    from sorawm.core import SoraWM
    from sorawm.schemas import CleanerType

    # Initialize console after importing rich
    console = Console()

    # The column class is also published under the module-level name
    # ``SpeedColumn`` once main() has run; kept so that side effect is unchanged.
    global SpeedColumn

    class SpeedColumnImpl(ProgressColumn):
        """Custom column to display processing speed in it/s format (only for video processing)"""

        def render(self, task):
            """Render the speed in it/s format, but only for video processing tasks"""
            # Only show speed for video processing, not for overall batch progress
            if "Overall Progress" in task.description:
                return Text("", style="")
            speed = task.finished_speed or task.speed
            if speed is None:
                return Text("-- it/s", style="progress.data.speed")
            return Text(f"{speed:.2f} it/s", style="cyan")

    SpeedColumn = SpeedColumnImpl

    def _expand_braces(glob_pattern: str) -> List[str]:
        """
        Expand shell-style ``{a,b,c}`` groups into a list of plain glob patterns.

        ``Path.glob`` has no brace expansion, so ``*.{mp4,mov}`` would otherwise
        be matched literally. As in the shell, a group without a comma (e.g.
        ``{x}``) is left untouched. Nested groups are not supported.
        """
        search_from = 0
        while True:
            start = glob_pattern.find("{", search_from)
            if start == -1:
                return [glob_pattern]
            end = glob_pattern.find("}", start + 1)
            if end == -1:
                return [glob_pattern]
            body = glob_pattern[start + 1 : end]
            if "," in body:
                break
            # Comma-less group: keep it literal and look for a later group.
            search_from = end + 1

        head, tail = glob_pattern[:start], glob_pattern[end + 1 :]
        expanded: List[str] = []
        for option in body.split(","):
            # Recurse so that further groups in the tail are expanded as well.
            expanded.extend(_expand_braces(head + option + tail))
        return expanded

    # Define BatchProcessor here to have access to all imports
    class BatchProcessorImpl:
        """Batch video processor with progress tracking"""

        def __init__(
            self,
            input_folder: Path,
            output_folder: Path,
            pattern: str = "*.mp4",
            cleaner_type: CleanerType = CleanerType.LAMA,
            quiet: bool = False,
        ):
            """
            Initialize the batch processor with paths, file-matching pattern, and watermark cleaner selection.

            Parameters:
                input_folder (Path): Directory containing videos to process.
                output_folder (Path): Directory where cleaned videos will be written.
                pattern (str): Glob pattern used to find video files in the input folder (default: "*.mp4").
                    Shell-style brace groups such as "*.{mp4,mov}" are expanded.
                cleaner_type (CleanerType): Cleaner model to use for watermark removal (e.g., CleanerType.LAMA or CleanerType.E2FGVI_HQ).
                quiet (bool): Forwarded as ``quiet=`` to ``SoraWM.run`` for every video.
            """
            self.input_folder = input_folder
            self.output_folder = output_folder
            self.pattern = pattern
            self.quiet = quiet
            self.sora_wm = SoraWM(cleaner_type=cleaner_type)
            self.console = console
            # Statistics
            self.successful: List[str] = []  # names of videos that were cleaned
            self.failed: Dict[str, str] = {}  # video name -> error message

        def show_banner(self):
            """Display a colorful welcome banner"""
            banner_text = Text()
            banner_text.append("🎬 ", style="bold yellow")
            banner_text.append("Sora Watermark Remover", style="bold cyan")
            banner_text.append(" - Batch Processor", style="bold magenta")
            panel = Panel(
                banner_text,
                box=box.DOUBLE,
                border_style="bright_blue",
                padding=(1, 2),
            )
            self.console.print(panel)
            self.console.print()

        def find_videos(self) -> List[Path]:
            """Return the sorted, de-duplicated regular files matching the pattern."""
            matches = set()
            for plain_pattern in _expand_braces(self.pattern):
                matches.update(self.input_folder.glob(plain_pattern))
            # Directories that happen to match the pattern are not videos.
            return sorted(path for path in matches if path.is_file())

        def _process_video(self, progress, input_path: Path, output_path: Path):
            """
            Clean one video while mirroring its progress on a temporary progress task.

            Returns:
                None on success, otherwise the exception raised while processing.
            """
            video_task = progress.add_task(" [green]Processing video", total=100)
            reported = 0

            def progress_callback(prog: int):
                """Update the video progress bar"""
                # NOTE(review): prog is presumably a 0-100 percentage (the task
                # total is 100) — confirm against SoraWM.run.
                nonlocal reported
                if prog > reported:
                    progress.update(video_task, advance=prog - reported)
                    reported = prog

            try:
                # Process the video (quiet=True suppresses internal tqdm bars if enabled)
                self.sora_wm.run(
                    input_path, output_path, progress_callback, quiet=self.quiet
                )
                # Ensure video progress reaches 100%
                if reported < 100:
                    progress.update(video_task, advance=100 - reported)
                return None
            except Exception as exc:
                return exc
            finally:
                # Single removal point: the task disappears exactly once on
                # success, failure, or interruption.
                progress.remove_task(video_task)

        def process_batch(self):
            """Process all videos in the batch with progress tracking"""
            # Show banner
            self.show_banner()

            # Find all videos
            video_files = self.find_videos()
            if not video_files:
                self.console.print(
                    f"[bold red]❌ No files matching '{escape(self.pattern)}' found in {escape(str(self.input_folder))}[/bold red]"
                )
                return
            total_videos = len(video_files)

            # Display configuration (user-supplied text is escaped so that
            # square brackets are not interpreted as rich markup)
            config_table = Table(show_header=False, box=box.SIMPLE, padding=(0, 1))
            config_table.add_row(
                "📁 Input folder:", f"[cyan]{escape(str(self.input_folder))}[/cyan]"
            )
            config_table.add_row(
                "📁 Output folder:", f"[green]{escape(str(self.output_folder))}[/green]"
            )
            config_table.add_row(
                "🔍 Pattern:", f"[yellow]{escape(self.pattern)}[/yellow]"
            )
            config_table.add_row(
                "🎬 Videos found:", f"[bold magenta]{total_videos}[/bold magenta]"
            )
            self.console.print(config_table)
            self.console.print()

            # Create output folder
            self.output_folder.mkdir(parents=True, exist_ok=True)

            # Process each video with batch-level progress bar
            start_time = datetime.now()

            # Create rich progress display
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(bar_width=40),
                TaskProgressColumn(),
                MofNCompleteColumn(),
                SpeedColumnImpl(),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=self.console,
            ) as progress:
                # Batch progress task
                batch_task = progress.add_task(
                    "[cyan]Overall Progress", total=total_videos
                )
                for idx, input_path in enumerate(video_files, 1):
                    output_path = self.output_folder / f"cleaned_{input_path.name}"

                    # Update batch task description
                    progress.update(
                        batch_task,
                        description=f"[cyan]Overall Progress ({idx}/{total_videos})",
                    )

                    # Show current file being processed
                    self.console.print(
                        f"\n[bold blue]📹 [{idx}/{total_videos}][/bold blue] "
                        f"[yellow]{escape(input_path.name)}[/yellow]"
                    )

                    error = self._process_video(progress, input_path, output_path)
                    if error is None:
                        self.successful.append(input_path.name)
                        self.console.print(
                            f" [bold green]✅ Completed:[/bold green] {escape(output_path.name)}"
                        )
                    else:
                        self.failed[input_path.name] = str(error)
                        self.console.print(
                            f" [bold red]❌ Error:[/bold red] {escape(str(error))}"
                        )

                    # Update batch progress
                    progress.update(batch_task, advance=1)

            # Print summary
            self._print_summary(start_time)

        def _print_summary(self, start_time: datetime):
            """Print processing summary with rich formatting"""
            duration = datetime.now() - start_time
            succeeded = len(self.successful)
            failed = len(self.failed)
            total = succeeded + failed
            success_rate = (succeeded / total * 100) if total > 0 else 0

            self.console.print()

            # Create summary statistics table
            summary_table = Table(
                show_header=False, box=box.ROUNDED, border_style="cyan"
            )
            summary_table.add_column("Metric", style="bold")
            summary_table.add_column("Value")
            summary_table.add_row("⏱️ Total Time", f"[yellow]{duration}[/yellow]")
            summary_table.add_row(
                "✅ Successful", f"[bold green]{succeeded}[/bold green]"
            )
            summary_table.add_row("❌ Failed", f"[bold red]{failed}[/bold red]")
            summary_table.add_row(
                "📊 Total",
                f"[bold magenta]{total}[/bold magenta]",
            )
            summary_table.add_row(
                "📈 Success Rate", f"[bold cyan]{success_rate:.1f}%[/bold cyan]"
            )

            # Wrap in a panel
            summary_panel = Panel(
                summary_table,
                title="[bold white]📋 BATCH PROCESSING SUMMARY[/bold white]",
                border_style="bright_cyan",
                box=box.DOUBLE,
            )
            self.console.print(summary_panel)

            # Successful files
            if self.successful:
                self.console.print()
                success_table = Table(
                    title="[bold green]✅ Successfully Processed[/bold green]",
                    box=box.SIMPLE,
                    show_header=True,
                    header_style="bold green",
                )
                success_table.add_column("#", style="dim", width=4)
                success_table.add_column("Filename", style="green")
                for idx, filename in enumerate(self.successful, 1):
                    # Table cells are parsed as markup too, hence the escape.
                    success_table.add_row(str(idx), escape(filename))
                self.console.print(success_table)

            # Failed files
            if self.failed:
                self.console.print()
                failed_table = Table(
                    title="[bold red]❌ Failed to Process[/bold red]",
                    box=box.SIMPLE,
                    show_header=True,
                    header_style="bold red",
                )
                failed_table.add_column("#", style="dim", width=4)
                failed_table.add_column("Filename", style="red")
                failed_table.add_column("Error", style="dim")
                for idx, (filename, error) in enumerate(self.failed.items(), 1):
                    # Truncate long error messages (before escaping, so the
                    # cut never lands inside an escape sequence)
                    error_msg = error if len(error) < 60 else error[:57] + "..."
                    failed_table.add_row(
                        str(idx), escape(filename), escape(error_msg)
                    )
                self.console.print(failed_table)

            # Final message
            self.console.print()
            if not self.failed:
                self.console.print(
                    "[bold green]🎉 All videos processed successfully![/bold green]",
                    justify="center",
                )
            else:
                self.console.print(
                    "[bold yellow]⚠️ Some videos failed to process. Check errors above.[/bold yellow]",
                    justify="center",
                )
            self.console.print()

    # Create processor and run
    try:
        cleaner_type = (
            CleanerType.LAMA if args.model == "lama" else CleanerType.E2FGVI_HQ
        )
        processor = BatchProcessorImpl(
            input_folder, output_folder, pattern, cleaner_type, quiet=args.quiet
        )
        processor.process_batch()
    except KeyboardInterrupt:
        console.print()
        console.print(
            "[bold yellow]⚠️ Processing interrupted by user[/bold yellow]",
            justify="center",
        )
        # 130 = 128 + SIGINT, the conventional exit status for Ctrl+C.
        sys.exit(130)
    except Exception as e:
        console.print()
        console.print(f"[bold red]❌ Fatal error:[/bold red] {escape(str(e))}")
        sys.exit(1)
# Script entry point: run the batch CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
1