Vchitect-2.0 / app.py
WeichenFan
update demo
f28a5b1
raw
history blame
5.27 kB
import os
import threading
import time
import gradio as gr
import torch
# from diffusers import CogVideoXPipeline
from models.pipeline import VchitectXLPipeline
from diffusers.utils import export_to_video
from datetime import datetime, timedelta
# from openai import OpenAI
import spaces
import moviepy.editor as mp
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered inside a headless Space container and would abort startup.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Half-precision dtype; not referenced elsewhere in the visible code
# (inference autocasts to bfloat16 instead).
dtype = torch.float16
# Use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the Vchitect-XL-2B pipeline once at import time so every request reuses it.
pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B", device)

# Scratch directories for generated media (swept by delete_old_files).
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
@spaces.GPU(duration=120)
def infer(
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
    *,
    negative_prompt: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    width: int = 768,
    height: int = 432,
    frames: int = 16,
):
    """Generate a video for ``prompt`` with the shared Vchitect-XL pipeline.

    Decorated with ``spaces.GPU(duration=120)`` so a GPU is allocated for the
    call on Hugging Face ZeroGPU. The keyword-only sampling settings default
    to the values this demo has always used; other resolutions previously
    tried here are 480x288, 624x352 and 432x240.

    Args:
        prompt: Text description of the video to generate.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm bars in the UI. Not referenced directly.
        negative_prompt: Text the sampler should steer away from.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.
        width: Output width in pixels.
        height: Output height in pixels.
        frames: Value forwarded as the pipeline's ``frames`` argument.

    Returns:
        Whatever ``pipe(...)`` returns; callers hand it to
        ``export_to_video`` (presumably a sequence of frames — TODO confirm).
    """
    # Release cached allocator blocks left over from a previous request.
    torch.cuda.empty_cache()
    # torch.cuda.amp.autocast(...) is deprecated; torch.autocast("cuda", ...)
    # is the documented equivalent.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        video = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            frames=frames,
        )
    return video
def save_video(tensor):
    """Write generated frames to a uniquely named MP4 under ``./output``.

    Args:
        tensor: Pipeline output accepted by ``export_to_video`` (as returned
            by ``infer``).

    Returns:
        Path of the written ``.mp4`` file.
    """
    import uuid

    # A per-second timestamp alone collides when two requests finish in the
    # same second, letting one user's video overwrite another's. Microseconds
    # plus a short random suffix keep names unique while staying sortable.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    video_path = os.path.join("./output", f"{timestamp}_{uuid.uuid4().hex[:8]}.mp4")
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path)
    return video_path
def convert_to_gif(video_path):
    """Render an 8 fps, 240-px-tall GIF preview next to ``video_path``.

    Args:
        video_path: Path to an existing video file (the ``.mp4`` written by
            ``save_video``).

    Returns:
        Path of the GIF: ``video_path`` with its extension replaced by ``.gif``.
    """
    # Swap only the trailing extension. str.replace(".mp4", ...) would also
    # rewrite earlier occurrences, and for a non-.mp4 input it would return
    # the input path itself, overwriting the video with the GIF.
    gif_path = os.path.splitext(video_path)[0] + ".gif"
    # VideoFileClip holds an open ffmpeg reader; the context manager closes it
    # once the GIF is written instead of leaking one per request.
    with mp.VideoFileClip(video_path) as source:
        preview = source.set_fps(8).resize(height=240)
        preview.write_gif(gif_path, fps=8)
    return gif_path
def _delete_files_older_than(directories, cutoff):
    """Delete regular files directly inside ``directories`` modified before ``cutoff``."""
    for directory in directories:
        try:
            filenames = os.listdir(directory)
        except FileNotFoundError:
            # Directory missing (never created or removed externally): nothing to purge.
            continue
        for filename in filenames:
            file_path = os.path.join(directory, filename)
            if not os.path.isfile(file_path):
                continue
            try:
                file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                if file_mtime < cutoff:
                    os.remove(file_path)
            except FileNotFoundError:
                # File vanished between listing and removal; it is already gone.
                continue
            except OSError as err:
                # Report and move on so one bad file cannot stop the sweep.
                print(f"delete_old_files: could not remove {file_path}: {err}")


def delete_old_files(
    directories=("./output", "./gradio_tmp"),
    max_age=timedelta(minutes=10),
    interval=600,
):
    """Periodically purge stale generated files; never returns.

    Meant to run on a daemon thread. Every ``interval`` seconds it deletes
    regular files in ``directories`` (non-recursive) whose modification time
    is older than ``max_age``. Per-directory and per-file errors are handled
    inside the sweep, so a single failure no longer kills the thread.

    Args:
        directories: Directories to scan.
        max_age: Files modified longer ago than this are deleted.
        interval: Seconds to sleep between sweeps.
    """
    while True:
        _delete_files_older_than(directories, datetime.now() - max_age)
        time.sleep(interval)
# Background sweeper for ./output and ./gradio_tmp; daemon so it never blocks exit.
threading.Thread(target=delete_old_files, daemon=True).start()

with gr.Blocks() as demo:
    gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Vchitect-XL 2B Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/Vchitect/Vchitect-XL-2B">🤗 2B Model Hub</a> |
<a href="https://vchitect.intern-ai.org.cn/">🌐 Website</a>
</div>
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
⚠️ This demo is for academic research and experiential use only.
Users should strictly adhere to local laws and ethics.
</div>
""")
    with gr.Row():
        # Left: prompt entry and trigger.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
            generate_button = gr.Button("🎬 Generate Video")
        # Right: result preview plus download links revealed after generation.
        with gr.Column():
            video_output = gr.Video(label="Vchitect-XL Generated Video", width=768, height=432)
            with gr.Row():
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)

    def generate(prompt_text, progress=gr.Progress(track_tqdm=True)):
        """Run one generation and return updates for the three output components.

        The former unused ``model_choice`` parameter had no matching input and
        only worked because Gradio happened to append the Progress object into
        that slot; it has been removed.

        Returns:
            ``(video_path, video_file_update, gif_file_update)`` matching
            ``outputs=[video_output, download_video_button, download_gif_button]``.
        """
        frames = infer(prompt_text, progress=progress)
        video_path = save_video(frames)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)
        return video_path, video_update, gif_update

    generate_button.click(
        generate,
        inputs=[prompt],
        outputs=[video_output, download_video_button, download_gif_button],
    )

if __name__ == "__main__":
    demo.launch()