# Author: Ashton99
# Commit 3ac36b2 (verified): "Create app.py"
import functools
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# `spaces` is kept ahead of `torch`, as in the original ordering; ZeroGPU
# expects its package to load before CUDA is touched.
import spaces
import torch
# Hugging Face Hub repository that holds the Matrix-Game 2.0 weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Torch device name chosen from CUDA availability; logged once at import time.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Device:", DEVICE)
# ----- one-time model + code download -----
@functools.lru_cache(maxsize=1)
def setup():
    """Fetch the Matrix-Game 2.0 weights and inference code, memoised per process.

    Downloads ``MODEL_REPO`` with ``snapshot_download`` into ``model_cache`` and,
    if no ``Matrix-Game`` directory exists in the working directory yet, clones
    the SkyworkAI/Matrix-Game repository there.

    Returns:
        Absolute path of the local model snapshot.
    """
    # functools.lru_cache is used instead of `spaces.cached`, which the
    # `spaces` package does not appear to provide
    # (NOTE(review): confirm against the installed `spaces` version).
    print("‣ downloading weights …")
    model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
    if not os.path.exists("Matrix-Game"):
        # Shallow clone: only the current working tree is needed for inference.
        subprocess.check_call(["git", "clone", "--depth", "1",
                               "https://github.com/SkyworkAI/Matrix-Game.git"])
    # Absolute, so the path stays valid for subprocesses started in another cwd.
    return os.path.abspath(model_dir)
# -----------------------------------------
@spaces.GPU(duration=120)
def run(img, frames, seed):
    """Generate a video from a start frame by running Matrix-Game-2 ``inference.py``.

    Args:
        img: PIL start frame from the Gradio image input, or ``None`` if nothing
            was uploaded.
        frames: number of output frames (slider value; coerced to ``int``).
        seed: random seed (``gr.Number`` value, possibly a float; coerced to ``int``).

    Returns:
        ``(video_path, status_message)`` for the video and status outputs;
        ``video_path`` is ``None`` when nothing could be generated.
    """
    if img is None:
        return None, "Upload an image first!"
    if seed is None:  # the Number field can be cleared by the user
        return None, "Enter a seed first!"
    model_dir = os.path.abspath(setup())
    # Absolute paths matter: the subprocess runs with cwd=m2, so a relative
    # script or model path would be resolved against the wrong directory.
    m2 = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))
    tmp = tempfile.mkdtemp()
    try:
        inp = os.path.join(tmp, "input.jpg")
        outd = os.path.join(tmp, "outputs")
        os.makedirs(outd, exist_ok=True)
        _prepare_image(img).save(inp)
        cmd = [sys.executable, os.path.join(m2, "inference.py"),
               "--img_path", inp,
               "--output_folder", outd,
               # gr.Number without precision yields floats ("42.0"); pass
               # integers, which inference.py presumably parses with type=int.
               "--num_output_frames", str(int(frames)),
               "--seed", str(int(seed)),
               "--pretrained_model_path", model_dir]
        print("‣ running:", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
        print(f"‣ exit code {proc.returncode}")
        # Log both streams; `stdout or stderr` dropped stderr whenever stdout was non-empty.
        print(proc.stdout)
        print(proc.stderr)
        video = _find_video(outd)
        if video is None:
            return None, "Generation failed – see logs"
        # Unique per-request file under the working directory (where the old
        # shared result.mp4 lived) so concurrent requests cannot overwrite each
        # other; the real container extension is kept instead of forcing .mp4.
        os.makedirs("results", exist_ok=True)
        fd, result = tempfile.mkstemp(prefix="result_",
                                      suffix=os.path.splitext(video)[1].lower(),
                                      dir="results")
        os.close(fd)
        shutil.move(video, result)
        return result, "✔ Done!"
    finally:
        # Clean up on success and failure alike (the failure path used to leak).
        shutil.rmtree(tmp, ignore_errors=True)


def _prepare_image(img, max_side=512):
    """Return *img* as RGB, downscaled so its longer side is at most *max_side* px."""
    # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
    if img.mode != "RGB":
        img = img.convert("RGB")
    longest = max(img.size)
    # down-size to <= max_side px to keep VRAM happy
    if longest > max_side:
        r = max_side / longest
        img = img.resize((max(1, int(img.size[0] * r)), max(1, int(img.size[1] * r))),
                         Image.Resampling.LANCZOS)
    return img


def _find_video(folder):
    """Return the path of the first video file found under *folder*, or ``None``."""
    for root, _, files in os.walk(folder):
        for f in sorted(files):
            if f.lower().endswith((".mp4", ".webm", ".mov")):
                return os.path.join(root, f)
    return None
with gr.Blocks() as demo:
    gr.Markdown("# Matrix-Game 2.0 demo")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Start frame (jpg/png)", type="pil")
            nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
            # precision=0 makes the component deliver an int (42) rather than
            # a float (42.0), which would be passed verbatim to the CLI.
            s = gr.Number(42, label="Seed", precision=0)
            go = gr.Button("Generate")
        with gr.Column():
            vid = gr.Video(label="Output")
            stat = gr.Textbox(label="Status")
    go.click(run, [img, nfrm, s], [vid, stat])

if __name__ == "__main__":
    # Fetch weights and code once at startup, so the first request does not
    # spend its @spaces.GPU(duration=120) budget on the download.
    setup()
    demo.launch()