Commit d8b9969 by mrq
Parent(s): b785192

a bunch of shit i had uncommited over the past while pertaining to VALL-E

Files changed:
- modules/tortoise-tts +1 -1
- src/utils.py +35 -15
- src/webui.py +58 -5
modules/tortoise-tts CHANGED
@@ -1 +1 @@
-Subproject commit
+Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0
src/utils.py CHANGED
@@ -75,6 +75,7 @@ try:
 
 	VALLE_ENABLED = True
 except Exception as e:
+	print(e)
 	pass
 
 if VALLE_ENABLED:
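
The only functional change in this first hunk is that the import failure is printed instead of silently swallowed, so a broken VALL-E install is at least visible in the console. The same guard pattern in isolation (the `vall_e` module name is a hypothetical stand-in for the real import):

VALLE_ENABLED = False
try:
	import vall_e  # hypothetical stand-in for the actual VALL-E import
	VALLE_ENABLED = True
except Exception as e:
	print(e)  # previously a bare `pass` hid the reason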
@@ -156,10 +157,12 @@ def generate_valle(**kwargs):
 
 voice_cache = {}
 def fetch_voice( voice ):
-	voice_dir = f'./
+	voice_dir = f'./training/{voice}/audio/'
+	if not os.path.isdir(voice_dir):
+		voice_dir = f'./voices/{voice}/'
 	files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
-	return files
-
+	# return files
+	return random.choice(files)
 
 def get_settings( override=None ):
 	settings = {
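
Behavior change worth flagging: `fetch_voice` now prefers the processed `./training/{voice}/audio/` folder, falls back to `./voices/{voice}/`, and returns one random clip instead of the whole list. A standalone sketch of the new logic (assumes `random` is imported at module scope in utils.py; note `random.choice` raises IndexError on a folder with no .wav files):

import os
import random

def fetch_voice( voice ):
	# prefer processed training audio, fall back to the raw voice folder
	voice_dir = f'./training/{voice}/audio/'
	if not os.path.isdir(voice_dir):
		voice_dir = f'./voices/{voice}/'
	files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
	# a single random reference clip per call, instead of the full list
	return random.choice(files)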
@@ -1089,13 +1092,13 @@ class TrainingState():
 			'ar-quarter.lr', 'nar-quarter.lr',
 		]
 		keys['losses'] = [
-			'ar.loss', 'nar.loss',
-			'ar-half.loss', 'nar-half.loss',
-			'ar-quarter.loss', 'nar-quarter.loss',
+			'ar.loss', 'nar.loss', 'ar+nar.loss',
+			'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
+			'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
 
-			'ar.loss.nll', 'nar.loss.nll',
-			'ar-half.loss.nll', 'nar-half.loss.nll',
-			'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
+			# 'ar.loss.nll', 'nar.loss.nll',
+			# 'ar-half.loss.nll', 'nar-half.loss.nll',
+			# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
 		]
 
 		keys['accuracies'] = [
@@ -1123,7 +1126,7 @@ class TrainingState():
 
 			prefix = ""
 
-			if
+			if "mode" in self.info and self.info["mode"] == "validation":
 				prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
 
 			self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
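
The prefix is now applied only to validation records, so training losses keep their bare names in the plot. A toy illustration with a made-up `info` dict:

info = { "mode": "validation", "name": "subtrain", "loss": 1.0 }

prefix = ""
if "mode" in info and info["mode"] == "validation":
	prefix = f'{info["name"] if "name" in info else "val"}_'

print(f'{prefix}loss')  # subtrain_loss; a training record would yield plain "loss"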
@@ -1231,6 +1234,7 @@ class TrainingState():
 
 		unq = {}
 		averager = None
+		prev_state = 0
 
 		for log in logs:
 			with open(log, 'r', encoding="utf-8") as f:
@@ -1250,6 +1254,7 @@ class TrainingState():
 
 					name = "train"
 					mode = "training"
+					prev_state = 0
 				elif line.find('Validation Metrics:') >= 0:
 					data = json.loads(line.split("Validation Metrics:")[-1])
 					if "it" not in data:
@@ -1257,8 +1262,15 @@ class TrainingState():
 					if "epoch" not in data:
 						data['epoch'] = epoch
 
-					name = data['name'] if 'name' in data else "val"
+					# name = data['name'] if 'name' in data else "val"
 					mode = "validation"
+
+					if prev_state == 0:
+						name = "subtrain"
+					else:
+						name = "val"
+
+					prev_state += 1
 				else:
 					continue
 
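
The `prev_state` counter added across these hunks disambiguates consecutive `Validation Metrics:` blocks: the first one after a `Training Metrics:` line is taken to be the subtrain split, and any later one the true validation split. A toy reproduction of that state machine (log lines are illustrative, not actual trainer output):

lines = [
	'Training Metrics: {"it": 100}',
	'Validation Metrics: {"it": 100, "loss": 1.0}',  # first after training -> subtrain
	'Validation Metrics: {"it": 100, "loss": 2.0}',  # next -> val
]

prev_state = 0
for line in lines:
	if line.find('Training Metrics:') >= 0:
		name = "train"
		prev_state = 0
	elif line.find('Validation Metrics:') >= 0:
		name = "subtrain" if prev_state == 0 else "val"
		prev_state += 1
	print(name)  # train, subtrain, val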
@@ -1272,6 +1284,7 @@ class TrainingState():
 				if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
 					averager = {
 						'key': f'{it}_{name}',
+						'name': name,
 						'mode': mode,
 						"metrics": {}
 					}
@@ -1292,11 +1305,13 @@ class TrainingState():
 			if update and it <= self.last_info_check_at:
 				continue
 
+		blacklist = [ "batch", "eval" ]
 		for it in unq:
 			if args.tts_backend == "vall-e":
 				stats = unq[it]
-				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
-				data['
+				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
+				data['name'] = stats['name']
+				data['mode'] = stats['mode']
 				data['steps'] = len(stats['metrics']['it'])
 			else:
 				data = unq[it]
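
Metric averaging now filters out non-plottable keys and carries the split name and mode into the averaged record. A toy run of the same comprehension (numbers are made up):

blacklist = [ "batch", "eval" ]
stats = {
	'name': 'val',
	'mode': 'validation',
	'metrics': {
		'it': [100, 100],
		'loss': [1.5, 2.5],
		'batch': [16, 16],  # dropped by the blacklist
	},
}

data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
data['name'] = stats['name']
data['mode'] = stats['mode']
data['steps'] = len(stats['metrics']['it'])
print(data)  # {'it': 100.0, 'loss': 2.0, 'name': 'val', 'mode': 'validation', 'steps': 2}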
@@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
 
 	device = "cuda" if get_device_name() == "cuda" else "cpu"
 	if whisper_vad:
+		# omits a considerable amount of the end
 		"""
 		if args.whisper_batchsize > 1:
 			result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
@@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 	messages = []
 
 	if not os.path.exists(infile):
-
+		message = f"Missing dataset: {infile}"
+		print(message)
+		return message
 
 	if results is None:
 		results = json.load(open(infile, 'r', encoding="utf-8"))
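
This hunk and the matching one in `prepare_dataset` below replace a now-elided failure path with an early return, so the Gradio handler shows a readable message instead of a stack trace. The guard in isolation (the helper name is hypothetical):

import json
import os

def load_whisper_results( infile ):
	if not os.path.exists(infile):
		message = f"Missing dataset: {infile}"
		print(message)
		return message  # surfaced to the web UI instead of raising
	return json.load(open(infile, 'r', encoding="utf-8"))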
@@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
 	if not os.path.exists(infile):
-
+		message = f"Missing dataset: {infile}"
+		print(message)
+		return message
 
 	results = json.load(open(infile, 'r', encoding="utf-8"))
 
src/webui.py CHANGED
@@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
 	return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
 
+def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
+	from pyannote.audio import Pipeline
+	pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
+
+	messages = []
+	files = sorted( get_voices(load_latents=False)[voice] )
+	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+		diarization = pipeline(file)
+		for turn, _, speaker in diarization.itertracks(yield_label=True):
+			message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
+			print(message)
+			messages.append(message)
+
+	return "\n".join(messages)
+
+def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
+	kwargs = locals()
+
+	messages = []
+	voices = get_voice_list()
+
+	"""
+	for voice in voices:
+		message = prepare_dataset_proxy(voice, **kwargs)
+		messages.append(message)
+	"""
+	for voice in voices:
+		print("Processing:", voice)
+		message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
+		messages.append(message)
+
+	if slice_audio:
+		for voice in voices:
+			print("Processing:", voice)
+			message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
+			messages.append(message)
+
+	for voice in voices:
+		print("Processing:", voice)
+		message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
+		messages.append(message)
+
+	return "\n".join(messages)
+
 def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
 	messages = []
 
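
The new `diarize_dataset` leans on pyannote.audio; `pyannote/speaker-diarization` is a gated model on Hugging Face, so `args.hf_token` must belong to an account that has accepted its terms. Minimal standalone usage of the same API (token and audio path are placeholders):

from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="hf_...")
diarization = pipeline("./voices/sample/clip.wav")
# itertracks(yield_label=True) yields (segment, track, speaker-label) tuples
for turn, _, speaker in diarization.itertracks(yield_label=True):
	print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")

As committed, the speaker turns are only printed and returned as text; nothing splits the audio per speaker yet. `prepare_all_datasets` likewise just chains the existing transcribe, slice, and prepare steps over every voice in the list.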
@@ -468,6 +512,8 @@ def setup_gradio():
 					DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
 
 					transcribe_button = gr.Button(value="Transcribe and Process")
+					transcribe_all_button = gr.Button(value="Transcribe All")
+					diarize_button = gr.Button(value="Diarize")
 
 					with gr.Row():
 						slice_dataset_button = gr.Button(value="(Re)Slice Audio")
@@ -579,7 +625,7 @@ def setup_gradio():
 						tooltip=['epoch', 'it', 'value', 'type'],
 						width=500,
 						height=350,
-						visible=args.tts_backend=="vall-e"
+						visible=False, # args.tts_backend=="vall-e"
 					)
 					view_losses = gr.Button(value="View Losses")
 
@@ -611,10 +657,7 @@ def setup_gradio():
 				# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
 
 				with gr.Column(visible=args.tts_backend=="vall-e"):
-
-					if len(valle_models):
-						default_valle_model_choice = valle_models[0]
-					EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
+					EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
 
 				with gr.Column(visible=args.tts_backend=="tortoise"):
 					EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
@@ -859,6 +902,16 @@ def setup_gradio():
 			inputs=dataset_settings,
 			outputs=prepare_dataset_output #console_output
 		)
+		transcribe_all_button.click(
+			prepare_all_datasets,
+			inputs=dataset_settings[1:],
+			outputs=prepare_dataset_output #console_output
+		)
+		diarize_button.click(
+			diarize_dataset,
+			inputs=dataset_settings[0],
+			outputs=prepare_dataset_output #console_output
+		)
 		prepare_dataset_button.click(
 			prepare_dataset,
 			inputs=[