kadirnar commited on
Commit
9f69156
1 Parent(s): 9fbf355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -1,9 +1,19 @@
1
  import gradio as gr
 
2
  import subprocess
3
  import tempfile
4
  import shutil
5
 
6
- def run_inference(config_path, ckpt_path, prompt_path):
 
 
 
 
 
 
 
 
 
7
  with open(config_path, 'r') as file:
8
  config_content = file.read()
9
  config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_path}"')
@@ -30,9 +40,13 @@ def main():
30
  gr.Interface(
31
  fn=run_inference,
32
  inputs=[
33
- gr.Textbox(label="Configuration Path"),
34
- gr.Dropdown(choices=["./path/to/model1.ckpt", "./path/to/model2.ckpt", "./path/to/model3.ckpt"], label="Checkpoint Path"),
35
- gr.Textbox(label="Prompt Path")
 
 
 
 
36
  ],
37
  outputs=[
38
  gr.Text(label="Status"),
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
  import subprocess
4
  import tempfile
5
  import shutil
6
 
7
+ def download_model(repo_id, model_name):
8
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
9
+ return model_path
10
+
11
+ def run_inference(config_path, model_name, prompt_path):
12
+ repo_id = "hpcai-tech/Open-Sora"
13
+
14
+ # Download the selected model
15
+ ckpt_path = download_model(repo_id, model_name)
16
+
17
  with open(config_path, 'r') as file:
18
  config_content = file.read()
19
  config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_path}"')
 
40
  gr.Interface(
41
  fn=run_inference,
42
  inputs=[
43
+ gr.Textbox(label="Configuration Path", default="configs/opensora/inference/16x256x256.py"),
44
+ gr.Dropdown(choices=[
45
+ "OpenSora-v1-16x256x256.pth",
46
+ "OpenSora-v1-HQ-16x256x256.pth",
47
+ "OpenSora-v1-HQ-16x512x512.pth"
48
+ ], label="Model Selection"),
49
+ gr.Textbox(label="Prompt Path", default="./assets/texts/t2v_samples.txt")
50
  ],
51
  outputs=[
52
  gr.Text(label="Status"),