ArianatorQualquer commited on
Commit
e1952b6
1 Parent(s): fc663f6

Create gui-gradio.py

Browse files
Files changed (1) hide show
  1. gui-gradio.py +188 -0
gui-gradio.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import threading
5
+ import queue
6
+ import json
7
+
8
+ # Função para rodar subprocessos
9
+ def run_subprocess(cmd, output_queue):
10
+ try:
11
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
12
+ for line in process.stdout:
13
+ output_queue.put(line)
14
+ process.wait()
15
+ if process.returncode == 0:
16
+ output_queue.put("Process completed successfully!")
17
+ else:
18
+ output_queue.put(f"Process failed with return code {process.returncode}")
19
+ except Exception as e:
20
+ output_queue.put(f"An error occurred: {str(e)}")
21
+
22
+ # Classe para gerenciar as configurações salvas
23
+ class ConfigManager:
24
+ def __init__(self, filepath="settings.json"):
25
+ self.filepath = filepath
26
+ self.settings = self.load_settings()
27
+
28
+ def save_settings(self):
29
+ with open(self.filepath, "w") as f:
30
+ json.dump(self.settings, f, indent=2, ensure_ascii=False)
31
+
32
+ def load_settings(self):
33
+ if os.path.exists(self.filepath):
34
+ with open(self.filepath, "r") as f:
35
+ return json.load(f)
36
+ return {"saved_combinations": {}}
37
+
38
+ def get_saved_combinations(self):
39
+ return self.settings.get("saved_combinations", {})
40
+
41
+ def add_combination(self, name, combination):
42
+ self.settings["saved_combinations"][name] = combination
43
+ self.save_settings()
44
+
45
+ config_manager = ConfigManager()
46
+
47
+ # Funções de treinamento e inferência
48
+ def run_training(model_type, config_path, start_checkpoint, results_path, data_paths, valid_paths, num_workers, device_ids):
49
+ if not (model_type and config_path and results_path and data_paths and valid_paths):
50
+ return "Error: Missing required inputs for training."
51
+
52
+ cmd = [
53
+ "python", "train.py",
54
+ "--model_type", model_type,
55
+ "--config_path", config_path,
56
+ "--results_path", results_path,
57
+ "--data_path", *data_paths.split(';'),
58
+ "--valid_path", *valid_paths.split(';'),
59
+ "--num_workers", str(num_workers),
60
+ "--device_ids", device_ids
61
+ ]
62
+
63
+ if start_checkpoint:
64
+ cmd += ["--start_check_point", start_checkpoint]
65
+
66
+ output_queue = queue.Queue()
67
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
68
+
69
+ output = []
70
+ while not output_queue.empty():
71
+ output.append(output_queue.get())
72
+ return "\n".join(output)
73
+
74
+ def run_inference(model_type, config_path, start_checkpoint, input_folder, store_dir, extract_instrumental):
75
+ if not (model_type and config_path and input_folder and store_dir):
76
+ return "Error: Missing required inputs for inference."
77
+
78
+ cmd = [
79
+ "python", "inference.py",
80
+ "--model_type", model_type,
81
+ "--config_path", config_path,
82
+ "--input_folder", input_folder,
83
+ "--store_dir", store_dir
84
+ ]
85
+
86
+ if start_checkpoint:
87
+ cmd += ["--start_check_point", start_checkpoint]
88
+ if extract_instrumental:
89
+ cmd += ["--extract_instrumental"]
90
+
91
+ output_queue = queue.Queue()
92
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
93
+
94
+ output = []
95
+ while not output_queue.empty():
96
+ output.append(output_queue.get())
97
+ return "\n".join(output)
98
+
99
+ # Interface Gradio
100
+ def add_preset(name, model_type, config_path, checkpoint):
101
+ if not name:
102
+ return "Error: Name is required to save a preset."
103
+
104
+ config_manager.add_combination(name, {
105
+ "model_type": model_type,
106
+ "config_path": config_path,
107
+ "checkpoint": checkpoint
108
+ })
109
+ return f"Preset '{name}' saved successfully."
110
+
111
+ saved_presets = config_manager.get_saved_combinations()
112
+ preset_names = list(saved_presets.keys())
113
+
114
+ def load_preset(name):
115
+ if name in saved_presets:
116
+ preset = saved_presets[name]
117
+ return preset["model_type"], preset["config_path"], preset["checkpoint"]
118
+ return "", "", ""
119
+
120
+ with gr.Blocks() as gui:
121
+ gr.Markdown("# Music Source Separation Training & Inference GUI")
122
+
123
+ # Treinamento
124
+ with gr.Accordion("Training Configuration", open=False):
125
+ model_type = gr.Dropdown(
126
+ choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
127
+ )
128
+ config_path = gr.Textbox(label="Config File Path")
129
+ start_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
130
+ results_path = gr.Textbox(label="Results Path")
131
+ data_paths = gr.Textbox(label="Data Paths (separated by ';')")
132
+ valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
133
+ num_workers = gr.Number(label="Number of Workers", value=4)
134
+ device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
135
+ train_output = gr.Textbox(label="Training Output", interactive=False)
136
+
137
+ gr.Button("Run Training").click(
138
+ run_training,
139
+ inputs=[
140
+ model_type, config_path, start_checkpoint, results_path, data_paths,
141
+ valid_paths, num_workers, device_ids
142
+ ],
143
+ outputs=train_output
144
+ )
145
+
146
+ # Inferência
147
+ with gr.Accordion("Inference Configuration", open=False):
148
+ infer_model_type = gr.Dropdown(
149
+ choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
150
+ )
151
+ infer_config_path = gr.Textbox(label="Config File Path")
152
+ infer_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
153
+ input_folder = gr.Textbox(label="Input Folder")
154
+ store_dir = gr.Textbox(label="Output Folder")
155
+ extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
156
+ infer_output = gr.Textbox(label="Inference Output", interactive=False)
157
+
158
+ gr.Button("Run Inference").click(
159
+ run_inference,
160
+ inputs=[
161
+ infer_model_type, infer_config_path, infer_checkpoint, input_folder,
162
+ store_dir, extract_instrumental
163
+ ],
164
+ outputs=infer_output
165
+ )
166
+
167
+ # Gerenciamento de Presets
168
+ with gr.Accordion("Presets", open=False):
169
+ preset_name = gr.Textbox(label="Preset Name")
170
+ preset_model_type = gr.Textbox(label="Model Type")
171
+ preset_config_path = gr.Textbox(label="Config Path")
172
+ preset_checkpoint = gr.Textbox(label="Checkpoint")
173
+ preset_feedback = gr.Textbox(label="Feedback", interactive=False)
174
+
175
+ gr.Button("Save Preset").click(
176
+ add_preset,
177
+ inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint],
178
+ outputs=preset_feedback
179
+ )
180
+
181
+ preset_dropdown = gr.Dropdown(
182
+ choices=preset_names, label="Load Preset"
183
+ )
184
+ gr.Button("Load Preset").click(
185
+ load_preset, inputs=preset_dropdown, outputs=[preset_model_type, preset_config_path, preset_checkpoint]
186
+ )
187
+
188
+ gui.launch(share=True)