rahul7star committed on
Commit 191fe1d · verified · 1 Parent(s): 8d3076c

Create app.py

Files changed (1): app.py (+260, −0)
app.py ADDED
@@ -0,0 +1,260 @@
# app.py
import os
import sys
import time
import shlex  # safe shell-quoting of user-supplied arguments
import subprocess
import tempfile
import shutil
from pathlib import Path

import gradio as gr
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator used below
from huggingface_hub import hf_hub_download

# ---------- Helper utilities ----------

def sh(cmd, check=True, env=None):
    """Run a shell command, print its captured stdout/stderr, and return (returncode, stdout)."""
    print(f"RUN: {cmd}")
    try:
        completed = subprocess.run(cmd, shell=True, check=check, capture_output=True, text=True, env=env)
        print(completed.stdout)
        if completed.stderr:
            print("ERR:", completed.stderr, file=sys.stderr)
        return completed.returncode, completed.stdout
    except subprocess.CalledProcessError as e:
        print("Command failed:", e, file=sys.stderr)
        print(e.stdout)
        print(e.stderr, file=sys.stderr)
        return e.returncode, e.stdout or ""

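# Illustrative usage of sh() (example command only; not part of the app's flow):
#   rc, out = sh("nvidia-smi", check=False)
#   print("GPU visible" if rc == 0 else "No GPU detected")
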
# ---------- FlashAttention install (best-effort) ----------
def try_install_flash_attention():
    """
    Attempt to download and install the FlashAttention wheel from the HF repo
    rahul7star/flash-attn-3 (path: 128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl, as provided).
    This is a best-effort install; failures are non-fatal.
    """
    flash_attention_installed = False
    try:
        print("Attempting to download and install FlashAttention wheel...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        print("Downloaded wheel:", wheel)
        rc, out = sh(f"pip install {wheel}")
        # refresh site-packages so Python can see the newly installed extension
        try:
            import importlib
            import site
            # add the first site-packages dir and invalidate import caches
            site.addsitedir(site.getsitepackages()[0])
            importlib.invalidate_caches()
        except Exception as e:
            print("Could not update site-packages cache:", e)
        flash_attention_installed = True
        print("FlashAttention installed successfully.")
    except Exception as e:
        print(f"⚠️ Could not install FlashAttention: {e}")
        print("Continuing without FlashAttention...")
    return flash_attention_installed

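# A minimal probe to confirm the extension is importable after the install (the
# module name is an assumption; it depends on how that wheel names its package):
#   import importlib.util
#   print("flash_attn importable:", importlib.util.find_spec("flash_attn") is not None)
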
# ---------- Model downloader ----------
def ensure_models_downloaded(marker_file=".models_ready"):
    """
    Run download_models.py if models haven't been downloaded yet.
    Creates a small marker file after success to avoid repeated downloads.
    """
    marker = Path(marker_file)
    if marker.exists():
        print("Models already downloaded (marker found).")
        return True

    if not Path("download_models.py").exists():
        print("Warning: download_models.py not found in repo. Please add it or run the model download manually.")
        return False

    try:
        print("Running download_models.py to fetch model artifacts...")
        # Call the script directly with the same Python executable.
        rc, out = sh(f"{sys.executable} download_models.py", check=True)
        if rc != 0:
            # sh() swallows CalledProcessError and returns the code, so check it here
            print("download_models.py exited with non-zero status; not creating marker.")
            return False
        # Success: create the marker so future runs skip the download.
        marker.write_text("ok")
        print("download_models.py finished. Marker created.")
        return True
    except Exception as e:
        print("Failed to run download_models.py:", e)
        return False

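# The marker file is a simple idempotency guard: deleting .models_ready (or ticking
# the "force re-download" checkbox in the UI below) makes the next call run
# download_models.py again.
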
# ---------- Inference runner ----------
def run_inference(prompt: str, image_path: str | None, seed: int | None = None, duration: float | None = None, workdir: str | None = None):
    """
    Run test.py with a prompt and an optional image. test.py is expected to produce
    a video file (e.g. output.mp4). Returns the path to the produced video, or None on failure.
    """
    workdir = workdir or os.getcwd()
    out_video = Path(workdir) / "output.mp4"

    # remove old output if present
    if out_video.exists():
        try:
            out_video.unlink()
        except Exception:
            pass

    if not Path("test.py").exists():
        raise FileNotFoundError("test.py not found in repo. Place the repo's test.py in the same folder as app.py.")

    # shlex.quote protects prompts containing spaces or quotes from breaking the shell command
    cmd = [sys.executable, "test.py", "--prompt", shlex.quote(prompt)]
    if image_path:
        cmd += ["--image_path", shlex.quote(image_path)]
    if seed is not None:
        cmd += ["--seed", str(seed)]
    if duration is not None:
        # test.py is assumed to accept --duration; adapt if your script uses a different flag name.
        cmd += ["--duration", str(duration)]

    # Join into a single command string for shell=True execution
    cmd_str = " ".join(cmd)
    print("Inference command:", cmd_str)

    try:
        # Run to completion, capturing output for the logs
        proc = subprocess.run(cmd_str, shell=True, check=True, capture_output=True, text=True, env=os.environ)
        print("Inference stdout:", proc.stdout)
        if proc.stderr:
            print("Inference stderr:", proc.stderr, file=sys.stderr)
    except subprocess.CalledProcessError as e:
        print("Inference failed:", e, file=sys.stderr)
        print(e.stdout or "")
        print(e.stderr or "", file=sys.stderr)
        return None

    # locate the output video
    if out_video.exists():
        return str(out_video)
    # fallback: newest mp4 anywhere in workdir
    candidates = sorted(Path(workdir).glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
    if candidates:
        return str(candidates[0])
    return None

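# Illustrative call (flags follow the assumed test.py interface above):
#   video = run_inference("A dog in a red hat, cinematic", None, seed=42, duration=5.0)
#   print("produced:", video)
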
# ---------- Gradio app callbacks ----------
@spaces.GPU(duration=50)  # allow up to 50 s of GPU time per call (ZeroGPU Spaces)
def generate(prompt, image, seed, duration, install_flash, force_download_models):
    """
    Main callback for the Gradio "Generate" button.
    - install_flash: whether to attempt the flash-attn install on this run
    - force_download_models: re-run download_models.py even if the marker exists
    Returns (video_file, status_text).
    """
    status_msgs = []
    # Convert the image (Gradio gives a PIL Image or None) to a temp file if provided
    temp_image_path = None
    if image is not None:
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            image.save(tmp, format="PNG")
            tmp.flush()
            temp_image_path = tmp.name
            tmp.close()
            status_msgs.append(f"Saved input image to {temp_image_path}")
        except Exception as e:
            status_msgs.append(f"Failed to save uploaded image: {e}")
            temp_image_path = None

    # Optionally install FlashAttention
    if install_flash:
        ok = try_install_flash_attention()
        status_msgs.append(f"Attempted FlashAttention install: {'OK' if ok else 'FAILED'}")
    else:
        status_msgs.append("Skipped FlashAttention install (checkbox unchecked).")

    # Ensure models are downloaded
    if force_download_models:
        # remove the marker if present so we re-download
        marker = Path(".models_ready")
        if marker.exists():
            try:
                marker.unlink()
                status_msgs.append("Removed existing model marker to force re-download.")
            except Exception as e:
                status_msgs.append(f"Could not remove marker file: {e}")

    ok_models = ensure_models_downloaded()
    status_msgs.append(f"Models ready: {'yes' if ok_models else 'no'}")
    if not ok_models:
        status_msgs.append("Warning: models not ready. Inference will probably fail.")

    # Run inference
    status_msgs.append("Starting inference (this may take time on GPU).")
    try:
        video_path = run_inference(prompt=prompt, image_path=temp_image_path, seed=seed, duration=duration)
    except Exception as e:
        status_msgs.append(f"Inference runner raised an exception: {e}")
        return None, "\n".join(status_msgs)

    if video_path:
        status_msgs.append(f"Video created: {video_path}")
        # Copy to a stable path that Gradio can serve, e.g. ./outputs/t2v_output_{ts}.mp4
        dest_dir = Path("outputs")
        dest_dir.mkdir(exist_ok=True)
        ts = int(time.time())
        dest = dest_dir / f"t2v_output_{ts}.mp4"
        try:
            shutil.copy(video_path, dest)
            status_msgs.append(f"Video copied to {dest}")
            return str(dest), "\n".join(status_msgs)
        except Exception as e:
            status_msgs.append(f"Could not copy video to outputs/: {e}")
            # still try to return the original path
            return str(video_path), "\n".join(status_msgs)
    else:
        status_msgs.append("No video produced by test.py (output not found). Check logs.")
        return None, "\n".join(status_msgs)

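# Note: temp_image_path is left on disk after the run. A cleanup sketch (an addition,
# not in the original flow) would be:
#   if temp_image_path and os.path.exists(temp_image_path):
#       os.unlink(temp_image_path)
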
# ---------- Build Gradio interface ----------
def build_ui():
    with gr.Blocks(title="Text+Image → Video (Spaces GPU)", css="""
        .output-video { max-width: 800px; }
    """) as demo:
        gr.Markdown("# Text + (Optional) Image → Video\nSimple UI to run the Kandinsky/Wan T2V `test.py` in this Space (GPU required).")

        with gr.Row():
            with gr.Column(scale=3):
                prompt = gr.Textbox(label="Prompt", placeholder="A dog in a red hat, cinematic, 5s", value="A dog in a red hat")
                image_in = gr.Image(label="Optional reference image (still)", type="pil")
                with gr.Row():
                    seed = gr.Number(value=42, label="Seed (optional)", precision=0)
                    duration = gr.Number(value=5.0, label="Duration (seconds, optional)", precision=2)
                install_flash = gr.Checkbox(label="Attempt FlashAttention install before running (best-effort)", value=False)
                force_download = gr.Checkbox(label="Force run download_models.py (re-download models)", value=False)
                generate_btn = gr.Button("Generate Video", variant="primary")
                status = gr.Textbox(label="Status / Logs", interactive=False, lines=10)
            with gr.Column(scale=2):
                out_video = gr.Video(label="Output video", elem_classes="output-video")
                gr.Markdown("**Notes**:\n- Ensure `download_models.py` and `test.py` are present and compatible.\n- `test.py` should produce an mp4 named `output.mp4` in the repo root, or an mp4 somewhere in the working dir.\n- Very long jobs may hit Space runtime limits.")

        # wire up the Generate button
        generate_btn.click(fn=generate,
                           inputs=[prompt, image_in, seed, duration, install_flash, force_download],
                           outputs=[out_video, status])

    return demo

# ---------- Main entrypoint ----------
if __name__ == "__main__":
    # Quick environment checks
    print("Starting T2V Gradio app. Python:", sys.executable)
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", "(not set)"))
    # FlashAttention is not installed automatically; the user opts in via the UI checkbox.
    # Pre-check models: report status if download_models.py has not run yet.
    if not Path(".models_ready").exists() and Path("download_models.py").exists():
        # We do NOT download on startup automatically, to avoid long startup delays on Spaces.
        print("download_models.py exists but models are not yet marked as downloaded. Use the UI to run the download (or tick the force checkbox).")

    # Create the outputs dir
    Path("outputs").mkdir(exist_ok=True)

    demo = build_ui()
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
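
# To run locally (assumes a CUDA GPU plus the repo's test.py and download_models.py):
#   pip install gradio spaces huggingface_hub
#   PORT=7860 python app.py
# On a Hugging Face Space this file is launched automatically as the app entrypoint.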