File size: 4,994 Bytes
69acc93
e588377
b38913b
69acc93
 
 
 
538d96c
69acc93
 
 
 
 
 
 
e588377
69acc93
 
 
 
 
 
 
 
 
 
 
 
 
 
974ed71
7bf8cc1
496445e
ff65188
69acc93
 
 
f5f1704
69acc93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d7b49
e18b82d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d7b49
7d4a56a
 
93d7b49
 
7d4a56a
e18b82d
93d7b49
 
 
ff65188
93d7b49
 
3a8867d
ae4aad6
 
cf589b1
ae4aad6
3a8867d
 
fbbf5a1
3a8867d
 
 
93d7b49
3a8867d
69acc93
 
e18b82d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
import subprocess
import tempfile
import shutil
import os
import spaces
import importlib
from transformers import T5ForConditionalGeneration, T5Tokenizer
import os

def download_t5_model(model_id, save_directory):
    """Download a full snapshot of a Hugging Face Hub repo into a local directory.

    Args:
        model_id: Hub repository id to download (e.g. "DeepFloyd/t5-v1_1-xxl").
            Previously this argument was ignored and a hard-coded repo was used.
        save_directory: Local directory the snapshot files are written into;
            created (including parents) if it does not exist.

    Returns:
        The local path reported by ``snapshot_download``.
    """
    # exist_ok avoids the separate exists-check + create race.
    os.makedirs(save_directory, exist_ok=True)
    # local_dir_use_symlinks=False asks huggingface_hub to place real files in
    # save_directory rather than symlinks into the shared cache.
    return snapshot_download(
        repo_id=model_id,
        local_dir=save_directory,
        local_dir_use_symlinks=False,
    )

# Hub repo of the T5 checkpoint and the local directory it is mirrored into.
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

# Runs at import time (module top level), before the UI is built in main().
download_t5_model(model_id=model_id, save_directory=save_directory)

def download_model(repo_id, model_name):
    """Fetch one file from a Hugging Face Hub repo and return its local path.

    Args:
        repo_id: Hub repository id, e.g. "hpcai-tech/Open-Sora".
        model_name: Filename inside the repository to download.

    Returns:
        Local filesystem path returned by ``hf_hub_download``.
    """
    return hf_hub_download(filename=model_name, repo_id=repo_id)

import glob
import sys

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is read by
# flash-attn's setup to skip compiling its CUDA extension locally.
# The variable is merged into the current environment: passing only that key
# (as before) replaced the whole environment and dropped PATH, HOME, etc.
# `sys.executable -m pip` installs into the interpreter running this app, and an
# argument list avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

def _write_temp_file(content, suffix):
    """Write *content* (UTF-8) to a new temp file that is kept on disk; return its path."""
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=suffix, encoding="utf-8"
    ) as handle:
        handle.write(content)
        return handle.name


def _latest_file(directory):
    """Return the entry in *directory* with the newest ctime, or None if it is empty."""
    candidates = glob.glob(os.path.join(directory, "*"))
    if not candidates:
        return None
    return max(candidates, key=os.path.getctime)


@spaces.GPU(duration=200)
def run_inference(prompt_text):
    """Generate a video for *prompt_text* with Open-Sora and return its file path.

    Steps:
    1. Download the checkpoint from the hub.
    2. Write the prompt and a copy of the inference config (with its
       ``prompt_path`` pointed at the prompt file) to temp files.
    3. Run ``scripts/inference.py`` through ``torchrun``.
    4. Return the newest file in the samples directory.

    Temp files are always removed, even if inference fails.

    Args:
        prompt_text: Prompt text written verbatim into the prompt file.

    Returns:
        Path of the newest file in ``./outputs/samples/``, or None if the
        inference process exited with a non-zero code or produced no file.
    """
    repo_id = "hpcai-tech/Open-Sora"

    # Map checkpoint filenames to the inference config used with them.
    # NOTE(review): the two HQ entries look shifted relative to their filenames
    # (HQ-16x256x256 -> 16x512x512.py, HQ-16x512x512 -> 64x512x512.py); confirm
    # against the Open-Sora README before changing the selected checkpoint.
    model_name = "OpenSora-v1-HQ-16x512x512.pth"
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py",
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Relative paths below assume the process runs from the Open-Sora repo root — TODO confirm.
    # Read the config before creating any temp file so a read failure leaks nothing.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config_content = config_file.read()

    prompt_path = _write_temp_file(prompt_text, ".txt")
    temp_config_path = None
    try:
        # Point the config at our prompt file. This is a no-op if the exact
        # default line is absent from the config.
        config_content = config_content.replace(
            'prompt_path = "./assets/texts/t2v_samples.txt"',
            f'prompt_path = "{prompt_path}"',
        )
        temp_config_path = _write_temp_file(config_content, ".py")

        cmd = [
            "torchrun", "--standalone", "--nproc_per_node", "1",
            "scripts/inference.py", temp_config_path,
            "--ckpt-path", ckpt_path,
        ]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            # Without this check, a failed run would return a stale video left
            # in the output directory by an earlier request.
            print(f"Inference failed with exit code {result.returncode}.")
            return None

        # Save directory used by inference.py (per the original note) — verify
        # against the config's save_dir.
        save_dir = "./outputs/samples/"
        latest_file = _latest_file(save_dir)
        if latest_file is None:
            print("No files found in the output directory.")
        return latest_file
    finally:
        # Previously this cleanup sat after the return statements and never ran.
        os.remove(prompt_path)
        if temp_config_path is not None:
            os.remove(temp_config_path)

def main():
    """Build the Open-Sora Gradio UI and launch it.

    Layout:
    - a header row with the title and author links;
    - a row with the prompt box and "Run Inference" button on the left and the
      generated video on the right.

    One example prompt is registered with ``cache_examples=True``.
    """
    example_prompt = "A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene. The video is shot from a slightly elevated angle, providing a comprehensive view of the turtle's surroundings. The overall style of the video is calm and peaceful, capturing the beauty and tranquility of the underwater world."

    with gr.Blocks() as demo:
        # Header: title and author links (HTML kept verbatim).
        with gr.Row():
            with gr.Column():
                gr.HTML(
                """
                <h1 style='text-align: center'>
               Open-Sora: Democratizing Efficient Video Production for All
                </h1>
                """
                )
                gr.HTML(
                    """
                    <h3 style='text-align: center'>
                    Follow me for more! 
                    <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>
                    </h3>
                    """
                )

        # Prompt input + trigger on the left, generated video on the right.
        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(
                    show_label=False,
                    placeholder="Enter prompt text here",
                    lines=4,
                )
                run_button = gr.Button("Run Inference")
            with gr.Column():
                video_out = gr.Video()

        run_button.click(fn=run_inference, inputs=[prompt_box], outputs=video_out)

        # With cache_examples=True, Gradio precomputes the example output by
        # calling fn ahead of time.
        gr.Examples(
            examples=[[example_prompt]],
            fn=run_inference,
            inputs=[prompt_box],
            outputs=[video_out],
            cache_examples=True,
        )

    demo.launch(debug=True)


if __name__ == "__main__":
    main()