# Source: cogvidx / app.py (Hugging Face Space by fantos)
# Revision 4f8a7c2 "Update app.py" (6.29 kB)
import os
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import uuid
import GPUtil
import gradio as gr
import psutil
import spaces
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
def load_model(model_name, enable_video_sys=False, pab_threshold=None, pab_range=2):
    """Construct a VideoSys engine for a CogVideoX checkpoint.

    Args:
        model_name: Checkpoint identifier passed to ``CogVideoXConfig``
            (the UI offers ``"THUDM/CogVideoX-2b"`` and ``"THUDM/CogVideoX-5b"``).
        enable_video_sys: Forwarded as ``enable_pab``; when True the engine
            uses the PAB configuration built below.
        pab_threshold: Two-element spatial threshold for PAB, ordered like the
            default ``[100, 850]`` (lower timestep first — presumably
            ``[end, start]``; confirm against videosys). ``None`` selects that
            default. Any 2-sequence is accepted and copied into a new list.
        pab_range: Spatial broadcast range for PAB.

    Returns:
        A newly constructed ``VideoSysEngine``; nothing is cached between calls.
    """
    # A fresh list per call avoids the shared-mutable-default pitfall.
    pab_threshold = [100, 850] if pab_threshold is None else list(pab_threshold)
    pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
    config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
    return VideoSysEngine(config)
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0, output_dir="./.tmp_outputs"):
    """Run one text-to-video generation and save the first video as an MP4.

    Args:
        engine: Object with ``generate(prompt, num_inference_steps=..., guidance_scale=...)``
            returning a result whose ``.video`` is indexable, and
            ``save_video(video, path)`` (a ``VideoSysEngine`` in this app).
        prompt: Text prompt for the model.
        num_inference_steps: Number of inference steps forwarded to the engine.
        guidance_scale: Guidance scale forwarded to the engine.
        output_dir: Directory for the output file; created if it does not
            exist. The default is ``./.tmp_outputs`` relative to the current
            working directory, the same folder this module sets as
            ``GRADIO_TEMP_DIR``.

    Returns:
        Path of the saved file, named with a random UUID hex so repeated or
        concurrent runs do not overwrite each other.
    """
    # Only the first entry of the returned batch is used.
    video = engine.generate(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
    ).video[0]
    # The output folder is not guaranteed to exist (e.g. on a fresh
    # container), so create it before asking the engine to write into it.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{uuid.uuid4().hex}.mp4")
    engine.save_video(video, output_path)
    return output_path
def get_server_status_detailed():
    """Return CPU/memory/disk usage plus a per-GPU breakdown.

    This variant has its own name. A definition named ``get_server_status``
    here would be shadowed by the later ``get_server_status`` that the status
    panel uses, and could never be called.

    Returns:
        dict with percentage strings under ``"cpu"``, ``"memory"`` and
        ``"disk"`` (disk measured on ``/``), and ``"gpu"``: a list with one
        dict per GPU reported by GPUtil (``id``, ``name``, ``load``,
        ``memory_used``, ``memory_total``). The list is empty when no GPU is
        found or the GPU query raises.
    """
    cpu_percent = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    disk = psutil.disk_usage("/")
    try:
        gpus = GPUtil.getGPUs()
    except Exception:
        # Best effort, matching get_server_status: GPU stats are optional.
        gpus = []
    gpu_info = [
        {
            "id": gpu.id,
            "name": gpu.name,
            "load": f"{gpu.load*100:.1f}%",
            "memory_used": f"{gpu.memoryUsed}MB",
            "memory_total": f"{gpu.memoryTotal}MB",
        }
        for gpu in gpus
    ]
    return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
@spaces.GPU()
def generate_vanilla(model_name, prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: generate a video with PAB disabled (original pipeline).

    Args:
        model_name: Checkpoint selected in the "Model Type" radio.
        prompt: Text prompt.
        num_inference_steps: Value of the "Inference Steps" number box.
        guidance_scale: Value of the "Guidance Scale" number box.
        progress: Gradio progress tracker with ``track_tqdm=True`` so tqdm
            bars from the pipeline show in the UI; not referenced directly.

    Returns:
        Path to the generated MP4 file.
    """
    engine = load_model(model_name)
    # gr.Number values may arrive as floats; coerce like generate_vs does for
    # its numeric inputs so the step count is a proper integer.
    return generate(engine, prompt, int(num_inference_steps), float(guidance_scale))
@spaces.GPU()
def generate_vs(
    model_name,
    prompt,
    num_inference_steps,
    guidance_scale,
    threshold_start,
    threshold_end,
    gap,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler: generate a video with VideoSys PAB acceleration enabled.

    Args:
        model_name: Checkpoint selected in the "Model Type" radio.
        prompt: Text prompt.
        num_inference_steps: Value of the "Inference Steps" number box.
        guidance_scale: Value of the "Guidance Scale" number box.
        threshold_start: "PAB Start Timestep" (UI default 850).
        threshold_end: "PAB End Timestep" (UI default 100).
        gap: "PAB Broadcast Range", passed as ``pab_range``.
        progress: Gradio progress tracker with ``track_tqdm=True``; not
            referenced directly.

    Returns:
        Path to the generated MP4 file.
    """
    # Ordered [end, start] so the UI defaults give [100, 850], the same
    # ordering as load_model's default threshold.
    threshold = [int(threshold_end), int(threshold_start)]
    engine = load_model(model_name, enable_video_sys=True, pab_threshold=threshold, pab_range=int(gap))
    # Coerce the shared generation inputs the same way generate_vanilla does.
    return generate(engine, prompt, int(num_inference_steps), float(guidance_scale))
def get_server_status():
    """Snapshot host resource usage for the server-status panel.

    Returns:
        dict of display strings:
          - ``"cpu"``, ``"memory"``, ``"disk"``: usage percentages (disk is
            measured on ``/``);
          - ``"gpu_memory"``: used/total MB and utilisation of the first GPU
            reported by GPUtil. It is ``"No GPU found"`` when none is
            reported and ``"GPU information unavailable"`` if querying fails.
    """
    # Without an interval, psutil compares against the previous call, which
    # suits the once-per-second polling of the UI.
    cpu_percent = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    disk = psutil.disk_usage("/")
    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
        else:
            gpu_memory = "No GPU found"
    except Exception:
        # Best effort: GPU stats must never break the status panel. Unlike a
        # bare `except:`, this still lets KeyboardInterrupt/SystemExit through.
        gpu_memory = "GPU information unavailable"
    return {
        "cpu": f"{cpu_percent}%",
        "memory": f"{memory.percent}%",
        "disk": f"{disk.percent}%",
        "gpu_memory": gpu_memory,
    }
def update_server_status():
    """Return status-panel values in output order: CPU, memory, disk, GPU memory.

    The order matches the ``[cpu_status, memory_status, disk_status, gpu_status]``
    outputs wired to this function in the UI.
    """
    snapshot = get_server_status()
    field_order = ("cpu", "memory", "disk", "gpu_memory")
    return tuple(snapshot[field] for field in field_order)
css = """
footer {
visibility: hidden;
}
"""
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=3)
with gr.Column():
gr.Markdown("**Generation Parameters**<br>")
with gr.Row():
model_name = gr.Radio(
["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], label="Model Type", value="THUDM/CogVideoX-2b"
)
with gr.Row():
num_inference_steps = gr.Number(label="Inference Steps", value=50)
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
with gr.Row():
pab_range = gr.Number(
label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
)
pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
with gr.Row():
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
generate_button = gr.Button("🎬 Generate Video (Original)")
with gr.Column(elem_classes="server-status"):
gr.Markdown("#### Server Status")
with gr.Row():
cpu_status = gr.Textbox(label="CPU", scale=1)
memory_status = gr.Textbox(label="Memory", scale=1)
with gr.Row():
disk_status = gr.Textbox(label="Disk", scale=1)
gpu_status = gr.Textbox(label="GPU Memory", scale=1)
with gr.Row():
refresh_button = gr.Button("Refresh")
with gr.Column():
with gr.Row():
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
with gr.Row():
video_output = gr.Video(label="CogVideoX", width=720, height=480)
generate_button.click(
generate_vanilla,
inputs=[model_name, prompt, num_inference_steps, guidance_scale],
outputs=[video_output],
concurrency_id="gen",
concurrency_limit=1,
)
generate_button_vs.click(
generate_vs,
inputs=[
model_name,
prompt,
num_inference_steps,
guidance_scale,
pab_threshold_start,
pab_threshold_end,
pab_range,
],
outputs=[video_output_vs],
concurrency_id="gen",
concurrency_limit=1,
)
refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
if __name__ == "__main__":
demo.queue(max_size=10, default_concurrency_limit=1)
demo.launch()