marinone94 committed
Commit e023bb4
Parent: b9bdaf4

debug dataset

run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -312,8 +312,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:

         # if bos token is appended in previous tokenization step,
         # cut bos token here as it's append later anyways
-        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-            labels = labels[:, 1:]
+        # if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+        #     labels = labels[:, 1:]
         # lang_token_ids = self.processor.tokenizer(lang_features).input_ids
         # # Replace language and task if they are in the beginning, otherwise add them
         # if (labels[:, 1] == self.task_id).all().cpu().item():
@@ -325,8 +325,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         #     labels = torch.cat((lang_token_ids, task_id, labels), dim=1)

         # Set language and task to pad token
-        labels[:, 0] = torch.full_like(labels[:, 0], -100)
-        labels[:, 1] = torch.full_like(labels[:, 1], -100)
+        # labels[:, 0] = torch.full_like(labels[:, 0], -100)
+        # labels[:, 1] = torch.full_like(labels[:, 1], -100)

         batch["labels"] = labels

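The two collator hunks above stop trimming the leading decoder-start token and stop overwriting label positions 0 and 1 with -100, so the language and task tokens remain part of the supervision signal. For reference, the label-handling step these lines belong to usually follows the conventional Whisper fine-tuning collator sketched below (hypothetical helper name and arguments; positions set to -100 are skipped by the cross-entropy loss):

import torch

def prepare_labels(label_features, processor, decoder_start_token_id):
    # Pad the tokenized transcripts and mask padding with -100 so the loss ignores it.
    labels_batch = processor.tokenizer.pad(label_features, return_tensors="pt")
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    # Optionally drop a leading decoder-start token, since generate() prepends it again
    # (this is the step the commit comments out).
    if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
        labels = labels[:, 1:]
    return labels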
@@ -458,6 +458,15 @@ def load_maybe_streaming_dataset(
     return dataset


+def print_data_samples(dataset, processor, max_samples=5):
+    shown_samples = 0
+    for batch in dataset:
+        print("Target: ", tokenizer.batch_decode(batch["labels"]))
+        shown_samples += len(batch)
+        if shown_samples >= max_samples:
+            break
+
+
 def main():
     # 1. Parse input arguments
     # See all possible arguments in src/transformers/training_args.py
@@ -592,6 +601,7 @@ def main():
         use_auth_token=hf_token if model_args.use_auth_token else None
     )

+    # Forced decoder ids will be overwritten before evaluation
     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})

     if training_args.gradient_checkpointing:
@@ -758,6 +768,7 @@ def main():
         elif isinstance(train_dataloader.dataset, IterableDataset):
             train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

+
     # Initialize Trainer
     logger.info("*** Init trainer ***")
     trainer = Seq2SeqTrainer(
@@ -775,6 +786,7 @@ def main():
     # 12. Training
     if training_args.do_train:
         logger.info("*** Train ***")
+        print_data_samples(vectorized_datasets["train"], processor)
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
             checkpoint = training_args.resume_from_checkpoint
@@ -817,7 +829,8 @@ def main():
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("Before setting language and task")
         logger.info(f"{pred_labels}")
-        trainer.data_collator.processor.tokenizer.set_prefix_tokens(language=data_args.language_eval, task=data_args.task)
+        trainer.model.config.forced_decoder_ids = \
+            processor.get_decoder_prompt_ids(language=data_args.language_eval, task=data_args.task, no_timestamps=True)
         preds = processor.batch_decode(predictions.predictions)
         labels = processor.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
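The hunk above switches from set_prefix_tokens, which only changes how new text would be tokenized, to overwriting the model's forced_decoder_ids, which is what generation actually reads during evaluation. A standalone sketch of the same mechanism (illustrative checkpoint, language and task; the script takes these from data_args, and the zero tensor stands in for real log-Mel features):

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")  # illustrative checkpoint
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Force <|sv|><|transcribe|><|notimestamps|> at the start of every generated sequence.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="sv", task="transcribe", no_timestamps=True
)

dummy_features = torch.zeros(1, 80, 3000)  # placeholder input for the sketch
generated_ids = model.generate(dummy_features, max_new_tokens=32)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))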
@@ -828,7 +841,7 @@ def main():
     results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-
+        print_data_samples(vectorized_datasets["eval"], processor)
         metrics = trainer.evaluate(
             metric_key_prefix="eval",
             max_length=training_args.generation_max_length,
 