TIMBOVILL committed on
Commit
5cf4082
1 Parent(s): ea76d52

Create tabs/inference/inference.py

Browse files
Files changed (1) hide show
  1. tabs/inference/inference.py +437 -0
tabs/inference/inference.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import gradio as gr
3
+ import regex as re
4
+ import shutil
5
+ import datetime
6
+ import random
7
+
8
+ from core import (
9
+ run_infer_script,
10
+ run_batch_infer_script,
11
+ )
12
+
13
+ from assets.i18n.i18n import I18nAuto
14
+
15
+ i18n = I18nAuto()
16
+
17
# Project folders are resolved against the current working directory, which
# is also appended to the import search path.
now_dir = os.getcwd()
sys.path.append(now_dir)

model_root = os.path.join(now_dir, "logs")
audio_root = os.path.join(now_dir, "assets", "audios")

# File-name suffixes accepted as audio inputs.
sup_audioext = {
    "wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4",
    "aac", "alac", "wma", "aiff", "webm", "ac3",
}

# Voice model weights found anywhere below ``model_root``.
names = [
    os.path.join(folder, entry)
    for folder, _, entries in os.walk(model_root, topdown=False)
    for entry in entries
    if entry.endswith(".pth") or entry.endswith(".onnx")
]

# Index files below ``model_root``; names containing "trained" are skipped.
indexes_list = [
    os.path.join(folder, entry)
    for folder, _, entries in os.walk(model_root, topdown=False)
    for entry in entries
    if "trained" not in entry and entry.endswith(".index")
]

# Audio files sitting directly in ``audio_root`` (sub-folders are ignored),
# leaving out anything whose name contains "_output".
audio_paths = [
    os.path.join(folder, entry)
    for folder, _, entries in os.walk(audio_root, topdown=False)
    for entry in entries
    if folder == audio_root
    and "_output" not in entry
    and entry.endswith(tuple(sup_audioext))
]
60
+
61
+
62
def output_path_fn(input_audio_path):
    """Return the default conversion output path for an input audio file.

    The result is placed in the same directory as ``input_audio_path``. Its
    file name is the input file name up to the last ``.``, followed by
    ``_output.wav``.
    """
    folder, filename = os.path.split(input_audio_path)
    # Only the final extension is dropped: "a.tar.gz" becomes "a.tar".
    stem = filename.rsplit(".", 1)[0]
    return os.path.join(folder, stem + "_output.wav")
69
+
70
+
71
def change_choices():
    """Rescan the model and audio folders and build dropdown refresh payloads.

    Returns:
        A tuple of three update dictionaries, in this order: voice model
        files, index files, and input audio files. Each dictionary carries
        a sorted ``choices`` list and the ``"__type__": "update"`` marker.
    """
    model_paths = []
    index_paths = []
    # A single pass over the model tree collects weights and index files.
    for folder, _, entries in os.walk(model_root, topdown=False):
        for entry in entries:
            full_path = os.path.join(folder, entry)
            if entry.endswith((".pth", ".onnx")):
                model_paths.append(full_path)
            if entry.endswith(".index") and "trained" not in entry:
                index_paths.append(full_path)

    audio_suffixes = tuple(sup_audioext)
    input_audios = []
    for folder, _, entries in os.walk(audio_root, topdown=False):
        # Only files directly inside the audio folder are offered.
        if folder != audio_root:
            continue
        for entry in entries:
            if "_output" in entry:
                continue
            if entry.endswith(audio_suffixes):
                input_audios.append(os.path.join(folder, entry))

    def as_update(choices):
        return {"choices": sorted(choices), "__type__": "update"}

    return as_update(model_paths), as_update(index_paths), as_update(input_audios)
100
+
101
+
102
def get_indexes():
    """List the index files found below ``model_root``.

    Files whose name contains "trained" are skipped.

    Returns:
        A list of index file paths, or an empty string when none exist.
    """
    found = []
    for dirpath, _, filenames in os.walk(model_root):
        for filename in filenames:
            if "trained" in filename or not filename.endswith(".index"):
                continue
            found.append(os.path.join(dirpath, filename))

    # Callers receive "" rather than an empty list when nothing was found.
    return found or ""
111
+
112
+
113
def match_index(model_file: str) -> str:
    """Pick the index file that best matches a voice model file.

    The model's own folder (``model_root/<base name>``) is searched first,
    then ``model_root`` itself. Inside the model's folder every index file is
    a candidate; in ``model_root`` the index file name must contain the model
    name (case-insensitive). Files whose name contains "trained" are never
    considered. Among the candidates, file names without spaces are
    preferred, then the largest file.

    Args:
        model_file: Path or file name of a ``.pth`` / ``.onnx`` model. May be
            empty or ``None`` when no model is selected.

    Returns:
        The path of the best matching index file, or ``""`` when there is no
        model or no candidate.
    """
    if not model_file:
        # Nothing selected (for example after unloading the voice).
        return ""

    # Both extensions are anchored to the end of the string so that a ".pth"
    # occurring elsewhere in the path is left untouched.
    model_files_trip = re.sub(r"\.(?:pth|onnx)$", "", model_file)
    model_file_name = os.path.basename(model_files_trip)
    if not model_file_name:
        # An empty name would be a substring of every index file name.
        return ""

    # Names ending in _eXXX_sXXX map back to the name without that suffix.
    if re.match(r".+_e\d+_s\d+$", model_file_name):
        base_model_name = model_file_name.rsplit("_", 2)[0]
    else:
        base_model_name = model_file_name

    sid_directory = os.path.join(model_root, base_model_name)
    directories_to_search = [sid_directory] if os.path.isdir(sid_directory) else []
    directories_to_search.append(model_root)

    wanted_names = (model_file_name.lower(), base_model_name.lower())
    candidates = []
    for directory in directories_to_search:
        if not os.path.isdir(directory):
            # A missing model folder simply yields no candidates.
            continue

        # Anything inside the model's own folder is automatically a match.
        folder_match = directory == sid_directory

        for filename in os.listdir(directory):
            if not filename.endswith(".index") or "trained" in filename:
                continue

            lowered = filename.lower()
            name_match = any(name in lowered for name in wanted_names)
            if not (name_match or folder_match):
                continue

            # The folder is inspected directly instead of consulting a list
            # captured at import time, so files added later are found too.
            index_path = os.path.join(directory, filename)
            candidates.append(
                (index_path, os.path.getsize(index_path), " " not in filename)
            )

    if not candidates:
        return ""

    # Prefer names without spaces, then the largest file; on ties the first
    # candidate encountered wins.
    best = max(candidates, key=lambda item: (item[2], item[1]))
    return best[0]
161
+
162
+
163
def save_to_wav(record_button):
    """Move a recorded audio file into the audio folder under a timestamped name.

    Args:
        record_button: Path of the recorded file, or ``None`` when there is
            no recording.

    Returns:
        ``(target_path, output_path)``, where ``target_path`` is the new
        location of the recording and ``output_path`` is the default
        conversion output path derived from it. When ``record_button`` is
        ``None``, a pair of empty update dictionaries is returned so that the
        two bound output components are left unchanged.
    """
    if record_button is None:
        # This handler feeds two outputs; return one no-op update for each
        # instead of falling through to an implicit ``None``.
        return {"__type__": "update"}, {"__type__": "update"}

    new_name = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".wav"
    target_path = os.path.join(audio_root, new_name)

    # Make sure the destination folder exists before moving the file.
    os.makedirs(audio_root, exist_ok=True)
    shutil.move(record_button, target_path)
    return target_path, output_path_fn(target_path)
173
+
174
+
175
def save_to_wav2(upload_audio):
    """Copy an uploaded audio file into the audio folder, replacing any namesake.

    Args:
        upload_audio: Path of the uploaded file, or ``None`` when there is
            no upload.

    Returns:
        ``(target_path, output_path)``, where ``target_path`` is the copy
        inside the audio folder and ``output_path`` is the default conversion
        output path derived from it. When ``upload_audio`` is ``None``, a
        pair of empty update dictionaries is returned so that the two bound
        output components are left unchanged.
    """
    if upload_audio is None:
        return {"__type__": "update"}, {"__type__": "update"}

    target_path = os.path.join(audio_root, os.path.basename(upload_audio))

    # If the upload already is the target file, removing the target first
    # would delete the only copy; in that case there is nothing to do.
    if os.path.realpath(upload_audio) != os.path.realpath(target_path):
        os.makedirs(audio_root, exist_ok=True)
        if os.path.exists(target_path):
            os.remove(target_path)
        shutil.copy(upload_audio, target_path)

    return target_path, output_path_fn(target_path)
184
+
185
+
186
def delete_outputs():
    """Delete generated output audio files and show a confirmation message.

    Walks ``audio_root`` recursively and removes every file whose name
    contains "_output" and ends with one of the supported audio suffixes.
    Other files are left in place.
    """
    # Built once; ``str.endswith`` needs a tuple, not a set.
    audio_suffixes = tuple(sup_audioext)

    for root, _, files in os.walk(audio_root, topdown=False):
        for name in files:
            if "_output" in name and name.endswith(audio_suffixes):
                os.remove(os.path.join(root, name))

    gr.Info("Outputs cleared!")
192
+
193
+
194
+ # Inference tab
195
def inference_tab():
    """Build the inference UI: model pickers plus "Single" and "Batch" sub-tabs.

    Creates the voice model and index dropdowns, the per-tab conversion
    settings, and wires the buttons to ``run_infer_script`` and
    ``run_batch_infer_script``. Presumably called by the application inside
    an open ``gr.Blocks`` context; confirm against the caller.
    """
    default_weight = random.choice(names) if names else ""
    with gr.Row():
        with gr.Row():
            model_file = gr.Dropdown(
                label=i18n("Voice Model"),
                choices=sorted(names),
                interactive=True,
                value=default_weight,
                allow_custom_value=True,
            )
            best_default_index_path = match_index(model_file.value)
            index_file = gr.Dropdown(
                label=i18n("Index File"),
                choices=get_indexes(),
                value=best_default_index_path,
                interactive=True,
                allow_custom_value=True,
            )
        with gr.Column():
            refresh_button = gr.Button(i18n("Refresh"))
            unload_button = gr.Button(i18n("Unload Voice"))

            unload_button.click(
                fn=lambda: ({"value": "", "__type__": "update"}),
                inputs=[],
                outputs=[model_file],
            )

            # Selecting a model suggests its best matching index file.
            model_file.select(
                fn=match_index,
                inputs=[model_file],
                outputs=[index_file],
            )

    # Single inference tab
    with gr.Tab(i18n("Single")):
        with gr.Row():
            with gr.Column():
                upload_audio = gr.Audio(
                    label=i18n("Upload Audio"), type="filepath", editable=False
                )
                with gr.Row():
                    audio = gr.Dropdown(
                        label=i18n("Select Audio"),
                        choices=sorted(audio_paths),
                        value=audio_paths[0] if audio_paths else "",
                        interactive=True,
                        allow_custom_value=True,
                    )

        with gr.Accordion(i18n("Advanced Settings"), open=False):
            with gr.Column():
                clear_outputs = gr.Button(
                    i18n("Clear Outputs (Deletes all audios in assets/audios)")
                )
                output_path = gr.Textbox(
                    label=i18n("Output Path"),
                    placeholder=i18n("Enter output path"),
                    value=(
                        output_path_fn(audio_paths[0])
                        if audio_paths
                        else os.path.join(now_dir, "assets", "audios", "output.wav")
                    ),
                    interactive=True,
                )
                split_audio = gr.Checkbox(
                    label=i18n("Split Audio"),
                    visible=True,
                    value=False,
                    interactive=True,
                )
                pitch = gr.Slider(-12, 12, 0, label=i18n("Pitch"))
                filter_radius = gr.Slider(
                    minimum=0,
                    maximum=7,
                    label=i18n(
                        "If >=3: apply median filtering to the harvested pitch results. The value represents the filter radius and can reduce breathiness"
                    ),
                    value=3,
                    step=1,
                    interactive=True,
                )
                index_rate = gr.Slider(
                    minimum=0,
                    maximum=1,
                    label=i18n("Search Feature Ratio"),
                    value=0.75,
                    interactive=True,
                )
                hop_length = gr.Slider(
                    minimum=1,
                    maximum=512,
                    step=1,
                    label=i18n("Hop Length"),
                    value=128,
                    interactive=True,
                )
            with gr.Column():
                f0method = gr.Radio(
                    label=i18n("Pitch extraction algorithm"),
                    choices=[
                        "pm",
                        "harvest",
                        "dio",
                        "crepe",
                        "crepe-tiny",
                        "rmvpe",
                    ],
                    value="rmvpe",
                    interactive=True,
                )

        convert_button1 = gr.Button(i18n("Convert"))

        with gr.Row():  # Defines output info + output audio download after conversion
            vc_output1 = gr.Textbox(label=i18n("Output Information"))
            vc_output2 = gr.Audio(label=i18n("Export Audio"))

    # Batch inference tab
    with gr.Tab(i18n("Batch")):
        with gr.Row():
            with gr.Column():
                input_folder_batch = gr.Textbox(
                    label=i18n("Input Folder"),
                    placeholder=i18n("Enter input path"),
                    value=os.path.join(now_dir, "assets", "audios"),
                    interactive=True,
                )
                output_folder_batch = gr.Textbox(
                    label=i18n("Output Folder"),
                    placeholder=i18n("Enter output path"),
                    value=os.path.join(now_dir, "assets", "audios"),
                    interactive=True,
                )
        with gr.Accordion(i18n("Advanced Settings"), open=False):
            with gr.Column():
                # Kept under its own name so it does not replace the Single
                # tab's button before the click handlers are attached.
                clear_outputs_batch = gr.Button(
                    i18n("Clear Outputs (Deletes all audios in assets/audios)")
                )
                pitch_batch = gr.Slider(-12, 12, 0, label=i18n("Pitch"))
                filter_radius_batch = gr.Slider(
                    minimum=0,
                    maximum=7,
                    label=i18n(
                        "If >=3: apply median filtering to the harvested pitch results. The value represents the filter radius and can reduce breathiness"
                    ),
                    value=3,
                    step=1,
                    interactive=True,
                )
                index_rate_batch = gr.Slider(
                    minimum=0,
                    maximum=1,
                    label=i18n("Search Feature Ratio"),
                    value=0.75,
                    interactive=True,
                )
                hop_length_batch = gr.Slider(
                    minimum=1,
                    maximum=512,
                    step=1,
                    label=i18n("Hop Length"),
                    value=128,
                    interactive=True,
                )
            with gr.Column():
                f0method_batch = gr.Radio(
                    label=i18n("Pitch extraction algorithm"),
                    choices=[
                        "pm",
                        "harvest",
                        "dio",
                        "crepe",
                        "crepe-tiny",
                        "rmvpe",
                    ],
                    value="rmvpe",
                    interactive=True,
                )

        convert_button2 = gr.Button(i18n("Convert"))

        with gr.Row():  # Defines output info + output audio download after conversion
            vc_output3 = gr.Textbox(label=i18n("Output Information"))

    refresh_button.click(
        fn=change_choices,
        inputs=[],
        outputs=[model_file, index_file, audio],
    )
    # Keep the suggested output path in sync with the selected input audio.
    audio.change(
        fn=output_path_fn,
        inputs=[audio],
        outputs=[output_path],
    )
    upload_audio.upload(
        fn=save_to_wav2,
        inputs=[upload_audio],
        outputs=[audio, output_path],
    )
    upload_audio.stop_recording(
        fn=save_to_wav,
        inputs=[upload_audio],
        outputs=[audio, output_path],
    )
    # Each sub-tab has its own "Clear Outputs" button; both get the handler.
    for clear_button in (clear_outputs, clear_outputs_batch):
        clear_button.click(
            fn=delete_outputs,
            inputs=[],
            outputs=[],
        )
    convert_button1.click(
        fn=run_infer_script,
        inputs=[
            pitch,
            filter_radius,
            index_rate,
            hop_length,
            f0method,
            audio,
            output_path,
            model_file,
            index_file,
            split_audio,
        ],
        outputs=[vc_output1, vc_output2],
    )
    convert_button2.click(
        fn=run_batch_infer_script,
        inputs=[
            pitch_batch,
            filter_radius_batch,
            index_rate_batch,
            hop_length_batch,
            f0method_batch,
            input_folder_batch,
            output_folder_batch,
            model_file,
            index_file,
        ],
        outputs=[vc_output3],
    )