# TTV / app.py — LTTEAM (commit 721dfc7, verified)
import os
import sys
import time
import torch
# Make this file's directory and its two ancestor directories importable,
# so the local `cogvideox` package resolves regardless of the launch cwd.
_this_file = os.path.abspath(__file__)
_candidate_roots = []
_dir = os.path.dirname(_this_file)
for _ in range(3):
    _candidate_roots.append(_dir)
    _dir = os.path.dirname(_dir)
for _root in _candidate_roots:
    if _root not in sys.path:
        sys.path.insert(0, _root)
from cogvideox.api.api import (
infer_forward_api,
update_diffusion_transformer_api,
update_edition_api
)
from cogvideox.ui.controller import flow_scheduler_dict
from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # UI flavour: "eas", "modelscope", or anything else for the default UI.
    ui_mode = "eas"

    # How model weights are kept on / moved off the GPU. One of:
    #   "model_cpu_offload", "model_cpu_offload_and_qfloat8",
    #   "sequential_cpu_offload"
    GPU_memory_mode = "model_cpu_offload"

    # Prefer bfloat16 when the CUDA device supports it; otherwise float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float16

    # OmegaConf config describing the WAN2.1 pipeline.
    config_path = "config/wan2.1/wan_civitai.yaml"

    # Gradio server binding.
    server_name = "0.0.0.0"
    server_port = 7860

    # Extra parameters used by the modelscope UI.
    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
    model_type = "Inpaint"  # or "Control"
    savedir_sample = "samples"

    # ------------------------------------------------------------------
    # Build the UI and its controller for the selected mode
    # ------------------------------------------------------------------
    if ui_mode == "modelscope":
        demo, controller = ui_modelscope(
            model_name,
            model_type,
            savedir_sample,
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )
    elif ui_mode == "eas":
        demo, controller = ui_eas(
            model_name,
            flow_scheduler_dict,
            savedir_sample,
            config_path,
        )
    else:
        demo, controller = ui(
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )

    # ------------------------------------------------------------------
    # Launch the Gradio app (non-blocking, so the API can be mounted next)
    # ------------------------------------------------------------------
    launch_kwargs = dict(
        share=False,  # local / notebook use; no public tunnel
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,  # return immediately; we keep the process alive below
    )
    app, _, _ = demo.queue(status_update_rate=1).launch(**launch_kwargs)

    # Mount the project's REST API endpoints onto the running app.
    infer_forward_api(None, app, controller)
    update_diffusion_transformer_api(None, app, controller)
    update_edition_api(None, app, controller)

    # Sleep forever so the non-blocking Gradio server is not torn down.
    while True:
        time.sleep(5)