kadirnar commited on
Commit
561516d
·
verified ·
1 Parent(s): 2dab131

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ import subprocess
4
+ import tempfile
5
+ import shutil
6
+ import os
7
+ import spaces
8
+
9
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
10
+ import os
11
+
12
+
13
def download_t5_model(model_id, save_directory):
    """Download a T5 model and its tokenizer from the Hugging Face Hub and save both locally.

    Parameters:
        model_id: Hub identifier of the T5 checkpoint (e.g. "DeepFloyd/t5-v1_1-xxl").
        save_directory: local directory the model and tokenizer files are written to
            (created if it does not exist).
    """
    # Fetch the model weights and the matching tokenizer.
    model = T5ForConditionalGeneration.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    # exist_ok=True replaces the original exists()/makedirs() pair,
    # which was racy and redundant.
    os.makedirs(save_directory, exist_ok=True)
    model.save_pretrained(save_directory)
    tokenizer.save_pretrained(save_directory)
23
+
24
# Model ID and the directory the weights are saved to.
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

# NOTE: this runs at import time, so the (large) T5 checkpoint is
# downloaded once when the app starts, before any request is served.
download_t5_model(model_id, save_directory)
30
+
31
def download_model(repo_id, model_name):
    """Fetch a single checkpoint file from the Hub and return its local cache path."""
    return hf_hub_download(repo_id=repo_id, filename=model_name)
34
+
35
+ import glob
36
+
37
@spaces.GPU
def run_inference(model_name, prompt_text):
    """Run Open-Sora text-to-video inference for a single prompt.

    Parameters:
        model_name: one of the checkpoint filenames listed in ``config_mapping``.
        prompt_text: the text prompt, written to a temp file and injected into
            the config in place of the default sample-prompt path.

    Returns:
        Path to the newest file in the output directory (presumably the
        generated video — depends on scripts/inference.py), or None if the
        directory is empty.
    """
    repo_id = "hpcai-tech/Open-Sora"

    # Map model names to their respective configuration files.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "64x512x512.py"
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Save prompt_text to a temporary text file (delete=False so torchrun,
    # a separate process, can read it; removed in the finally block).
    prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode='w')
    prompt_file.write(prompt_text)
    prompt_file.close()

    temp_config_path = None
    try:
        with open(config_path, 'r') as file:
            config_content = file.read()
        # Point the config at our prompt file instead of the bundled samples.
        config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_file.name}"')

        with tempfile.NamedTemporaryFile('w', delete=False, suffix='.py') as temp_file:
            temp_file.write(config_content)
            temp_config_path = temp_file.name

        cmd = [
            "torchrun", "--standalone", "--nproc_per_node", "1",
            "scripts/inference.py", temp_config_path,
            "--ckpt-path", ckpt_path
        ]
        subprocess.run(cmd)

        # Output directory used by scripts/inference.py.
        save_dir = "./outputs/samples/"
        list_of_files = glob.glob(f'{save_dir}/*')
        if list_of_files:
            # Newest file by creation time is assumed to be this run's result.
            return max(list_of_files, key=os.path.getctime)
        print("No files found in the output directory.")
        return None
    finally:
        # BUG FIX: in the original, these removals were placed after the
        # return statements and never executed, leaking a prompt file and a
        # config file on every call. finally guarantees cleanup on all paths.
        if temp_config_path is not None:
            os.remove(temp_config_path)
        os.remove(prompt_file.name)
83
+
84
def main():
    """Build the Gradio interface for Open-Sora inference and launch it."""
    model_selector = gr.Dropdown(
        choices=[
            "OpenSora-v1-16x256x256.pth",
            "OpenSora-v1-HQ-16x256x256.pth",
            "OpenSora-v1-HQ-16x512x512.pth",
        ],
        value="OpenSora-v1-16x256x256.pth",
        label="Model Selection",
    )
    prompt_box = gr.Textbox(label="Prompt Text", value="Enter prompt text here")

    demo = gr.Interface(
        fn=run_inference,
        inputs=[model_selector, prompt_box],
        outputs=gr.Video(label="Output Video"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    )
    demo.launch()
101
+
102
# Script entry point: launch the Gradio app when run directly.
if __name__ == "__main__":
    main()