#!/usr/bin/env python3
"""
Gradio front-end wrapper for SeedVR2's official inference_cli.py.

This script provides a simple web UI for single-image (and video) upscaling
using the official ComfyUI-SeedVR2_VideoUpscaler `inference_cli.py`. It runs
the CLI as a subprocess and streams stdout/stderr lines in real time into the
Gradio logs textbox, using a queue fed by reader threads. The Gradio handler
`ui_upscale` is implemented as a generator so the frontend receives
incremental updates.

Model weights are downloaded automatically from Hugging Face
(numz/SeedVR2_comfyUI) if they are missing. If the
ComfyUI-SeedVR2_VideoUpscaler repository is not present, the script attempts
to `git clone` it into ./ComfyUI-SeedVR2_VideoUpscaler.

Run:
    python app.py

Requirements:
    - Python 3.10+
    - Gradio (pip install gradio)
    - Git available in PATH (for automatic cloning), or clone the repo manually
    - PyTorch + CUDA (if using a GPU)

Notes:
    - Because the wrapper calls the repo's `inference_cli.py` as a subprocess,
      the CLI's memory/optimization features (BlockSwap, VAE tiling, etc.)
      remain available.
    - Models are downloaded to ./models/SeedVR2 (next to this script) if
      missing. Set the HF_ACCESS_TOKEN env var if private-repo access is
      required.
"""
import os
import sys
import cv2
import time
import torch
import queue
import shutil
import zipfile
import threading
import subprocess
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Optional, Tuple, Generator, List

# huggingface helper (used for model auto-download)
from huggingface_hub import hf_hub_download
def imreadUTF8(path, flags=cv2.IMREAD_COLOR):
    """
    OpenCV's cv2.imread cannot handle non-ASCII paths.
    This function reads an image from a path that may contain UTF-8 characters.
    """
    try:
        # Read the raw bytes ourselves; open() handles UTF-8 paths correctly
        with open(path, "rb") as stream:
            bytes_data = bytearray(stream.read())
        numpyarray = np.asarray(bytes_data, dtype=np.uint8)
        # Use cv2.imdecode to decode the image from the memory buffer
        img = cv2.imdecode(numpyarray, flags)
        return img
    except Exception as e:
        # If reading fails, print the error message and return None
        print(f"ERROR: Failed to read image with UTF-8 path: {path}")
        print(f"       Details: {e}")
        return None
def imwriteUTF8(save_path, image):
    """
    OpenCV's cv2.imwrite cannot handle non-ASCII paths.
    This function writes an image to a path that may contain UTF-8 characters.
    """
    try:
        img_name = os.path.basename(save_path)
        _, extension = os.path.splitext(img_name)
        # Encode the image into the format implied by the file extension
        is_success, im_buf_arr = cv2.imencode(extension, image)
        if is_success:
            # Write the encoded image data from memory to the file
            im_buf_arr.tofile(save_path)
            return True
        else:
            print(f"ERROR: Failed to encode image for path: {save_path}")
            return False
    except Exception as e:
        print(f"ERROR: Failed to write image with UTF-8 path: {save_path}")
        print(f"       Details: {e}")
        return False
# Apply the monkey patch to cv2 (for app.py usage)
print("[SeedVR2 Gradio] Applying UTF-8 patch to OpenCV (Frontend)...")
cv2.imread = imreadUTF8
cv2.imwrite = imwriteUTF8
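
# A minimal usage sketch of the patched helpers (illustrative only, not
# executed here; the non-ASCII filenames are hypothetical):
#
#     img = cv2.imread("测试图片.png")           # routed to imreadUTF8
#     if img is not None:
#         cv2.imwrite("测试图片_copy.png", img)  # routed to imwriteUTF8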
# ----------------
# Config / paths
# ----------------
REPO_URL = "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler.git"
CLONE_DIR = Path(__file__).resolve().parent / "ComfyUI-SeedVR2_VideoUpscaler"
INFERENCE_CLI = CLONE_DIR / "inference_cli.py"
PY_EXE = sys.executable  # Use the same Python executable to run the CLI

# Paths to the custom improved BlockSwap / memory-manager files
IMPROVED_BLOCKSWAP_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "blockswap.py"
IMPROVED_MEMORY_MANAGER_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "memory_manager.py"

# Default HF repo for the VAE (the VAE is static and comes from the official repo)
DEFAULT_VAE_REPO_ID = "numz/SeedVR2_comfyUI"

# Models are stored in a fixed top-level directory, independent of the clone dir
DEFAULT_MODEL_DIR = Path(__file__).resolve().parent / "models" / "SeedVR2"
# ----------------
# Model definitions (RepoID / Filename)
# ----------------
# Standard models (safetensors)
MODEL_CHOICES = [
    "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp8_e4m3fn.safetensors",
    "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp16.safetensors",
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors",
    "numz/SeedVR2_comfyUI/seedvr2_ema_7b_fp16.safetensors",
    # sharp variants
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors",
    "numz/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp16.safetensors",
]

# GGUF / alternate model support
GGUF_CHOICES = [
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q4_K_M.gguf",
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q8_0.gguf",
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b-Q4_K_M.gguf",
    # sharp variants
    "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp-Q4_K_M.gguf",
    # custom GGUF from cmeka
    "cmeka/SeedVR2-GGUF/seedvr2_ema_7b-Q8_0.gguf",
    "cmeka/SeedVR2-GGUF/seedvr2_ema_7b_sharp-Q8_0.gguf",
]
# # Model registry with metadata
# MODEL_REGISTRY = {
#     # 3B models
#     "seedvr2_ema_3b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q4_K_M", sha256="e665e3909de1a8c88a69c609bca9d43ff5a134647face2ce4497640cc3597f0e"),
#     "seedvr2_ema_3b-Q8_0.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q8_0", sha256="be0d60083a2051a265eb4b77f28edf494e6db67ffc250216f32b72292e5cbd96"),
#     "seedvr2_ema_3b_fp8_e4m3fn.safetensors": ModelInfo(size="3B", precision="fp8_e4m3fn", sha256="3bf1e43ebedd570e7e7a0b1b60d6a02e105978f505c8128a241cde99a8240cff"),
#     "seedvr2_ema_3b_fp16.safetensors": ModelInfo(size="3B", precision="fp16", sha256="2fd0e03a3dad24e07086750360727ca437de4ecd456f769856e960ae93e2b304"),
#     # 7B models
#     "seedvr2_ema_7b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", sha256="db9cb2ad90ebd40d2e8c29da2b3fc6fd03ba87cd58cbadceccca13ad27162789"),
#     "seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", sha256="3d68b5ec0b295ae28092e355c8cad870edd00b817b26587d0cb8f9dd2df19bb2"),
#     "seedvr2_ema_7b_fp16.safetensors": ModelInfo(size="7B", precision="fp16", sha256="7b8241aa957606ab6cfb66edabc96d43234f9819c5392b44d2492d9f0b0bbe4a"),
#     # 7B sharp variants
#     "seedvr2_ema_7b_sharp-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", variant="sharp", sha256="7aed800ac4eb8e0d18569a954c0ff35f5a1caa3ed5d920e66cc31405f75b6e69"),
#     "seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", variant="sharp", sha256="0d2c5b8be0fda94351149c5115da26aef4f4932a7a2a928c6f184dda9186e0be"),
#     "seedvr2_ema_7b_sharp_fp16.safetensors": ModelInfo(size="7B", precision="fp16", variant="sharp", sha256="20a93e01ff24beaeebc5de4e4e5be924359606c356c9c51509fba245bd2d77dd"),
#     # VAE models
#     "ema_vae_fp16.safetensors": ModelInfo(category="vae", precision="fp16", sha256="20678548f420d98d26f11442d3528f8b8c94e57ee046ef93dbb7633da8612ca1"),
# }
# Detect hardware availability
CUDA_AVAILABLE = torch.cuda.is_available()
# Detect MPS availability (for Apple Silicon)
MPS_AVAILABLE = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built()
# Check for any hardware acceleration
ACCELERATOR_AVAILABLE = CUDA_AVAILABLE or MPS_AVAILABLE
# -----------------
# Repo / model helpers
# -----------------
def ensure_repo_cloned(
    repo_url: str = REPO_URL,
    clone_dir_name: str = "ComfyUI-SeedVR2_VideoUpscaler",
    repo_branch: str = "",
    force_update: bool = False
) -> Path:
    """
    Ensure the repository is cloned locally into the specified directory name.
    Supports specific branches/tags/commits via checkout.
    Returns the resolved Path to the cloned directory.
    """
    # Resolve the physical path relative to this script's directory
    target_clone_dir = Path(__file__).resolve().parent / clone_dir_name
    target_cli = target_clone_dir / "inference_cli.py"

    # Helper to handle detached/orphaned commits
    def _smart_checkout(cwd, ref):
        print(f"[SeedVR2 Gradio] Checking out '{ref}' in {cwd} ...")
        try:
            # Try a standard checkout first (fastest if the ref exists locally)
            subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True)
        except subprocess.CalledProcessError:
            # Fallback: if the ref is not found locally (e.g. an orphaned commit hash), fetch it explicitly
            print(f"[SeedVR2 Gradio] Standard checkout failed. Attempting to fetch specific ref '{ref}' from origin...")
            try:
                subprocess.run(["git", "-C", str(cwd), "fetch", "origin", ref], check=True)
                subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True)
            except Exception as e:
                raise RuntimeError(f"Failed to fetch/checkout specific ref '{ref}': {e}")

    if target_clone_dir.exists() and (target_clone_dir / ".git").exists():
        # Repo already exists
        if force_update:
            try:
                print(f"[SeedVR2 Gradio] Updating {target_clone_dir} ...")
                subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True)
                if repo_branch:
                    # A specific branch/hash was requested. We strictly check
                    # out the ref rather than pulling: distinguishing a branch
                    # name from a detached hash is non-trivial, and pinning the
                    # ref is safer for reproducibility.
                    _smart_checkout(target_clone_dir, repo_branch)
                else:
                    subprocess.run(["git", "-C", str(target_clone_dir), "pull"], check=True)
            except Exception as e:
                raise RuntimeError(f"Failed to update repository {target_clone_dir}: {e}")
        # Not forcing an update, but a branch is specified: make sure we are on it
        elif repo_branch:
            try:
                subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True)
                _smart_checkout(target_clone_dir, repo_branch)
            except Exception as e:
                raise RuntimeError(f"Failed to switch to branch {repo_branch}: {e}")
        # Ensure inference_cli.py is present
        if not target_cli.exists():
            raise RuntimeError(f"Repository found at {target_clone_dir} but inference_cli.py is missing.")
        return target_clone_dir

    # Clone the repo if it does not exist yet
    try:
        print(f"[SeedVR2 Gradio] Cloning {repo_url} into {target_clone_dir} ...")
        # Standard clone (fetches the default branch)
        subprocess.run(["git", "clone", repo_url, str(target_clone_dir)], check=True)
        if repo_branch:
            _smart_checkout(target_clone_dir, repo_branch)
    except FileNotFoundError:
        raise RuntimeError("git not found: please install Git or clone the repository manually.")
    except Exception as e:
        raise RuntimeError(f"Failed to clone repository: {e}")

    if not target_cli.exists():
        raise RuntimeError(f"Clone completed but inference_cli.py not found in {target_clone_dir}.")
    return target_clone_dir
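
# Example usage (a sketch; the ref value is hypothetical): clone or update the
# default repo and pin it to a specific tag before running the CLI:
#
#     repo_root = ensure_repo_cloned(repo_branch="v2.5.0", force_update=True)
#     print(repo_root / "inference_cli.py")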
def apply_inference_cli_patch(cli_path: Path):
    """
    Injects UTF-8 compatible imread/imwrite wrappers directly into inference_cli.py.
    This modifies the physical file so the subprocess (even with Windows spawn)
    uses the patch.
    """
    if not cli_path.exists():
        return
    try:
        with open(cli_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Check if already patched, to avoid duplicates
        if "def imreadUTF8" in content:
            return

        # The patch content to inject.
        # Note: 'import numpy as np' and 'import os' are (re-)imported to be safe.
        # inference_cli.py typically has 'import cv2'; we inject right after it.
        patch_code = r'''
# =============================================================================
# GRADIO APP PATCH: UTF-8 Support for Windows (Auto-Injected)
# =============================================================================
import numpy as np
import os

def imreadUTF8(path, flags=cv2.IMREAD_COLOR):
    try:
        with open(path, "rb") as stream:
            bytes_data = bytearray(stream.read())
        numpyarray = np.asarray(bytes_data, dtype=np.uint8)
        return cv2.imdecode(numpyarray, flags)
    except Exception as e:
        print(f"Error reading image {path}: {e}")
        return None

def imwriteUTF8(save_path, image):
    try:
        img_name = os.path.basename(save_path)
        _, extension = os.path.splitext(img_name)
        is_success, im_buf_arr = cv2.imencode(extension, image)
        if is_success:
            im_buf_arr.tofile(save_path)
            return True
        else:
            return False
    except Exception as e:
        print(f"Error writing image {save_path}: {e}")
        return False

# Override cv2 methods
cv2.imread = imreadUTF8
cv2.imwrite = imwriteUTF8
# =============================================================================
'''
        # Inject right after 'import cv2'
        if "import cv2" in content:
            print(f"[SeedVR2 Gradio] Patching {cli_path} for UTF-8 subprocess support...")
            new_content = content.replace("import cv2", "import cv2" + patch_code, 1)
            with open(cli_path, "w", encoding="utf-8") as f:
                f.write(new_content)
        else:
            print("[SeedVR2 Gradio] WARNING: Could not find 'import cv2' in inference_cli.py. UTF-8 patch skipped.")
    except Exception as e:
        print(f"[SeedVR2 Gradio] ERROR applying UTF-8 patch to inference_cli: {e}")
def patch_model_registry(repo_root: Path):
    """
    Appends custom model definitions to src/utils/model_registry.py.
    This lets the CLI recognize GGUF models that aren't in the official registry.
    """
    registry_path = repo_root / "src" / "utils" / "model_registry.py"
    if not registry_path.exists():
        print(f"[SeedVR2 Gradio] WARN: Could not find model_registry.py at {registry_path}")
        return
    try:
        with open(registry_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Check if already patched
        if "seedvr2_ema_7b-Q8_0.gguf" in content:
            return
        print(f"[SeedVR2 Gradio] Patching {registry_path} with custom GGUF models...")

        # Code appended to the end of the file. Since ModelInfo and
        # MODEL_REGISTRY are defined there, we can use them directly.
        patch_code = r'''
# =============================================================================
# GRADIO APP PATCH: Custom Model Registry Entries
# =============================================================================
try:
    # Update the registry with the custom GGUF models requested by the user
    MODEL_REGISTRY.update({
        "seedvr2_ema_7b-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", sha256="669788655e8f15f306284f267a444e9766c8a421869577b16a961e43029c737b"),
        "seedvr2_ema_7b_sharp-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", variant="sharp", sha256="b1f81cb5700b0b1f432f2c785528356c952c41c74d03d205c6f14b0bd6da303d"),
    })
    print("[Internal] Custom GGUF models injected into MODEL_REGISTRY successfully.")
except Exception as e:
    print(f"[Internal] Failed to inject custom models: {e}")
# =============================================================================
'''
        with open(registry_path, "a", encoding="utf-8") as f:
            f.write(patch_code)
    except Exception as e:
        print(f"[SeedVR2 Gradio] ERROR patching model_registry.py: {e}")
# -----------------
# BlockSwap management
# -----------------
def manage_blockswap_file(use_improved: bool, repo_root: Path) -> str:
    """
    Manages the blockswap.py / memory_manager.py files in the given cloned repo.
    Accepts a `repo_root` path to ensure we modify the correct repo.
    """
    target_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py"
    backup_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py.bak"
    target_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py"
    backup_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py.bak"
    msg = ""

    # Ensure src/optimization exists (some forks differ in structure)
    if not target_path_blockswap.parent.exists():
        return f"[WARN] Optimization folder not found at {target_path_blockswap.parent}. Skipping blockswap patch.\n"

    if use_improved:
        if not IMPROVED_BLOCKSWAP_SOURCE.exists():
            return f"[WARN] Improved blockswap source not found at {IMPROVED_BLOCKSWAP_SOURCE}. Keeping current version.\n"
        # 1. Back up the original blockswap (only if no backup exists yet)
        if target_path_blockswap.exists() and not backup_path_blockswap.exists():
            try:
                shutil.move(str(target_path_blockswap), str(backup_path_blockswap))
                msg += f"[INFO] Backed up original blockswap to {backup_path_blockswap.name}.\n"
            except Exception as e:
                return f"[ERROR] Failed to backup blockswap: {e}\n"
        # 2. Copy the improved blockswap into place
        try:
            shutil.copy(str(IMPROVED_BLOCKSWAP_SOURCE), str(target_path_blockswap))
            msg += "[INFO] Switched to Improved BlockSwap (Nunchaku implementation).\n"
        except Exception as e:
            return f"[ERROR] Failed to install improved blockswap: {e}\n"

        # Memory manager handling
        if not IMPROVED_MEMORY_MANAGER_SOURCE.exists():
            return f"[WARN] Improved memory_manager source not found at {IMPROVED_MEMORY_MANAGER_SOURCE}. Keeping current version.\n"
        # 3. Back up the original memory_manager (only if no backup exists yet)
        if target_path_memory_manager.exists() and not backup_path_memory_manager.exists():
            try:
                shutil.move(str(target_path_memory_manager), str(backup_path_memory_manager))
                msg += f"[INFO] Backed up original memory_manager to {backup_path_memory_manager.name}.\n"
            except Exception as e:
                return f"[ERROR] Failed to backup memory_manager: {e}\n"
        # 4. Copy the improved memory_manager into place
        try:
            shutil.copy(str(IMPROVED_MEMORY_MANAGER_SOURCE), str(target_path_memory_manager))
            msg += "[INFO] Switched to Improved memory_manager (Nunchaku implementation).\n"
        except Exception as e:
            return f"[ERROR] Failed to install improved memory_manager: {e}\n"
        return msg
    else:
        # Restore the original blockswap if a backup is available
        if backup_path_blockswap.exists():
            try:
                # Remove the current file (which might be the improved one)
                if target_path_blockswap.exists():
                    os.remove(target_path_blockswap)
                # Restore the backup
                shutil.move(str(backup_path_blockswap), str(target_path_blockswap))
                msg += "[INFO] Restored Original BlockSwap from backup.\n"
            except Exception as e:
                return f"[ERROR] Failed to restore original blockswap: {e}\n"
        else:
            # No backup: assume we are already on the original or a clean install
            msg += "[INFO] Using Original BlockSwap (no backup found/needed).\n"
        # Restore the original memory_manager if a backup is available
        if backup_path_memory_manager.exists():
            try:
                # Remove the current file (which might be the improved one)
                if target_path_memory_manager.exists():
                    os.remove(target_path_memory_manager)
                # Restore the backup
                shutil.move(str(backup_path_memory_manager), str(target_path_memory_manager))
                msg += "[INFO] Restored Original memory_manager from backup.\n"
            except Exception as e:
                return f"[ERROR] Failed to restore original memory_manager: {e}\n"
        else:
            # No backup: assume we are already on the original or a clean install
            msg += "[INFO] Using Original memory_manager (no backup found/needed).\n"
        return msg
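
# Example (sketch): switch a cloned repo to the improved BlockSwap files and
# back again; the second call restores the ".bak" originals:
#
#     print(manage_blockswap_file(True, repo_root=CLONE_DIR))   # install improved
#     print(manage_blockswap_file(False, repo_root=CLONE_DIR))  # restore originals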
# -----------------
# Model download
# -----------------
def ensure_models_available(
    selected_model_filename: str,
    model_dir: Optional[Path] = None,
    repo_id: str = DEFAULT_VAE_REPO_ID
) -> None:
    """
    Ensure the selected DiT model and the VAE file exist locally.
    If missing, download from the given Hugging Face repo directly into
    model_dir using 'local_dir', avoiding nested cache structures.
    """
    if model_dir is None:
        model_dir = DEFAULT_MODEL_DIR
    else:
        model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Files to check: the VAE + the selected DiT model
    required = ["ema_vae_fp16.safetensors", selected_model_filename]
    # Check whether the files physically exist at the target location
    missing = [_f for _f in required if not (model_dir / _f).exists()]
    if not missing:
        return

    # Optional: silence the HF symlink warning
    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

    # Attempt a download for each missing file
    hf_token = os.environ.get("HF_ACCESS_TOKEN")
    for fname in missing:
        target_path = model_dir / fname
        # If the file somehow already exists at the target, skip it
        if target_path.exists():
            continue
        # Decide which repo serves this filename:
        # - the VAE always comes from the official DEFAULT_VAE_REPO_ID (numz/SeedVR2_comfyUI)
        # - the DiT model uses the provided repo_id (from the dropdown selection)
        repo_for_fname = DEFAULT_VAE_REPO_ID if fname == "ema_vae_fp16.safetensors" else repo_id
        try:
            print(f"[SeedVR2 Gradio] Downloading {fname} from {repo_for_fname} directly to {model_dir} ...")
            # Use local_dir instead of cache_dir: this forces the file to be
            # saved exactly at {model_dir}/{fname} as a real file rather than
            # a cache symlink, so the CLI subprocess can always resolve it.
            downloaded_path = hf_hub_download(
                repo_id=repo_for_fname,
                filename=fname,
                local_dir=str(model_dir),  # Download directly into the target folder
                repo_type="model",
                token=hf_token,
            )
            print(f"[SeedVR2 Gradio] Download completed: {downloaded_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to download {fname} from Hugging Face repo {repo_for_fname}: {e}")
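
# Example (sketch; the filename is taken from MODEL_CHOICES above): fetch the
# 3B fp16 DiT model plus the VAE into the fixed model directory if missing:
#
#     ensure_models_available(
#         "seedvr2_ema_3b_fp16.safetensors",
#         model_dir=DEFAULT_MODEL_DIR,
#         repo_id="numz/SeedVR2_comfyUI",
#     )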
# ----------------
# Subprocess streaming helpers
# ----------------
def _start_process_stream(cmd_args, cwd: str, env: dict) -> Tuple[Optional[subprocess.Popen], queue.Queue, Optional[threading.Thread], Optional[threading.Thread]]:
    """Start a subprocess and return (proc, q, t_out, t_err).

    The returned queue receives text lines as they arrive. Lines are plain
    strings (newline-terminated); stderr lines are prefixed with "stderr: ".
    """
    q = queue.Queue()
    try:
        proc = subprocess.Popen(
            cmd_args,
            cwd=cwd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            encoding='utf-8',
            errors='replace',
            bufsize=1,
            env=env
        )
    except Exception as e:
        # Report the launch error via the queue; there is no process to return
        q.put(f"[FAILED TO LAUNCH] {e}\n")
        return None, q, None, None

    def _reader(fh, prefix: str):
        try:
            while True:
                line = fh.readline()
                if not line:
                    break
                if not line.endswith("\n"):
                    line = line + "\n"
                q.put(prefix + line)
        except Exception as e:
            q.put(f"[reader error] {e}\n")

    t_out = threading.Thread(target=_reader, args=(proc.stdout, ""), daemon=True)
    t_err = threading.Thread(target=_reader, args=(proc.stderr, "stderr: "), daemon=True)
    t_out.start()
    t_err.start()
    return proc, q, t_out, t_err
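
# Example (sketch): drain the queue until the process exits, the same pattern
# the runner below uses (a trivial command stands in for the real CLI):
#
#     proc, q, t_out, t_err = _start_process_stream(
#         [PY_EXE, "-c", "print('hello')"], cwd=".", env=os.environ.copy())
#     while proc is not None and (proc.poll() is None or not q.empty()):
#         try:
#             print(q.get(timeout=0.5), end="")
#         except queue.Empty:
#             pass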
# -----------------
# CLI runner (streaming)
# -----------------
def expected_upscaled_path(input_path: str, output_format: str = "png") -> str:
    """Calculates the expected output path from the input path and requested format."""
    p = Path(input_path)
    stem = p.stem
    parent = p.parent
    suffix = "_upscaled"
    return str((parent / f"{stem}{suffix}.{output_format}").resolve())
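
# Example (sketch): the CLI writes next to the input with an "_upscaled" suffix:
#
#     expected_upscaled_path("/data/cat.jpg")                        # -> "/data/cat_upscaled.png"
#     expected_upscaled_path("/data/clip.mp4", output_format="mp4")  # -> "/data/clip_upscaled.mp4"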
# Single-image/video CLI runner (generator)
def run_cli_upscale_stream(
    input_path: str,
    resolution: int = 1080,
    max_resolution: int = 0,
    dit_model_filename: Optional[str] = None,  # Receives just the filename
    cuda_device: Optional[str] = None,
    # Compilation & performance
    compile_dit: bool = False,
    compile_vae: bool = False,
    compile_backend: str = "inductor",
    compile_mode: str = "default",
    compile_fullgraph: bool = False,
    compile_dynamic: bool = False,
    compile_dynamo_cache_size_limit: int = 64,
    compile_dynamo_recompile_limit: int = 128,
    attention_mode: str = "sdpa",
    # Tiling (split encode/decode)
    vae_encode_tiled: bool = False,
    vae_encode_tile_size: int = 1024,
    vae_encode_tile_overlap: int = 128,
    vae_decode_tiled: bool = False,
    vae_decode_tile_size: int = 1024,
    vae_decode_tile_overlap: int = 128,
    tile_debug: str = "false",
    # Processing
    batch_size: int = 1,
    uniform_batch_size: bool = False,
    seed: int = 42,
    skip_first_frames: int = 0,
    load_cap: int = 0,
    # Quality & color
    color_correction: str = "lab",
    input_noise_scale: float = 0.0,
    latent_noise_scale: float = 0.0,
    # Memory & offload
    blocks_to_swap: int = 0,
    swap_io_components: bool = False,
    dit_offload_device: str = "none",
    vae_offload_device: str = "none",
    tensor_offload_device: str = "cpu",
    cache_dit: bool = False,
    cache_vae: bool = False,
    extra_args: str = "",
    model_dir: Optional[str] = None,
    repo_id: str = DEFAULT_VAE_REPO_ID,  # Receives the specific repo ID for the DiT model
    repo_path: Optional[Path] = None,
    timeout: int = 3600,
    pre_downscale: bool = False,  # for artifact removal
    downscale_rate: float = 0.5,
    output_format: str = "png",  # can also be "mp4"
    use_improved_blockswap: bool = False,  # switches the blockswap version
    # Video args
    chunk_size: int = 0,
    temporal_overlap: int = 0,
    prepend_frames: int = 0,
    video_backend: str = "opencv",
    use_10bit: bool = False,
    # Debug arg
    debug: bool = False
) -> Generator[Tuple[Optional[str], str], None, None]:
| """ | |
| Generator yields (out_path_or_None, logs_so_far) while streaming CLI logs. | |
| Includes Phase-Aware Dynamic Fallback logic. | |
| """ | |
| # Defaults | |
| if repo_path is None: | |
| repo_path = CLONE_DIR | |
| current_inference_cli = repo_path / "inference_cli.py" | |
| # 1. Repo Check | |
| if not current_inference_cli.exists(): | |
| yield None, f"[ERROR] inference_cli.py not found in {repo_path}\n" | |
| return | |
| # Patch inference_cli.py with UTF-8 support | |
| try: | |
| apply_inference_cli_patch(current_inference_cli) | |
| except Exception as e: | |
| yield None, f"[WARN] Failed to patch inference_cli: {e}\n" | |
| # Patch model_registry.py with custom models | |
| try: | |
| patch_model_registry(repo_path) | |
| except Exception as e: | |
| yield None, f"[WARN] Failed to patch model_registry: {e}\n" | |
| # Handle BlockSwap File Replacement Logic | |
| try: | |
| swap_log = manage_blockswap_file(use_improved_blockswap, repo_root=repo_path) | |
| # Yield the log about blockswap immediately | |
| yield None, swap_log | |
| except Exception as e: | |
| yield None, f"[ERROR] BlockSwap management failed: {e}\n" | |
| # Use the global default if not provided | |
| if model_dir is None: | |
| model_dir = str(DEFAULT_MODEL_DIR) | |
| # Ensure model files present | |
| if dit_model_filename: | |
| try: | |
| ensure_models_available( | |
| dit_model_filename, | |
| model_dir=Path(model_dir), | |
| repo_id=repo_id, | |
| ) | |
| except Exception as e: | |
| yield None, f"[ERROR] Model download failed: {e}\n" | |
| return | |
| safe_input_path = input_path | |
| temp_copy = None | |
| # Pre-downscale logic (Artifact Removal Trick) - Only applies to Images in this implementation | |
| # We skip this for MP4 files to avoid complex video processing in python before CLI | |
| is_video = input_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')) | |
| # Pre-downscale (Images only) | |
| if pre_downscale and not is_video: | |
| try: | |
| filename = os.path.basename(input_path) | |
| # Load original image using OpenCV | |
| img_obj = cv2.imread(input_path, cv2.IMREAD_UNCHANGED) | |
| if img_obj is None: | |
| raise ValueError(f"Failed to load image: {input_path}") | |
| # Calculate new dimensions (OpenCV shape is [height, width]) | |
| h, w = img_obj.shape[:2] | |
| if (max(w, h) > 250): | |
| new_w = int(w * downscale_rate) | |
| new_h = int(h * downscale_rate) | |
| # Resize | |
| # Use INTER_AREA for downscaling (better quality/less aliasing for shrinking) | |
| # Use INTER_LANCZOS4 if scaling up (though this block is specifically for downscaling) | |
| interpolation_method = cv2.INTER_AREA if downscale_rate < 1.0 else cv2.INTER_LANCZOS4 | |
| img_resized = cv2.resize(img_obj, (new_w, new_h), interpolation=interpolation_method) | |
| # Prepare temp directory | |
| tmp_dir = CLONE_DIR / "tmp_inputs" | |
| tmp_dir.mkdir(parents=True, exist_ok=True) | |
| # Save to a unique temp file (forces .png for intermediate input) | |
| new_name = f"{filename}_downscaled.png" | |
| temp_copy = str(tmp_dir / new_name) | |
| # Use patched cv2.imwrite | |
| cv2.imwrite(temp_copy, img_resized) | |
| # Use this temp file as the input for CLI | |
| safe_input_path = temp_copy | |
| yield None, f"[INFO] Pre-downscaled input by factor {downscale_rate} (Size: {w}x{h} -> {new_w}x{new_h}) to reduce artifacts.\n" | |
| except Exception as e: | |
| yield None, f"[ERROR] Failed to pre-downscale image: {e}\n" | |
| return | |
    # 2. Command builder
    def _build_cmd(curr_tile_size, curr_batch_size):
        # Determine the strict output format
        if is_video:
            # For video input, force mp4 output from the CLI: users usually
            # want a video back rather than a png sequence.
            cmd_format = "mp4"
        else:
            # For images we force png here (the standard CLI outputs png/mp4);
            # app.py handles any webp/jpg conversion afterwards.
            cmd_format = "png"
        cmd = [PY_EXE, str(current_inference_cli), safe_input_path,
               "--resolution", str(resolution),
               "--output_format", cmd_format,
               "--batch_size", str(curr_batch_size),
               "--color_correction", color_correction,
               "--model_dir", str(model_dir),
               "--seed", str(seed),
               "--attention_mode", str(attention_mode)]
        if max_resolution and int(max_resolution) > 0:
            cmd += ["--max_resolution", str(int(max_resolution))]
        if dit_model_filename:
            # The CLI just needs the filename relative to --model_dir (or an absolute path)
            cmd += ["--dit_model", str(dit_model_filename)]
        # Only add --cuda_device if CUDA is available and the user provided a value
        if CUDA_AVAILABLE and cuda_device:
            cmd += ["--cuda_device", str(cuda_device)]

        # --- Compilation options ---
        if compile_dit:
            cmd += ["--compile_dit"]
        if compile_vae:
            cmd += ["--compile_vae"]
        if compile_dit or compile_vae:
            cmd += [
                "--compile_backend", str(compile_backend),
                "--compile_mode", str(compile_mode),
                "--compile_dynamo_cache_size_limit", str(compile_dynamo_cache_size_limit),
                "--compile_dynamo_recompile_limit", str(compile_dynamo_recompile_limit)
            ]
        if compile_fullgraph:
            cmd += ["--compile_fullgraph"]
        if compile_dynamic:
            cmd += ["--compile_dynamic"]

        # --- Tiling options ---
        # Note: curr_tile_size comes from the fallback loop strategy; normally
        # it equals the user-provided vae_encode_tile_size.
        if vae_encode_tiled:
            cmd += ["--vae_encode_tiled",
                    "--vae_encode_tile_size", str(curr_tile_size),
                    "--vae_encode_tile_overlap", str(vae_encode_tile_overlap)]
        if vae_decode_tiled:
            cmd += ["--vae_decode_tiled",
                    "--vae_decode_tile_size", str(vae_decode_tile_size),
                    "--vae_decode_tile_overlap", str(vae_decode_tile_overlap)]
        if tile_debug != "false":
            cmd += ["--tile_debug", str(tile_debug)]

        # --- Processing & quality ---
        if uniform_batch_size:
            cmd += ["--uniform_batch_size"]
        if skip_first_frames > 0:
            cmd += ["--skip_first_frames", str(int(skip_first_frames))]
        if load_cap > 0:
            cmd += ["--load_cap", str(int(load_cap))]
        if input_noise_scale > 0:
            cmd += ["--input_noise_scale", str(input_noise_scale)]
        if latent_noise_scale > 0:
            cmd += ["--latent_noise_scale", str(latent_noise_scale)]

        # --- BlockSwap / offload / caching ---
        if blocks_to_swap and int(blocks_to_swap) > 0:
            cmd += ["--blocks_to_swap", str(int(blocks_to_swap))]
        if swap_io_components:
            cmd += ["--swap_io_components"]
        # Offload flags: these are strings like "none"/"cpu"/"cuda:0".
        # Never pass a cuda offload target when CUDA isn't available.
        if dit_offload_device and dit_offload_device != "none":
            if not (dit_offload_device.startswith("cuda") and not CUDA_AVAILABLE):
                cmd += ["--dit_offload_device", str(dit_offload_device)]
        if vae_offload_device and vae_offload_device != "none":
            if not (vae_offload_device.startswith("cuda") and not CUDA_AVAILABLE):
                cmd += ["--vae_offload_device", str(vae_offload_device)]
        if tensor_offload_device and tensor_offload_device != "none":
            if not (tensor_offload_device.startswith("cuda") and not CUDA_AVAILABLE):
                cmd += ["--tensor_offload_device", str(tensor_offload_device)]
        if cache_dit:
            cmd += ["--cache_dit"]
        if cache_vae:
            cmd += ["--cache_vae"]

        # --- Video-specific flags ---
        if chunk_size > 0:
            cmd += ["--chunk_size", str(int(chunk_size))]
        if temporal_overlap > 0:
            cmd += ["--temporal_overlap", str(int(temporal_overlap))]
        if prepend_frames > 0:
            cmd += ["--prepend_frames", str(int(prepend_frames))]
        if video_backend and video_backend != "opencv":
            cmd += ["--video_backend", str(video_backend)]
        if use_10bit:
            cmd += ["--10bit"]
        # Debug flag
        if debug:
            cmd += ["--debug"]
        if extra_args:
            # Allow advanced users to pass additional flags (space separated)
            cmd += extra_args.split()
        return cmd
    # 3. Dynamic strategy loop
    # Use the encode tile size as the dynamic variable for fallback
    current_tile_size = int(vae_encode_tile_size)
    current_batch_size = int(batch_size)

    # Initialize log tracking, carrying over the earlier blockswap log
    logs_buf = swap_log

    max_attempts = 5  # Prevent infinite loops
    attempt_count = 0
    while attempt_count < max_attempts:
        attempt_count += 1
        note = f"Tile: {current_tile_size}, Batch: {current_batch_size}"
        header = f"\n\n=== ATTEMPT {attempt_count}/{max_attempts} ({note}) ===\n"
        logs_buf += header
        # Yield the header immediately
        yield None, logs_buf

        cmd = _build_cmd(current_tile_size, current_batch_size)
        logs_buf += f"[CMD] {' '.join(cmd)}\n"
        yield None, logs_buf

        # Start the streaming process
        env = os.environ.copy()
        # Make Python in the child process print in UTF-8 (avoids cp950 UnicodeEncodeError on Windows)
        env['PYTHONIOENCODING'] = 'utf-8'
        env['PYTHONUTF8'] = '1'
        # # Helps with fragmentation/alloc issues; the user may tune this
        # env.setdefault('PYTORCH_ALLOC_CONF', os.environ.get('PYTORCH_ALLOC_CONF', 'max_split_size_mb:128'))
        proc, q, t_out, t_err = _start_process_stream(cmd, cwd=str(CLONE_DIR), env=env)
        if proc is None:
            logs_buf += "[ERROR] Failed to launch subprocess.\n"
            yield None, logs_buf
            break  # A launch failure is treated as fatal; retrying would not help

        # State tracking for this run
        current_phase = "init"  # init, vae_enc, dit, vae_dec, post
        oom_detected = False
        start_time = time.time()

        # Poll the queue
        while True:
            try:
                # Wait up to 0.5 s for a line
                line = q.get(timeout=0.5)
                logs_buf += line
                yield None, logs_buf
                lower_line = line.lower()
                # Track the current phase
                if "phase 1: vae encoding" in lower_line:
                    current_phase = "vae_enc"
                elif "phase 2: dit upscaling" in lower_line:
                    current_phase = "dit"
                elif "phase 3: vae decoding" in lower_line:
                    current_phase = "vae_dec"
                elif "saving" in lower_line or "converting" in lower_line:
                    current_phase = "post"
                # Check for OOM
                oom_indicators = ["outofmemory", "out of memory", "allocation on device", "oom", "cuda out of memory"]
                if any(k in lower_line for k in oom_indicators):
                    logs_buf += f"\n[WARN] OOM detected during phase: {current_phase.upper()}\n"
                    yield None, logs_buf
                    oom_detected = True
                    try:
                        proc.kill()  # Kill immediately to recover VRAM
                    except Exception:
                        pass
                    break
            except queue.Empty:
                # No new line: check the process status
                if proc.poll() is not None:
                    break
                # Still running: keep polling
                continue

        # Flush any remaining lines
        while True:
            try:
                line = q.get_nowait()
                logs_buf += line
                yield None, logs_buf
            except queue.Empty:
                break

        # Wait for the reader threads to exit
        try:
            if t_out:
                t_out.join(timeout=1)
            if t_err:
                t_err.join(timeout=1)
        except Exception:
            pass

        runtime = time.time() - start_time
        logs_buf += f"[Attempt {attempt_count} finished in {runtime:.2f}s] returncode={proc.returncode}\n"
        yield None, logs_buf

        # 4. Success check using safe_input_path
        # The CLI writes its output next to the actual input file used
        # (which might be the temp one); for video, detect the mp4 strictly.
        out_fmt_check = "mp4" if is_video else "png"
        out_path = expected_upscaled_path(safe_input_path, output_format=out_fmt_check)
        if Path(out_path).exists():
            logs_buf += f"[SUCCESS] Intermediate output: {out_path}\n"
            # Clean up the temp file if we created one
            if temp_copy:
                try:
                    os.remove(temp_copy)
                except Exception:
                    pass
            yield out_path, logs_buf
            return

        # 5. Failure analysis & parameter adjustment
        if oom_detected or proc.returncode != 0:
            logs_buf += f"\n[INFO] Attempt {attempt_count} failed. Analyzing OOM phase: {current_phase.upper()}...\n"

            # --- Intelligent adjustment logic ---
            # Case A: VAE OOM (phase 1 or 3) -> reduce tile size
            if current_phase in ["vae_enc", "vae_dec"]:
                if current_tile_size > 256:
                    new_tile = max(256, current_tile_size // 2)
                    logs_buf += f"[STRATEGY] VAE OOM detected. Reducing tile size: {current_tile_size} -> {new_tile}\n"
                    current_tile_size = new_tile
                else:
                    # Tile size is already at the minimum; reduce batch size as a last resort
                    new_batch = max(1, current_batch_size // 2)
                    logs_buf += f"[STRATEGY] VAE OOM but tile size is at minimum. Reducing batch size: {current_batch_size} -> {new_batch}\n"
                    current_batch_size = new_batch
            # Case B: DiT OOM (phase 2) -> reduce batch size
            elif current_phase == "dit":
                if current_batch_size > 1:
                    # For video temporal consistency a 4n+1 batch would be ideal; halving is the simple fallback
                    new_batch = max(1, current_batch_size // 2)
                    logs_buf += f"[STRATEGY] DiT OOM detected. Reducing batch size: {current_batch_size} -> {new_batch}\n"
                    current_batch_size = new_batch
                else:
                    logs_buf += "[FAIL] DiT OOM with batch size 1. Cannot reduce further.\n"
                    break
            # Case C: post-process OOM -> reduce batch size
            elif current_phase == "post":
                if current_batch_size > 1:
                    new_batch = max(1, current_batch_size // 2)
                    logs_buf += f"[STRATEGY] Post-process OOM detected. Reducing batch size: {current_batch_size} -> {new_batch}\n"
                    current_batch_size = new_batch
                else:
                    logs_buf += "[FAIL] Post-process OOM with batch size 1.\n"
                    break
            # Case D: unknown/init OOM -> reduce both, conservatively
            else:
                current_tile_size = max(256, current_tile_size // 2)
                current_batch_size = max(1, current_batch_size // 2)
                logs_buf += f"[STRATEGY] Early OOM. Reducing both tile ({current_tile_size}) and batch ({current_batch_size}).\n"

            # Stop once the attempt budget is exhausted (infinite-loop prevention)
            if attempt_count >= max_attempts:
                logs_buf += "[FAIL] Max attempts reached.\n"
                break
            yield None, logs_buf
            # The loop continues with the new settings
        else:
            # Non-OOM fatal error
            logs_buf += "[ERROR] Non-zero return code (not OOM) - stopping.\n"
            yield None, logs_buf
            return

    # All strategies exhausted
    logs_buf += "[FAILED] No output produced after all strategies.\n"
    if temp_copy:
        try:
            os.remove(temp_copy)
        except Exception:
            pass
    yield None, logs_buf
    return
# --- Preset change handler (considers CUDA & MPS availability) ---
def preset_changed(preset_value):
    # Returned tuple order:
    #  0: compile_dit, 1: compile_vae,
    #  2: vae_encode_tiled, 3: vae_encode_tile_size,
    #  4: vae_decode_tiled, 5: vae_decode_tile_size,
    #  6: max_resolution, 7: blocks_to_swap, 8: swap_io_components,
    #  9: dit_offload_device, 10: vae_offload_device, 11: tensor_offload_device,
    # 12: extra_args, 13: chunk_size, 14: temporal_overlap
    if preset_value == "Recommended (low VRAM)":
        return (
            False,   # compile_dit
            False,   # compile_vae
            True,    # vae_encode_tiled
            512,     # vae_encode_tile_size
            True,    # vae_decode_tiled
            512,     # vae_decode_tile_size (kept in sync with encode for safety)
            1920,    # max_resolution
            32,      # blocks_to_swap
            True,    # swap_io_components
            "cpu",   # dit_offload_device
            "none",  # vae_offload_device (keep the VAE on-device if possible)
            "cpu",   # tensor_offload_device (offload tensors to save VRAM)
            "--blocks_to_swap 0",  # extra_args
            0,       # chunk_size
            0        # temporal_overlap (0 = auto/disabled)
        )
    elif preset_value == "Offload (very slow)":
        return (
            False,  # compile_dit
            False,  # compile_vae
            True,   # vae_encode_tiled
            256,    # vae_encode_tile_size
            True,   # vae_decode_tiled
            256,    # vae_decode_tile_size
            1440,   # max_resolution
            99,     # blocks_to_swap
            True,   # swap_io_components
            "cpu",  # dit_offload_device
            "cpu",  # vae_offload_device
            "cpu",  # tensor_offload_device
            "--blocks_to_swap 99 --swap_io_components --dit_offload_device cpu --vae_offload_device cpu --tensor_offload_device cpu",
            0,
            0
        )
    elif preset_value == "High quality (fast if lots of VRAM)":
        return (
            True,    # compile_dit
            True,    # compile_vae
            False,   # vae_encode_tiled
            512,
            False,   # vae_decode_tiled
            512,
            0,       # max_resolution (no limit)
            0,       # blocks_to_swap
            False,   # swap_io_components
            "none",
            "none",
            "none",  # keep tensors on the GPU/MPS device
            "",
            0,
            0
        )
    # Fallback
    return (False, False, True, 256, True, 256, 1920, 0, False, "none", "none", "cpu", "--blocks_to_swap 0", 0, 0)
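
# Example (sketch; assumes the Gradio Blocks wiring elsewhere in this file
# defines these components): hook the preset dropdown so a change fans out to
# all 15 controls listed in the tuple-order comment above:
#
#     preset_mode.change(
#         preset_changed,
#         inputs=[preset_mode],
#         outputs=[compile_dit, compile_vae, vae_encode_tiled, vae_encode_tile_size,
#                  vae_decode_tiled, vae_decode_tile_size, max_resolution,
#                  blocks_to_swap, swap_io_components, dit_offload_device,
#                  vae_offload_device, tensor_offload_device, extra_args,
#                  chunk_size, temporal_overlap],
#     )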
# ---------------- Paste JS (attached to the gallery element) ----------------
paste_js = """
function initPaste() {
    document.addEventListener('paste', function(e) {
        const gallery = document.getElementById('input_gallery');
        if (!gallery) return;
        if (!gallery.matches(':hover')) return;
        const clipboardData = e.clipboardData || e.originalEvent.clipboardData;
        if (!clipboardData) return;
        const items = clipboardData.items;
        const files = [];
        for (let i = 0; i < items.length; i++) {
            if (items[i].kind === 'file' && items[i].type.startsWith('image/')) {
                files.push(items[i].getAsFile());
            }
        }
        if (files.length === 0 && clipboardData.files.length > 0) {
            for (let i = 0; i < clipboardData.files.length; i++) {
                if (clipboardData.files[i].type.startsWith('image/')) {
                    files.push(clipboardData.files[i]);
                }
            }
        }
        if (files.length === 0) return;
        const uploadInput = gallery.querySelector('input[type="file"]');
        if (uploadInput) {
            e.preventDefault();
            e.stopPropagation();
            const dataTransfer = new DataTransfer();
            files.forEach(file => dataTransfer.items.add(file));
            uploadInput.files = dataTransfer.files;
            uploadInput.dispatchEvent(new Event('change', { bubbles: true }));
        }
    });
}
"""
# ----------------
# Gradio layout
# ----------------
# Helper: generate the progress-bar HTML
def make_progress_html(current, total, step_desc):
    if total == 0:
        percent = 0
    else:
        percent = min(max(current / total * 100, 0), 100)
    # Use Gradio's CSS variables so the bar adapts to dark/light mode:
    #   var(--background-fill-secondary): container background color
    #   var(--border-color-primary): border and progress track (visible in dark mode)
    #   var(--body-text-color): main text color
    #   var(--color-accent): progress bar fill (follows the theme accent)
    return f"""
    <div style="
        border: 1px solid var(--border-color-primary);
        border-radius: 8px;
        padding: 10px;
        background: var(--background-fill-secondary);
        margin-bottom: 10px;
    ">
        <div style="display: flex; justify-content: space-between; margin-bottom: 5px; font-family: var(--font); font-size: var(--text-sm);">
            <span style="font-weight: bold; color: var(--body-text-color);">{step_desc}</span>
            <span style="color: var(--body-text-color); opacity: 0.8;">{percent:.1f}%</span>
        </div>
        <div style="width: 100%; background-color: var(--border-color-primary); border-radius: 10px; height: 12px; overflow: hidden;">
            <div style="width: {percent}%; background-color: var(--color-accent); height: 100%; border-radius: 10px; transition: width 0.3s ease-in-out;"></div>
        </div>
    </div>
    """
# ---------------- UI: main ----------------
def ui_upscale_main(
    gallery_input,  # image list
    video_input,    # video path
    resolution, max_resolution, preset_mode, dit_model_combo, use_gguf, cuda_device,
    # Compile
    compile_dit, compile_vae, compile_backend, compile_mode, compile_fullgraph,
    compile_dynamic, compile_dynamo_cache_size_limit, compile_dynamo_recompile_limit, attention_mode,
    # Tiling
    vae_encode_tiled, vae_encode_tile_size, vae_encode_tile_overlap,
    vae_decode_tiled, vae_decode_tile_size, vae_decode_tile_overlap, tile_debug,
    # Processing
    batch_size, uniform_batch_size, seed, skip_first_frames, load_cap,
    # Color/quality
    color_correction, input_noise_scale, latent_noise_scale,
    # Memory
    blocks_to_swap, swap_io_components, dit_offload_device, vae_offload_device, tensor_offload_device,
    cache_dit, cache_vae, extra_args,
    # General
    pre_downscale, downscale_rate, repetition_count, output_format, use_improved_blockswap,
    # Video
    chunk_size, temporal_overlap, prepend_frames, video_backend, use_10bit,
    debug,
    # Repo config inputs
    custom_repo_url, custom_branch, custom_clone_name
):
    # Initialize the empty progress-bar HTML
    empty_progress = make_progress_html(0, 100, "Waiting to start...")

    # Determine the input source
    target_inputs = []
    is_video_mode = False
    if video_input is not None:
        # Video takes precedence if provided (the user can clear it)
        target_inputs = [video_input]
        is_video_mode = True
        # Force mp4 format for the internal logic when handling video
        output_format = "mp4"
    elif gallery_input:
        # Normalize gallery entries to file paths. Gradio versions vary:
        # an entry may be a str path, a list/tuple whose first element is the
        # path (e.g. [path, caption]), or a dict with a 'name' key.
        for entry in gallery_input:
            if isinstance(entry, (list, tuple)):
                path = entry[0]
            else:
                path = entry
            if isinstance(path, dict) and 'name' in path:
                path = path['name']
            target_inputs.append(str(path))
    else:
        yield None, "No images or video provided.\n", empty_progress
        return
    # # apply presets
    # if preset_mode == "Recommended (low VRAM)":
    #     compile_dit = False
    #     compile_vae = False
    #     vae_encode_tiled = True
    #     if not vae_tile_size:
    #         vae_tile_size = 256
    #     if max_resolution is None:
    #         max_resolution = 1920
    #     # keep blocks_to_swap = 0 by default
    # elif preset_mode == "Offload (very slow)":
    #     compile_dit = False
    #     compile_vae = False
    #     vae_encode_tiled = True
    #     if not vae_tile_size:
    #         vae_tile_size = 256
    #     if max_resolution is None:
    #         max_resolution = 1440
    #     if not blocks_to_swap:
    #         blocks_to_swap = 32
    #     swap_io_components = True
    #     dit_offload_device = "cpu"
    #     vae_offload_device = "cpu"
    #     tensor_offload_device = "cpu"
    # elif preset_mode == "High quality (fast if lots of VRAM)":
    #     compile_dit = True
    #     compile_vae = True
    #     vae_encode_tiled = False
    #     if max_resolution is None:
    #         max_resolution = 0  # no limit
    # Dynamic repo handling
    current_repo_path = CLONE_DIR  # fallback
    # The model directory is fixed and independent of the repo location
    current_model_dir = DEFAULT_MODEL_DIR

    # Default values if the fields are empty
    target_repo_url = custom_repo_url.strip() if custom_repo_url and custom_repo_url.strip() else REPO_URL
    target_clone_name = custom_clone_name.strip() if custom_clone_name and custom_clone_name.strip() else "ComfyUI-SeedVR2_VideoUpscaler"
    target_branch = custom_branch.strip() if custom_branch else ""

    # Ensure the repo exists (clones if missing)
    try:
        yield None, f"Checking Repository ({target_clone_name})...", make_progress_html(5, 100, "Checking Repo...")
        current_repo_path = ensure_repo_cloned(
            repo_url=target_repo_url,
            clone_dir_name=target_clone_name,
            repo_branch=target_branch,
            force_update=False
        )
    except Exception as e:
        yield None, f"Repo clone/check failed: {e}\n", make_progress_html(0, 100, "Repo Error")
        return

    # Parse the selected combo "RepoID/Filename"
    selected_repo_id = DEFAULT_VAE_REPO_ID  # default fallback
    selected_filename = None
    if dit_model_combo:
        # A slash indicates a Repo/File structure
        if "/" in dit_model_combo:
            # Split from the right: the filename is the last component
            parts = dit_model_combo.split("/")
            selected_filename = parts[-1]
            # Join the rest as the repo ID (e.g. "owner/repo")
            selected_repo_id = "/".join(parts[:-1])
        else:
            # Plain filename: assume the default repo
            selected_filename = dit_model_combo

    if selected_filename:
        try:
            yield None, f"Checking Model {selected_filename}...", make_progress_html(10, 100, "Checking Models...")
            # Download into the fixed model directory if missing
            ensure_models_available(
                selected_filename,
                model_dir=current_model_dir,
                repo_id=selected_repo_id,
            )
        except Exception as e:
            yield None, f"Model download failed: {e}\n", make_progress_html(0, 100, "Model Error")
            return
| # Stores log history for all completed images | |
| full_logs_history = "" | |
| # successful_outputs will now store tuples (physical_path, archive_name) | |
| successful_outputs = [] | |
| total_files = len(target_inputs) | |
| # Ensure repetition is at least 1 | |
| safe_repetition = max(1, int(repetition_count)) | |
| if is_video_mode: | |
| safe_repetition = 1 # Force 1 pass for video to avoid endless waits | |
| total_operations = total_files * safe_repetition | |
| # Process sequentially | |
| for idx, img_path in enumerate(target_inputs, start=1): | |
| filename = os.path.basename(img_path) | |
| original_stem = Path(img_path).stem | |
| # This variable tracks the input for the current pass | |
| # Initially it is the original file, in subsequent loops it becomes the output of the previous pass | |
| current_input_path = img_path | |
| final_output_for_image = None | |
| # Loop for Repetitions | |
| for loop_idx in range(1, safe_repetition + 1): | |
| # Calculate global progress index | |
| # (File 1 Pass 1 = 0, File 1 Pass 2 = 1 ... File 2 Pass 1 = N) | |
| global_op_index = (idx - 1) * safe_repetition + (loop_idx - 1) | |
| # Prepare header | |
| pass_info = f" (Pass {loop_idx}/{safe_repetition})" if safe_repetition > 1 else "" | |
| # Prepare header for this file | |
| header_log = f"\n\n=== FILE {idx}/{total_files}: {filename}{pass_info} ===\n" | |
| # Progress calculation | |
| # Calculate base progress (e.g., 2nd image of 4, base progress is 25%) | |
| # Reserve 10% for preparation, allocate remaining 90% to images | |
| # start_pct = 10 + ((idx - 1) / total_images) * 90 | |
| start_pct = 10 + (global_op_index / total_operations) * 90 | |
| progress_html = make_progress_html(start_pct, 100, f"File {idx}/{total_files} - Pass {loop_idx}: Preparing...") | |
| yield None, full_logs_history + header_log, progress_html | |
| # Call generator | |
| # Note: pre_downscale is passed every time. | |
| # If enabled, it will downscale 'current_input_path' before upscaling. | |
| gen = run_cli_upscale_stream( | |
| input_path=current_input_path, | |
| resolution=int(resolution), | |
| max_resolution=int(max_resolution) if max_resolution is not None else 0, | |
| dit_model_filename=selected_filename if selected_filename else None, | |
| cuda_device=(cuda_device if CUDA_AVAILABLE else None), | |
| # New Compile Args | |
| compile_dit=bool(compile_dit), | |
| compile_vae=bool(compile_vae), | |
| compile_backend=compile_backend, | |
| compile_mode=compile_mode, | |
| compile_fullgraph=bool(compile_fullgraph), | |
| compile_dynamic=bool(compile_dynamic), | |
| compile_dynamo_cache_size_limit=int(compile_dynamo_cache_size_limit), | |
| compile_dynamo_recompile_limit=int(compile_dynamo_recompile_limit), | |
| attention_mode=attention_mode, | |
| # New Tiling Args | |
| vae_encode_tiled=bool(vae_encode_tiled), | |
| vae_encode_tile_size=int(vae_encode_tile_size), | |
| vae_encode_tile_overlap=int(vae_encode_tile_overlap), | |
| vae_decode_tiled=bool(vae_decode_tiled), | |
| vae_decode_tile_size=int(vae_decode_tile_size), | |
| vae_decode_tile_overlap=int(vae_decode_tile_overlap), | |
| tile_debug=tile_debug, | |
| # New Processing Args | |
| batch_size=int(batch_size), | |
| uniform_batch_size=bool(uniform_batch_size), | |
| seed=int(seed), | |
| skip_first_frames=int(skip_first_frames), | |
| load_cap=int(load_cap), | |
| # New Quality Args | |
| color_correction=color_correction, | |
| input_noise_scale=float(input_noise_scale), | |
| latent_noise_scale=float(latent_noise_scale), | |
| # Memory & Caching | |
| blocks_to_swap=int(blocks_to_swap), | |
| swap_io_components=bool(swap_io_components), | |
| dit_offload_device=str(dit_offload_device), | |
| vae_offload_device=str(vae_offload_device), | |
| tensor_offload_device=str(tensor_offload_device), | |
| cache_dit=bool(cache_dit), | |
| cache_vae=bool(cache_vae), | |
| extra_args=extra_args or "", | |
| model_dir=str(current_model_dir), # Use current_model_dir | |
| repo_id=selected_repo_id, # Pass the extracted Repo ID | |
| repo_path=current_repo_path, # Pass dynamic repo path | |
| pre_downscale=pre_downscale, | |
| downscale_rate=downscale_rate, | |
| output_format=output_format, | |
| use_improved_blockswap=use_improved_blockswap, | |
| # Pass Video Args | |
| chunk_size=int(chunk_size), | |
| temporal_overlap=int(temporal_overlap), | |
| prepend_frames=int(prepend_frames), | |
| video_backend=video_backend, | |
| use_10bit=use_10bit, | |
| debug=bool(debug) | |
| ) | |
| out_for_this_pass = None | |
| current_stream_logs = "" | |
| current_total_pct = start_pct # Fallback so the error handler below has a defined value even if the stream fails before its first update | |
| try: | |
| for out_path, logs in gen: | |
| # logs is the complete log from start to now for this pass (behavior of run_cli_upscale_stream) | |
| current_stream_logs = logs | |
| # Progress logic per pass: parse the log content to estimate the current stage. | |
| # Each pass occupies (90 / total_operations)% of the total progress bar. | |
| per_pass_slice = 90 / total_operations | |
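| # e.g. 2 files x 3 passes -> each pass owns 90/6 = 15% of the bar | |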
| local_ratio = 0.1 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Init" | |
| if "Phase 1: VAE encoding" in logs: | |
| local_ratio = 0.2 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Encoding" | |
| if "Phase 2: DiT upscaling" in logs: | |
| local_ratio = 0.45 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Upscaling" | |
| if "Phase 3: Decode" in logs: | |
| local_ratio = 0.8 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Decoding" | |
| if "Phase 4: Post-process" in logs: | |
| local_ratio = 0.95 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Post-proc" | |
| if "Saving" in logs: | |
| local_ratio = 0.98 | |
| status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Saving" | |
| # Calculate current total progress | |
| current_total_pct = start_pct + (per_pass_slice * local_ratio) | |
| progress_html = make_progress_html(current_total_pct, 100, status_text) | |
| # Combine historical log + current image header + current streaming log | |
| yield None, full_logs_history + header_log + current_stream_logs, progress_html | |
| if out_path: | |
| out_for_this_pass = out_path | |
| except Exception as e: | |
| error_msg = f"[ERROR] Exception: {e}\n" | |
| current_stream_logs += error_msg | |
| yield None, full_logs_history + header_log + current_stream_logs, make_progress_html(current_total_pct, 100, "Error") | |
| # If error, break the repetition loop for this image | |
| break | |
| full_logs_history += header_log + current_stream_logs | |
| # After this pass completes | |
| if out_for_this_pass and os.path.exists(out_for_this_pass): | |
| # Success for this pass | |
| current_input_path = out_for_this_pass # Update input for next pass | |
| final_output_for_image = out_for_this_pass | |
| else: | |
| # Failure in this pass, stop repeating | |
| full_logs_history += f"\n[WARN] Pass {loop_idx} failed, stopping.\n" | |
| break | |
| # End Repetitions | |
| if final_output_for_image and os.path.exists(final_output_for_image): | |
| # Rename the physical file to "<original name>_<timestamp>" before adding it to the ZIP, | |
| # converting the output format first if necessary. | |
| output_dir = Path(final_output_for_image).parent | |
| # Generate timestamp | |
| ts = int(time.time()) | |
| # Logic for Image format conversion vs Video | |
| if is_video_mode: | |
| # Keep as mp4 | |
| target_filename = f"{original_stem}_{ts}.mp4" | |
| target_path = output_dir / target_filename | |
| try: | |
| shutil.move(final_output_for_image, target_path) | |
| final_output_for_image = str(target_path) | |
| full_logs_history += f"[INFO] Renamed output to: {target_filename}\n" | |
| except Exception as e: | |
| full_logs_history += f"[WARN] Rename failed: {e}\n" | |
| target_filename = os.path.basename(final_output_for_image) | |
| else: | |
| # Image Logic (PNG/JPG/WEBP conversion) | |
| # Set restored filename (original_stem_{timestamp}.{ext}) | |
| target_filename = f"{original_stem}_{ts}.{output_format}" | |
| target_path = output_dir / target_filename | |
| # Delete target if exists to avoid collision | |
| if target_path.exists(): | |
| os.remove(target_path) | |
| if output_format == "png": | |
| # Just rename | |
| shutil.move(final_output_for_image, target_path) | |
| full_logs_history += f"[INFO] Renamed output to: {target_filename}\n" | |
| else: | |
| # Convert (jpg, webp, etc.) | |
| try: | |
| # Read via the UTF-8-safe helper (cv2.imread cannot handle non-ASCII paths) | |
| img = imreadUTF8(final_output_for_image, cv2.IMREAD_UNCHANGED) | |
| if img is None: | |
| raise ValueError("Result image could not be loaded (imreadUTF8).") | |
| # Convert BGRA (OpenCV default for alpha) to BGR if saving as JPEG | |
| if output_format in ["jpg", "jpeg"]: | |
| # Check if image has 4 channels | |
| if len(img.shape) == 3 and img.shape[2] == 4: | |
| img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) | |
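| # (JPEG cannot store alpha; OpenCV refuses to encode 4-channel data as JPEG.) | |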
| # Save with quality control; encode to memory, then write via | |
| # tofile so UTF-8 output paths work (same trick as imwriteUTF8) | |
| quality_val = 95 | |
| ok, buf = cv2.imencode(f".{output_format}", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_val]) | |
| if not ok: | |
| raise ValueError(f"cv2.imencode failed for format: {output_format}") | |
| buf.tofile(str(target_path)) | |
| else: | |
| # Save directly (webp, etc.) via the UTF-8-safe helper | |
| if not imwriteUTF8(str(target_path), img): | |
| raise ValueError("Failed to write converted image.") | |
| full_logs_history += f"[INFO] Converted PNG to {output_format}: {target_filename}\n" | |
| # Remove original png | |
| os.remove(final_output_for_image) | |
| except Exception as e: | |
| full_logs_history += f"[ERROR] Conversion failed: {e}\n" | |
| # Fallback: if conversion failed, try to keep the original file if possible | |
| target_path = Path(final_output_for_image) # Fallback to png | |
| target_filename = target_path.name | |
| # Update variable to point to the new path | |
| final_output_for_image = str(target_path) | |
| # Store both the physical path and the intended ZIP name | |
| successful_outputs.append((final_output_for_image, target_filename)) | |
| full_logs_history += f"[INFO] Item {idx} completed: {final_output_for_image}\n" | |
| # Update progress bar to completed state for this file | |
| end_pct = 10 + (idx / total_files) * 90 # Rough estimate for completion of this file's block | |
| yield None, full_logs_history, make_progress_html(end_pct, 100, f"Item {idx} Done") | |
| # Final Output Logic | |
| # After all images processed, logic to handle output (Single file vs ZIP) | |
| if successful_outputs: | |
| # Check output count | |
| if len(successful_outputs) == 1: | |
| # If there is only one file, return the image path directly; do not compress. | |
| # successful_outputs stores tuples: (physical_path, archive_name) | |
| single_file_path = successful_outputs[0][0] | |
| full_logs_history += f"\n[DONE] Single image processed. Returning: {single_file_path}\n" | |
| # Yield the single file path directly | |
| yield single_file_path, full_logs_history, make_progress_html(100, 100, "Processing Complete!") | |
| else: | |
| # If more than one file, execute the standard ZIP packaging logic. | |
| # Use ZIP_STORED (store only, no compression) for speed to avoid CPU bottlenecks. | |
| # PNG is already a compressed format; re-compressing via Deflate offers little benefit and is extremely slow. | |
| compression_method = zipfile.ZIP_STORED | |
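| # (ZIP_DEFLATED would shrink already-compressed PNG/WEBP/JPG data by only a few percent at a large CPU cost.) | |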
| # Use parent dir of the first output for the zip location | |
| out_dir = Path(successful_outputs[0][0]).parent | |
| # Derive model tag for filename (fall back to "output" when unknown) | |
| model_tag = selected_filename or "output" | |
| # sanitize model_tag to be filesystem-safe (allow alnum, dot, dash, underscore) | |
| sanitized = "".join(c if (c.isalnum() or c in "._-") else "_" for c in model_tag) | |
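| # e.g. a hypothetical "my model(1).png" becomes "my_model_1_.png" | |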
| zip_name = f"seedvr2_{sanitized}_{int(time.time())}.zip" | |
| zip_path = out_dir / zip_name | |
| full_logs_history += f"\n[INFO] Zipping {len(successful_outputs)} items...\n" | |
| yield None, full_logs_history, make_progress_html(100, 100, "Packaging...") | |
| # Add allowZip64=True to support files larger than 4GB | |
| with zipfile.ZipFile(zip_path, "w", compression=compression_method, allowZip64=True) as zf: | |
| total_to_pack = len(successful_outputs) # Avoid shadowing the input-file count above | |
| # Yield progress inside the packaging loop to prevent Gradio disconnects due to long periods of unresponsiveness | |
| for i, (physical_path, archive_name) in enumerate(successful_outputs): | |
| try: | |
| zf.write(physical_path, arcname=archive_name) | |
| except Exception as e: | |
| full_logs_history += f"\n[WARN] Failed to pack {archive_name}: {e}\n" | |
| # Regularly yield progress updates. | |
| # Although ZIP_STORED is fast, writing hundreds of images to disk still takes time. | |
| # Update the UI every 10 items so the frontend knows the connection is still alive. | |
| if i % 10 == 0 or i == total_to_pack - 1: | |
| status_msg = f"Packaging {i+1}/{total_to_pack}..." | |
| # Here we only update the progress bar, not the full log history, to avoid excessive data transmission | |
| yield None, full_logs_history, make_progress_html(100, 100, status_msg) | |
| final_msg = f"\n[DONE] Successfully packaged {len(successful_outputs)} items into {zip_path}\n" | |
| full_logs_history += final_msg | |
| # Return the ZIP file path | |
| yield str(zip_path), full_logs_history, make_progress_html(100, 100, "Complete!") | |
| else: | |
| full_logs_history += "\n[DONE] No outputs were generated.\n" | |
| yield None, full_logs_history, make_progress_html(100, 100, "Failed / No Output") | |
| # ---------------- UI layout ---------------- | |
| def main(): | |
| css = """ | |
| /* small UI tweaks */ | |
| #input_gallery:hover { border-color: var(--color-accent) !important; box-shadow: 0 0 8px rgba(0,0,0,0.08); } | |
| """ | |
| is_low_vram = False | |
| # Print CUDA/MPS availability, useful when running on a CPU-only server | |
| if CUDA_AVAILABLE: | |
| try: | |
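| # Cap this process at 95% of total VRAM so allocations fail fast instead of exhausting the device. | |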
| torch.cuda.set_per_process_memory_fraction(0.95, device='cuda:0') | |
| except Exception: | |
| pass | |
| # Set torch options to avoid black images on GTX 16xx cards | |
| # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801 | |
| torch.backends.cudnn.enabled = True | |
| torch.backends.cudnn.benchmark = True | |
| print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}") | |
| try: | |
| # Check if total VRAM is at most 6.5 GiB (6.5 * 1024**3 bytes) | |
| # If so, default to Offload mode to prevent OOM on startup | |
| if torch.cuda.get_device_properties(0).total_memory <= (6.5 * 1024**3): | |
| is_low_vram = True | |
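| # e.g. a 4 GB card reports ~4.29e9 bytes of total_memory, under the ~6.98e9 threshold | |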
| except Exception: | |
| # Fallback if device property read fails | |
| pass | |
| elif MPS_AVAILABLE: | |
| print("MPS (Apple Silicon) is available. Using Metal Performance Shaders.") | |
| else: | |
| print("Neither CUDA nor MPS detected. GPU-related UI controls hidden.") | |
| # Automatic defaults logic: | |
| # 1. No accelerator (CPU only) -> force "Offload" | |
| # 2. MPS (Mac) -> default "Recommended" (8GB+ unified memory is usually sufficient for this preset) | |
| # 3. CUDA -> check VRAM; if low, use "Offload", else "Recommended" | |
| DEFAULT_PRESET = "Offload (very slow)" if not ACCELERATOR_AVAILABLE or is_low_vram else "Recommended (low VRAM)" | |
| # Unpack default values from the calculated preset immediately. | |
| # This ensures that all sliders and checkboxes match the Dropdown's initial value. | |
| ( | |
| init_compile_dit, init_compile_vae, | |
| init_vae_encode_tiled, init_vae_encode_tile_size, | |
| init_vae_decode_tiled, init_vae_decode_tile_size, | |
| init_max_resolution, init_blocks_to_swap, init_swap_io_components, | |
| init_dit_offload_device, init_vae_offload_device, init_tensor_offload_device, | |
| init_extra_args, init_chunk_size, init_temporal_overlap | |
| ) = preset_changed(DEFAULT_PRESET) | |
| # Default to GGUF only if no accelerator is found or VRAM is low. | |
| # MPS/CUDA users usually prefer standard safetensors unless extremely VRAM-constrained. | |
| DEFAULT_USE_GGUF = not ACCELERATOR_AVAILABLE or is_low_vram | |
| # initial model choices depend on DEFAULT_USE_GGUF | |
| initial_model_choices = GGUF_CHOICES if DEFAULT_USE_GGUF else MODEL_CHOICES | |
| initial_model_value = initial_model_choices[0] if initial_model_choices else (MODEL_CHOICES[0] if MODEL_CHOICES else None) | |
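| # Fallback chain: first entry of the active list -> first safetensors model -> None if both lists are empty. | |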
| with gr.Blocks(title="SeedVR2 Image/Video Upscaler", css=css) as demo: | |
| gr.Markdown("# SeedVR2 Upscaler — Image & Video\nSupport for single image, batch images, and MP4 video upscaling.") | |
| gr.Markdown("This application utilizes the [ComfyUI-SeedVR2_VideoUpscaler](https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler) backend logic for inference. ") | |
| if MPS_AVAILABLE: | |
| gr.HTML( | |
| """ | |
| <div style=" | |
| padding: 1rem; | |
| border-radius: 8px; | |
| margin-bottom: 10px; | |
| border-left: 5px solid #2196f3; | |
| background-color: rgba(33, 150, 243, 0.1); | |
| color: var(--body-text-color); | |
| "> | |
| <h3 style="margin: 0 0 5px 0; color: var(--body-text-color);">🍎 macOS MPS Detected</h3> | |
| <p style="margin: 0;"> | |
| Running on <b>Metal Performance Shaders (MPS)</b>. | |
| Performance is better than CPU. | |
| </p> | |
| </div> | |
| """ | |
| ) | |
| elif not CUDA_AVAILABLE: | |
| gr.HTML( | |
| """ | |
| <div style=" | |
| padding: 1rem; | |
| border-radius: 8px; | |
| margin-bottom: 10px; | |
| border-left: 5px solid #ff9800; | |
| background-color: rgba(255, 152, 0, 0.1); | |
| color: var(--body-text-color); | |
| "> | |
| <h3 style="margin: 0 0 5px 0; color: var(--body-text-color);">⚠️ No GPU Detected (CPU Mode)</h3> | |
| <p style="margin: 0 0 8px 0;"> | |
| Neither CUDA (NVIDIA) nor MPS (macOS) was detected. Processing will be extremely slow. | |
| </p> | |
| <ul style="margin: 0 0 0 20px; padding: 0;"> | |
| <li><b>Recommendation:</b> Clone this repository to a local machine with a GPU for full functionality.</li> | |
| <li><b>If running online (CPU):</b> Please process <b>Images Only</b>.</li> | |
| <li><b>Model Selection:</b> Use <b>GGUF 3B</b> models or <b>7B (Q4_K_M)</b> quantization. Heavier models will likely fail.</li> | |
| </ul> | |
| </div> | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| submit = gr.Button("Start Upscale Processing", variant="primary", size="lg") | |
| # TABS for Input | |
| with gr.Tabs(): | |
| with gr.TabItem("🖼️ Image Gallery"): | |
| gallery = gr.Gallery( | |
| label="Input Images (Batch Support)", | |
| elem_id="input_gallery", | |
| columns=4, rows=3, show_label=False, interactive=True, height=350 | |
| ) | |
| with gr.TabItem("🎥 Video Input"): | |
| video_input = gr.Video( | |
| label="Input Video (MP4/AVI)", | |
| sources=["upload"], | |
| format="mp4" | |
| ) | |
| # Group 0: Repo Settings (New) | |
| with gr.Accordion("🛠️ Repository Settings (Advanced)", open=False): | |
| gr.Markdown("Configure a custom GitHub repository to test different versions or forks.") | |
| with gr.Row(): | |
| custom_repo_url = gr.Textbox( | |
| label="Repository URL", | |
| value="", | |
| placeholder="https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler.git" | |
| ) | |
| custom_clone_name = gr.Textbox( | |
| label="Clone Directory Name", | |
| value="", | |
| placeholder="ComfyUI-SeedVR2_VideoUpscaler", | |
| info="Folder name inside the app directory. Change this to avoid overwriting the default clone." | |
| ) | |
| custom_branch = gr.Textbox( | |
| label="Branch / Tag / Commit Hash", | |
| value="", | |
| placeholder="e.g. main, dev, or hash like d69b65f...", | |
| info="Leave empty for default branch. If changing repo, use a new directory name." | |
| ) | |
| # Group 1: General Settings (Resolution & Presets) | |
| with gr.Accordion("### ⚙️ General Settings", open=True): | |
| preset_mode = gr.Dropdown( | |
| choices=["Recommended (low VRAM)", "Offload (very slow)", "High quality (fast if lots of VRAM)"], | |
| value=DEFAULT_PRESET, | |
| label="Preset mode", | |
| info="Automatically adjusts compilation, tiling, and offload settings based on your hardware capabilities." | |
| ) | |
| with gr.Row(): | |
| resolution = gr.Slider( | |
| minimum=256, maximum=4096, step=64, value=1920, | |
| label="Target Resolution (Short Side)", | |
| info="Target short-side resolution in pixels. The aspect ratio is preserved." | |
| ) | |
| max_resolution = gr.Number( | |
| value=init_max_resolution, | |
| label="Max resolution (0=unlimited)", | |
| info="Maximum resolution for any edge. Scales down if exceeded. 0 = no limit." | |
| ) | |
| # Output Format selection | |
| with gr.Row(): | |
| output_format = gr.Dropdown( | |
| choices=["webp", "png", "jpg"], | |
| value="webp", | |
| label="Output Format (Default: webp)", | |
| info="Format for saved images. For video input, the CLI produces MP4 (or PNG sequence), and this app converts the final result if needed." | |
| ) | |
| seed = gr.Number( | |
| value=42, | |
| label="Seed", | |
| precision=0, | |
| info="Random seed for reproducibility." | |
| ) | |
| repetition_count = gr.Slider( | |
| minimum=1, | |
| maximum=5, | |
| step=1, | |
| value=1, | |
| label="Loop Count (Images only, Repeat Upscale)", | |
| info="Run SeedVR2 N times per image. If downscale is checked, it applies before EACH run to progressively upscale/refine." | |
| ) | |
| # Group 2: Video Specific | |
| with gr.Accordion("### 🎥 Video Settings", open=False): | |
| with gr.Row(): | |
| chunk_size = gr.Number( | |
| value=init_chunk_size, | |
| label="Chunk Size (Frames)", | |
| info="Frames per chunk for streaming mode. 0 = load all frames at once. Set to specific amount (e.g. 100) to limit VRAM usage on long videos." | |
| ) | |
| temporal_overlap = gr.Number( | |
| value=init_temporal_overlap, | |
| label="Temporal Overlap", | |
| info="Frames to overlap between chunks/batches for smooth blending and to prevent seams." | |
| ) | |
| with gr.Row(): | |
| prepend_frames = gr.Number( | |
| value=0, | |
| label="Prepend Frames", | |
| info="Prepend N reversed frames to reduce start artifacts. These are automatically removed from the output." | |
| ) | |
| skip_first_frames = gr.Number( | |
| value=0, | |
| label="Skip First Frames", | |
| precision=0, | |
| info="Skip N initial frames of the video." | |
| ) | |
| load_cap = gr.Number( | |
| value=0, | |
| label="Load Cap (Max Frames)", | |
| precision=0, | |
| info="Load maximum N frames from video. 0 = load all." | |
| ) | |
| with gr.Row(): | |
| video_backend = gr.Dropdown( | |
| choices=["opencv", "ffmpeg"], | |
| value="opencv", | |
| label="Video Backend", | |
| info="Video encoder backend. 'ffmpeg' requires ffmpeg in system PATH but supports advanced features like 10-bit." | |
| ) | |
| use_10bit = gr.Checkbox( | |
| label="10-bit Output (ffmpeg only)", | |
| value=False, | |
| info="Use x265 10-bit encoding (reduces banding). Requires ffmpeg backend." | |
| ) | |
| # Group 3: Model & Quality | |
| with gr.Accordion("### 🤖 Model & Quality", open=True): | |
| use_gguf = gr.Checkbox( | |
| label="Use GGUF-quantized models (gguf)", | |
| value=DEFAULT_USE_GGUF, | |
| info="When checked, the DiT model dropdown will show GGUF models from cmeka/SeedVR2-GGUF. Efficient for lower VRAM." | |
| ) | |
| dit_model = gr.Dropdown( | |
| choices=initial_model_choices, | |
| value=initial_model_value, | |
| label="DiT model (Format: RepoID/Filename)", | |
| info="DiT transformer model. 7B models have higher quality but require more memory than 3B models." | |
| ) | |
| # Callback: model choices (gguf <-> safetensors) | |
| def _toggle_model_list(gguf_enabled: bool): | |
| if gguf_enabled: | |
| # Switch to the GGUF list, defaulting to its first entry (guard against an empty list) | |
| return gr.update(choices=GGUF_CHOICES, value=GGUF_CHOICES[0] if GGUF_CHOICES else None) | |
| else: | |
| return gr.update(choices=MODEL_CHOICES, value=MODEL_CHOICES[0] if MODEL_CHOICES else None) | |
| use_gguf.change(fn=_toggle_model_list, inputs=[use_gguf], outputs=[dit_model]) | |
| # Show CUDA device textbox only if CUDA available | |
| cuda_device = gr.Textbox( | |
| label="CUDA device", | |
| value="0" if CUDA_AVAILABLE else "", | |
| visible=CUDA_AVAILABLE, | |
| info="CUDA device IDs (e.g. '0' or '0,1'). Leave blank for default." | |
| ) | |
| with gr.Row(): | |
| color_correction = gr.Dropdown( | |
| choices=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"], | |
| value="lab", | |
| label="Color correction", | |
| info="Method to match colors. 'lab' (perceptual, recommended), 'wavelet' (frequency-based), 'adain' (statistical), etc." | |
| ) | |
| input_noise_scale = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.0, | |
| step=0.01, | |
| label="Input Noise Scale", | |
| info="Input noise injection scale (0.0-1.0). Adds variation to input images." | |
| ) | |
| latent_noise_scale = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.0, | |
| step=0.01, | |
| label="Latent Noise Scale", | |
| info="Latent noise injection scale (0.0-1.0). Adds variation to latent space." | |
| ) | |
| with gr.Row(): | |
| pre_downscale = gr.Checkbox( | |
| label="Pre-downscale image (Images only, removes noise/artifacts)", | |
| value=False, | |
| info="Reduces image size before upscaling. Helps remove JPEG artifacts or noise as noted in community tips." | |
| ) | |
| downscale_rate = gr.Slider( | |
| minimum=0.1, | |
| maximum=0.9, | |
| step=0.1, | |
| value=0.5, | |
| label="Downscale factor", | |
| info="0.5 means the input is resized to 50% size before being upscaled to target resolution." | |
| ) | |
| # Group 4: Performance & Memory (Advanced) | |
| with gr.Accordion("### ⚡ Optimization & Memory", open=True): | |
| with gr.Row(): | |
| batch_size = gr.Slider( | |
| minimum=1, | |
| maximum=65, | |
| step=4, | |
| value=1, | |
| label="Batch size (4n+1 recommended)", | |
| info="Frames per batch. 4n+1 (1, 5, 9, 13...) is optimized for temporal consistency. Higher values use more VRAM." | |
| ) | |
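| # Note: the 4n+1 pattern likely reflects the video VAE's 4x temporal compression (1 + 4n frames -> 1 + n latents), so these sizes avoid padding. | |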
| uniform_batch_size = gr.Checkbox( | |
| label="Uniform Batch Size (Pad final batch)", | |
| value=False, | |
| info="Pad final batch to match batch_size. Prevents temporal artifacts caused by small final batches. Adds extra compute." | |
| ) | |
| with gr.Accordion("Memory & Offload / Caching", open=False): | |
| use_improved_blockswap = gr.Checkbox( | |
| label="Use Improved BlockSwap (Nunchaku Ping-Pong CPUOffload)", | |
| value=False, | |
| info="Replaces the standard blockswap logic with the improved version from Nunchaku. Useful for faster offloading." | |
| ) | |
| with gr.Row(): | |
| blocks_to_swap = gr.Number( | |
| value=init_blocks_to_swap, | |
| label="Blocks to swap", | |
| info="Transformer blocks to swap to RAM. 0=disabled. Use large value like 99 for auto-detection of max blocks. Requires Offload Device." | |
| ) | |
| swap_io_components = gr.Checkbox( | |
| label="Swap I/O components", | |
| value=init_swap_io_components, | |
| info="Offload DiT I/O layers for extra VRAM savings. Requires Offload Device." | |
| ) | |
| # Offload device choices adapt to CUDA availability | |
| offload_choices = ["none", "cpu"] + (["cuda:0"] if CUDA_AVAILABLE else []) | |
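| # e.g. CUDA machine -> ["none", "cpu", "cuda:0"]; CPU/MPS-only -> ["none", "cpu"] | |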
| with gr.Row(): | |
| dit_offload_device = gr.Dropdown( | |
| choices=offload_choices, | |
| value=init_dit_offload_device, | |
| label="DiT Offload device", | |
| info="Device to move DiT to when idle. 'cpu' frees VRAM between phases." | |
| ) | |
| vae_offload_device = gr.Dropdown( | |
| choices=offload_choices, | |
| value=init_vae_offload_device, | |
| label="VAE Offload device", | |
| info="Device to move VAE to when idle. 'cpu' frees VRAM between phases." | |
| ) | |
| tensor_offload_device = gr.Dropdown( | |
| choices=offload_choices, | |
| value=init_tensor_offload_device, | |
| label="Tensor Offload device", | |
| info="Where to store intermediate tensors. 'cpu' is recommended to save VRAM." | |
| ) | |
| with gr.Row(): | |
| cache_dit = gr.Checkbox( | |
| label="Cache DiT", | |
| value=False, | |
| info="Keep DiT model in memory between generations. Useful for batch/directory mode or streaming." | |
| ) | |
| cache_vae = gr.Checkbox( | |
| label="Cache VAE", | |
| value=False, | |
| info="Keep VAE model in memory between generations. Useful for batch/directory mode or streaming." | |
| ) | |
| with gr.Accordion("Advanced Tiling (VRAM Saving)", open=False): | |
| with gr.Row(): | |
| vae_encode_tiled = gr.Checkbox( | |
| label="Enable VAE Encode tiling", | |
| value=init_vae_encode_tiled, | |
| info="Process VAE encoding in tiles to reduce VRAM usage (good for large inputs)." | |
| ) | |
| vae_encode_tile_size = gr.Number( | |
| value=init_vae_encode_tile_size, | |
| label="Encode Tile Size", | |
| info="Tile size in pixels for encoding." | |
| ) | |
| vae_encode_tile_overlap = gr.Number( | |
| value=64, | |
| label="Encode Overlap", | |
| info="Overlap in pixels to reduce visible seams." | |
| ) | |
| with gr.Row(): | |
| vae_decode_tiled = gr.Checkbox( | |
| label="Enable Decode Tiling", | |
| value=init_vae_decode_tiled, | |
| info="Process VAE decoding in tiles to reduce VRAM usage." | |
| ) | |
| vae_decode_tile_size = gr.Number( | |
| value=init_vae_decode_tile_size, | |
| label="Decode Tile Size", | |
| info="Tile size in pixels for decoding." | |
| ) | |
| vae_decode_tile_overlap = gr.Number( | |
| value=64, | |
| label="Decode Overlap", | |
| info="Overlap in pixels to reduce visible seams." | |
| ) | |
| tile_debug = gr.Dropdown( | |
| choices=["false", "encode", "decode"], | |
| value="false", | |
| label="Tile Debug Visualization", | |
| info="Visualizes the tiling process for debugging purposes." | |
| ) | |
| with gr.Accordion("Compilation & Backend (Torch 2.0+)", open=False): | |
| with gr.Row(): | |
| compile_dit = gr.Checkbox( | |
| label="Enable torch.compile for DiT", | |
| value=init_compile_dit, | |
| info="20-40% speedup. Requires PyTorch 2.0+ and Triton. May increase memory usage." | |
| ) | |
| compile_vae = gr.Checkbox( | |
| label="Enable torch.compile for VAE", | |
| value=init_compile_vae, | |
| info="15-25% speedup for VAE encoding/decoding." | |
| ) | |
| with gr.Row(): | |
| compile_backend = gr.Dropdown( | |
| choices=["inductor", "cudagraphs"], | |
| value="inductor", | |
| label="Backend", | |
| info="'inductor' (full optimization) or 'cudagraphs' (lightweight)." | |
| ) | |
| compile_mode = gr.Dropdown( | |
| choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], | |
| value="default", | |
| label="Mode", | |
| info="Optimization level: 'default' (fast compile), 'max-autotune' (best speed, slow compile), etc." | |
| ) | |
| with gr.Row(): | |
| attention_mode = gr.Dropdown( | |
| choices=["sdpa", "flash_attn_2", "flash_attn_3", "sageattn_2", "sageattn_3"], | |
| value="sdpa", | |
| label="Attention Mode", | |
| info="Attention backend. 'sdpa' (default, always available), 'flash_attn_2/3' (faster, requires flash-attn), or 'sageattn_2/3' (SageAttention; v3 targets Blackwell GPUs)." | |
| ) | |
| compile_fullgraph = gr.Checkbox( | |
| label="Fullgraph", | |
| value=False, | |
| info="Compile entire model as single graph. Faster but less flexible." | |
| ) | |
| compile_dynamic = gr.Checkbox( | |
| label="Dynamic Shapes", | |
| value=False, | |
| info="Handle varying input shapes without recompilation." | |
| ) | |
| with gr.Row(): | |
| compile_dynamo_cache_size_limit = gr.Number( | |
| value=64, | |
| label="Dynamo Cache Limit", | |
| info="Max cached compiled versions per function." | |
| ) | |
| compile_dynamo_recompile_limit = gr.Number( | |
| value=128, | |
| label="Dynamo Recompile Limit", | |
| info="Max recompilation attempts before fallback to eager mode." | |
| ) | |
| with gr.Row(): | |
| debug_mode = gr.Checkbox( | |
| label="Enable Debug Logs", | |
| value=True, | |
| info="Show verbose output in CLI logs." | |
| ) | |
| extra_args = gr.Textbox( | |
| label="Extra CLI args", | |
| value=init_extra_args, | |
| info="Manually pass additional flags to the CLI (e.g. --custom_flag value)." | |
| ) | |
| # Bind the preset change callback (outputs updated to match new UI elements) | |
| preset_mode.change( | |
| fn=preset_changed, | |
| inputs=[preset_mode], | |
| outputs=[ | |
| compile_dit, compile_vae, | |
| vae_encode_tiled, vae_encode_tile_size, | |
| vae_decode_tiled, vae_decode_tile_size, | |
| max_resolution, blocks_to_swap, swap_io_components, | |
| dit_offload_device, vae_offload_device, tensor_offload_device, | |
| extra_args, chunk_size, temporal_overlap | |
| ] | |
| ) | |
| with gr.Column(scale=1, variant="panel"): | |
| # Custom progress bar HTML component | |
| progress_display = gr.HTML(label="Progress", value=make_progress_html(0, 100, "Ready")) | |
| download_zip = gr.File(label="Download Result") | |
| logs = gr.Textbox(label="CLI logs (streaming)", lines=25, autoscroll=True) | |
| clear = gr.ClearButton(components=[gallery, video_input, download_zip, logs, progress_display], variant="secondary") | |
| submit.click( | |
| fn=ui_upscale_main, | |
| inputs=[ | |
| gallery, video_input, | |
| resolution, max_resolution, preset_mode, dit_model, use_gguf, cuda_device, | |
| # Compiled Inputs | |
| compile_dit, compile_vae, compile_backend, compile_mode, compile_fullgraph, | |
| compile_dynamic, compile_dynamo_cache_size_limit, compile_dynamo_recompile_limit, attention_mode, | |
| # Tiling Inputs | |
| vae_encode_tiled, vae_encode_tile_size, vae_encode_tile_overlap, | |
| vae_decode_tiled, vae_decode_tile_size, vae_decode_tile_overlap, tile_debug, | |
| # Processing Inputs | |
| batch_size, uniform_batch_size, seed, skip_first_frames, load_cap, | |
| # Quality Inputs | |
| color_correction, input_noise_scale, latent_noise_scale, | |
| # Memory Inputs | |
| blocks_to_swap, swap_io_components, dit_offload_device, vae_offload_device, tensor_offload_device, | |
| cache_dit, cache_vae, extra_args, | |
| # General Inputs | |
| pre_downscale, downscale_rate, repetition_count, output_format, use_improved_blockswap, | |
| # Video Inputs | |
| chunk_size, temporal_overlap, prepend_frames, video_backend, use_10bit, | |
| debug_mode, | |
| custom_repo_url, custom_branch, custom_clone_name | |
| ], | |
| outputs=[download_zip, logs, progress_display] | |
| ) | |
| # load paste JS | |
| demo.load(None, None, None, js=paste_js) | |
| demo.queue(max_size=1) | |
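| # max_size=1 keeps at most one job waiting in the queue; further submissions are rejected while it is full. | |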
| demo.launch(inbrowser=True) | |
| if __name__ == "__main__": | |
| main() | |