# EPiC-LowRes / gradio_batch.py
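# Batch driver for the EPiC two-stage pipeline: extract an uploaded zip of input
# videos plus a config.yaml of camera trajectories, run get_anchor_video and
# inference for every (video, trajectory) pair, and serve the zipped results
# through a Gradio UI.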
import os
import torch
import cv2
import yaml
import shutil
import zipfile
import subprocess
import gradio as gr
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
from gradio_app import get_anchor_video, inference
# -----------------------------
# Environment Setup
# -----------------------------
HF_HOME = "/app/hf_cache"
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)
PRETRAINED_DIR = "/app/pretrained"
os.makedirs(PRETRAINED_DIR, exist_ok=True)
INPUT_VIDEOS_DIR = "Input_Videos"
CONFIG_FILE = "config.yaml"
FINAL_RESULTS_DIR = "Final_results"
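# Expected layout: the uploaded zip should contain an Input_Videos/ folder of
# .mp4 clips; per-run outputs land in Final_results/<video_name>/<trajectory_name>.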
# -----------------------------
# File Upload Handler
# -----------------------------
def handle_uploads(zip_file, config_file):
    # Extract the uploaded zip into the working directory
    with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
        zip_ref.extractall(".")
        extracted_names = zip_ref.namelist()
    # Copy the uploaded config into place (gr.File yields a temp file object, not a string)
    shutil.copy(config_file.name, CONFIG_FILE)
    # List only the files that came out of the archive
    summary = "\n".join(name for name in extracted_names if not name.endswith("/"))
    return f"""✅ Upload Successful!
📁 Extracted Files:
{summary}
📝 Config file saved to: `{CONFIG_FILE}`
"""
# -----------------------------
# Utility Functions
# -----------------------------
def download_models():
expected_model = os.path.join(PRETRAINED_DIR, "RAFT/raft-things.pth")
if not Path(expected_model).exists():
print("\u2699\ufe0f Downloading pretrained models...")
try:
subprocess.check_call(["bash", "download/download_models.sh"])
print("\u2705 Models downloaded.")
except subprocess.CalledProcessError as e:
print(f"Model download failed: {e}")
else:
print("\u2705 Pretrained models already exist.")
def visualize_depth_npy_as_video(npy_file, fps):
depth_np = np.load(npy_file)
tensor = torch.from_numpy(depth_np)
T, _, H, W = tensor.shape
video_path = "/app/depth_video_preview.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(video_path, fourcc, fps, (W, H))
for i in range(T):
frame = tensor[i, 0].numpy()
norm = (frame - frame.min()) / (frame.max() - frame.min() + 1e-8)
frame_uint8 = (norm * 255).astype(np.uint8)
colored = cv2.applyColorMap(frame_uint8, cv2.COLORMAP_INFERNO)
out.write(colored)
out.release()
return video_path
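# Recursively zip dir_path, storing entries relative to dir_path so the archive
# unpacks without the enclosing directory prefix.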
def zip_dir(dir_path, zip_path):
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(dir_path):
for file in files:
full_path = os.path.join(root, file)
rel_path = os.path.relpath(full_path, dir_path)
zf.write(full_path, rel_path)
# -----------------------------
# Inference Functions
# -----------------------------
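# Illustrative config.yaml shape, inferred from the params.get(...) calls below.
# Top-level keys name trajectories; only target_pose has no default and is
# required. Trajectory names and values here are hypothetical:
#
#   orbit_slow:
#     target_pose: "<pose string expected by get_anchor_video>"  # required
#     mode: gradual
#     fps: 24
#     num_frames: 49
#     radius_scale: 1.0
#     seed: 43
#   zoom_in:
#     target_pose: "<pose string>"
#     radius_scale: 1.2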
def run_batch_process(progress=gr.Progress()):
    with open(CONFIG_FILE, 'r') as f:
        # Guard against an empty config, which yaml.safe_load returns as None
        trajectories = yaml.safe_load(f) or {}
os.makedirs(FINAL_RESULTS_DIR, exist_ok=True)
logs = ""
videos = list(Path(INPUT_VIDEOS_DIR).glob("*.mp4"))
total = len(videos) * len(trajectories)
idx = 0
for video_path in videos:
video_name = video_path.stem
for traj_name, params in trajectories.items():
idx += 1
logs += f"\n---\nRunning {video_name}/{traj_name} ({idx}/{total})\n"
out_dir = Path(FINAL_RESULTS_DIR) / video_name / traj_name
out_dir.mkdir(parents=True, exist_ok=True)
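            # Stage 1: generate the camera-controlled anchor video, plus depth and caption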
anchor_path, logs1, caption, depth_path = get_anchor_video(
video_path=str(video_path),
fps=params.get("fps",24),
num_frames=params.get("num_frames",49),
target_pose=params["target_pose"],
mode=params.get("mode", "gradual"),
radius_scale=params.get("radius_scale", 1.0),
near_far_estimated=params.get("near_far_estimated", True),
sampler_name=params.get("sampler_name", "DDIM_Origin"),
diffusion_guidance_scale=params.get("diff_guidance", 6.0),
diffusion_inference_steps=params.get("diff_steps", 50),
prompt=params.get("prompt", ""),
negative_prompt=params.get("neg_prompt", ""),
refine_prompt=params.get("refine_prompt", ""),
depth_inference_steps=params.get("depth_steps", 5),
depth_guidance_scale=params.get("depth_guidance", 1.0),
window_size=params.get("window_size", 64),
overlap=params.get("overlap", 25),
max_res=params.get("max_res", 720),
sample_size=params.get("sample_size", "384, 672"),
seed_input=params.get("seed", 43),
height=params.get("height", 480),
width=params.get("width", 720),
aspect_ratio_inputs=params.get("aspect_ratio", "3,4"),
init_dx=params.get("init_dx", 0.0),
init_dy=params.get("init_dy", 0.0),
init_dz=params.get("init_dz", 0.0)
)
if not anchor_path:
logs += f"❌ Failed: {video_name}/{traj_name}\n"
continue
shutil.copy(anchor_path, out_dir / "anchor_video.mp4")
shutil.copy(depth_path, out_dir / "depth.mp4")
with open(out_dir / "captions.txt", "w") as f:
f.write(caption or "")
with open(out_dir / "step1_logs.txt", "w") as f:
f.write(logs1 or "")
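            # Stage 2: refine the anchor video with the ControlNet-conditioned diffusion model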
final_video, logs2 = inference(
fps=params.get("fps", 24),
num_frames=params.get("num_frames", 49),
controlnet_weights=params.get("controlnet_weights", 0.5),
controlnet_guidance_start=params.get("controlnet_guidance_start", 0.0),
controlnet_guidance_end=params.get("controlnet_guidance_end", 0.5),
guidance_scale=params.get("guidance_scale", 6.0),
num_inference_steps=params.get("inference_steps", 50),
dtype=params.get("dtype", "bfloat16"),
seed=params.get("seed2", 42),
height=params.get("height", 480),
width=params.get("width", 720),
downscale_coef=params.get("downscale_coef", 8),
vae_channels=params.get("vae_channels", 16),
controlnet_input_channels=params.get("controlnet_input_channels", 6),
controlnet_transformer_num_layers=params.get("controlnet_transformer_layers", 8)
)
if final_video:
shutil.copy(final_video, out_dir / "final_video.mp4")
with open(out_dir / "step2_logs.txt", "w") as f:
f.write(logs2 or "")
progress(idx / total)
zip_path = FINAL_RESULTS_DIR + ".zip"
zip_dir(FINAL_RESULTS_DIR, zip_path)
return logs, zip_path
# -----------------------------
# Gradio Interface
# -----------------------------
with gr.Blocks() as demo:
gr.Markdown("## πŸš€ EPiC Pipeline: Upload Inputs + Run Inference")
with gr.Tab("πŸ“€ Upload Files"):
with gr.Row():
zip_input = gr.File(label="Upload Folder (.zip)", file_types=[".zip"])
config_input = gr.File(label="Upload config.yaml", file_types=[".yaml", ".yml"])
upload_btn = gr.Button("Upload & Extract")
upload_output = gr.Textbox(label="Upload Result", lines=10)
upload_btn.click(handle_uploads, inputs=[zip_input, config_input], outputs=upload_output)
with gr.Tab("πŸ“ Run Experiments"):
with gr.Row():
run_batch_btn = gr.Button("▢️ Run Batch Experiments")
download_btn = gr.Button("⬇️ Download Results")
batch_logs = gr.Textbox(label="Logs", lines=25)
zip_file_output = gr.File(label="Final ZIP", visible=True)
run_batch_btn.click(run_batch_process, outputs=[batch_logs, zip_file_output])
        download_btn.click(lambda: FINAL_RESULTS_DIR + ".zip" if os.path.exists(FINAL_RESULTS_DIR + ".zip") else None, outputs=zip_file_output)
if __name__ == "__main__":
download_models()
demo.launch(server_name="0.0.0.0", server_port=7860)