Spaces:
Running
Running
| import argparse | |
| import sys | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List | |
def validate_args_and_show_help():
    """
    Parse CLI arguments, validate input/output locations, and return resolved paths and parsed args.

    This runs before any heavy dependency is imported, so bad arguments fail fast.
    Parses command-line options for input, output, pattern, quiet, and model, and
    converts input and output to resolved Path objects (``~`` is expanded).

    Validation:
      * the input path must exist and be a directory;
      * the output path may be missing (it is created later), but if it already
        exists it must be a directory.

    Any validation failure prints an error to stderr and exits the process with
    code 1. Malformed options are rejected by argparse itself, which exits with
    code 2.

    Returns:
        (input_folder, output_folder, args):
            input_folder (Path): Resolved Path to the input directory.
            output_folder (Path): Resolved Path to the output directory (may not exist yet).
            args (argparse.Namespace): Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="🎬 Batch process videos to remove Sora watermarks",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process all .mp4 files in input folder
python batch_process.py -i /path/to/input -o /path/to/output
# Process all .mov files
python batch_process.py -i /path/to/input -o /path/to/output --pattern "*.mov"
# Process all video files (mp4, mov, avi)
python batch_process.py -i /path/to/input -o /path/to/output --pattern "*.{mp4,mov,avi}"
# Without displaying the Tqdm bar inside sorawm processing.
python batch_process.py -i /path/to/input -o /path/to/output --quiet
""",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="📁 Input folder containing video files",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="📁 Output folder for cleaned videos",
    )
    parser.add_argument(
        "-p",
        "--pattern",
        type=str,
        default="*.mp4",
        help="🔍 File pattern to match (default: *.mp4)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        default=False,
        help="Run in quiet mode (suppress tqdm and most logs).",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="lama",
        choices=["lama", "e2fgvi_hq"],
        help="🔧 Model to use for watermark removal (default: lama). Options: lama (fast, may flicker), e2fgvi_hq (time consistent, slower)",
    )
    args = parser.parse_args()

    # Convert to absolute Path objects so later messages show unambiguous locations.
    input_folder = Path(args.input).expanduser().resolve()
    output_folder = Path(args.output).expanduser().resolve()

    # Validate input folder
    if not input_folder.exists():
        print(f"❌ Error: Input folder does not exist: {input_folder}", file=sys.stderr)
        sys.exit(1)
    if not input_folder.is_dir():
        print(
            f"❌ Error: Input path is not a directory: {input_folder}", file=sys.stderr
        )
        sys.exit(1)

    # Validate output folder now rather than after the models are loaded:
    # mkdir(exist_ok=True) would otherwise fail late on an existing regular file.
    if output_folder.exists() and not output_folder.is_dir():
        print(
            f"❌ Error: Output path is not a directory: {output_folder}",
            file=sys.stderr,
        )
        sys.exit(1)

    return input_folder, output_folder, args
| # Classes are now defined inside main() after imports | |
def main():
    """
    Validate the CLI, lazy-load heavy dependencies, and run the batch video processing workflow.

    Arguments are validated first, so bad input fails before rich, sorawm and the
    watermark-removal model are imported or constructed. main() then selects the
    cleaner model, builds the batch processor and runs it.

    A failure on a single video does not abort the batch: it is recorded and
    reported in the final summary.

    Exit codes:
      * 130 on user interrupt (Ctrl-C);
      * 1 on any other fatal error.
    """
    # Validate arguments BEFORE loading heavy dependencies (ffmpeg, torch, etc.)
    input_folder, output_folder, args = validate_args_and_show_help()
    pattern = args.pattern

    # Only NOW import heavy dependencies, after validation passed.
    from rich import box
    from rich.console import Console
    from rich.markup import escape
    from rich.panel import Panel
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        ProgressColumn,
        SpinnerColumn,
        TaskProgressColumn,
        TextColumn,
        TimeElapsedColumn,
        TimeRemainingColumn,
    )
    from rich.table import Table
    from rich.text import Text
    from sorawm.core import SoraWM
    from sorawm.schemas import CleanerType

    console = Console()

    class SpeedColumn(ProgressColumn):
        """Progress column showing processing speed in it/s (per-video tasks only)."""

        def render(self, task):
            """Render the task speed, or an empty cell for the batch-level task."""
            # The batch task advances once per video, so a speed there is not meaningful.
            if "Overall Progress" in task.description:
                return Text("", style="")
            speed = task.finished_speed or task.speed
            if speed is None:
                return Text("-- it/s", style="progress.data.speed")
            return Text(f"{speed:.2f} it/s", style="cyan")

    class BatchProcessorImpl:
        """Batch video processor with progress tracking."""

        def __init__(
            self,
            input_folder: Path,
            output_folder: Path,
            pattern: str = "*.mp4",
            cleaner_type: CleanerType = CleanerType.LAMA,
        ):
            """
            Initialize the batch processor with paths, file-matching pattern, and cleaner selection.

            Parameters:
                input_folder (Path): Directory containing videos to process.
                output_folder (Path): Directory where cleaned videos will be written.
                pattern (str): Glob pattern used to find video files in the input folder
                    (default: "*.mp4"). Shell-style ``{a,b}`` alternatives are supported.
                cleaner_type (CleanerType): Cleaner model to use for watermark removal
                    (e.g., CleanerType.LAMA or CleanerType.E2FGVI_HQ).
            """
            self.input_folder = input_folder
            self.output_folder = output_folder
            self.pattern = pattern
            self.sora_wm = SoraWM(cleaner_type=cleaner_type)
            self.console = console
            # Statistics: names of cleaned files, and file name -> error message.
            self.successful: List[str] = []
            self.failed: Dict[str, str] = {}

        def show_banner(self):
            """Display a colorful welcome banner."""
            banner_text = Text()
            banner_text.append("🎬 ", style="bold yellow")
            banner_text.append("Sora Watermark Remover", style="bold cyan")
            banner_text.append(" - Batch Processor", style="bold magenta")
            panel = Panel(
                banner_text,
                box=box.DOUBLE,
                border_style="bright_blue",
                padding=(1, 2),
            )
            self.console.print(panel)
            self.console.print()

        @staticmethod
        def _expand_braces(glob_pattern: str) -> List[str]:
            """Expand shell-style ``{a,b}`` groups (not supported by Path.glob) into plain patterns."""
            start = glob_pattern.find("{")
            end = glob_pattern.find("}", start + 1) if start != -1 else -1
            if start == -1 or end == -1:
                return [glob_pattern]
            head = glob_pattern[:start]
            options = glob_pattern[start + 1 : end].split(",")
            tail = glob_pattern[end + 1 :]
            expanded: List[str] = []
            for option in options:
                # Recurse so that later brace groups in the tail are expanded too.
                expanded.extend(BatchProcessorImpl._expand_braces(head + option + tail))
            return expanded

        def find_videos(self) -> List[Path]:
            """
            Return the regular files matching the pattern, sorted and without duplicates.

            A pattern such as ``*.{mp4,mov,avi}`` is first expanded into ``*.mp4``,
            ``*.mov`` and ``*.avi``. A file matched by several alternatives is
            listed only once.
            """
            matches = set()
            for sub_pattern in self._expand_braces(self.pattern):
                matches.update(
                    path for path in self.input_folder.glob(sub_pattern) if path.is_file()
                )
            return sorted(matches)

        def _process_video(self, progress, input_path: Path) -> None:
            """Clean one video under its own progress bar, recording success or failure."""
            output_path = self.output_folder / f"cleaned_{input_path.name}"
            video_task = progress.add_task(" [green]Processing video", total=100)
            last_progress = 0

            def progress_callback(prog: int):
                """Advance the video bar to ``prog`` percent (clamped to 100, never backwards)."""
                nonlocal last_progress
                prog = min(prog, 100)
                if prog > last_progress:
                    progress.update(video_task, advance=prog - last_progress)
                    last_progress = prog

            try:
                # quiet=True suppresses sorawm's internal tqdm bars
                self.sora_wm.run(
                    input_path, output_path, progress_callback, quiet=args.quiet
                )
            except Exception as e:
                # Some exceptions carry no message; fall back to the type name.
                error = str(e) or type(e).__name__
                self.failed[input_path.name] = error
                self.console.print(f" [bold red]❌ Error:[/bold red] {escape(error)}")
            else:
                # Ensure the video bar reaches 100% even if the last callback was below it.
                if last_progress < 100:
                    progress.update(video_task, advance=100 - last_progress)
                self.successful.append(input_path.name)
                self.console.print(
                    f" [bold green]✅ Completed:[/bold green] {escape(output_path.name)}"
                )
            finally:
                # Removed exactly once per video, whether it succeeded, failed or was interrupted.
                progress.remove_task(video_task)

        def process_batch(self):
            """Process all matching videos with batch-level progress, then print a summary."""
            self.show_banner()

            video_files = self.find_videos()
            if not video_files:
                self.console.print(
                    f"[bold red]❌ No files matching '{escape(self.pattern)}' found in "
                    f"{escape(str(self.input_folder))}[/bold red]"
                )
                return
            total = len(video_files)

            # Display configuration. User-supplied text is escaped so brackets in
            # paths or glob patterns are not parsed as rich markup.
            config_table = Table(show_header=False, box=box.SIMPLE, padding=(0, 1))
            config_table.add_row(
                "📁 Input folder:", f"[cyan]{escape(str(self.input_folder))}[/cyan]"
            )
            config_table.add_row(
                "📁 Output folder:", f"[green]{escape(str(self.output_folder))}[/green]"
            )
            config_table.add_row("🔍 Pattern:", f"[yellow]{escape(self.pattern)}[/yellow]")
            config_table.add_row(
                "🎬 Videos found:", f"[bold magenta]{total}[/bold magenta]"
            )
            self.console.print(config_table)
            self.console.print()

            self.output_folder.mkdir(parents=True, exist_ok=True)

            start_time = datetime.now()
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(bar_width=40),
                TaskProgressColumn(),
                MofNCompleteColumn(),
                SpeedColumn(),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=self.console,
            ) as progress:
                batch_task = progress.add_task("[cyan]Overall Progress", total=total)
                for idx, input_path in enumerate(video_files, 1):
                    progress.update(
                        batch_task,
                        description=f"[cyan]Overall Progress ({idx}/{total})",
                    )
                    self.console.print(
                        f"\n[bold blue]📹 [{idx}/{total}][/bold blue] "
                        f"[yellow]{escape(input_path.name)}[/yellow]"
                    )
                    self._process_video(progress, input_path)
                    progress.update(batch_task, advance=1)

            self._print_summary(start_time)

        def _print_summary(self, start_time: datetime):
            """Print processing statistics and per-file result tables."""
            duration = datetime.now() - start_time
            succeeded = len(self.successful)
            failed = len(self.failed)
            total = succeeded + failed
            success_rate = (succeeded / total * 100) if total > 0 else 0

            self.console.print()
            summary_table = Table(
                show_header=False, box=box.ROUNDED, border_style="cyan"
            )
            summary_table.add_column("Metric", style="bold")
            summary_table.add_column("Value")
            summary_table.add_row("⏱️ Total Time", f"[yellow]{duration}[/yellow]")
            summary_table.add_row(
                "✅ Successful", f"[bold green]{succeeded}[/bold green]"
            )
            summary_table.add_row("❌ Failed", f"[bold red]{failed}[/bold red]")
            summary_table.add_row("📊 Total", f"[bold magenta]{total}[/bold magenta]")
            summary_table.add_row(
                "📈 Success Rate", f"[bold cyan]{success_rate:.1f}%[/bold cyan]"
            )
            summary_panel = Panel(
                summary_table,
                title="[bold white]📋 BATCH PROCESSING SUMMARY[/bold white]",
                border_style="bright_cyan",
                box=box.DOUBLE,
            )
            self.console.print(summary_panel)

            if self.successful:
                self.console.print()
                success_table = Table(
                    title="[bold green]✅ Successfully Processed[/bold green]",
                    box=box.SIMPLE,
                    show_header=True,
                    header_style="bold green",
                )
                success_table.add_column("#", style="dim", width=4)
                success_table.add_column("Filename", style="green")
                for idx, filename in enumerate(self.successful, 1):
                    success_table.add_row(str(idx), escape(filename))
                self.console.print(success_table)

            if self.failed:
                self.console.print()
                failed_table = Table(
                    title="[bold red]❌ Failed to Process[/bold red]",
                    box=box.SIMPLE,
                    show_header=True,
                    header_style="bold red",
                )
                failed_table.add_column("#", style="dim", width=4)
                failed_table.add_column("Filename", style="red")
                failed_table.add_column("Error", style="dim")
                for idx, (filename, error) in enumerate(self.failed.items(), 1):
                    # Truncate before escaping so the escape backslashes are never cut off.
                    error_msg = error if len(error) < 60 else error[:57] + "..."
                    failed_table.add_row(str(idx), escape(filename), escape(error_msg))
                self.console.print(failed_table)

            self.console.print()
            if not self.failed:
                self.console.print(
                    "[bold green]🎉 All videos processed successfully![/bold green]",
                    justify="center",
                )
            else:
                self.console.print(
                    "[bold yellow]⚠️ Some videos failed to process. Check errors above.[/bold yellow]",
                    justify="center",
                )
            self.console.print()

    # Create processor and run
    try:
        cleaner_type = (
            CleanerType.LAMA if args.model == "lama" else CleanerType.E2FGVI_HQ
        )
        processor = BatchProcessorImpl(
            input_folder, output_folder, pattern, cleaner_type
        )
        processor.process_batch()
    except KeyboardInterrupt:
        console.print()
        console.print(
            "[bold yellow]⚠️ Processing interrupted by user[/bold yellow]",
            justify="center",
        )
        sys.exit(130)
    except Exception as e:
        console.print()
        console.print(
            f"[bold red]❌ Fatal error:[/bold red] {escape(str(e) or type(e).__name__)}"
        )
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: main() returns None on normal completion (exit status 0)
    # and calls sys.exit() itself for interrupts and fatal errors.
    raise SystemExit(main())
| 1 | |