# Spaces: Paused
import functools
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image
# Hugging Face Hub repository the model weights are downloaded from.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# "cuda" when torch reports an available CUDA device, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
# ----- one-time model + code download -----
@functools.lru_cache(maxsize=1)
def setup():
    """Fetch the model weights and the Matrix-Game source tree.

    Downloads a snapshot of ``MODEL_REPO`` into the local ``model_cache``
    directory and clones the Matrix-Game repository when no ``Matrix-Game``
    path exists in the current working directory.

    The result is memoized, so only the first successful call does any work.
    ``lru_cache`` does not store raised exceptions, so a failed attempt is
    retried on the next call.

    Returns:
        Absolute path of the downloaded snapshot directory. It is made
        absolute because the inference subprocess runs with a different
        working directory.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
    """
    print("‣ downloading weights …")
    model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
    if not os.path.exists("Matrix-Game"):
        subprocess.check_call(["git", "clone",
                               "https://github.com/SkyworkAI/Matrix-Game.git"])
    return os.path.abspath(model_dir)
# -----------------------------------------
def run(img, frames, seed):
    """Generate a video from a start frame via the Matrix-Game-2 inference script.

    Args:
        img: Start frame as a PIL image, or ``None`` when nothing was uploaded.
        frames: Number of output frames, forwarded as ``--num_output_frames``.
        seed: Random seed, forwarded as ``--seed``.

    Returns:
        A ``(video_path, status_message)`` tuple: ``("result.mp4", "✔ Done!")``
        when a video file was produced, otherwise ``(None, <message>)``.
    """
    if img is None:
        return None, "Upload an image first!"

    # UI numeric widgets may hand over floats (e.g. 42.0) or None for an
    # empty field; the values are passed as command-line text, presumably
    # parsed as integers by inference.py, so normalise them up front.
    try:
        frames = int(frames)
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        return None, "Frames and seed must be whole numbers"

    # The subprocess runs with cwd=repo_dir, so every path handed to it must
    # be absolute; a relative script path would be resolved inside repo_dir
    # a second time and not be found.
    model_dir = os.path.abspath(setup())
    repo_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))

    tmp = tempfile.mkdtemp()
    try:
        inp = os.path.join(tmp, "input.jpg")
        outd = os.path.join(tmp, "outputs")
        os.makedirs(outd, exist_ok=True)

        # down-size to <=512 px to keep VRAM happy
        longest = max(img.size)
        if longest > 512:
            r = 512 / longest
            # Clamp to 1 px so extremely elongated images never get a 0 side.
            new_size = (max(1, int(img.size[0] * r)),
                        max(1, int(img.size[1] * r)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)

        # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.save(inp)

        cmd = [sys.executable, os.path.join(repo_dir, "inference.py"),
               "--img_path", inp,
               "--output_folder", outd,
               "--num_output_frames", str(frames),
               "--seed", str(seed),
               "--pretrained_model_path", model_dir]
        print("‣ running:", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, text=True, cwd=repo_dir)

        # Show both streams: errors usually land on stderr even when stdout
        # is non-empty.
        if proc.stdout:
            print(proc.stdout)
        if proc.stderr:
            print(proc.stderr, file=sys.stderr)

        # grab first video file we find (sorted for a deterministic pick)
        for root, _, files in os.walk(outd):
            for f in sorted(files):
                if f.lower().endswith((".mp4", ".webm", ".mov")):
                    shutil.move(os.path.join(root, f), "result.mp4")
                    return "result.mp4", "✔ Done!"
        return None, "Generation failed – see logs"
    finally:
        # Remove the scratch directory on success and on every failure path.
        shutil.rmtree(tmp, ignore_errors=True)
# UI layout: inputs in the left column, generated video and status on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Matrix-Game 2.0 demo")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Start frame (jpg/png)", type="pil")
            nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
            # precision=0 makes the component deliver an integer seed.
            s = gr.Number(42, label="Seed", precision=0)
            go = gr.Button("Generate")
        with gr.Column():
            vid = gr.Video(label="Output")
            stat = gr.Textbox(label="Status")
    # Event wiring must happen inside the Blocks context.
    go.click(run, [img, nfrm, s], [vid, stat])

if __name__ == "__main__":
    demo.launch()