versae commited on
Commit
28848ba
1 Parent(s): 74e7873

Fix train script for NPSC

Browse files
Files changed (1) hide show
  1. run_speech_recognition_ctc.py +21 -0
run_speech_recognition_ctc.py CHANGED
@@ -391,6 +391,23 @@ def main():
391
  # Set seed before initializing model.
392
  set_seed(training_args.seed)
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  # 1. First, let's load the dataset
395
  raw_datasets = DatasetDict()
396
 
@@ -401,6 +418,8 @@ def main():
401
  split=data_args.train_split_name,
402
  use_auth_token=data_args.use_auth_token,
403
  )
 
 
404
 
405
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
406
  raise ValueError(
@@ -426,6 +445,8 @@ def main():
426
  split=data_args.eval_split_name,
427
  use_auth_token=data_args.use_auth_token,
428
  )
 
 
429
 
430
  if data_args.max_eval_samples is not None:
431
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
391
  # Set seed before initializing model.
392
  set_seed(training_args.seed)
393
 
394
+ # Pre-processing dataset
395
+ def preprocess_dataset(entry):
396
+ return (
397
+ "<INAUDIBLE>" not in entry["text"]
398
+ and entry["sentence_language_code"].lower() == "nb-no"
399
+ )
400
+
401
+ def map_dataset(entry):
402
+ return {"text": (entry["text"]
403
+ .lower()
404
+ .replace("<ee>", "eee")
405
+ .replace("<mm>", "mmm")
406
+ .replace("<qq>", "qqq")
407
+ .replace("ó", "o")
408
+ .replace("é", "e")
409
+ )}
410
+
411
  # 1. First, let's load the dataset
412
  raw_datasets = DatasetDict()
413
 
 
418
  split=data_args.train_split_name,
419
  use_auth_token=data_args.use_auth_token,
420
  )
421
+ raw_datasets["train"] = raw_datasets["train"].filter(preprocess_dataset)
422
+ raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
423
 
424
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
425
  raise ValueError(
 
445
  split=data_args.eval_split_name,
446
  use_auth_token=data_args.use_auth_token,
447
  )
448
+ raw_datasets["eval"] = raw_datasets["eval"].filter(preprocess_dataset)
449
+ raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
450
 
451
  if data_args.max_eval_samples is not None:
452
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))