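"""Gradio demo for Open-Sora v2 text-to-video and image-to-video generation.

On first use the app tries to fetch the released checkpoints from the
Hugging Face Hub, then shells out to Open-Sora's torchrun-based inference
script for each request and serves the newest .mp4 it produces.
"""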

import os
import subprocess
import time
from pathlib import Path

import gradio as gr
from huggingface_hub import snapshot_download

MODEL_REPO = "hpcai-tech/Open-Sora-v2"  # released Open-Sora v2 weights on the Hub
CKPT_DIR = Path("ckpts")                # where the weights are staged locally
SAMPLES_DIR = Path("samples")           # the inference script writes generated .mp4 files here


def ensure_ckpts():
    """Return True if model weights are available locally, downloading them if possible."""
    if CKPT_DIR.exists() and any(CKPT_DIR.iterdir()):
        print("Found existing checkpoints in", CKPT_DIR)
        return True
    # huggingface_hub picks the token up from the environment on its own; we
    # only check for it here so we can fail early with a clear message.
    hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if not hf_token:
        print("No HF token found in the environment; cannot auto-download. "
              "Set HUGGINGFACE_HUB_TOKEN or place the checkpoints in ./ckpts manually.")
        return False
    print("Downloading model weights from the Hugging Face Hub... (this can take several minutes)")
    try:
        snapshot_download(repo_id=MODEL_REPO, local_dir=str(CKPT_DIR), local_dir_use_symlinks=False)
        print("Download complete.")
        return True
    except Exception as e:
        print("Error downloading checkpoints:", e)
        return False
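

# If you would rather stage the weights yourself, the equivalent CLI download is:
#
#   huggingface-cli download hpcai-tech/Open-Sora-v2 --local-dir ./ckpts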


def find_latest_video():
    SAMPLES_DIR.mkdir(exist_ok=True)
    matches = list(SAMPLES_DIR.glob("*.mp4"))
    if not matches:
        return None
    matches.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return str(matches[0])


def run_torch_inference(config, prompt, ref_image=None, aspect_ratio=None, num_frames=None, offload=False):
    """Run Open-Sora's inference script as a subprocess; raises on failure."""
    SAMPLES_DIR.mkdir(exist_ok=True)
    cmd = [
        "torchrun", "--nproc_per_node", "1", "--standalone",
        "scripts/diffusion/inference.py",
        f"configs/diffusion/inference/{config}.py",
        "--save-dir", str(SAMPLES_DIR),
        "--prompt", prompt,
    ]
    if ref_image:
        # Image-to-video: pass the reference image as conditioning.
        cmd += ["--cond_type", "i2v_head", "--ref", ref_image]
    if aspect_ratio:
        cmd += ["--aspect_ratio", aspect_ratio]
    if num_frames:
        cmd += ["--num_frames", str(num_frames)]
    if offload:
        # Trade speed for lower GPU memory usage.
        cmd += ["--offload", "True"]
    print("Running command:", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print("Inference failed:", e)
        raise
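

# For reference, a typical command this builds looks like the following
# (paths assume the app runs from the Open-Sora repo root):
#
#   torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py \
#       configs/diffusion/inference/t2i2v_256px.py \
#       --save-dir samples --prompt "A cinematic shot of ..." \
#       --aspect_ratio 16:9 --num_frames 17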


def generate_video(prompt, mode="256 (t2i2v)", ref_image_path=None, aspect_ratio="16:9", num_frames=None, offload=False):
    ok = ensure_ckpts()
    if not ok:
        return "Model checkpoints not found and no HF token provided. Upload ckpts to ./ckpts or set HUGGINGFACE_HUB_TOKEN."

    # Map the UI labels onto Open-Sora's inference config names.
    config_map = {
        "256 (t2i2v)": "t2i2v_256px",
        "256 (t2v)": "256px",
        "768 (t2v)": "768px",
        "768 (t2i2v)": "t2i2v_768px",
    }
    config = config_map.get(mode, "t2i2v_256px")
    try:
        run_torch_inference(config, prompt, ref_image=ref_image_path,
                            aspect_ratio=aspect_ratio, num_frames=num_frames, offload=offload)
        # Poll for up to two minutes for the output file to land on disk.
        for _ in range(120):
            latest = find_latest_video()
            if latest:
                return latest
            time.sleep(1)
        return "No output video detected after inference."
    except Exception as e:
        return f"Error during generation: {e}"
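

# Quick sanity check without the UI (hypothetical prompt; requires the
# checkpoints and a GPU):
#
#   result = generate_video("A cat chasing a butterfly", mode="256 (t2i2v)")
#   print(result)  # path to an .mp4 on success, an error message otherwise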


with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Open-Sora v2: Text/Image to Video")
    with gr.Row():
        prompt = gr.Textbox(lines=3, label="Prompt", placeholder="A cinematic shot of ...")
    with gr.Row():
        mode = gr.Radio(["256 (t2i2v)", "256 (t2v)", "768 (t2v)", "768 (t2i2v)"], value="256 (t2i2v)", label="Generation Mode")
        aspect_ratio = gr.Dropdown(["16:9", "9:16", "1:1", "2.39:1"], value="16:9", label="Aspect Ratio")
        num_frames = gr.Number(value=17, label="Frames (must be 4k+1, e.g. 17)", precision=0)
    with gr.Row():
        ref_image = gr.Image(type="filepath", label="Reference image (optional, for I2V)")
        offload = gr.Checkbox(label="Memory offload (slower, but uses less GPU memory)", value=False)
    generate_btn = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    status = gr.Textbox(label="Status/Logs", interactive=False)

    def on_generate(prompt_text, mode_val, ar, nf, ref_img, off):
        # generate_video returns a file path on success and an error message
        # otherwise; only feed real files to the video component.
        res = generate_video(prompt_text, mode_val, ref_image_path=ref_img,
                             aspect_ratio=ar, num_frames=int(nf) if nf else None, offload=off)
        if res and os.path.isfile(res):
            return res, f"Completed: {res}"
        return None, res

    generate_btn.click(on_generate, inputs=[prompt, mode, aspect_ratio, num_frames, ref_image, offload], outputs=[output_video, status])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
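
# To run locally (assumes this file sits at the root of an Open-Sora v2
# checkout, since the scripts/ and configs/ paths above are relative):
#
#   HUGGINGFACE_HUB_TOKEN=hf_xxx python app.py   # app.py = whatever this file is named
#
# The PORT environment variable overrides the default port 7860.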