marinone94
committed on
Commit
•
ed53c37
1
Parent(s):
2786d7b
add logs
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -337,6 +337,7 @@ def main():
|
|
337 |
# See all possible arguments in src/transformers/training_args.py
|
338 |
# or by passing the --help flag to this script.
|
339 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
|
|
340 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
341 |
|
342 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
@@ -351,6 +352,7 @@ def main():
|
|
351 |
send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
|
352 |
|
353 |
# 2. Setup logging
|
|
|
354 |
logging.basicConfig(
|
355 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
356 |
datefmt="%m/%d/%Y %H:%M:%S",
|
@@ -396,6 +398,7 @@ def main():
|
|
396 |
set_seed(training_args.seed)
|
397 |
|
398 |
# 4. Load dataset
|
|
|
399 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
400 |
|
401 |
if training_args.do_train:
|
@@ -433,7 +436,7 @@ def main():
|
|
433 |
)
|
434 |
|
435 |
# 5. Load pretrained model, tokenizer, and feature extractor
|
436 |
-
|
437 |
# Distributed training:
|
438 |
# The .from_pretrained methods guarantee that only one local process can concurrently
|
439 |
config = AutoConfig.from_pretrained(
|
@@ -483,6 +486,7 @@ def main():
|
|
483 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
484 |
|
485 |
# 6. Resample speech dataset if necessary
|
|
|
486 |
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
487 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
488 |
raw_datasets = raw_datasets.cast_column(
|
@@ -491,6 +495,7 @@ def main():
|
|
491 |
|
492 |
# 7. Preprocessing the datasets.
|
493 |
# We need to read the audio files as arrays and tokenize the targets.
|
|
|
494 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
495 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
496 |
audio_column_name = data_args.audio_column_name
|
@@ -554,6 +559,7 @@ def main():
|
|
554 |
)
|
555 |
|
556 |
# 8. Load Metric
|
|
|
557 |
metric = evaluate.load("wer")
|
558 |
do_normalize_eval = data_args.do_normalize_eval
|
559 |
|
@@ -578,6 +584,7 @@ def main():
|
|
578 |
return {"wer": wer}
|
579 |
|
580 |
# 9. Create a single speech processor
|
|
|
581 |
if is_main_process(training_args.local_rank):
|
582 |
# save feature extractor, tokenizer and config
|
583 |
feature_extractor.save_pretrained(training_args.output_dir)
|
@@ -595,6 +602,7 @@ def main():
|
|
595 |
# 11. Configure Trainer
|
596 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
597 |
# Only required for streaming: Trainer automatically shuffles non-streaming datasets
|
|
|
598 |
class ShuffleCallback(TrainerCallback):
|
599 |
def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
|
600 |
if isinstance(train_dataloader.dataset, IterableDatasetShard):
|
@@ -603,6 +611,7 @@ def main():
|
|
603 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
604 |
|
605 |
# Initialize Trainer
|
|
|
606 |
trainer = Seq2SeqTrainer(
|
607 |
model=model,
|
608 |
args=training_args,
|
@@ -616,6 +625,7 @@ def main():
|
|
616 |
|
617 |
# 12. Training
|
618 |
if training_args.do_train:
|
|
|
619 |
checkpoint = None
|
620 |
if training_args.resume_from_checkpoint is not None:
|
621 |
checkpoint = training_args.resume_from_checkpoint
|
@@ -663,14 +673,18 @@ def main():
|
|
663 |
if model_args.model_index_name is not None:
|
664 |
kwargs["model_name"] = model_args.model_index_name
|
665 |
|
|
|
666 |
if training_args.push_to_hub:
|
667 |
trainer.push_to_hub(**kwargs)
|
668 |
else:
|
669 |
trainer.create_model_card(**kwargs)
|
670 |
|
671 |
# Training complete notification
|
|
|
672 |
notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
|
673 |
|
|
|
|
|
674 |
return results
|
675 |
|
676 |
|
|
|
337 |
# See all possible arguments in src/transformers/training_args.py
|
338 |
# or by passing the --help flag to this script.
|
339 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
340 |
+
logger.info("*** Parse args ***")
|
341 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
342 |
|
343 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
|
|
352 |
send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
|
353 |
|
354 |
# 2. Setup logging
|
355 |
+
logger.info("*** Setup logging ***")
|
356 |
logging.basicConfig(
|
357 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
358 |
datefmt="%m/%d/%Y %H:%M:%S",
|
|
|
398 |
set_seed(training_args.seed)
|
399 |
|
400 |
# 4. Load dataset
|
401 |
+
logger.info("*** Load dataset ***")
|
402 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
403 |
|
404 |
if training_args.do_train:
|
|
|
436 |
)
|
437 |
|
438 |
# 5. Load pretrained model, tokenizer, and feature extractor
|
439 |
+
logger.info("*** Load pretrained model, tokenizer, and feature extractor ***")
|
440 |
# Distributed training:
|
441 |
# The .from_pretrained methods guarantee that only one local process can concurrently
|
442 |
config = AutoConfig.from_pretrained(
|
|
|
486 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
487 |
|
488 |
# 6. Resample speech dataset if necessary
|
489 |
+
logger.info("*** Resample dataset ***")
|
490 |
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
491 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
492 |
raw_datasets = raw_datasets.cast_column(
|
|
|
495 |
|
496 |
# 7. Preprocessing the datasets.
|
497 |
# We need to read the audio files as arrays and tokenize the targets.
|
498 |
+
logger.info("*** Preprocess dataset ***")
|
499 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
500 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
501 |
audio_column_name = data_args.audio_column_name
|
|
|
559 |
)
|
560 |
|
561 |
# 8. Load Metric
|
562 |
+
logger.info("*** Load metric ***")
|
563 |
metric = evaluate.load("wer")
|
564 |
do_normalize_eval = data_args.do_normalize_eval
|
565 |
|
|
|
584 |
return {"wer": wer}
|
585 |
|
586 |
# 9. Create a single speech processor
|
587 |
+
logger.info("*** Init processor ***")
|
588 |
if is_main_process(training_args.local_rank):
|
589 |
# save feature extractor, tokenizer and config
|
590 |
feature_extractor.save_pretrained(training_args.output_dir)
|
|
|
602 |
# 11. Configure Trainer
|
603 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
604 |
# Only required for streaming: Trainer automatically shuffles non-streaming datasets
|
605 |
+
logger.info("*** Set shuffle callback ***")
|
606 |
class ShuffleCallback(TrainerCallback):
|
607 |
def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
|
608 |
if isinstance(train_dataloader.dataset, IterableDatasetShard):
|
|
|
611 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
612 |
|
613 |
# Initialize Trainer
|
614 |
+
logger.info("*** Init trainer ***")
|
615 |
trainer = Seq2SeqTrainer(
|
616 |
model=model,
|
617 |
args=training_args,
|
|
|
625 |
|
626 |
# 12. Training
|
627 |
if training_args.do_train:
|
628 |
+
logger.info("*** Train ***")
|
629 |
checkpoint = None
|
630 |
if training_args.resume_from_checkpoint is not None:
|
631 |
checkpoint = training_args.resume_from_checkpoint
|
|
|
673 |
if model_args.model_index_name is not None:
|
674 |
kwargs["model_name"] = model_args.model_index_name
|
675 |
|
676 |
+
logger.info("*** Pushing to hub ***")
|
677 |
if training_args.push_to_hub:
|
678 |
trainer.push_to_hub(**kwargs)
|
679 |
else:
|
680 |
trainer.create_model_card(**kwargs)
|
681 |
|
682 |
# Training complete notification
|
683 |
+
logger.info("*** Sending notification ***")
|
684 |
notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
|
685 |
|
686 |
+
logger.info("*** Training complete!!! ***")
|
687 |
+
|
688 |
return results
|
689 |
|
690 |
|