ArianatorQualquer
commited on
Commit
•
e1952b6
1
Parent(s):
fc663f6
Create gui-gradio.py
Browse files- gui-gradio.py +188 -0
gui-gradio.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
+
import threading
|
5 |
+
import queue
|
6 |
+
import json
|
7 |
+
|
8 |
+
# Função para rodar subprocessos
|
9 |
+
def run_subprocess(cmd, output_queue):
|
10 |
+
try:
|
11 |
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
|
12 |
+
for line in process.stdout:
|
13 |
+
output_queue.put(line)
|
14 |
+
process.wait()
|
15 |
+
if process.returncode == 0:
|
16 |
+
output_queue.put("Process completed successfully!")
|
17 |
+
else:
|
18 |
+
output_queue.put(f"Process failed with return code {process.returncode}")
|
19 |
+
except Exception as e:
|
20 |
+
output_queue.put(f"An error occurred: {str(e)}")
|
21 |
+
|
22 |
+
# Classe para gerenciar as configurações salvas
|
23 |
+
class ConfigManager:
|
24 |
+
def __init__(self, filepath="settings.json"):
|
25 |
+
self.filepath = filepath
|
26 |
+
self.settings = self.load_settings()
|
27 |
+
|
28 |
+
def save_settings(self):
|
29 |
+
with open(self.filepath, "w") as f:
|
30 |
+
json.dump(self.settings, f, indent=2, ensure_ascii=False)
|
31 |
+
|
32 |
+
def load_settings(self):
|
33 |
+
if os.path.exists(self.filepath):
|
34 |
+
with open(self.filepath, "r") as f:
|
35 |
+
return json.load(f)
|
36 |
+
return {"saved_combinations": {}}
|
37 |
+
|
38 |
+
def get_saved_combinations(self):
|
39 |
+
return self.settings.get("saved_combinations", {})
|
40 |
+
|
41 |
+
def add_combination(self, name, combination):
|
42 |
+
self.settings["saved_combinations"][name] = combination
|
43 |
+
self.save_settings()
|
44 |
+
|
45 |
+
config_manager = ConfigManager()
|
46 |
+
|
47 |
+
# Funções de treinamento e inferência
|
48 |
+
def run_training(model_type, config_path, start_checkpoint, results_path, data_paths, valid_paths, num_workers, device_ids):
|
49 |
+
if not (model_type and config_path and results_path and data_paths and valid_paths):
|
50 |
+
return "Error: Missing required inputs for training."
|
51 |
+
|
52 |
+
cmd = [
|
53 |
+
"python", "train.py",
|
54 |
+
"--model_type", model_type,
|
55 |
+
"--config_path", config_path,
|
56 |
+
"--results_path", results_path,
|
57 |
+
"--data_path", *data_paths.split(';'),
|
58 |
+
"--valid_path", *valid_paths.split(';'),
|
59 |
+
"--num_workers", str(num_workers),
|
60 |
+
"--device_ids", device_ids
|
61 |
+
]
|
62 |
+
|
63 |
+
if start_checkpoint:
|
64 |
+
cmd += ["--start_check_point", start_checkpoint]
|
65 |
+
|
66 |
+
output_queue = queue.Queue()
|
67 |
+
threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
|
68 |
+
|
69 |
+
output = []
|
70 |
+
while not output_queue.empty():
|
71 |
+
output.append(output_queue.get())
|
72 |
+
return "\n".join(output)
|
73 |
+
|
74 |
+
def run_inference(model_type, config_path, start_checkpoint, input_folder, store_dir, extract_instrumental):
|
75 |
+
if not (model_type and config_path and input_folder and store_dir):
|
76 |
+
return "Error: Missing required inputs for inference."
|
77 |
+
|
78 |
+
cmd = [
|
79 |
+
"python", "inference.py",
|
80 |
+
"--model_type", model_type,
|
81 |
+
"--config_path", config_path,
|
82 |
+
"--input_folder", input_folder,
|
83 |
+
"--store_dir", store_dir
|
84 |
+
]
|
85 |
+
|
86 |
+
if start_checkpoint:
|
87 |
+
cmd += ["--start_check_point", start_checkpoint]
|
88 |
+
if extract_instrumental:
|
89 |
+
cmd += ["--extract_instrumental"]
|
90 |
+
|
91 |
+
output_queue = queue.Queue()
|
92 |
+
threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
|
93 |
+
|
94 |
+
output = []
|
95 |
+
while not output_queue.empty():
|
96 |
+
output.append(output_queue.get())
|
97 |
+
return "\n".join(output)
|
98 |
+
|
99 |
+
# Interface Gradio
|
100 |
+
def add_preset(name, model_type, config_path, checkpoint):
|
101 |
+
if not name:
|
102 |
+
return "Error: Name is required to save a preset."
|
103 |
+
|
104 |
+
config_manager.add_combination(name, {
|
105 |
+
"model_type": model_type,
|
106 |
+
"config_path": config_path,
|
107 |
+
"checkpoint": checkpoint
|
108 |
+
})
|
109 |
+
return f"Preset '{name}' saved successfully."
|
110 |
+
|
111 |
+
saved_presets = config_manager.get_saved_combinations()
|
112 |
+
preset_names = list(saved_presets.keys())
|
113 |
+
|
114 |
+
def load_preset(name):
|
115 |
+
if name in saved_presets:
|
116 |
+
preset = saved_presets[name]
|
117 |
+
return preset["model_type"], preset["config_path"], preset["checkpoint"]
|
118 |
+
return "", "", ""
|
119 |
+
|
120 |
+
with gr.Blocks() as gui:
|
121 |
+
gr.Markdown("# Music Source Separation Training & Inference GUI")
|
122 |
+
|
123 |
+
# Treinamento
|
124 |
+
with gr.Accordion("Training Configuration", open=False):
|
125 |
+
model_type = gr.Dropdown(
|
126 |
+
choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
|
127 |
+
)
|
128 |
+
config_path = gr.Textbox(label="Config File Path")
|
129 |
+
start_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
|
130 |
+
results_path = gr.Textbox(label="Results Path")
|
131 |
+
data_paths = gr.Textbox(label="Data Paths (separated by ';')")
|
132 |
+
valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
|
133 |
+
num_workers = gr.Number(label="Number of Workers", value=4)
|
134 |
+
device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
|
135 |
+
train_output = gr.Textbox(label="Training Output", interactive=False)
|
136 |
+
|
137 |
+
gr.Button("Run Training").click(
|
138 |
+
run_training,
|
139 |
+
inputs=[
|
140 |
+
model_type, config_path, start_checkpoint, results_path, data_paths,
|
141 |
+
valid_paths, num_workers, device_ids
|
142 |
+
],
|
143 |
+
outputs=train_output
|
144 |
+
)
|
145 |
+
|
146 |
+
# Inferência
|
147 |
+
with gr.Accordion("Inference Configuration", open=False):
|
148 |
+
infer_model_type = gr.Dropdown(
|
149 |
+
choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
|
150 |
+
)
|
151 |
+
infer_config_path = gr.Textbox(label="Config File Path")
|
152 |
+
infer_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
|
153 |
+
input_folder = gr.Textbox(label="Input Folder")
|
154 |
+
store_dir = gr.Textbox(label="Output Folder")
|
155 |
+
extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
|
156 |
+
infer_output = gr.Textbox(label="Inference Output", interactive=False)
|
157 |
+
|
158 |
+
gr.Button("Run Inference").click(
|
159 |
+
run_inference,
|
160 |
+
inputs=[
|
161 |
+
infer_model_type, infer_config_path, infer_checkpoint, input_folder,
|
162 |
+
store_dir, extract_instrumental
|
163 |
+
],
|
164 |
+
outputs=infer_output
|
165 |
+
)
|
166 |
+
|
167 |
+
# Gerenciamento de Presets
|
168 |
+
with gr.Accordion("Presets", open=False):
|
169 |
+
preset_name = gr.Textbox(label="Preset Name")
|
170 |
+
preset_model_type = gr.Textbox(label="Model Type")
|
171 |
+
preset_config_path = gr.Textbox(label="Config Path")
|
172 |
+
preset_checkpoint = gr.Textbox(label="Checkpoint")
|
173 |
+
preset_feedback = gr.Textbox(label="Feedback", interactive=False)
|
174 |
+
|
175 |
+
gr.Button("Save Preset").click(
|
176 |
+
add_preset,
|
177 |
+
inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint],
|
178 |
+
outputs=preset_feedback
|
179 |
+
)
|
180 |
+
|
181 |
+
preset_dropdown = gr.Dropdown(
|
182 |
+
choices=preset_names, label="Load Preset"
|
183 |
+
)
|
184 |
+
gr.Button("Load Preset").click(
|
185 |
+
load_preset, inputs=preset_dropdown, outputs=[preset_model_type, preset_config_path, preset_checkpoint]
|
186 |
+
)
|
187 |
+
|
188 |
+
gui.launch(share=True)
|