MStater committed on
Commit
0dbcc5b
·
verified ·
1 Parent(s): dea3006

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +48 -22
  2. requirements.txt +5 -1
app.py CHANGED
@@ -6,6 +6,7 @@ import shutil
6
  import subprocess
7
  from pathlib import Path
8
  from typing import Tuple, Optional
 
9
 
10
  import gradio as gr
11
  from huggingface_hub import snapshot_download
@@ -17,10 +18,10 @@ CACHE_DIR.mkdir(exist_ok=True)
17
  WEIGHTS_REPO = "JunhaoZhuang/FlashVSR"
18
  FLASH_GIT = "https://github.com/OpenImagingLab/FlashVSR.git"
19
 
20
- # ----------------------------- helpers -----------------------------
21
 
22
  def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None) -> Tuple[int, str, str]:
23
- """Run a command and capture full stdout/stderr (so failures are visible in logs)."""
24
  proc = subprocess.run(
25
  cmd,
26
  cwd=str(cwd) if cwd else None,
@@ -31,6 +32,37 @@ def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None)
31
  )
32
  return proc.returncode, proc.stdout, proc.stderr
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def ensure_flashvsr_repo() -> Path:
36
  """Clone FlashVSR repo if missing."""
@@ -39,11 +71,9 @@ def ensure_flashvsr_repo() -> Path:
39
  code, out, err = run(["git", "clone", "--depth=1", FLASH_GIT, str(repo_dir)])
40
  if code != 0:
41
  raise RuntimeError(f"Failed to clone FlashVSR.\n{err}")
42
- # best-effort submodules (harmless if none)
43
  run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir)
44
  return repo_dir
45
 
46
-
47
  def ensure_weights() -> Path:
48
  """Download weights snapshot locally once."""
49
  target = CACHE_DIR / "weights"
@@ -57,14 +87,14 @@ def ensure_weights() -> Path:
57
  )
58
  return target
59
 
 
60
 
61
  def normalize_to_mp4(src_path: Path) -> Path:
62
  """
63
  Convert ANY uploaded file to a very browser-friendly MP4:
64
  - H.264 (yuv420p), Baseline profile
65
- - Even dimensions + CFR 30 fps
66
- - +faststart
67
- - No-audio (avoids codec/container edge cases in browser preview)
68
  """
69
  out_path = src_path.with_name(src_path.stem + "_playable.mp4")
70
  vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
@@ -84,6 +114,7 @@ def normalize_to_mp4(src_path: Path) -> Path:
84
  raise RuntimeError(f"ffmpeg failed to normalize the video.\n\n{err or out}")
85
  return out_path
86
 
 
87
 
88
  def ensure_modelscope_stub() -> Path:
89
  """
@@ -102,23 +133,25 @@ def ensure_modelscope_stub() -> Path:
102
  )
103
  return stub_root
104
 
105
- # ------------------------- core pipeline ---------------------------
106
 
107
  def run_flashvsr_on_video(
108
  in_video: Path,
109
  scale: int,
110
  prefer_sparse: bool,
111
  log_file: Path,
112
- ) -> Tuple[Optional[Path], str]:
113
  """
114
  Execute the FlashVSR example script and return (output_path | None, combined_logs).
115
- Writes full logs to log_file either way.
116
  """
117
  logs = []
 
 
118
 
 
119
  repo_dir = ensure_flashvsr_repo()
120
  weights_dir = ensure_weights()
121
- stub_root = ensure_modelscope_stub() # <— NEW
122
 
123
  logs.append(f"Python: {sys.version}")
124
  logs.append(f"Repo: {repo_dir}")
@@ -138,8 +171,7 @@ def run_flashvsr_on_video(
138
  out_dir.mkdir(exist_ok=True)
139
  out_mp4 = out_dir / f"{in_video.stem}_x{scale}.mp4"
140
 
141
- # Build env so the repo's modules (e.g., diffsynth) import WITHOUT pip install -e .
142
- # And place our 'modelscope' stub at the front so its snapshot_download is used.
143
  env = os.environ.copy()
144
  extra_paths = [str(stub_root), str(repo_dir), str(repo_dir / "diffsynth")]
145
  existing = env.get("PYTHONPATH", "")
@@ -159,7 +191,7 @@ def run_flashvsr_on_video(
159
  logs.append(err)
160
  return code == 0 and out_mp4.exists()
161
 
162
- # Try several common arg shapes used across revisions.
163
  if _try(["--input", str(in_video), "--output", str(out_mp4), "--scale", str(scale), "--weights", str(weights_dir)]):
164
  pass
165
  elif _try(["--video", str(in_video), "--outdir", str(out_dir), "--scale", str(scale), "--weights", str(weights_dir)]):
@@ -174,15 +206,12 @@ def run_flashvsr_on_video(
174
  combined = "\n".join(logs)
175
  log_file.write_text(combined)
176
 
177
- # Normalize the produced result for browser playback just in case.
178
  playable = normalize_to_mp4(out_mp4) if out_mp4.exists() else None
179
  return playable, combined
180
 
181
-
182
  def infer(ui_video: str, scale: int, prefer_sparse: bool):
183
- """
184
- Gradio handler: return (video_path | None, diagnostics text, logs file path)
185
- """
186
  logs_path = CACHE_DIR / "last_run_logs.txt"
187
 
188
  if not ui_video:
@@ -197,7 +226,6 @@ def infer(ui_video: str, scale: int, prefer_sparse: bool):
197
  return None, msg, str(logs_path)
198
 
199
  try:
200
- # Ensure the *input* itself is previewable in the browser.
201
  src_playable = normalize_to_mp4(src)
202
  except Exception as e:
203
  msg = f"Input normalization failed:\n{e}"
@@ -207,13 +235,11 @@ def infer(ui_video: str, scale: int, prefer_sparse: bool):
207
  try:
208
  out_path, combined = run_flashvsr_on_video(src_playable, scale, prefer_sparse, logs_path)
209
  if out_path is None:
210
- # Show input so player still has something; include reason in diagnostics.
211
  return str(src_playable), "FlashVSR failed. See logs below.", str(logs_path)
212
  return str(out_path), "Done.", str(logs_path)
213
  except Exception as e:
214
  msg = f"Pipeline error:\n{e}"
215
  logs_path.write_text(msg)
216
- # Fall back to showing the normalized input so the UI still previews something.
217
  return str(src_playable), msg, str(logs_path)
218
 
219
  # ------------------------------ UI -------------------------------
 
6
  import subprocess
7
  from pathlib import Path
8
  from typing import Tuple, Optional
9
+ import importlib.util as _import_spec
10
 
11
  import gradio as gr
12
  from huggingface_hub import snapshot_download
 
18
  WEIGHTS_REPO = "JunhaoZhuang/FlashVSR"
19
  FLASH_GIT = "https://github.com/OpenImagingLab/FlashVSR.git"
20
 
21
+ # ----------------------------- shell helpers -----------------------------
22
 
23
  def run(cmd: list[str], cwd: Optional[Path] = None, env: Optional[dict] = None) -> Tuple[int, str, str]:
24
+ """Run a command and capture stdout/stderr."""
25
  proc = subprocess.run(
26
  cmd,
27
  cwd=str(cwd) if cwd else None,
 
32
  )
33
  return proc.returncode, proc.stdout, proc.stderr
34
 
35
+ # ----------------------------- deps helpers -----------------------------
36
+
37
def _has_module(name: str) -> bool:
    """Return True if module *name* is available, without importing it.

    ``find_spec`` does not always answer with ``None`` for a missing module:

    - for a dotted name it imports the parent package first and raises
      ``ModuleNotFoundError`` (an ``ImportError``) when that parent is absent;
    - for a module that is already in ``sys.modules`` but has no usable
      ``__spec__`` it raises ``ValueError``.

    Both cases are mapped to a plain boolean so a dependency probe can never
    abort the caller.
    """
    try:
        return _import_spec.find_spec(name) is not None
    except ImportError:
        # Parent package of a dotted name could not be imported.
        return False
    except ValueError:
        # Raised for spec-less modules that are already loaded (those are
        # usable, so report them as present) and for malformed names.
        return name in sys.modules
39
+
40
def ensure_python_deps(logs: list[str]) -> None:
    """Install lightweight runtime dependencies that are not importable yet.

    Each entry is probed by import name with ``_has_module``; only presence is
    checked, the version bound is applied solely when pip has to install the
    package. Missing requirements are installed in a single pip invocation
    using the current interpreter.

    Args:
        logs: Mutable list of log lines; the install notice and pip's
            stdout/stderr are appended to it.

    Raises:
        RuntimeError: If pip exits with a non-zero status.
    """
    import importlib

    # Import name -> pip requirement. Minimal set seen in FlashVSR's example
    # imports and typical backbones.
    requirements = {
        "transformers": "transformers>=4.44",
        "sentencepiece": "sentencepiece>=0.1.99",
        "safetensors": "safetensors>=0.4.3",
        "timm": "timm>=0.9.16",
        "accelerate": "accelerate>=0.33",
    }
    need = [req for module, req in requirements.items() if not _has_module(module)]
    if not need:
        return

    logs.append(f"Installing missing deps: {need}")
    code, out, err = run([sys.executable, "-m", "pip", "install", "--no-cache-dir", *need])
    logs.append(out)
    logs.append(err)
    if code != 0:
        raise RuntimeError(f"Pip failed while installing {need}:\n{err or out}")

    # Packages installed while the interpreter is running may be hidden by the
    # import system's directory caches; refresh them so later probes in this
    # process see the new modules instead of triggering another install.
    importlib.invalidate_caches()
64
+
65
+ # ----------------------------- repo / weights -----------------------------
66
 
67
  def ensure_flashvsr_repo() -> Path:
68
  """Clone FlashVSR repo if missing."""
 
71
  code, out, err = run(["git", "clone", "--depth=1", FLASH_GIT, str(repo_dir)])
72
  if code != 0:
73
  raise RuntimeError(f"Failed to clone FlashVSR.\n{err}")
 
74
  run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir)
75
  return repo_dir
76
 
 
77
  def ensure_weights() -> Path:
78
  """Download weights snapshot locally once."""
79
  target = CACHE_DIR / "weights"
 
87
  )
88
  return target
89
 
90
+ # ----------------------------- IO utils -----------------------------
91
 
92
  def normalize_to_mp4(src_path: Path) -> Path:
93
  """
94
  Convert ANY uploaded file to a very browser-friendly MP4:
95
  - H.264 (yuv420p), Baseline profile
96
+ - Even dimensions + CFR 30 fps, +faststart
97
+ - No audio (avoids browser codec edge cases)
 
98
  """
99
  out_path = src_path.with_name(src_path.stem + "_playable.mp4")
100
  vf = "scale=trunc(iw/2)*2:trunc(ih/2)*2,fps=30"
 
114
  raise RuntimeError(f"ffmpeg failed to normalize the video.\n\n{err or out}")
115
  return out_path
116
 
117
+ # ----------------------------- ModelScope stub -----------------------------
118
 
119
  def ensure_modelscope_stub() -> Path:
120
  """
 
133
  )
134
  return stub_root
135
 
136
+ # ----------------------------- pipeline -----------------------------
137
 
138
  def run_flashvsr_on_video(
139
  in_video: Path,
140
  scale: int,
141
  prefer_sparse: bool,
142
  log_file: Path,
143
+ ):
144
  """
145
  Execute the FlashVSR example script and return (output_path | None, combined_logs).
 
146
  """
147
  logs = []
148
+ # 1) Make sure python deps exist
149
+ ensure_python_deps(logs)
150
 
151
+ # 2) Repo + weights
152
  repo_dir = ensure_flashvsr_repo()
153
  weights_dir = ensure_weights()
154
+ stub_root = ensure_modelscope_stub()
155
 
156
  logs.append(f"Python: {sys.version}")
157
  logs.append(f"Repo: {repo_dir}")
 
171
  out_dir.mkdir(exist_ok=True)
172
  out_mp4 = out_dir / f"{in_video.stem}_x{scale}.mp4"
173
 
174
+ # Environment: make repo importable and put our stub first
 
175
  env = os.environ.copy()
176
  extra_paths = [str(stub_root), str(repo_dir), str(repo_dir / "diffsynth")]
177
  existing = env.get("PYTHONPATH", "")
 
191
  logs.append(err)
192
  return code == 0 and out_mp4.exists()
193
 
194
+ # Try a few CLI shapes used across revisions
195
  if _try(["--input", str(in_video), "--output", str(out_mp4), "--scale", str(scale), "--weights", str(weights_dir)]):
196
  pass
197
  elif _try(["--video", str(in_video), "--outdir", str(out_dir), "--scale", str(scale), "--weights", str(weights_dir)]):
 
206
  combined = "\n".join(logs)
207
  log_file.write_text(combined)
208
 
209
+ # Normalize result for browser playback
210
  playable = normalize_to_mp4(out_mp4) if out_mp4.exists() else None
211
  return playable, combined
212
 
 
213
  def infer(ui_video: str, scale: int, prefer_sparse: bool):
214
+ """Gradio handler: return (video_path | None, diagnostics text, logs file path)"""
 
 
215
  logs_path = CACHE_DIR / "last_run_logs.txt"
216
 
217
  if not ui_video:
 
226
  return None, msg, str(logs_path)
227
 
228
  try:
 
229
  src_playable = normalize_to_mp4(src)
230
  except Exception as e:
231
  msg = f"Input normalization failed:\n{e}"
 
235
  try:
236
  out_path, combined = run_flashvsr_on_video(src_playable, scale, prefer_sparse, logs_path)
237
  if out_path is None:
 
238
  return str(src_playable), "FlashVSR failed. See logs below.", str(logs_path)
239
  return str(out_path), "Done.", str(logs_path)
240
  except Exception as e:
241
  msg = f"Pipeline error:\n{e}"
242
  logs_path.write_text(msg)
 
243
  return str(src_playable), msg, str(logs_path)
244
 
245
  # ------------------------------ UI -------------------------------
requirements.txt CHANGED
@@ -6,7 +6,11 @@ opencv-python>=4.10
6
  einops>=0.8.0
7
  diffsynth>=1.1.8
8
  modelscope>=1.15.0
9
-
 
 
 
 
10
 
11
  # PyTorch (choose the wheel that matches your CUDA setup)
12
  # For CUDA 12.4 wheels (edit for your machine or Space):
 
6
  einops>=0.8.0
7
  diffsynth>=1.1.8
8
  modelscope>=1.15.0
9
+ transformers>=4.44
10
+ sentencepiece>=0.1.99
11
+ safetensors>=0.4.3
12
+ timm>=0.9.16
13
+ accelerate>=0.33
14
 
15
  # PyTorch (choose the wheel that matches your CUDA setup)
16
  # For CUDA 12.4 wheels (edit for your machine or Space):