Shalmoni commited on
Commit
13a051d
·
verified ·
1 Parent(s): 938e7a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os, io, time, base64, random, subprocess
2
  from typing import Optional, List
3
- from urllib.parse import quote
4
 
5
  import requests
6
  from PIL import Image
@@ -30,17 +30,42 @@ def _download_to_bytes(url: str) -> bytes:
30
  r.raise_for_status()
31
  return r.content
32
 
33
- def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
 
 
 
 
 
 
 
 
 
34
  """
35
- Calls your Modal backend with two images + prompt + seed and returns a local /tmp video path.
 
36
  """
37
  if start_img is None or end_img is None:
38
  return None
39
 
 
40
  if seed in (None, 0, -1):
41
  seed = random.randint(1, 2**31 - 1)
42
 
43
- url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  files = {
45
  "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
46
  "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
@@ -135,8 +160,10 @@ def collect_choices(*imgs):
135
  choices.append(str(i))
136
  return gr.update(choices=choices), gr.update(choices=choices)
137
 
138
- def stitch_selected(prompt, seed, start_idx_str, end_idx_str, *imgs):
139
- """Run inference for selected start/end indices (1-based strings)."""
 
 
140
  if not start_idx_str or not end_idx_str:
141
  gr.Warning("Please select Start and End frames.")
142
  return None
@@ -157,7 +184,19 @@ def stitch_selected(prompt, seed, start_idx_str, end_idx_str, *imgs):
157
  gr.Warning("Selected slots are empty.")
158
  return None
159
 
160
- vid = stitch_call(start_img, end_img, prompt or "", int(seed or 0))
 
 
 
 
 
 
 
 
 
 
 
 
161
  if not vid:
162
  gr.Warning("Generation failed.")
163
  return None
@@ -246,7 +285,7 @@ with gr.Blocks(css=CSS, title="StitchTool") as demo:
246
  outputs=img_comps
247
  )
248
 
249
- # Seed + Start/End selection + Prompt + Stitch + Preview
250
  seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
251
 
252
  with gr.Row():
@@ -254,12 +293,35 @@ with gr.Blocks(css=CSS, title="StitchTool") as demo:
254
  with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
255
  start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
256
  end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
 
257
  prompt = gr.Textbox(
258
  placeholder="Describe the transition between the selected start and end frames…",
259
  lines=3,
260
  label="Prompt",
261
  elem_classes=["rounded"]
262
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  run_btn = gr.Button("Generate", elem_classes=["pill"])
264
  add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
265
 
@@ -278,7 +340,7 @@ with gr.Blocks(css=CSS, title="StitchTool") as demo:
278
  # stitch action → preview
279
  run_btn.click(
280
  fn=stitch_selected,
281
- inputs=[prompt, seed, start_dd, end_dd] + img_comps,
282
  outputs=[preview]
283
  )
284
 
@@ -306,4 +368,4 @@ with gr.Blocks(css=CSS, title="StitchTool") as demo:
306
  )
307
 
308
  if __name__ == "__main__":
309
- demo.queue().launch()
 
1
  import os, io, time, base64, random, subprocess
2
  from typing import Optional, List
3
+ from urllib.parse import urlencode
4
 
5
  import requests
6
  from PIL import Image
 
30
  r.raise_for_status()
31
  return r.content
32
 
33
+ def stitch_call(
34
+ start_img: Image.Image,
35
+ end_img: Image.Image,
36
+ prompt: str,
37
+ seed: Optional[int],
38
+ negative_prompt: Optional[str] = None,
39
+ frames_per_second: int = 24,
40
+ video_length: int = 4,
41
+ num_inference_steps: Optional[int] = None,
42
+ ) -> Optional[str]:
43
  """
44
+ Required (in body): image_bytes (+ image_bytes_end)
45
+ In URL query: prompt, negative_prompt, frames_per_second, video_length, seed, num_inference_steps
46
  """
47
  if start_img is None or end_img is None:
48
  return None
49
 
50
+ # default seed behavior
51
  if seed in (None, 0, -1):
52
  seed = random.randint(1, 2**31 - 1)
53
 
54
+ # Build query string
55
+ q = {
56
+ "prompt": prompt or "",
57
+ "seed": int(seed),
58
+ "frames_per_second": int(frames_per_second),
59
+ "video_length": int(video_length),
60
+ }
61
+ if negative_prompt:
62
+ q["negative_prompt"] = negative_prompt
63
+ if num_inference_steps is not None:
64
+ q["num_inference_steps"] = int(num_inference_steps)
65
+
66
+ url = f"{INFERENCE_URL}?{urlencode(q)}"
67
+
68
+ # Images go in the body
69
  files = {
70
  "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
71
  "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
 
160
  choices.append(str(i))
161
  return gr.update(choices=choices), gr.update(choices=choices)
162
 
163
+ def stitch_selected(
164
+ prompt, negative_prompt, fps, length_sec, seed, start_idx_str, end_idx_str, *imgs
165
+ ):
166
+ """Run inference for selected start/end indices (1-based strings) + options."""
167
  if not start_idx_str or not end_idx_str:
168
  gr.Warning("Please select Start and End frames.")
169
  return None
 
184
  gr.Warning("Selected slots are empty.")
185
  return None
186
 
187
+ fps_val = int(str(fps)) if fps else 24
188
+ len_val = int(str(length_sec)) if length_sec else 4
189
+
190
+ vid = stitch_call(
191
+ start_img=start_img,
192
+ end_img=end_img,
193
+ prompt=prompt or "",
194
+ seed=int(seed or 0),
195
+ negative_prompt=(negative_prompt or "").strip() or None,
196
+ frames_per_second=fps_val,
197
+ video_length=len_val,
198
+ num_inference_steps=None,
199
+ )
200
  if not vid:
201
  gr.Warning("Generation failed.")
202
  return None
 
285
  outputs=img_comps
286
  )
287
 
288
+ # Seed + Start/End selection + Prompt + options + Stitch + Preview
289
  seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
290
 
291
  with gr.Row():
 
293
  with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
294
  start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
295
  end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
296
+
297
  prompt = gr.Textbox(
298
  placeholder="Describe the transition between the selected start and end frames…",
299
  lines=3,
300
  label="Prompt",
301
  elem_classes=["rounded"]
302
  )
303
+
304
+ negative = gr.Textbox(
305
+ placeholder="Optional: things to avoid (e.g., 'no cuts, no angle switch, no text overlays')",
306
+ lines=2,
307
+ label="Negative prompt",
308
+ elem_classes=["rounded"]
309
+ )
310
+
311
+ with gr.Row():
312
+ fps = gr.Dropdown(
313
+ label="Frame rate",
314
+ choices=["16", "24", "32"],
315
+ value="24",
316
+ interactive=True,
317
+ )
318
+ length_sec = gr.Dropdown(
319
+ label="Video length (sec)",
320
+ choices=["2", "4"],
321
+ value="4",
322
+ interactive=True,
323
+ )
324
+
325
  run_btn = gr.Button("Generate", elem_classes=["pill"])
326
  add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
327
 
 
340
  # stitch action → preview
341
  run_btn.click(
342
  fn=stitch_selected,
343
+ inputs=[prompt, negative, fps, length_sec, seed, start_dd, end_dd] + img_comps,
344
  outputs=[preview]
345
  )
346
 
 
368
  )
369
 
370
  if __name__ == "__main__":
371
+ demo.queue().launch()