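"""Gradio demo for Vchitect-XL 2B text-to-video generation on Hugging Face Spaces."""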
import os
import threading
import time
from datetime import datetime, timedelta

import gradio as gr
import torch
import moviepy.editor as mp
import spaces
# from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
# from openai import OpenAI
from huggingface_hub import login

from models.pipeline import VchitectXLPipeline

# Authenticate with the Hugging Face Hub; HF_TOKEN is provided via the Space secrets.
login(token=os.getenv('HF_TOKEN'))
dtype = torch.bfloat16  # matches the autocast dtype used in infer() below
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B", device)

# Working directories for generated outputs and Gradio's temporary files.
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
@spaces.GPU  # ZeroGPU: request a GPU for this call (assumed intent of the otherwise-unused `spaces` import)
def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()
    with torch.cuda.amp.autocast(dtype=dtype):
        video = pipe(
            prompt,
            negative_prompt="",
            num_inference_steps=50,
            guidance_scale=7.5,
            width=768,
            height=432,  # other tested resolutions: 480x288, 624x352, 432x240
            frames=16,
        )
    return video
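
# Output helpers: persist the frames as a timestamped MP4 (via diffusers'
# export_to_video) and derive a 240p GIF preview with moviepy.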
def save_video(tensor):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path)
    return video_path


def convert_to_gif(video_path):
    clip = mp.VideoFileClip(video_path)
    clip = clip.set_fps(8)
    clip = clip.resize(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path
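
# Housekeeping: a daemon thread that purges files older than 10 minutes from
# the output and Gradio temp directories, checking every 600 seconds.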
def delete_old_files():
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]
        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()
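
# UI: prompt entry and generate button on the left, video player and hidden
# download widgets on the right.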
with gr.Blocks() as demo:
    gr.Markdown("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
        Vchitect-XL 2B Huggingface Space🤗
    </div>
    <div style="text-align: center;">
        <a href="https://huggingface.co/Vchitect/Vchitect-XL-2B">🤗 2B Model Hub</a> |
        <a href="https://vchitect.intern-ai.org.cn/">🌐 Website</a>
    </div>
    <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
        ⚠️ This demo is for academic research and experimental use only.
        Users must strictly adhere to local laws and ethics.
    </div>
    """)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
            # with gr.Row():
            #     gr.Markdown(
            #         "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
            #     enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
            with gr.Column():
                # gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
                #             "Increasing the number of inference steps will produce more detailed videos, but it will slow down the process.<br>"
                #             "50 steps are recommended for most cases.<br>"
                #             "For the 5B model, 50 steps will take approximately 350 seconds.")
                # with gr.Row():
                #     num_inference_steps = gr.Number(label="Inference Steps", value=50)
                #     guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                generate_button = gr.Button("🎬 Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=768, height=432)
            with gr.Row():
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)
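
    # Request path: run inference, save the MP4, render the GIF preview, and
    # reveal the hidden download widgets.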
    def generate(prompt, progress=gr.Progress(track_tqdm=True)):
        # (The unused model_choice argument was dropped: the click handler only passes the prompt.)
        tensor = infer(prompt, progress=progress)
        video_path = save_video(tensor)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)
        return video_path, video_update, gif_update
    # def enhance_prompt_func(prompt):
    #     return convert_prompt(prompt, retry_times=1)

    generate_button.click(
        generate,
        inputs=[prompt],
        outputs=[video_output, download_video_button, download_gif_button]
    )

    # enhance_button.click(
    #     enhance_prompt_func,
    #     inputs=[prompt],
    #     outputs=[prompt]
    # )
if __name__ == "__main__":
    demo.launch()