import os
import threading
import time
from datetime import datetime, timedelta

import gradio as gr
import moviepy.editor as mp
import spaces
import torch
from diffusers.utils import export_to_video
from huggingface_hub import login

from models.pipeline import VchitectXLPipeline

# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment.
login(token=os.getenv('HF_TOKEN'))

# Run inference in bfloat16 (this dtype is reused by the autocast context in infer below).
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B", device)
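# NOTE (assumption): VchitectXLPipeline moves the model to `device` during
# construction, so no separate .to(device) call is needed here.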

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

@spaces.GPU(duration=120)  # request a ZeroGPU device for up to 120 s per call
def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate a short clip from a text prompt with the Vchitect-XL pipeline."""
    torch.cuda.empty_cache()
    with torch.autocast(device_type="cuda", dtype=dtype):
        video = pipe(
            prompt,
            negative_prompt="",
            num_inference_steps=50,
            guidance_scale=7.5,
            width=768,
            height=432,  # alternative resolutions: 480x288, 624x352, 432x240
            frames=16,
        )
    return video
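
# The pipeline output is assumed to hold frames in a format that diffusers'
# export_to_video accepts (e.g. a list of PIL images or numpy arrays).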


def save_video(tensor):
    """Write the generated frames to a timestamped MP4 under ./output."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path)
    return video_path
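
# NOTE: export_to_video falls back to diffusers' default frame rate here;
# pass fps= explicitly if the MP4 should match the 8 fps GIF below.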


def convert_to_gif(video_path):
    """Convert an MP4 to a 240 px-high, 8 fps GIF next to the source file."""
    clip = mp.VideoFileClip(video_path)
    clip = clip.set_fps(8)
    clip = clip.resize(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    clip.close()  # release the file handle held by moviepy
    return gif_path


def delete_old_files():
    """Every 10 minutes, delete generated files older than 10 minutes."""
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]

        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()
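# daemon=True lets the interpreter exit without waiting for the cleanup loop.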

with gr.Blocks() as demo:
    gr.Markdown("""
           <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
               Vchitect-XL 2B Huggingface Space🤗
           </div>
           <div style="text-align: center;">
               <a href="https://huggingface.co/Vchitect-XL/Vchitect-XL-2B">🤗 2B Model Hub</a> |
               <a href="https://vchitect.intern-ai.org.cn/">🌐 Website</a> |
           </div>
           <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
            ⚠️ This demo is for academic research and experiential use only. 
            Users should strictly adhere to local laws and ethics.
            </div>
           """)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)

            generate_button = gr.Button("🎬 Generate Video")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=768, height=432)
            with gr.Row():
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)




    def generate(prompt, progress=gr.Progress(track_tqdm=True)):
        """Run inference, save the MP4 and GIF, and reveal the download buttons."""
        tensor = infer(prompt, progress=progress)
        video_path = save_video(tensor)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)

        return video_path, video_update, gif_update



    generate_button.click(
        generate,
        inputs=[prompt],
        outputs=[video_output, download_video_button, download_gif_button]
    )
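
    # The three outputs fill the video player and reveal the two hidden
    # download widgets via the gr.update(visible=True, ...) values from generate().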


if __name__ == "__main__":
    demo.launch()