mrq committed
Commit d8b9969 · 1 Parent(s): b785192

a bunch of shit i had uncommited over the past while pertaining to VALL-E

Files changed (3)
  1. modules/tortoise-tts +1 -1
  2. src/utils.py +35 -15
  3. src/webui.py +58 -5
modules/tortoise-tts CHANGED
@@ -1 +1 @@
-Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130
+Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0
src/utils.py CHANGED
@@ -75,6 +75,7 @@ try:
 
 	VALLE_ENABLED = True
 except Exception as e:
+	print(e)
 	pass
 
 if VALLE_ENABLED:
@@ -156,10 +157,12 @@ def generate_valle(**kwargs):
 
 	voice_cache = {}
 	def fetch_voice( voice ):
-		voice_dir = f'./voices/{voice}/'
+		voice_dir = f'./training/{voice}/audio/'
+		if not os.path.isdir(voice_dir):
+			voice_dir = f'./voices/{voice}/'
 		files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
-		return files
-		# return random.choice(files)
+		# return files
+		return random.choice(files)
 
 	def get_settings( override=None ):
 		settings = {
@@ -1089,13 +1092,13 @@ class TrainingState():
 			'ar-quarter.lr', 'nar-quarter.lr',
 		]
 		keys['losses'] = [
-			'ar.loss', 'nar.loss',
-			'ar-half.loss', 'nar-half.loss',
-			'ar-quarter.loss', 'nar-quarter.loss',
+			'ar.loss', 'nar.loss', 'ar+nar.loss',
+			'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
+			'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
 
-			'ar.loss.nll', 'nar.loss.nll',
-			'ar-half.loss.nll', 'nar-half.loss.nll',
-			'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
+			# 'ar.loss.nll', 'nar.loss.nll',
+			# 'ar-half.loss.nll', 'nar-half.loss.nll',
+			# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
 		]
 
 		keys['accuracies'] = [
@@ -1123,7 +1126,7 @@ class TrainingState():
 
 		prefix = ""
 
-		if data["mode"] == "validation":
+		if "mode" in self.info and self.info["mode"] == "validation":
 			prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
 
 		self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
@@ -1231,6 +1234,7 @@ class TrainingState():
 
 		unq = {}
 		averager = None
+		prev_state = 0
 
 		for log in logs:
 			with open(log, 'r', encoding="utf-8") as f:
@@ -1250,6 +1254,7 @@ class TrainingState():
 
 						name = "train"
 						mode = "training"
+						prev_state = 0
 					elif line.find('Validation Metrics:') >= 0:
 						data = json.loads(line.split("Validation Metrics:")[-1])
 						if "it" not in data:
@@ -1257,8 +1262,15 @@ class TrainingState():
 						if "epoch" not in data:
 							data['epoch'] = epoch
 
-						name = data['name'] if 'name' in data else "val"
+						# name = data['name'] if 'name' in data else "val"
 						mode = "validation"
+
+						if prev_state == 0:
+							name = "subtrain"
+						else:
+							name = "val"
+
+						prev_state += 1
 					else:
 						continue
 
@@ -1272,6 +1284,7 @@ class TrainingState():
 					if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
 						averager = {
 							'key': f'{it}_{name}',
+							'name': name,
 							'mode': mode,
 							"metrics": {}
 						}
@@ -1292,11 +1305,13 @@ class TrainingState():
 				if update and it <= self.last_info_check_at:
 					continue
 
+		blacklist = [ "batch", "eval" ]
 		for it in unq:
 			if args.tts_backend == "vall-e":
 				stats = unq[it]
-				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
-				data['mode'] = stats
+				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
+				data['name'] = stats['name']
+				data['mode'] = stats['mode']
 				data['steps'] = len(stats['metrics']['it'])
 			else:
 				data = unq[it]
@@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
 
 	device = "cuda" if get_device_name() == "cuda" else "cpu"
 	if whisper_vad:
+		# omits a considerable amount of the end
 		"""
 		if args.whisper_batchsize > 1:
 			result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
@@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 	messages = []
 
 	if not os.path.exists(infile):
-		raise Exception(f"Missing dataset: {infile}")
+		message = f"Missing dataset: {infile}"
+		print(message)
+		return message
 
 	if results is None:
 		results = json.load(open(infile, 'r', encoding="utf-8"))
@@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
 	if not os.path.exists(infile):
-		raise Exception(f"Missing dataset: {infile}")
+		message = f"Missing dataset: {infile}"
+		print(message)
+		return message
 
 	results = json.load(open(infile, 'r', encoding="utf-8"))
 
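For context, the new `prev_state` counter in the log-parsing loop above distinguishes the two `Validation Metrics:` lines the VALL-E trainer emits per evaluation pass: the first one after a `Training Metrics:` line is labeled `subtrain`, any later one `val`. Below is a minimal sketch of that labeling logic; the log lines and metric values are made up for illustration.

```python
import json

# Hypothetical trainer log excerpt: the real logs are read from disk, and the
# metric names/values here are invented for illustration only.
log_lines = [
    'Training Metrics: {"it": 100, "epoch": 1, "ar.loss": 3.2}',
    'Validation Metrics: {"it": 100, "epoch": 1, "ar.loss": 3.5}',
    'Validation Metrics: {"it": 100, "epoch": 1, "ar.loss": 3.9}',
]

prev_state = 0  # reset on every training line; counts validation lines seen since
parsed = []

for line in log_lines:
    if line.find('Training Metrics:') >= 0:
        data = json.loads(line.split("Training Metrics:")[-1])
        name, mode = "train", "training"
        prev_state = 0
    elif line.find('Validation Metrics:') >= 0:
        data = json.loads(line.split("Validation Metrics:")[-1])
        mode = "validation"
        # the first validation line after a training line is the subtrain split
        name = "subtrain" if prev_state == 0 else "val"
        prev_state += 1
    else:
        continue
    parsed.append({ "name": name, "mode": mode, **data })

for row in parsed:
    print(row)
```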
 
src/webui.py CHANGED
@@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
 	return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
 
+def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
+	from pyannote.audio import Pipeline
+	pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
+
+	messages = []
+	files = sorted( get_voices(load_latents=False)[voice] )
+	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+		diarization = pipeline(file)
+		for turn, _, speaker in diarization.itertracks(yield_label=True):
+			message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
+			print(message)
+			messages.append(message)
+
+	return "\n".join(messages)
+
+def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
+	kwargs = locals()
+
+	messages = []
+	voices = get_voice_list()
+
+	"""
+	for voice in voices:
+		message = prepare_dataset_proxy(voice, **kwargs)
+		messages.append(message)
+	"""
+	for voice in voices:
+		print("Processing:", voice)
+		message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
+		messages.append(message)
+
+	if slice_audio:
+		for voice in voices:
+			print("Processing:", voice)
+			message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
+			messages.append(message)
+
+	for voice in voices:
+		print("Processing:", voice)
+		message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
+		messages.append(message)
+
+	return "\n".join(messages)
+
 def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
 	messages = []
 
@@ -468,6 +512,8 @@ def setup_gradio():
 					DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
 
 				transcribe_button = gr.Button(value="Transcribe and Process")
+				transcribe_all_button = gr.Button(value="Transcribe All")
+				diarize_button = gr.Button(value="Diarize")
 
 				with gr.Row():
 					slice_dataset_button = gr.Button(value="(Re)Slice Audio")
@@ -579,7 +625,7 @@ def setup_gradio():
 						tooltip=['epoch', 'it', 'value', 'type'],
 						width=500,
 						height=350,
-						visible=args.tts_backend=="vall-e"
+						visible=False, # args.tts_backend=="vall-e"
 					)
 					view_losses = gr.Button(value="View Losses")
 
@@ -611,10 +657,7 @@ def setup_gradio():
 				# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
 
 				with gr.Column(visible=args.tts_backend=="vall-e"):
-					default_valle_model_choice = ""
-					if len(valle_models):
-						default_valle_model_choice = valle_models[0]
-					EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
+					EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
 
 				with gr.Column(visible=args.tts_backend=="tortoise"):
 					EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
@@ -859,6 +902,16 @@ def setup_gradio():
 			inputs=dataset_settings,
 			outputs=prepare_dataset_output #console_output
 		)
+		transcribe_all_button.click(
+			prepare_all_datasets,
+			inputs=dataset_settings[1:],
+			outputs=prepare_dataset_output #console_output
+		)
+		diarize_button.click(
+			diarize_dataset,
+			inputs=dataset_settings[0],
+			outputs=prepare_dataset_output #console_output
+		)
 		prepare_dataset_button.click(
 			prepare_dataset,
 			inputs=[
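The new Diarize button wired up above runs pyannote.audio's pretrained speaker-diarization pipeline over each file of the selected voice and prints one line per speaker turn. A minimal standalone sketch of those same calls, assuming a valid Hugging Face access token and a hypothetical input path:

```python
from pyannote.audio import Pipeline

HF_TOKEN = "hf_..."                    # assumption: your own Hugging Face token
AUDIO_PATH = "./voices/example/1.wav"  # hypothetical input file

# Load the pretrained diarization pipeline and run it on one audio file.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HF_TOKEN)
diarization = pipeline(AUDIO_PATH)

# itertracks(yield_label=True) yields (segment, track, speaker) tuples.
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
```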