Open-Sora / app.py
kadirnar's picture
Update app.py
b4c555f verified
raw
history blame
4.85 kB
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
import subprocess
import tempfile
import shutil
import os
import spaces
import importlib
from transformers import T5ForConditionalGeneration, T5Tokenizer
import os
def download_t5_model(model_id, save_directory):
    """Download a full snapshot of a Hub repository into a local directory.

    Args:
        model_id: Hugging Face Hub repository ID to download
            (e.g. ``"DeepFloyd/t5-v1_1-xxl"``).
        save_directory: Local directory that receives the repository files.
            It is created if it does not exist yet.
    """
    # exist_ok=True replaces the separate "exists?" check and is safe when the
    # directory is already present.
    os.makedirs(save_directory, exist_ok=True)
    # Download the requested repository (model weights and tokenizer files)
    # as real files rather than symlinks into the Hub cache.
    snapshot_download(repo_id=model_id, local_dir=save_directory, local_dir_use_symlinks=False)
# Hub model ID and the local directory the T5 files are saved to
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"
# Download the model; this runs at import time, before the UI is built
download_t5_model(model_id, save_directory)
def download_model(repo_id, model_name):
    """Fetch a single file from a Hub repository and return what hf_hub_download returns.

    Args:
        repo_id: Hugging Face Hub repository ID that hosts the file.
        model_name: Name of the file inside the repository.

    Returns:
        The value returned by ``hf_hub_download`` (the local path of the
        downloaded file).
    """
    return hf_hub_download(repo_id=repo_id, filename=model_name)
import glob
# Install flash-attn when the module is loaded.
# The extra variable is merged into the current environment instead of
# replacing it: passing a bare dict as ``env`` would drop PATH, HOME, etc.
# for the pip process.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's
# setup to skip compiling CUDA kernels -- confirm against the flash-attn docs.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
@spaces.GPU(duration=180)
def run_inference(prompt_text):
    """Run the Open-Sora inference script for one prompt and return a video path.

    The prompt is written to a temporary text file, a temporary copy of the
    inference config is created whose ``prompt_path`` points at that file, and
    ``scripts/inference.py`` is launched through ``torchrun``. Both temporary
    files are removed before returning, whether or not inference succeeded.

    Args:
        prompt_text: Text prompt to generate a video from.

    Returns:
        The path in ``./outputs/samples/`` with the greatest
        ``os.path.getctime``, or ``None`` if that directory has no entries.
    """
    repo_id = "hpcai-tech/Open-Sora"
    model_name = "OpenSora-v1-HQ-16x512x512.pth"
    # Map model names to their respective configuration files.
    # NOTE(review): the two HQ checkpoints map to configs whose names differ
    # from the checkpoint names (16x256x256 -> 16x512x512 and
    # 16x512x512 -> 64x512x512) -- confirm this is intentional.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py"
    }
    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Every temporary file is recorded here as soon as it exists so the
    # ``finally`` clause can delete it on every exit path.
    temp_paths = []
    try:
        # Save prompt_text to a temporary text file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode='w') as prompt_file:
            temp_paths.append(prompt_file.name)
            prompt_file.write(prompt_text)

        # Copy the config, redirecting its prompt_path to the prompt file.
        with open(config_path, 'r') as file:
            config_content = file.read()
        config_content = config_content.replace(
            'prompt_path = "./assets/texts/t2v_samples.txt"',
            f'prompt_path = "{prompt_file.name}"',
        )
        with tempfile.NamedTemporaryFile('w', delete=False, suffix='.py') as temp_file:
            temp_paths.append(temp_file.name)
            temp_file.write(config_content)
        temp_config_path = temp_file.name

        cmd = [
            "torchrun", "--standalone", "--nproc_per_node", "1",
            "scripts/inference.py", temp_config_path,
            "--ckpt-path", ckpt_path
        ]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            # Surface the failure in the logs; the lookup below may otherwise
            # silently pick up a file left behind by an earlier run.
            print(f"Inference command exited with code {result.returncode}.")

        # Output directory; for example, the one inference.py saves into.
        save_dir = "./outputs/samples/"
        list_of_files = glob.glob(f'{save_dir}/*')
        if list_of_files:
            return max(list_of_files, key=os.path.getctime)
        print("No files found in the output directory.")
        return None
    finally:
        # Clean up the temporary files (previously unreachable after return).
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)
def main():
    """Build the Gradio Blocks interface for the demo and launch it.

    The page has a header row with two HTML banners, then a row holding the
    prompt textbox and run button next to the video output. The button and
    the cached example both call ``run_inference``.
    """
    # Banner markup, spelled out line by line; the resulting strings are
    # identical to the original triple-quoted literals.
    title_html = (
        "\n"
        "<h1 style='text-align: center'>\n"
        "Open-Sora: Democratizing Efficient Video Production for All\n"
        "</h1>\n"
    )
    links_html = (
        "\n"
        "<h3 style='text-align: center'>\n"
        "Follow me for more!\n"
        "<a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>\n"
        "</h3>\n"
    )
    example_prompt = "Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image."

    with gr.Blocks() as demo:
        # Header row.
        with gr.Row():
            with gr.Column():
                gr.HTML(title_html)
                gr.HTML(links_html)

        # Input column on the left, generated video on the right.
        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(show_label=False, placeholder="Enter prompt text here", lines=4)
                run_button = gr.Button("Run Inference")
            with gr.Column():
                video_out = gr.Video()

        run_button.click(fn=run_inference, inputs=[prompt_box], outputs=video_out)

        gr.Examples(
            examples=[[example_prompt]],
            fn=run_inference,
            inputs=[prompt_box],
            outputs=[video_out],
            cache_examples=True,
        )

    demo.launch(debug=True)
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()