Upload app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ GITHUB_REPO = "https://github.com/OpenImagingLab/FlashVSR"
|
|
| 10 |
ROOT = Path(__file__).parent.resolve()
|
| 11 |
CODE_DIR = ROOT / "FlashVSR"
|
| 12 |
EXAMPLES_DIR = CODE_DIR / "examples" / "WanVSR"
|
| 13 |
-
WEIGHTS_DIR = EXAMPLES_DIR / "FlashVSR" #
|
| 14 |
OUT_DIR = ROOT / "outputs"
|
| 15 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
| 16 |
|
|
@@ -18,7 +18,7 @@ OUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
| 18 |
|
| 19 |
def _run(cmd, cwd=None, env=None, allow_fail=False):
|
| 20 |
"""
|
| 21 |
-
Run a command; if it fails, include stdout/stderr in the error so
|
| 22 |
the real cause in Space logs and the Gradio toast.
|
| 23 |
"""
|
| 24 |
print(">>", " ".join(map(str, cmd)))
|
|
@@ -44,6 +44,9 @@ def _ffmpeg():
|
|
| 44 |
return get_ffmpeg_exe()
|
| 45 |
|
| 46 |
def _coerce_to_path(video_input) -> Path:
|
|
|
|
|
|
|
|
|
|
| 47 |
if not video_input:
|
| 48 |
raise gr.Error("Upload a video first.")
|
| 49 |
if isinstance(video_input, str) and Path(video_input).exists():
|
|
@@ -58,11 +61,24 @@ def _coerce_to_path(video_input) -> Path:
|
|
| 58 |
return Path(name)
|
| 59 |
raise gr.Error("Could not read uploaded video path. Please re-upload as a file (mp4/avi/mkv).")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# -------------------- io normalization --------------------
|
| 62 |
|
| 63 |
def normalize_to_mp4(src_path: Path) -> Path:
|
| 64 |
"""
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
"""
|
| 67 |
dst_dir = Path(tempfile.mkdtemp(prefix="norm_"))
|
| 68 |
dst = dst_dir / "input_norm.mp4"
|
|
@@ -71,7 +87,7 @@ def normalize_to_mp4(src_path: Path) -> Path:
|
|
| 71 |
_run([
|
| 72 |
_ffmpeg(), "-y",
|
| 73 |
"-i", str(src_path),
|
| 74 |
-
"-map", "0:v:0", "-map", "0:a:0?",
|
| 75 |
"-vf", vf,
|
| 76 |
"-c:v", "libx264", "-pix_fmt", "yuv420p",
|
| 77 |
"-profile:v", "high", "-level", "4.1",
|
|
@@ -102,7 +118,9 @@ def normalize_to_mp4(src_path: Path) -> Path:
|
|
| 102 |
|
| 103 |
def make_web_playable(src: Path, dst: Path):
|
| 104 |
"""
|
| 105 |
-
Transcode the model output to a browser-friendly MP4
|
|
|
|
|
|
|
| 106 |
"""
|
| 107 |
vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
|
| 108 |
try:
|
|
@@ -138,6 +156,9 @@ def make_web_playable(src: Path, dst: Path):
|
|
| 138 |
# ------------------------ setup ------------------------
|
| 139 |
|
| 140 |
def _pip_install_requirements():
|
|
|
|
|
|
|
|
|
|
| 141 |
any_found = False
|
| 142 |
for req in [
|
| 143 |
CODE_DIR / "requirements.txt",
|
|
@@ -158,19 +179,17 @@ def _git_submodules():
|
|
| 158 |
def _verify_weights(tiny: bool | None = None):
|
| 159 |
"""
|
| 160 |
Ensure at least some weights exist under WEIGHTS_DIR.
|
| 161 |
-
If tiny/full is provided, try to find a matching file name.
|
| 162 |
"""
|
| 163 |
if not WEIGHTS_DIR.exists():
|
| 164 |
raise RuntimeError(f"Weights folder missing: {WEIGHTS_DIR}")
|
| 165 |
files = list(WEIGHTS_DIR.rglob("*"))
|
| 166 |
if not files:
|
| 167 |
raise RuntimeError(
|
| 168 |
-
f"No files found in {WEIGHTS_DIR}
|
| 169 |
-
f"Expected weights snapshot from {HF_MODEL_ID} to be here."
|
| 170 |
)
|
| 171 |
if tiny is None:
|
| 172 |
return
|
| 173 |
-
# Heuristic check: look for 'tiny'/'full' keywords or common extensions
|
| 174 |
names = [p.name.lower() for p in files]
|
| 175 |
if tiny and not any(("tiny" in n or n.endswith((".pth", ".pt", ".safetensors"))) for n in names):
|
| 176 |
print("[warn] Could not confirm Tiny weights by name; proceeding anyway.")
|
|
@@ -224,9 +243,10 @@ def prepare_once(build_sparse=False, progress: gr.Progress | None = None):
|
|
| 224 |
def _try_runner_variants(runner: Path, inp: Path, raw_out: Path, scale: int) -> None:
|
| 225 |
"""
|
| 226 |
Try several CLI shapes in case upstream changed arg names.
|
|
|
|
| 227 |
"""
|
| 228 |
variants = [
|
| 229 |
-
# 1) Basic (
|
| 230 |
[sys.executable, str(runner), "--input", str(inp), "--output", str(raw_out), "--scale", str(scale)],
|
| 231 |
# 2) With explicit weights dir
|
| 232 |
[sys.executable, str(runner), "--input", str(inp), "--output", str(raw_out), "--scale", str(scale),
|
|
@@ -243,19 +263,18 @@ def _try_runner_variants(runner: Path, inp: Path, raw_out: Path, scale: int) ->
|
|
| 243 |
for i, cmd in enumerate(variants, 1):
|
| 244 |
try:
|
| 245 |
print(f"[runner] Trying variant {i}/{len(variants)}: {' '.join(cmd)}")
|
| 246 |
-
_run(cmd, cwd=str(EXAMPLES_DIR))
|
| 247 |
return
|
| 248 |
except Exception as e:
|
| 249 |
last_err = e
|
| 250 |
print(f"[runner] Variant {i} failed: {e}")
|
| 251 |
-
# If all variants failed, re-raise the last error so it shows up in the UI.
|
| 252 |
raise last_err or RuntimeError("All runner variants failed.")
|
| 253 |
|
| 254 |
def flashvsr_infer(video_input, scale: int = 4, tiny: bool = True, use_sparse: bool = False,
|
| 255 |
progress: gr.Progress | None = None) -> str:
|
| 256 |
prepare_once(build_sparse=use_sparse, progress=progress)
|
| 257 |
|
| 258 |
-
#
|
| 259 |
runner = EXAMPLES_DIR / ("infer_flashvsr_tiny.py" if tiny else "infer_flashvsr_full.py")
|
| 260 |
if not runner.exists():
|
| 261 |
raise FileNotFoundError(
|
|
@@ -263,7 +282,6 @@ def flashvsr_infer(video_input, scale: int = 4, tiny: bool = True, use_sparse: b
|
|
| 263 |
f"Check repo structure; if they renamed scripts, update this app."
|
| 264 |
)
|
| 265 |
|
| 266 |
-
# Re-check weights (with size hint)
|
| 267 |
_verify_weights(tiny=tiny)
|
| 268 |
|
| 269 |
# 1) Normalize input
|
|
@@ -321,6 +339,7 @@ with gr.Blocks(title="FlashVSR (Unofficial)", fill_height=True) as demo:
|
|
| 321 |
"- Optional: Enable **Block-Sparse-Attention** (GPU) for potential speedups on some GPUs."
|
| 322 |
)
|
| 323 |
with gr.Row():
|
|
|
|
| 324 |
inp = gr.Video(label="Input video (MP4/AVI/MKV, etc.)", sources=["upload"])
|
| 325 |
with gr.Row():
|
| 326 |
tiny = gr.Checkbox(label="Use Tiny model (faster)", value=True)
|
|
@@ -330,6 +349,7 @@ with gr.Blocks(title="FlashVSR (Unofficial)", fill_height=True) as demo:
|
|
| 330 |
out = gr.Video(label="Upscaled output", format="mp4")
|
| 331 |
btn.click(ui_run, inputs=[inp, tiny, scale, sparse], outputs=out)
|
| 332 |
|
|
|
|
| 333 |
demo.queue(max_size=8)
|
| 334 |
|
| 335 |
if __name__ == "__main__":
|
|
|
|
| 10 |
ROOT = Path(__file__).parent.resolve()
|
| 11 |
CODE_DIR = ROOT / "FlashVSR"
|
| 12 |
EXAMPLES_DIR = CODE_DIR / "examples" / "WanVSR"
|
| 13 |
+
WEIGHTS_DIR = EXAMPLES_DIR / "FlashVSR" # where example scripts usually look
|
| 14 |
OUT_DIR = ROOT / "outputs"
|
| 15 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
| 16 |
|
|
|
|
| 18 |
|
| 19 |
def _run(cmd, cwd=None, env=None, allow_fail=False):
|
| 20 |
"""
|
| 21 |
+
Run a command; if it fails, include stdout/stderr in the error so we see
|
| 22 |
the real cause in Space logs and the Gradio toast.
|
| 23 |
"""
|
| 24 |
print(">>", " ".join(map(str, cmd)))
|
|
|
|
| 44 |
return get_ffmpeg_exe()
|
| 45 |
|
| 46 |
def _coerce_to_path(video_input) -> Path:
|
| 47 |
+
"""
|
| 48 |
+
Gradio 5 may pass: str path, dict with 'path'/'name', or obj with .name
|
| 49 |
+
"""
|
| 50 |
if not video_input:
|
| 51 |
raise gr.Error("Upload a video first.")
|
| 52 |
if isinstance(video_input, str) and Path(video_input).exists():
|
|
|
|
| 61 |
return Path(name)
|
| 62 |
raise gr.Error("Could not read uploaded video path. Please re-upload as a file (mp4/avi/mkv).")
|
| 63 |
|
| 64 |
+
def _python_env():
|
| 65 |
+
"""
|
| 66 |
+
Build an env where the repo root is importable without 'pip install -e .'
|
| 67 |
+
"""
|
| 68 |
+
env = os.environ.copy()
|
| 69 |
+
extra = [str(CODE_DIR), str(CODE_DIR / "diffsynth")]
|
| 70 |
+
existing = env.get("PYTHONPATH", "")
|
| 71 |
+
env["PYTHONPATH"] = os.pathsep.join([p for p in (os.pathsep.join(extra), existing) if p])
|
| 72 |
+
env.setdefault("HF_HOME", str(ROOT / ".hf_home"))
|
| 73 |
+
return env
|
| 74 |
+
|
| 75 |
# -------------------- io normalization --------------------
|
| 76 |
|
| 77 |
def normalize_to_mp4(src_path: Path) -> Path:
|
| 78 |
"""
|
| 79 |
+
Convert ANY uploaded file to a safe MP4 for the FlashVSR runner:
|
| 80 |
+
- H.264 (yuv420p), High@4.1, even dimensions, CFR 30fps, +faststart
|
| 81 |
+
- Keep audio if present (AAC), otherwise no-audio.
|
| 82 |
"""
|
| 83 |
dst_dir = Path(tempfile.mkdtemp(prefix="norm_"))
|
| 84 |
dst = dst_dir / "input_norm.mp4"
|
|
|
|
| 87 |
_run([
|
| 88 |
_ffmpeg(), "-y",
|
| 89 |
"-i", str(src_path),
|
| 90 |
+
"-map", "0:v:0", "-map", "0:a:0?", # audio only if exists
|
| 91 |
"-vf", vf,
|
| 92 |
"-c:v", "libx264", "-pix_fmt", "yuv420p",
|
| 93 |
"-profile:v", "high", "-level", "4.1",
|
|
|
|
| 118 |
|
| 119 |
def make_web_playable(src: Path, dst: Path):
|
| 120 |
"""
|
| 121 |
+
Transcode the model output to a browser-friendly MP4:
|
| 122 |
+
- H.264 (yuv420p), High@4.1, even dimensions, CFR 30fps, +faststart
|
| 123 |
+
- Keep/convert audio to AAC if present; otherwise no audio.
|
| 124 |
"""
|
| 125 |
vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
|
| 126 |
try:
|
|
|
|
| 156 |
# ------------------------ setup ------------------------
|
| 157 |
|
| 158 |
def _pip_install_requirements():
|
| 159 |
+
"""
|
| 160 |
+
Install any requirements files present in the repo tree, without editable install.
|
| 161 |
+
"""
|
| 162 |
any_found = False
|
| 163 |
for req in [
|
| 164 |
CODE_DIR / "requirements.txt",
|
|
|
|
| 179 |
def _verify_weights(tiny: bool | None = None):
|
| 180 |
"""
|
| 181 |
Ensure at least some weights exist under WEIGHTS_DIR.
|
| 182 |
+
If tiny/full is provided, try to find a matching file name (best-effort).
|
| 183 |
"""
|
| 184 |
if not WEIGHTS_DIR.exists():
|
| 185 |
raise RuntimeError(f"Weights folder missing: {WEIGHTS_DIR}")
|
| 186 |
files = list(WEIGHTS_DIR.rglob("*"))
|
| 187 |
if not files:
|
| 188 |
raise RuntimeError(
|
| 189 |
+
f"No files found in {WEIGHTS_DIR}. Expected weights snapshot from {HF_MODEL_ID}."
|
|
|
|
| 190 |
)
|
| 191 |
if tiny is None:
|
| 192 |
return
|
|
|
|
| 193 |
names = [p.name.lower() for p in files]
|
| 194 |
if tiny and not any(("tiny" in n or n.endswith((".pth", ".pt", ".safetensors"))) for n in names):
|
| 195 |
print("[warn] Could not confirm Tiny weights by name; proceeding anyway.")
|
|
|
|
| 243 |
def _try_runner_variants(runner: Path, inp: Path, raw_out: Path, scale: int) -> None:
|
| 244 |
"""
|
| 245 |
Try several CLI shapes in case upstream changed arg names.
|
| 246 |
+
Always pass PYTHONPATH so 'import' works without pip install -e .
|
| 247 |
"""
|
| 248 |
variants = [
|
| 249 |
+
# 1) Basic (expected)
|
| 250 |
[sys.executable, str(runner), "--input", str(inp), "--output", str(raw_out), "--scale", str(scale)],
|
| 251 |
# 2) With explicit weights dir
|
| 252 |
[sys.executable, str(runner), "--input", str(inp), "--output", str(raw_out), "--scale", str(scale),
|
|
|
|
| 263 |
for i, cmd in enumerate(variants, 1):
|
| 264 |
try:
|
| 265 |
print(f"[runner] Trying variant {i}/{len(variants)}: {' '.join(cmd)}")
|
| 266 |
+
_run(cmd, cwd=str(EXAMPLES_DIR), env=_python_env())
|
| 267 |
return
|
| 268 |
except Exception as e:
|
| 269 |
last_err = e
|
| 270 |
print(f"[runner] Variant {i} failed: {e}")
|
|
|
|
| 271 |
raise last_err or RuntimeError("All runner variants failed.")
|
| 272 |
|
| 273 |
def flashvsr_infer(video_input, scale: int = 4, tiny: bool = True, use_sparse: bool = False,
|
| 274 |
progress: gr.Progress | None = None) -> str:
|
| 275 |
prepare_once(build_sparse=use_sparse, progress=progress)
|
| 276 |
|
| 277 |
+
# Choose runner
|
| 278 |
runner = EXAMPLES_DIR / ("infer_flashvsr_tiny.py" if tiny else "infer_flashvsr_full.py")
|
| 279 |
if not runner.exists():
|
| 280 |
raise FileNotFoundError(
|
|
|
|
| 282 |
f"Check repo structure; if they renamed scripts, update this app."
|
| 283 |
)
|
| 284 |
|
|
|
|
| 285 |
_verify_weights(tiny=tiny)
|
| 286 |
|
| 287 |
# 1) Normalize input
|
|
|
|
| 339 |
"- Optional: Enable **Block-Sparse-Attention** (GPU) for potential speedups on some GPUs."
|
| 340 |
)
|
| 341 |
with gr.Row():
|
| 342 |
+
# No `type` kwarg in Gradio 5; backend coerces path
|
| 343 |
inp = gr.Video(label="Input video (MP4/AVI/MKV, etc.)", sources=["upload"])
|
| 344 |
with gr.Row():
|
| 345 |
tiny = gr.Checkbox(label="Use Tiny model (faster)", value=True)
|
|
|
|
| 349 |
out = gr.Video(label="Upscaled output", format="mp4")
|
| 350 |
btn.click(ui_run, inputs=[inp, tiny, scale, sparse], outputs=out)
|
| 351 |
|
| 352 |
+
# Queue (defaults only; Gradio 5 removed concurrency_count arg)
|
| 353 |
demo.queue(max_size=8)
|
| 354 |
|
| 355 |
if __name__ == "__main__":
|