Upload 2 files
Browse files- app.py +48 -22
- requirements.txt +5 -1
app.py
CHANGED
|
@@ -6,6 +6,7 @@ import shutil
|
|
| 6 |
import subprocess
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Tuple, Optional
|
|
|
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
from huggingface_hub import snapshot_download
|
|
@@ -17,10 +18,10 @@ CACHE_DIR.mkdir(exist_ok=True)
|
|
| 17 |
WEIGHTS_REPO = "JunhaoZhuang/FlashVSR"
|
| 18 |
FLASH_GIT = "https://github.com/OpenImagingLab/FlashVSR.git"
|
| 19 |
|
| 20 |
-
# ----------------------------- helpers -----------------------------
|
| 21 |
|
| 22 |
def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None) -> Tuple[int, str, str]:
|
| 23 |
-
"""Run a command and capture
|
| 24 |
proc = subprocess.run(
|
| 25 |
cmd,
|
| 26 |
cwd=str(cwd) if cwd else None,
|
|
@@ -31,6 +32,37 @@ def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None)
|
|
| 31 |
)
|
| 32 |
return proc.returncode, proc.stdout, proc.stderr
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def ensure_flashvsr_repo() -> Path:
|
| 36 |
"""Clone FlashVSR repo if missing."""
|
|
@@ -39,11 +71,9 @@ def ensure_flashvsr_repo() -> Path:
|
|
| 39 |
code, out, err = run(["git", "clone", "--depth=1", FLASH_GIT, str(repo_dir)])
|
| 40 |
if code != 0:
|
| 41 |
raise RuntimeError(f"Failed to clone FlashVSR.\n{err}")
|
| 42 |
-
# best-effort submodules (harmless if none)
|
| 43 |
run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir)
|
| 44 |
return repo_dir
|
| 45 |
|
| 46 |
-
|
| 47 |
def ensure_weights() -> Path:
|
| 48 |
"""Download weights snapshot locally once."""
|
| 49 |
target = CACHE_DIR / "weights"
|
|
@@ -57,14 +87,14 @@ def ensure_weights() -> Path:
|
|
| 57 |
)
|
| 58 |
return target
|
| 59 |
|
|
|
|
| 60 |
|
| 61 |
def normalize_to_mp4(src_path: Path) -> Path:
|
| 62 |
"""
|
| 63 |
Convert ANY uploaded file to a very browser-friendly MP4:
|
| 64 |
- H.264 (yuv420p), Baseline profile
|
| 65 |
-
- Even dimensions + CFR 30 fps
|
| 66 |
-
-
|
| 67 |
-
- No-audio (avoids codec/container edge cases in browser preview)
|
| 68 |
"""
|
| 69 |
out_path = src_path.with_name(src_path.stem + "_playable.mp4")
|
| 70 |
vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
|
|
@@ -84,6 +114,7 @@ def normalize_to_mp4(src_path: Path) -> Path:
|
|
| 84 |
raise RuntimeError(f"ffmpeg failed to normalize the video.\n\n{err or out}")
|
| 85 |
return out_path
|
| 86 |
|
|
|
|
| 87 |
|
| 88 |
def ensure_modelscope_stub() -> Path:
|
| 89 |
"""
|
|
@@ -102,23 +133,25 @@ def ensure_modelscope_stub() -> Path:
|
|
| 102 |
)
|
| 103 |
return stub_root
|
| 104 |
|
| 105 |
-
#
|
| 106 |
|
| 107 |
def run_flashvsr_on_video(
|
| 108 |
in_video: Path,
|
| 109 |
scale: int,
|
| 110 |
prefer_sparse: bool,
|
| 111 |
log_file: Path,
|
| 112 |
-
)
|
| 113 |
"""
|
| 114 |
Execute the FlashVSR example script and return (output_path | None, combined_logs).
|
| 115 |
-
Writes full logs to log_file either way.
|
| 116 |
"""
|
| 117 |
logs = []
|
|
|
|
|
|
|
| 118 |
|
|
|
|
| 119 |
repo_dir = ensure_flashvsr_repo()
|
| 120 |
weights_dir = ensure_weights()
|
| 121 |
-
stub_root = ensure_modelscope_stub()
|
| 122 |
|
| 123 |
logs.append(f"Python: {sys.version}")
|
| 124 |
logs.append(f"Repo: {repo_dir}")
|
|
@@ -138,8 +171,7 @@ def run_flashvsr_on_video(
|
|
| 138 |
out_dir.mkdir(exist_ok=True)
|
| 139 |
out_mp4 = out_dir / f"{in_video.stem}_x{scale}.mp4"
|
| 140 |
|
| 141 |
-
#
|
| 142 |
-
# And place our 'modelscope' stub at the front so its snapshot_download is used.
|
| 143 |
env = os.environ.copy()
|
| 144 |
extra_paths = [str(stub_root), str(repo_dir), str(repo_dir / "diffsynth")]
|
| 145 |
existing = env.get("PYTHONPATH", "")
|
|
@@ -159,7 +191,7 @@ def run_flashvsr_on_video(
|
|
| 159 |
logs.append(err)
|
| 160 |
return code == 0 and out_mp4.exists()
|
| 161 |
|
| 162 |
-
# Try
|
| 163 |
if _try(["--input", str(in_video), "--output", str(out_mp4), "--scale", str(scale), "--weights", str(weights_dir)]):
|
| 164 |
pass
|
| 165 |
elif _try(["--video", str(in_video), "--outdir", str(out_dir), "--scale", str(scale), "--weights", str(weights_dir)]):
|
|
@@ -174,15 +206,12 @@ def run_flashvsr_on_video(
|
|
| 174 |
combined = "\n".join(logs)
|
| 175 |
log_file.write_text(combined)
|
| 176 |
|
| 177 |
-
# Normalize
|
| 178 |
playable = normalize_to_mp4(out_mp4) if out_mp4.exists() else None
|
| 179 |
return playable, combined
|
| 180 |
|
| 181 |
-
|
| 182 |
def infer(ui_video: str, scale: int, prefer_sparse: bool):
|
| 183 |
-
"""
|
| 184 |
-
Gradio handler: return (video_path | None, diagnostics text, logs file path)
|
| 185 |
-
"""
|
| 186 |
logs_path = CACHE_DIR / "last_run_logs.txt"
|
| 187 |
|
| 188 |
if not ui_video:
|
|
@@ -197,7 +226,6 @@ def infer(ui_video: str, scale: int, prefer_sparse: bool):
|
|
| 197 |
return None, msg, str(logs_path)
|
| 198 |
|
| 199 |
try:
|
| 200 |
-
# Ensure the *input* itself is previewable in the browser.
|
| 201 |
src_playable = normalize_to_mp4(src)
|
| 202 |
except Exception as e:
|
| 203 |
msg = f"Input normalization failed:\n{e}"
|
|
@@ -207,13 +235,11 @@ def infer(ui_video: str, scale: int, prefer_sparse: bool):
|
|
| 207 |
try:
|
| 208 |
out_path, combined = run_flashvsr_on_video(src_playable, scale, prefer_sparse, logs_path)
|
| 209 |
if out_path is None:
|
| 210 |
-
# Show input so player still has something; include reason in diagnostics.
|
| 211 |
return str(src_playable), "FlashVSR failed. See logs below.", str(logs_path)
|
| 212 |
return str(out_path), "Done.", str(logs_path)
|
| 213 |
except Exception as e:
|
| 214 |
msg = f"Pipeline error:\n{e}"
|
| 215 |
logs_path.write_text(msg)
|
| 216 |
-
# Fall back to showing the normalized input so the UI still previews something.
|
| 217 |
return str(src_playable), msg, str(logs_path)
|
| 218 |
|
| 219 |
# ------------------------------ UI -------------------------------
|
|
|
|
| 6 |
import subprocess
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Tuple, Optional
|
| 9 |
+
import importlib.util as _import_spec
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
from huggingface_hub import snapshot_download
|
|
|
|
| 18 |
WEIGHTS_REPO = "JunhaoZhuang/FlashVSR"
|
| 19 |
FLASH_GIT = "https://github.com/OpenImagingLab/FlashVSR.git"
|
| 20 |
|
| 21 |
+
# ----------------------------- shell helpers -----------------------------
|
| 22 |
|
| 23 |
def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None) -> Tuple[int, str, str]:
|
| 24 |
+
"""Run a command and capture stdout/stderr."""
|
| 25 |
proc = subprocess.run(
|
| 26 |
cmd,
|
| 27 |
cwd=str(cwd) if cwd else None,
|
|
|
|
| 32 |
)
|
| 33 |
return proc.returncode, proc.stdout, proc.stderr
|
| 34 |
|
| 35 |
+
# ----------------------------- deps helpers -----------------------------
|
| 36 |
+
|
| 37 |
+
def _has_module(name: str) -> bool:
|
| 38 |
+
return _import_spec.find_spec(name) is not None
|
| 39 |
+
|
| 40 |
+
def ensure_python_deps(logs: list[str]) -> None:
|
| 41 |
+
"""
|
| 42 |
+
Install lightweight runtime deps if missing. Safer than failing the whole run.
|
| 43 |
+
"""
|
| 44 |
+
need = []
|
| 45 |
+
# Minimal set seen in FlashVSR's example imports and typical backbones
|
| 46 |
+
if not _has_module("transformers"):
|
| 47 |
+
need += ["transformers>=4.44"]
|
| 48 |
+
if not _has_module("sentencepiece"):
|
| 49 |
+
need += ["sentencepiece>=0.1.99"]
|
| 50 |
+
if not _has_module("safetensors"):
|
| 51 |
+
need += ["safetensors>=0.4.3"]
|
| 52 |
+
if not _has_module("timm"):
|
| 53 |
+
need += ["timm>=0.9.16"]
|
| 54 |
+
if not _has_module("accelerate"):
|
| 55 |
+
need += ["accelerate>=0.33"]
|
| 56 |
+
|
| 57 |
+
if need:
|
| 58 |
+
logs.append(f"Installing missing deps: {need}")
|
| 59 |
+
code, out, err = run([sys.executable, "-m", "pip", "install", "--no-cache-dir", *need])
|
| 60 |
+
logs.append(out)
|
| 61 |
+
logs.append(err)
|
| 62 |
+
if code != 0:
|
| 63 |
+
raise RuntimeError(f"Pip failed while installing {need}:\n{err or out}")
|
| 64 |
+
|
| 65 |
+
# ----------------------------- repo / weights -----------------------------
|
| 66 |
|
| 67 |
def ensure_flashvsr_repo() -> Path:
|
| 68 |
"""Clone FlashVSR repo if missing."""
|
|
|
|
| 71 |
code, out, err = run(["git", "clone", "--depth=1", FLASH_GIT, str(repo_dir)])
|
| 72 |
if code != 0:
|
| 73 |
raise RuntimeError(f"Failed to clone FlashVSR.\n{err}")
|
|
|
|
| 74 |
run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir)
|
| 75 |
return repo_dir
|
| 76 |
|
|
|
|
| 77 |
def ensure_weights() -> Path:
|
| 78 |
"""Download weights snapshot locally once."""
|
| 79 |
target = CACHE_DIR / "weights"
|
|
|
|
| 87 |
)
|
| 88 |
return target
|
| 89 |
|
| 90 |
+
# ----------------------------- IO utils -----------------------------
|
| 91 |
|
| 92 |
def normalize_to_mp4(src_path: Path) -> Path:
|
| 93 |
"""
|
| 94 |
Convert ANY uploaded file to a very browser-friendly MP4:
|
| 95 |
- H.264 (yuv420p), Baseline profile
|
| 96 |
+
- Even dimensions + CFR 30 fps, +faststart
|
| 97 |
+
- No audio (avoids browser codec edge cases)
|
|
|
|
| 98 |
"""
|
| 99 |
out_path = src_path.with_name(src_path.stem + "_playable.mp4")
|
| 100 |
vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
|
|
|
|
| 114 |
raise RuntimeError(f"ffmpeg failed to normalize the video.\n\n{err or out}")
|
| 115 |
return out_path
|
| 116 |
|
| 117 |
+
# ----------------------------- ModelScope stub -----------------------------
|
| 118 |
|
| 119 |
def ensure_modelscope_stub() -> Path:
|
| 120 |
"""
|
|
|
|
| 133 |
)
|
| 134 |
return stub_root
|
| 135 |
|
| 136 |
+
# ----------------------------- pipeline -----------------------------
|
| 137 |
|
| 138 |
def run_flashvsr_on_video(
|
| 139 |
in_video: Path,
|
| 140 |
scale: int,
|
| 141 |
prefer_sparse: bool,
|
| 142 |
log_file: Path,
|
| 143 |
+
):
|
| 144 |
"""
|
| 145 |
Execute the FlashVSR example script and return (output_path | None, combined_logs).
|
|
|
|
| 146 |
"""
|
| 147 |
logs = []
|
| 148 |
+
# 1) Make sure python deps exist
|
| 149 |
+
ensure_python_deps(logs)
|
| 150 |
|
| 151 |
+
# 2) Repo + weights
|
| 152 |
repo_dir = ensure_flashvsr_repo()
|
| 153 |
weights_dir = ensure_weights()
|
| 154 |
+
stub_root = ensure_modelscope_stub()
|
| 155 |
|
| 156 |
logs.append(f"Python: {sys.version}")
|
| 157 |
logs.append(f"Repo: {repo_dir}")
|
|
|
|
| 171 |
out_dir.mkdir(exist_ok=True)
|
| 172 |
out_mp4 = out_dir / f"{in_video.stem}_x{scale}.mp4"
|
| 173 |
|
| 174 |
+
# Environment: make repo importable and put our stub first
|
|
|
|
| 175 |
env = os.environ.copy()
|
| 176 |
extra_paths = [str(stub_root), str(repo_dir), str(repo_dir / "diffsynth")]
|
| 177 |
existing = env.get("PYTHONPATH", "")
|
|
|
|
| 191 |
logs.append(err)
|
| 192 |
return code == 0 and out_mp4.exists()
|
| 193 |
|
| 194 |
+
# Try a few CLI shapes used across revisions
|
| 195 |
if _try(["--input", str(in_video), "--output", str(out_mp4), "--scale", str(scale), "--weights", str(weights_dir)]):
|
| 196 |
pass
|
| 197 |
elif _try(["--video", str(in_video), "--outdir", str(out_dir), "--scale", str(scale), "--weights", str(weights_dir)]):
|
|
|
|
| 206 |
combined = "\n".join(logs)
|
| 207 |
log_file.write_text(combined)
|
| 208 |
|
| 209 |
+
# Normalize result for browser playback
|
| 210 |
playable = normalize_to_mp4(out_mp4) if out_mp4.exists() else None
|
| 211 |
return playable, combined
|
| 212 |
|
|
|
|
| 213 |
def infer(ui_video: str, scale: int, prefer_sparse: bool):
|
| 214 |
+
"""Gradio handler: return (video_path | None, diagnostics text, logs file path)"""
|
|
|
|
|
|
|
| 215 |
logs_path = CACHE_DIR / "last_run_logs.txt"
|
| 216 |
|
| 217 |
if not ui_video:
|
|
|
|
| 226 |
return None, msg, str(logs_path)
|
| 227 |
|
| 228 |
try:
|
|
|
|
| 229 |
src_playable = normalize_to_mp4(src)
|
| 230 |
except Exception as e:
|
| 231 |
msg = f"Input normalization failed:\n{e}"
|
|
|
|
| 235 |
try:
|
| 236 |
out_path, combined = run_flashvsr_on_video(src_playable, scale, prefer_sparse, logs_path)
|
| 237 |
if out_path is None:
|
|
|
|
| 238 |
return str(src_playable), "FlashVSR failed. See logs below.", str(logs_path)
|
| 239 |
return str(out_path), "Done.", str(logs_path)
|
| 240 |
except Exception as e:
|
| 241 |
msg = f"Pipeline error:\n{e}"
|
| 242 |
logs_path.write_text(msg)
|
|
|
|
| 243 |
return str(src_playable), msg, str(logs_path)
|
| 244 |
|
| 245 |
# ------------------------------ UI -------------------------------
|
requirements.txt
CHANGED
|
@@ -6,7 +6,11 @@ opencv-python>=4.10
|
|
| 6 |
einops>=0.8.0
|
| 7 |
diffsynth>=1.1.8
|
| 8 |
modelscope>=1.15.0
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# PyTorch (choose the wheel that matches your CUDA setup)
|
| 12 |
# For CUDA 12.4 wheels (edit for your machine or Space):
|
|
|
|
| 6 |
einops>=0.8.0
|
| 7 |
diffsynth>=1.1.8
|
| 8 |
modelscope>=1.15.0
|
| 9 |
+
transformers>=4.44
|
| 10 |
+
sentencepiece>=0.1.99
|
| 11 |
+
safetensors>=0.4.3
|
| 12 |
+
timm>=0.9.16
|
| 13 |
+
accelerate>=0.33
|
| 14 |
|
| 15 |
# PyTorch (choose the wheel that matches your CUDA setup)
|
| 16 |
# For CUDA 12.4 wheels (edit for your machine or Space):
|