import contextlib
import io
import os
import sys
import traceback
from pathlib import Path

# NOTE(review): `spaces` was the first import in the original file and is kept
# ahead of torch here; presumably required by Hugging Face ZeroGPU — confirm
# before reordering.
import spaces

import gradio as gr
import GPUtil
import psutil
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face dataset repository the checkpoint files are fetched from.
HF_DATASET_REPO = "roll-ai/FloVD-weights"

# Checkpoint files, expressed relative to the local "ckpt" directory.
_CKPT_RELATIVE_PATHS = (
    "FVSM/FloVD_FVSM_Controlnet.pt",
    "OMSM/selected_blocks.safetensors",
    "OMSM/pytorch_lora_weights.safetensors",
    "others/depth_anything_v2_metric_hypersim_vitb.pth",
)

# Maps the filename inside the repository ("ckpt/<relative path>") to the
# same path relative to the local "ckpt" directory.
WEIGHT_FILES = {
    f"ckpt/{relative_path}": relative_path
    for relative_path in _CKPT_RELATIVE_PATHS
}

print("\nDownloading model...", flush=True)
|
|
|
def download_weights():
    """Download every checkpoint listed in ``WEIGHT_FILES`` that is missing locally.

    For each entry the local path ``ckpt/<relative path>`` is checked first;
    files that already exist are skipped, the rest are fetched from the
    ``HF_DATASET_REPO`` dataset repository with ``hf_hub_download``.
    """
    print("๐ Downloading model weights via huggingface_hub...")
    for hf_path, local_rel_path in WEIGHT_FILES.items():
        local_path = Path("ckpt") / local_rel_path

        if local_path.exists():
            print(f"✅ Already exists: {local_path}")
            continue

        print(f"๐ฅ Downloading {hf_path}")
        # NOTE(review): with local_dir="./" the repo filename ("ckpt/...") is
        # expected to land at the same path checked above — confirm against
        # the huggingface_hub documentation for `local_dir`.
        hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=hf_path,
            local_dir="./",
        )
|
|
|
# Fetch any missing checkpoints at import time; this runs before the
# inference module is imported further down in this file.
download_weights()
|
|
|
def print_ckpt_structure(base_path="ckpt"):
    """Print the directory tree under ``base_path`` to stdout.

    Each directory is printed on its own line followed by its files, indented
    two spaces per nesting level. Nothing but the header line is printed when
    ``base_path`` does not exist.

    Args:
        base_path: Root directory to list. A trailing path separator is
            tolerated.
    """
    print(f"๐ Listing structure of: {base_path}", flush=True)
    for root, _dirs, files in os.walk(base_path):
        # Derive the depth from the path relative to the root instead of
        # string replacement, which miscounts when base_path ends with a
        # separator or occurs again deeper in the path.
        relative = os.path.relpath(root, base_path)
        level = 0 if relative == os.curdir else relative.count(os.sep) + 1

        indent = ' ' * 2 * level
        # normpath drops a trailing separator so the top-level name is not empty.
        dir_name = os.path.basename(os.path.normpath(root))
        print(f"{indent}๐ {dir_name}/", flush=True)

        sub_indent = ' ' * 2 * (level + 1)
        for f in files:
            print(f"{sub_indent}๐ {f}", flush=True)
|
|
|
# Log the checkpoint directory tree; diagnostic output only.
print_ckpt_structure()

# Imported here rather than at the top of the file, i.e. only after the
# weight download above has run.
# NOTE(review): presumably deliberate — confirm before moving this import.
from inference.flovd_demo import generate_video
|
|
|
def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Run one FloVD generation request and capture everything it prints.

    Args:
        prompt: Text prompt; also used to build the expected output filename.
        image: Input image, either a PIL image or an array convertible with
            ``Image.fromarray`` after ``astype("uint8")``.
        pose_type: Camera pose type, forwarded to ``generate_video``.
        speed: Camera speed; converted with ``float()`` before forwarding.
        use_flow_integration: Flag forwarded to ``generate_video``.
        cam_pose_name: Camera trajectory name; ``"default"`` is used in the
            expected filename when it is empty.

    Returns:
        A ``(video_path, logs)`` tuple. ``video_path`` is ``None`` unless the
        expected output file exists on disk. ``logs`` is the text written to
        stdout during the call, including the traceback when inference fails.
    """
    video_path = None
    log_buffer = io.StringIO()

    # redirect_stdout restores sys.stdout on exit even if a BaseException
    # (e.g. KeyboardInterrupt) escapes. The swap is process-wide, so output
    # printed by other threads during the call is captured as well.
    with contextlib.redirect_stdout(log_buffer):
        try:
            print("๐ Starting inference...", flush=True)

            if image is None:
                raise ValueError("No input image provided.")

            os.makedirs("input_images", exist_ok=True)
            image_path = "input_images/input_image.png"

            if not isinstance(image, Image.Image):
                image = Image.fromarray(image.astype("uint8"))

            image.save(image_path)
            print(f"๐ธ Saved input image to {image_path}", flush=True)

            generate_video(
                prompt=prompt,
                image_path=image_path,
                fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
                omsm_path="./ckpt/OMSM",
                output_path="./outputs",
                num_frames=49,
                fps=16,
                width=None,
                height=None,
                seed=42,
                guidance_scale=6.0,
                dtype=torch.float16,
                controlnet_guidance_end=0.4,
                use_dynamic_cfg=False,
                pose_type=pose_type,
                speed=float(speed),
                use_flow_integration=use_flow_integration,
                cam_pose_name=cam_pose_name,
                depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
            )

            # NOTE(review): this reconstructs the filename generate_video is
            # expected to write — confirm against inference.flovd_demo.
            video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
            video_path = f"./outputs/generated_videos/{video_name}"
            print(f"✅ Inference complete. Video saved to {video_path}")

        except Exception:
            # Deliberate best-effort boundary: report the failure in the
            # returned logs instead of raising.
            print("๐ฅ Inference failed with exception:")
            # print_exc() defaults to stderr, which is not captured; write the
            # traceback into the log buffer so it is part of the returned logs.
            traceback.print_exc(file=log_buffer)

    logs = log_buffer.getvalue()
    log_buffer.close()

    return (video_path if video_path and os.path.exists(video_path) else None), logs
|
|
|
|
|
|
|
|
|
|
|
# Standalone Blocks for the video generation form; components are created in
# the same order as before.
with gr.Blocks() as video_tab:
    gr.Markdown("## ๐ฅ FloVD: Optical Flow + CogVideoX Video Generation")

    # --- Request inputs ---
    prompt = gr.Textbox(
        value="A girl riding a bicycle through a park.",
        label="Prompt",
    )
    # run_inference converts non-PIL input with Image.fromarray before saving.
    image = gr.Image(label="Input Image")
    pose_type = gr.Radio(
        label="Camera Pose Type",
        choices=["manual", "re10k"],
        value="manual",
    )
    speed = gr.Slider(
        label="Camera Speed",
        minimum=0.1,
        maximum=2.0,
        step=0.1,
        value=0.5,
    )
    use_flow_integration = gr.Checkbox(value=False, label="Use Flow Integration")
    cam_pose_name = gr.Textbox(
        label="Camera Trajectory",
        lines=1,
        placeholder="e.g., zoom_in, custom_motion, etc.",
    )

    generate_btn = gr.Button("๐ฌ Generate Video")

    # --- Results: the generated video and the captured log text ---
    video_output = gr.Video(label="Generated Video")
    log_output = gr.Textbox(interactive=False, lines=20, label="Logs")

    # The input list follows run_inference's parameter order.
    generate_btn.click(
        fn=run_inference,
        inputs=[
            prompt,
            image,
            pose_type,
            speed,
            use_flow_integration,
            cam_pose_name,
        ],
        outputs=[video_output, log_output],
    )
|
|
|
|
|
|
|
|
|
|
|
def get_system_stats():
    """Return a four-line text summary of CPU, RAM, disk and GPU usage.

    GPU details come from ``GPUtil``; any exception raised while querying
    them is reported in the GPU line instead of being propagated.
    """
    cpu_pct = psutil.cpu_percent()
    ram = psutil.virtual_memory()
    root_disk = psutil.disk_usage('/')

    try:
        detected_gpus = GPUtil.getGPUs()
        if detected_gpus:
            gpu_info = "\n".join(
                f"GPU {i}: {gpu.name}, {gpu.memoryUsed}MB / {gpu.memoryTotal}MB, Util: {gpu.load * 100:.1f}%"
                for i, gpu in enumerate(detected_gpus)
            )
        else:
            gpu_info = "No GPU detected"
    except Exception as e:
        gpu_info = f"GPU info error: {e}"

    # Sizes are reported in decimal gigabytes (bytes / 1e9).
    report_lines = [
        f"๐ง CPU Usage: {cpu_pct}%",
        f"๐พ RAM: {ram.used / 1e9:.2f} GB / {ram.total / 1e9:.2f} GB ({ram.percent}%)",
        f"๐๏ธ Disk: {root_disk.used / 1e9:.2f} GB / {root_disk.total / 1e9:.2f} GB ({root_disk.percent}%)",
        f"๐ฎ {gpu_info}",
    ]
    return "\n".join(report_lines)
|
|
|
# Standalone Blocks showing system resource usage; the text is refreshed only
# when the button is clicked.
with gr.Blocks() as monitor_tab:
    gr.Markdown("## ๐ Live System Resource Monitor")
    stats_box = gr.Textbox(
        interactive=False,
        lines=10,
        label="Live Stats",
    )

    def update_stats():
        """Return a Gradio update carrying a fresh system stats snapshot."""
        snapshot = get_system_stats()
        return gr.update(value=snapshot)

    stats_btn = gr.Button("๐ Refresh Stats")
    stats_btn.click(
        fn=update_stats,
        outputs=stats_box,
    )
|
|
|
|
|
|
|
|
|
|
|
# Top-level application: each of the two standalone Blocks is rendered
# inside its own tab.
with gr.Blocks() as app:
    with gr.Tab("๐ฅ Video Generator"):
        video_tab.render()
    with gr.Tab("๐ System Monitor"):
        monitor_tab.render()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860, with Gradio's debug and
    # show_error options enabled.
    app.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)
|
|