sanchit-gandhi committed on
Commit
0353d01
1 Parent(s): e90fbd2
Files changed (1)
  1. run_xtreme_s.py +46 -8
run_xtreme_s.py CHANGED
@@ -136,6 +136,10 @@ class ModelArguments:
         metadata={"help": "Length of vector span to mask along the feature axis."},
     )
     layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
+    ctc_zero_infinity: bool = field(
+        default=False,
+        metadata={"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`."},
+    )
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
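Editor's aside, not part of the commit: the new flag is forwarded into the model config and, for CTC models, typically ends up as the `zero_infinity` argument of PyTorch's CTC loss. A minimal sketch of what that toggles, with made-up shapes:

import torch

# Targets deliberately longer than the inputs, so no valid CTC alignment exists
# and the loss is infinite; zero_infinity=True zeroes the loss and its gradient.
log_probs = torch.randn(50, 2, 32).log_softmax(-1)      # (time, batch, vocab)
targets = torch.randint(1, 32, (2, 60))                  # 60 target tokens > 50 frames
input_lengths = torch.full((2,), 50, dtype=torch.long)
target_lengths = torch.full((2,), 60, dtype=torch.long)

loss_inf = torch.nn.CTCLoss(blank=0, zero_infinity=False)(log_probs, targets, input_lengths, target_lengths)
loss_zeroed = torch.nn.CTCLoss(blank=0, zero_infinity=True)(log_probs, targets, input_lengths, target_lengths)
print(loss_inf.item(), loss_zeroed.item())  # inf 0.0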
@@ -166,6 +170,15 @@ class DataTrainingArguments:
         default="all",
         metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
     )
+    language_group: str = field(
+        default=None,
+        metadata={
+            "help": "The language group to select a subset of languages to train on. "
+            "This option is only used for the 'fleurs-asr' task. Should be one of: "
+            "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
+            "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+        },
+    )
     train_split_name: str = field(
         default="train",
         metadata={
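For context (an editor's sketch, not the commit's code): dataclass fields like this one surface as CLI flags through `HfArgumentParser`. The field name below is taken from the diff; everything else is a pared-down illustration.

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # same field name as in the diff; help text shortened for the example
    language_group: str = field(default=None, metadata={"help": "Language group to train on."})

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(["--language_group", "western_european_we"])
print(data_args.language_group)  # western_european_we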
@@ -441,6 +454,11 @@ def main():
             "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
             " for multi-lingual fine-tuning."
         )
+    if data_args.language_group is not None:
+        if data_args.task != "fleurs-asr":
+            raise ValueError("--language_group should only be used with --task=fleurs-asr")
+        if data_args.language != "all":
+            raise ValueError("--language_group should only be used with --language=all")
 
     if data_args.target_column_name is None:
         target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
@@ -502,11 +520,23 @@ def main():
         if data_args.max_predict_samples is not None:
             raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
 
+    lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
     if not is_text_target:
         label_list = next(iter(raw_datasets.values())).features[target_column_name].names
-        lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
         num_labels = len(label_list)
 
+    num_workers = data_args.preprocessing_num_workers
+
+    lang_group = data_args.language_group
+    if lang_group is not None:
+        with training_args.main_process_first(desc="language group filter"):
+            lang_group_id = next(iter(raw_datasets.values())).features["lang_group_id"].str2int(lang_group)
+            raw_datasets = raw_datasets.filter(
+                lambda lang_group: lang_group == lang_group_id,
+                num_proc=num_workers,
+                input_columns=["lang_group_id"],
+            )
+
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
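An aside on the `datasets` filter pattern used above: when `input_columns` is given, the predicate receives only that column's value instead of the whole example dict. A toy, self-contained illustration with made-up data:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"], "lang_group_id": [0, 1, 0]})
lang_group_id = 0  # e.g. the id returned by features["lang_group_id"].str2int(...)
subset = ds.filter(lambda lang_group: lang_group == lang_group_id, input_columns=["lang_group_id"])
print(subset["text"])  # ['a', 'c']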
@@ -616,6 +646,7 @@ def main():
         "mask_feature_length": model_args.mask_feature_length,
         "gradient_checkpointing": training_args.gradient_checkpointing,
         "layerdrop": model_args.layerdrop,
+        "ctc_zero_infinity": model_args.ctc_zero_infinity,
         "ctc_loss_reduction": model_args.ctc_loss_reduction,
         "activation_dropout": model_args.activation_dropout,
     }
@@ -677,7 +708,6 @@ def main():
     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
-    num_workers = data_args.preprocessing_num_workers
 
     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
@@ -742,13 +772,13 @@ def main():
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
 
-    def compute_asr_metric(pred):
-        pred_logits = pred.predictions
-        pred_ids = np.argmax(pred_logits, axis=-1)
+    def asr_logits_argmax(logits, labels):
+        return logits.argmax(dim=-1)
 
+    def compute_asr_metric(pred):
         pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
 
-        pred_str = tokenizer.batch_decode(pred_ids)
+        pred_str = tokenizer.batch_decode(pred.predictions)
         # we do not want to group tokens when computing the metrics
         label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
 
@@ -785,6 +815,7 @@ def main():
             model=model,
             data_collator=data_collator,
             args=training_args,
+            preprocess_logits_for_metrics=asr_logits_argmax if training_args.predict_with_generate else None,
             compute_metrics=compute_asr_metric if training_args.predict_with_generate else None,
             train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
             eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
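Why wiring in `asr_logits_argmax` via `preprocess_logits_for_metrics` helps: the `Trainer` applies this hook to each batch's logits before accumulating predictions for `compute_asr_metric`, so only argmax token ids, rather than vocab-sized float tensors, are kept and gathered. A rough, self-contained size comparison with illustrative shapes:

import torch

logits = torch.randn(8, 1500, 32)   # (batch, frames, vocab) float32 logits
pred_ids = logits.argmax(dim=-1)    # (batch, frames) int64 ids, as asr_logits_argmax returns

print(logits.nelement() * logits.element_size())      # 1536000 bytes per batch
print(pred_ids.nelement() * pred_ids.element_size())  # 96000 bytes per batch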
@@ -795,6 +826,7 @@ def main():
             model=model,
             data_collator=data_collator,
             args=training_args,
+            preprocess_logits_for_metrics=asr_logits_argmax if is_text_target else None,
             compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric,
             train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
             eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
@@ -839,11 +871,17 @@ def main():
         average_metrics = defaultdict(list)
         for lang_id in range(len(lang_list)):
             lang_name = lang_list[lang_id]
-            lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
+            with training_args.main_process_first(desc="per-language dataset filter"):
+                lang_dataset = vectorized_datasets["predict"].filter(
+                    lambda lang: lang == lang_id,
+                    num_proc=num_workers,
+                    input_columns=["lang"],
+                )
             lang_metrics = trainer.evaluate(lang_dataset)
+            redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"]
             for metric_name, value in lang_metrics.items():
                 average_metrics[metric_name].append(value)
-                if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
+                if metric_name not in redundant_metrics:
                     metrics[f"{metric_name}_{lang_name}"] = value
         for metric_name, value in average_metrics.items():
             metrics[metric_name] = np.mean(value)
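To make the per-language bookkeeping above concrete, a tiny self-contained sketch with made-up metric values and two FLEURS language names:

from collections import defaultdict

import numpy as np

per_lang = {"af_za": {"eval_wer": 0.31}, "am_et": {"eval_wer": 0.45}}  # made-up values
metrics, average_metrics = {}, defaultdict(list)
redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"]

for lang_name, lang_metrics in per_lang.items():
    for metric_name, value in lang_metrics.items():
        average_metrics[metric_name].append(value)
        if metric_name not in redundant_metrics:
            metrics[f"{metric_name}_{lang_name}"] = value
for metric_name, values in average_metrics.items():
    metrics[metric_name] = np.mean(values)

print(metrics["eval_wer"])  # 0.38, the average over the two languages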
 