StitchTool / app.py
Shalmoni's picture
Update app.py
26b0caf verified
raw
history blame
9.62 kB
import base64
import io
import os
import random
import time
import uuid
from typing import Optional, Tuple
from urllib.parse import quote_plus

import gradio as gr
import requests
# -----------------------------
# Config
# -----------------------------
# Backend endpoint for image->video generation. Override it with the MM_I2V_URL
# env var; an unset *or blank* value falls back to the default Modal endpoint
# (a blank URL would otherwise break every backend call).
DEFAULT_API_URL = (os.getenv("MM_I2V_URL") or "").strip() or (
    "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
# Directory where generated and stitched clips are written (override with MM_SAVE_DIR).
SAVE_DIR = (os.getenv("MM_SAVE_DIR") or "").strip() or "outputs"
os.makedirs(SAVE_DIR, exist_ok=True)
# -----------------------------
# Helpers
# -----------------------------
def _save_bytes_to_mp4(buf: bytes, name_prefix: str, save_dir: Optional[str] = None) -> str:
    """Write ``buf`` to a new, uniquely named ``.mp4`` file and return its path.

    The bytes are written verbatim; they are not validated as MP4.

    Args:
        buf: Raw video bytes.
        name_prefix: Filename prefix, e.g. ``"clip"`` -> ``clip-<ms>-<rand>.mp4``.
        save_dir: Target directory; defaults to the module-level ``SAVE_DIR``.
            It is created if it does not exist.

    Returns:
        Path of the written file.

    Raises:
        OSError: If the directory or file cannot be created/written.
    """
    directory = SAVE_DIR if save_dir is None else save_dir
    os.makedirs(directory, exist_ok=True)
    ts = int(time.time() * 1000)
    # A millisecond timestamp alone collides when two requests (e.g. concurrent
    # Gradio sessions) finish in the same millisecond; a random suffix keeps
    # every path unique.
    path = os.path.join(directory, f"{name_prefix}-{ts}-{uuid.uuid4().hex[:12]}.mp4")
    # "xb" = exclusive create: never silently overwrite an existing clip.
    with open(path, "xb") as f:
        f.write(buf)
    return path
def _download(url: str) -> bytes:
    """Return the body of an HTTP GET to ``url``.

    Uses a 600-second timeout. 4xx/5xx statuses raise ``requests.HTTPError``
    (via ``raise_for_status``); network failures raise other
    ``requests.RequestException`` subclasses. Both propagate to the caller.
    """
    response = requests.get(url, timeout=600)
    response.raise_for_status()  # surface HTTP errors instead of returning an error page
    payload = response.content
    return payload
def _resolve_seed(seed) -> int:
    """Return ``seed`` as an int, or a random 31-bit seed when it is None/blank.

    Ints and floats are converted with ``int()`` (floats truncate, e.g. the
    ``42.0`` a ``gr.Number`` may deliver becomes ``42``); strings are parsed as
    int, then as float.

    Raises:
        ValueError / OverflowError: if ``seed`` is not interpretable as an integer.
    """
    if seed is None or str(seed).strip() == "":
        return random.randint(0, 2**31 - 1)
    if isinstance(seed, (int, float)):
        return int(seed)
    text = str(seed).strip()
    try:
        return int(text)
    except ValueError:
        return int(float(text))


def _video_from_json(data) -> Tuple[Optional[str], Optional[str]]:
    """Save the video described by a decoded JSON payload.

    Handles ``{"video": "<base64>"}`` (plain base64 or a ``data:...;base64,``
    URI) and ``{"video_url" | "result_url" | "url": "http..."}``.
    Returns ``(path, None)`` or ``(None, error_message)``. May raise
    ``requests.RequestException`` (from ``_download``) or ``OSError`` (saving).
    """
    if isinstance(data, dict):
        video = data.get("video")
        # Heuristic: only strings longer than 50 chars are treated as a base64 payload.
        if isinstance(video, str) and len(video) > 50:
            if video.startswith("data:") and "," in video:
                video = video.split(",", 1)[1]
            try:
                raw = base64.b64decode(video, validate=True)
            except ValueError as e:  # binascii.Error subclasses ValueError
                return None, f"Could not decode base64 video: {e}"
            return _save_bytes_to_mp4(raw, "clip"), None
        for key in ("video_url", "result_url", "url"):
            value = data.get(key)
            if isinstance(value, str) and value.startswith("http"):
                return _save_bytes_to_mp4(_download(value), "clip"), None
    return None, 'JSON response did not include "video" (base64) or a known url key.'


def _video_from_response(resp) -> Tuple[Optional[str], Optional[str]]:
    """Turn a successful backend response into a saved clip, by content type."""
    ctype = resp.headers.get("Content-Type", "")
    if "application/json" in ctype:
        try:
            data = resp.json()
        except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
            return None, f"Backend returned invalid JSON: {e}"
        return _video_from_json(data)
    # Raw bytes (ideally mp4)
    if "video" in ctype or "octet-stream" in ctype:
        return _save_bytes_to_mp4(resp.content, "clip"), None
    # Some backends still reply with bytes under a missing/odd content-type.
    if resp.content and len(resp.content) > 1024:
        return _save_bytes_to_mp4(resp.content, "clip"), None
    return None, f"Unexpected content type: {ctype}"


def call_i2v(
    image_path: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Call the image->video backend and return ``(video_path, error_message)``.

    Exactly one element of the returned tuple is non-None. Accepted successful
    (status < 400) backend responses:
      1) raw mp4 bytes (``video/*`` / ``octet-stream``, or any body over 1 KiB
         with an unrecognized content type)
      2) JSON ``{"video": "<base64>"}`` (plain base64 or a ``data:`` URI)
      3) JSON ``{"video_url" | "result_url" | "url": "https://..."}``

    Args:
        image_path: Local path of the source image, uploaded as ``image_bytes``.
        prompt: Text prompt, sent as the ``prompt`` query parameter.
        seed: Integer seed; None/blank picks a random 31-bit seed.
        api_url: Endpoint override; None/blank falls back to ``DEFAULT_API_URL``.
    """
    if not image_path:
        return None, "No image provided."
    try:
        used_seed = _resolve_seed(seed)
    except (ValueError, OverflowError):
        return None, f"Invalid seed: {seed!r}"
    api = ((api_url or "").strip() or DEFAULT_API_URL).strip().rstrip("/")

    # Read the image up front so the file handle is always closed.
    try:
        with open(image_path, "rb") as fh:
            image_bytes = fh.read()
    except OSError as e:
        return None, f"Could not read image: {e}"

    files = {
        "image_bytes": (os.path.basename(image_path), image_bytes, "application/octet-stream")
    }
    headers = {"accept": "application/json"}
    # requests builds the query string (form-encoded like quote_plus) and appends
    # correctly even if the API URL already carries parameters.
    params = {"prompt": prompt or "", "seed": used_seed}
    try:
        resp = requests.post(api, params=params, headers=headers, files=files, timeout=1200)
        if not resp.ok:
            # Without this check an HTML/JSON error body could be saved as a "video".
            detail = resp.text[:300].strip()
            return None, f"Backend returned HTTP {resp.status_code}: {detail}"
        return _video_from_response(resp)
    # RequestException subclasses OSError, so it must be matched first.
    except requests.RequestException as e:
        return None, f"Request failed: {e}"
    except OSError as e:
        return None, f"Could not save video: {e}"
def stitch_pair(
    image_a: str,
    image_b: str,
    prompt: str,
    seed: Optional[int],
    api_url: Optional[str],
    crossfade: float,
) -> Tuple[Optional[str], str]:
    """Generate one clip per image and join them into a single video.

    Strategy:
      - Generate a short clip from image A.
      - Generate a short clip from image B (same prompt/seed).
      - Concatenate them, optionally with a crossfade, using MoviePy.
    If you already have a backend endpoint that stitches directly, replace
    this function body with a single backend call.

    Args:
        image_a, image_b: Local image paths; both are required.
        prompt, seed, api_url: Forwarded to ``call_i2v`` for both clips.
        crossfade: Overlap in seconds. 0/None/negative concatenates back to
            back; values longer than either clip are clamped to the shorter one.

    Returns:
        ``(output_path, "")`` on success, ``(None, message)`` on failure.
    """
    if not image_a or not image_b:
        return None, "Please upload both images."

    # First generate both clips.
    clip1_path, err1 = call_i2v(image_a, prompt, seed, api_url)
    if err1:
        return None, f"Clip 1 failed: {err1}"
    clip2_path, err2 = call_i2v(image_b, prompt, seed, api_url)
    if err2:
        return None, f"Clip 2 failed: {err2}"

    try:
        from moviepy.editor import VideoFileClip, concatenate_videoclips
    except Exception as e:  # broad on purpose: missing ffmpeg can fail at import, not just ImportError
        return None, f"MoviePy import failed. Add moviepy & imageio-ffmpeg to requirements.txt. Error: {e}"

    opened = []  # every clip we open, so all are closed even if stitching fails
    try:
        c1 = VideoFileClip(clip1_path)
        opened.append(c1)
        c2 = VideoFileClip(clip2_path)
        opened.append(c2)

        fade = min(float(crossfade or 0), c1.duration, c2.duration)
        if fade > 0:
            # MoviePy's crossfade recipe: fade in only the incoming clip and overlap
            # via negative padding. Also fading out the first clip would make both
            # semi-transparent over the black compose background and dim the seam.
            merged = concatenate_videoclips(
                [c1, c2.crossfadein(fade)], method="compose", padding=-fade
            )
        else:
            # "compose" tolerates clips of different sizes.
            merged = concatenate_videoclips([c1, c2], method="compose")
        opened.append(merged)

        out_path = os.path.join(
            SAVE_DIR, f"stitched-{int(time.time() * 1000)}-{uuid.uuid4().hex[:12]}.mp4"
        )
        merged.write_videofile(out_path, codec="libx264", audio_codec="aac", verbose=False, logger=None)
        return out_path, ""
    except Exception as e:
        return None, f"Stitching failed: {e}"
    finally:
        for clip in reversed(opened):
            try:
                clip.close()
            except Exception:
                # Best-effort cleanup: a close() failure must not mask the result above.
                pass
# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="Image Stitch to Video", css="""
/* Rounded tiles like the sketch */
.rounded { border-radius: 24px; }
.tile { background: #f7f7ff; padding: 12px; }
.tile-blue { background: #e8f0ff; }
.tile-yellow { background: #fff7d6; }
.small-btn button { padding: 6px 10px; border-radius: 999px; }
.label-center label { text-align:center; width: 100%; }
""") as demo:
    # Crossfade seconds used by the slider and restored by "Clear All".
    DEFAULT_CROSSFADE = 0.4

    gr.Markdown(
        "### Image → Video (Stitch Adjacent Pairs)\n"
        "Upload 3 images, enter prompts for each stitch, then click the stitch buttons."
    )
    with gr.Row(equal_height=True):
        # Left column: images + add image
        with gr.Column(scale=1):
            gr.Markdown("**Images**")
            img1 = gr.Image(type="filepath", label="Image 1", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img2 = gr.Image(type="filepath", label="Image 2", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            img3 = gr.Image(type="filepath", label="Image 3", height=220, elem_classes=["rounded", "tile", "tile-blue"])
            # Optional extra slots (Images 4-8), hidden until "Add Image" reveals them.
            # NOTE(review): these slots are not yet wired to any stitch action.
            extra_imgs = [
                gr.Image(type="filepath", label=f"Image {i}", height=220, visible=False, elem_classes=["rounded", "tile", "tile-blue"])
                for i in range(4, 9)
            ]
            # Number of revealed extra slots, tracked per browser session.
            extra_count = gr.State(0)
            add_btn = gr.Button("Add Image", variant="secondary")
        # Middle column: prompts + stitch buttons
        with gr.Column(scale=1):
            gr.Markdown("**Prompts**")
            prompt12 = gr.Textbox(label="Prompt for Stitch 1 & 2", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed12 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch12 = gr.Button("Stitch 1 & 2", elem_classes=["small-btn"])
            prompt23 = gr.Textbox(label="Prompt for Stitch 2 & 3", lines=3, placeholder="Describe motion/style/etc.", elem_classes=["rounded", "tile"])
            seed23 = gr.Number(label="Seed (optional)", value=None, precision=0)
            stitch23 = gr.Button("Stitch 2 & 3", elem_classes=["small-btn"])
            with gr.Accordion("Advanced (API & Stitch)", open=False):
                api_url = gr.Textbox(label="Backend API URL", value=DEFAULT_API_URL)
                crossfade = gr.Slider(0.0, 1.5, value=DEFAULT_CROSSFADE, step=0.1, label="Crossfade seconds")
            clear_btn = gr.Button("Clear All")
        # Right column: video outputs plus visible status (stitch_pair's error message)
        with gr.Column(scale=1):
            gr.Markdown("**Outputs**")
            vid12 = gr.Video(label="Video (image 1 + 2) output", elem_classes=["rounded", "tile", "tile-yellow"])
            status12 = gr.Textbox(label="Status (1 + 2)", interactive=False)
            vid23 = gr.Video(label="Video (image 2 + 3) output", elem_classes=["rounded", "tile", "tile-yellow"])
            status23 = gr.Textbox(label="Status (2 + 3)", interactive=False)

    # Wire up actions
    def _on_add(shown):
        """Reveal one more hidden uploader; returns the new count and visibilities.

        The count lives in gr.State rather than on the shared component objects,
        so one user's clicks do not change what other sessions see.
        """
        shown = min(int(shown or 0) + 1, len(extra_imgs))
        return [shown, *[gr.update(visible=i < shown) for i in range(len(extra_imgs))]]

    add_btn.click(
        _on_add,
        inputs=[extra_count],
        outputs=[extra_count, *extra_imgs],
    )
    # The status boxes show stitch_pair's message (empty on success), so
    # failures are visible instead of being routed to a hidden textbox.
    stitch12.click(
        stitch_pair,
        inputs=[img1, img2, prompt12, seed12, api_url, crossfade],
        outputs=[vid12, status12],
    )
    stitch23.click(
        stitch_pair,
        inputs=[img2, img3, prompt23, seed23, api_url, crossfade],
        outputs=[vid23, status23],
    )

    def _on_clear():
        """Reset every input/output, re-hide the extra slots, and zero their count."""
        base_imgs = [gr.update(value=None, visible=True) for _ in (img1, img2, img3)]
        extras = [gr.update(value=None, visible=False) for _ in extra_imgs]
        return [
            *base_imgs,
            *extras,
            0,                                  # extra_count
            None, None,                         # vid12, vid23
            "", "",                             # status12, status23
            "", "",                             # prompt12, prompt23
            None, None,                         # seed12, seed23
            gr.update(value=DEFAULT_API_URL),   # api_url
            DEFAULT_CROSSFADE,                  # crossfade
        ]

    clear_btn.click(
        _on_clear,
        inputs=None,
        outputs=[
            img1, img2, img3, *extra_imgs, extra_count,
            vid12, vid23, status12, status23,
            prompt12, prompt23, seed12, seed23, api_url, crossfade,
        ],
    )
if __name__ == "__main__":
    # Each stitch makes two backend calls (up to 1200 s timeout each) and then an
    # ffmpeg encode. Gradio 3.x needs queue() for handlers that run longer than
    # about a minute. Gradio 4.x queues by default, so this call is harmless there.
    demo.queue().launch()