marinone94 committed
Commit: e023bb4
Parent(s): b9bdaf4

debug dataset

run_speech_recognition_seq2seq_streaming.py (CHANGED)
@@ -312,8 +312,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
 
         # if bos token is appended in previous tokenization step,
         # cut bos token here as it's append later anyways
-        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-            labels = labels[:, 1:]
+        # if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+        #     labels = labels[:, 1:]
         # lang_token_ids = self.processor.tokenizer(lang_features).input_ids
         # # Replace language and task if they are in the beginning, otherwise add them
         # if (labels[:, 1] == self.task_id).all().cpu().item():
@@ -325,8 +325,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         # labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
 
         # Set language and task to pad token
-        labels[:, 0] = torch.full_like(labels[:, 0], -100)
-        labels[:, 1] = torch.full_like(labels[:, 1], -100)
+        # labels[:, 0] = torch.full_like(labels[:, 0], -100)
+        # labels[:, 1] = torch.full_like(labels[:, 1], -100)
 
         batch["labels"] = labels
 
@@ -458,6 +458,15 @@ def load_maybe_streaming_dataset(
     return dataset
 
 
+def print_data_samples(dataset, processor, max_samples=5):
+    shown_samples = 0
+    for batch in dataset:
+        print("Target: ", tokenizer.batch_decode(batch["labels"]))
+        shown_samples += len(batch)
+        if shown_samples >= max_samples:
+            break
+
+
 def main():
     # 1. Parse input arguments
     # See all possible arguments in src/transformers/training_args.py
@@ -592,6 +601,7 @@ def main():
         use_auth_token=hf_token if model_args.use_auth_token else None
     )
 
+    # Forced decoder ids will be overwritten before evaluation
     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
 
     if training_args.gradient_checkpointing:
@@ -758,6 +768,7 @@ def main():
     elif isinstance(train_dataloader.dataset, IterableDataset):
         train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
 
+
     # Initialize Trainer
     logger.info("*** Init trainer ***")
     trainer = Seq2SeqTrainer(
@@ -775,6 +786,7 @@ def main():
     # 12. Training
     if training_args.do_train:
         logger.info("*** Train ***")
+        print_data_samples(vectorized_datasets["train"], processor)
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
             checkpoint = training_args.resume_from_checkpoint
@@ -817,7 +829,8 @@ def main():
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("Before setting language and task")
         logger.info(f"{pred_labels}")
-        trainer.
+        trainer.model.config.forced_decoder_ids = \
+            processor.get_decoder_prompt_ids(language=data_args.language_eval, task=data_args.task, no_timestamps=True)
         preds = processor.batch_decode(predictions.predictions)
         labels = processor.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
@@ -828,7 +841,7 @@ def main():
     results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-
+        print_data_samples(vectorized_datasets["eval"], processor)
         metrics = trainer.evaluate(
             metric_key_prefix="eval",
             max_length=training_args.generation_max_length,
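Two things worth flagging in the `print_data_samples` helper as committed: it calls `tokenizer.batch_decode`, but only the `processor` argument is in scope, so the call will likely raise a `NameError` unless a module-level `tokenizer` exists; and `len(batch)` counts dict keys rather than samples when the iterator yields mapping-style examples. A minimal corrected sketch, assuming `processor` is the script's `WhisperProcessor` and the streamed dataset yields one example at a time whose `labels` field is a list of token ids:

def print_data_samples(dataset, processor, max_samples=5):
    # Print a few decoded target transcriptions for quick dataset debugging.
    shown_samples = 0
    for sample in dataset:
        # Decode through the processor argument, not an undefined global tokenizer
        print("Target: ", processor.tokenizer.decode(sample["labels"]))
        shown_samples += 1
        if shown_samples >= max_samples:
            break

If the iterator instead yields collated batches with 2-D labels, swap `decode` for `batch_decode` and count the decoded rows.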
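For context on the evaluation-time override: `get_decoder_prompt_ids` returns `(position, token_id)` pairs that `generate` forces at the start of every decoded sequence, which is how Whisper pins the language and task tokens. A standalone sketch; the checkpoint name and the `language`/`task` values are illustrative stand-ins for this script's fine-tuned model and its `data_args.language_eval` / `data_args.task` fields:

from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Illustrative base checkpoint; the script uses its own model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Returns e.g. [(1, <|sv|> id), (2, <|transcribe|> id), (3, <|notimestamps|> id)]
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="swedish", task="transcribe", no_timestamps=True
)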