marinone94 committed
Commit 21f22fe · 1 Parent(s): 8f1a9b5

final swedish training

Files changed (2)
  1. run.sh +9 -8
  2. run_speech_recognition_seq2seq_streaming.py +197 -23
run.sh CHANGED
@@ -1,12 +1,14 @@
 python run_speech_recognition_seq2seq_streaming.py \
     --model_name_or_path="marinone94/whisper-medium-nordic" \
-    --dataset_name="mozilla-foundation/common_voice_11_0" \
-    --dataset_config_name="sv-SE" \
+    --dataset_train_name="mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,google/fleurs" \
+    --dataset_train_config_name="sv-SE,nst,sv_se" \
     --language="swedish" \
-    --train_split_name="train+validation" \
+    --train_split_name="train+validation,train,train+validation+test" \
+    --dataset_eval_name="mozilla-foundation/common_voice_11_0" \
+    --dataset_eval_config_name="sv-SE" \
     --eval_split_name="test" \
     --model_index_name="Whisper Medium Swedish" \
-    --max_steps="2500" \
+    --max_steps="5000" \
     --output_dir="./" \
     --per_device_train_batch_size="32" \
     --per_device_eval_batch_size="16" \
@@ -20,9 +22,9 @@ python run_speech_recognition_seq2seq_streaming.py \
     --generation_max_length="225" \
     --length_column_name="input_length" \
     --max_duration_in_seconds="30" \
-    --text_column_name="sentence" \
+    --text_column_name="sentence,raw_transcription" \
     --freeze_feature_encoder="False" \
-    --report_to="tensorboard" \
+    --report_to="wandb" \
     --metric_for_best_model="wer" \
     --greater_is_better="False" \
     --load_best_model_at_end \
@@ -34,5 +36,4 @@ python run_speech_recognition_seq2seq_streaming.py \
     --predict_with_generate \
     --do_normalize_eval \
     --streaming \
-    --use_auth_token \
-    --push_to_hub
+    --use_auth_token
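For orientation only (not part of the commit): a minimal sketch of what the new comma-separated flags amount to once the script parses them. Each dataset/config/split triple from run.sh is loaded in streaming mode, its text column is normalized to a single reference name, its audio is resampled, and the streams are interleaved. The 16 kHz rate and the "audio"/"sentence" column names are assumptions for illustration; the real script takes them from the feature extractor and the CLI flags.

```python
# Illustrative sketch mirroring the logic added in this commit (assumes dataset access).
from datasets import interleave_datasets, load_dataset
from datasets.features import Audio

specs = [
    # (dataset, config, splits, text column) as in run.sh above
    ("mozilla-foundation/common_voice_11_0", "sv-SE", "train+validation", "sentence"),
    ("babelbox/babelbox_voice", "nst", "train", "sentence"),
    ("google/fleurs", "sv_se", "train+validation+test", "raw_transcription"),
]

parts = []
for name, config, splits, text_col in specs:
    for split in splits.split("+"):
        ds = load_dataset(name, config, split=split, streaming=True, use_auth_token=True)
        if text_col != "sentence":
            ds = ds.rename_column(text_col, "sentence")  # normalize to one text column
        ds = ds.cast_column("audio", Audio(sampling_rate=16_000))  # resample on the fly
        parts.append(ds)

# "all_exhausted" keeps drawing examples until every source stream is fully consumed
train_stream = interleave_datasets(parts, stopping_strategy="all_exhausted")
```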
 
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -20,6 +20,7 @@ with 🤗 Datasets' streaming mode.
 # You can also adapt this script for your own sequence to sequence speech
 # recognition task. Pointers for this are left as comments.
 
+import json
 import logging
 import os
 import sys
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Union
 
 import datasets
 import torch
+import wandb
 from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
 from torch.utils.data import IterableDataset
 
@@ -60,6 +62,42 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
 logger = logging.getLogger(__name__)
 
 
+SENDING_NOTIFICATION = "*** Sending notification to email ***"
+RECIPIENT_ADDRESS = "marinone94@gmail.com"
+
+wandb_token = os.environ.get("WANDB_TOKEN", "None")
+hf_token = os.environ.get("HF_TOKEN", None)
+if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
+    with open("./creds.txt", "r") as f:
+        lines = f.readlines()
+        for line in lines:
+            key, value = line.split("=")
+            if key == "HF_TOKEN":
+                hf_token = value.strip()
+            if key == "WANDB_TOKEN":
+                wandb_token = value.strip()
+            if key == "EMAIL_ADDRESS":
+                os.environ["EMAIL_ADDRESS"] = value.strip()
+            if key == "EMAIL_PASSWORD":
+                os.environ["EMAIL_PASSWORD"] = value.strip()
+
+if hf_token is not None:
+    try:
+        os.makedirs("/root/.huggingface", exist_ok=True)
+        with open("/root/.huggingface/token", "w") as f:
+            f.write(hf_token)
+        logger.info("Huggingface API key set")
+    except (PermissionError, OSError):
+        logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
+else:
+    logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
+
+wandb.login(key=wandb_token, relogin=True, timeout=5)
+wandb.init(project="whisper", entity="pn-aa")
+
+logger.info("Wandb API key set, logging to wandb")
+
+
 @dataclass
 class ModelArguments:
     """
@@ -265,27 +303,131 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         return batch
 
 
-def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
+def rename_col_and_resample(dataset, dataset_name, text_column_names, text_col_name_ref, audio_column_name, sampling_rate):
+    raw_datasets_features = list(dataset.features.keys())
+    logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
+
+    if text_col_name_ref not in raw_datasets_features:
+        if len(text_column_names) == 1:
+            raise ValueError("None of the text column names provided found in dataset."
+                             f"Text columns: {text_column_names}"
+                             f"Dataset columns: {raw_datasets_features}")
+        flag = False
+        for text_column_name in text_column_names:
+            if text_column_name in raw_datasets_features:
+                logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
+                dataset = dataset.rename_column(text_column_name, text_col_name_ref)
+                flag = True
+                break
+        if flag is False:
+            raise ValueError("None of the text column names provided found in dataset."
+                             f"Text columns: {text_column_names}"
+                             f"Dataset columns: {raw_datasets_features}")
+    if audio_column_name is not None and sampling_rate is not None:
+        ds_sr = int(dataset.features[audio_column_name].sampling_rate)
+        if ds_sr != sampling_rate:
+            dataset = dataset.cast_column(
+                audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+            )
+
+    raw_datasets_features = list(dataset.features.keys())
+    raw_datasets_features.remove(audio_column_name)
+    raw_datasets_features.remove(text_col_name_ref)
+    # Keep only audio and sentence
+    dataset = dataset.remove_columns(column_names=raw_datasets_features)
+    return dataset
+
+
+def load_maybe_streaming_dataset(
+    dataset_names,
+    dataset_config_names,
+    split="train",
+    streaming=True,
+    audio_column_name=None,
+    sampling_rate=None,
+    **kwargs
+):
     """
     Utility function to load a dataset in streaming mode. For datasets with multiple splits,
     each split is loaded individually and then splits combined by taking alternating examples from
     each (interleaving).
     """
-    if "+" in split:
+    text_column_names = None
+    if "text_column_name" in kwargs:
+        text_column_names = kwargs.pop("text_column_name").split(",")
+    text_col_name_ref = text_column_names[0]
+
+    if "," in dataset_names or "+" in split:
         # load multiple splits separated by the `+` symbol with streaming mode
-        dataset_splits = [
-            load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
-            for split_name in split.split("+")
-        ]
+        dataset_splits = []
+        for dataset_name, dataset_config_name, split_names in zip(
+            dataset_names.split(","), dataset_config_names.split(","), split.split(",")
+        ):
+            for split_name in split_names.split("+"):
+                if dataset_config_name:
+                    dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
+                else:
+                    dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
+
+                dataset = rename_col_and_resample(
+                    dataset,
+                    dataset_name,
+                    text_column_names,
+                    text_col_name_ref,
+                    audio_column_name,
+                    sampling_rate
+                )
+
+                dataset_splits.append(dataset)
+
         # interleave multiple splits to form one dataset
-        interleaved_dataset = interleave_datasets(dataset_splits)
+        interleaved_dataset = interleave_datasets(dataset_splits, stopping_strategy="all_exhausted")
         return interleaved_dataset
     else:
         # load a single split *with* streaming mode
-        dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
+
+        dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
+        dataset = rename_col_and_resample(
+            dataset,
+            dataset_names,
+            text_column_names,
+            text_col_name_ref,
+            audio_column_name,
+            sampling_rate
+        )
        return dataset
 
 
+def notify_me(recipient, message=None):
+    """
+    Send an email to the specified address with the specified message
+    """
+    sender = os.environ.get("EMAIL_ADDRESS", None)
+    password = os.environ.get("EMAIL_PASSWORD", None)
+    if sender is None:
+        logging.warning("No email address specified, not sending notification")
+    if password is None:
+        logging.warning("No email password specified, not sending notification")
+    if message is None:
+        message = "Training is finished!"
+
+    if sender is not None:
+        import smtplib
+        from email.mime.text import MIMEText
+
+        msg = MIMEText(message)
+        msg["Subject"] = "Training updates..."
+        msg["From"] = "marinone.auto@gmail.com"
+        msg["To"] = recipient
+
+        # send the email
+        smtp_obj = smtplib.SMTP("smtp.gmail.com", 587)
+        smtp_obj.starttls()
+        smtp_obj.login(sender, password)
+        smtp_obj.sendmail(sender, recipient, msg.as_string())
+        smtp_obj.quit()
+
+
 def main():
     # 1. Parse input arguments
     # See all possible arguments in src/transformers/training_args.py
@@ -349,25 +491,41 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
+    # Load feature extractor
+    feature_extractor = AutoFeatureExtractor.from_pretrained(
+        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=hf_token if model_args.use_auth_token else None,
+    )
+
     # 4. Load dataset
     raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
 
     if training_args.do_train:
         raw_datasets["train"] = load_maybe_streaming_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
+            data_args.dataset_train_name,
+            data_args.dataset_train_config_name,
             split=data_args.train_split_name,
-            use_auth_token=True if model_args.use_auth_token else None,
+            use_auth_token=hf_token if model_args.use_auth_token else None,
            streaming=data_args.streaming,
+            text_column_name=data_args.text_column_name,
+            audio_column_name=data_args.audio_column_name,
+            sampling_rate=int(feature_extractor.sampling_rate),
+            # language=data_args.language_train
        )
 
     if training_args.do_eval:
         raw_datasets["eval"] = load_maybe_streaming_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
+            data_args.dataset_eval_name,
+            data_args.dataset_eval_config_name,
             split=data_args.eval_split_name,
-            use_auth_token=True if model_args.use_auth_token else None,
+            use_auth_token=hf_token if model_args.use_auth_token else None,
            streaming=data_args.streaming,
+            text_column_name=data_args.text_column_name,
+            audio_column_name=data_args.audio_column_name,
+            sampling_rate=int(feature_extractor.sampling_rate),
+            # language=data_args.language_eval
        )
 
     raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
@@ -394,7 +552,7 @@ def main():
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
+        use_auth_token=hf_token if model_args.use_auth_token else None,
     )
 
     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
@@ -402,25 +560,19 @@ def main():
     if training_args.gradient_checkpointing:
         config.update({"use_cache": False})
 
-    feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
+        use_auth_token=hf_token if model_args.use_auth_token else None,
     )
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
+        use_auth_token=hf_token if model_args.use_auth_token else None,
    )
 
     if model.config.decoder_start_token_id is None:
@@ -568,6 +720,9 @@ def main():
         callbacks=[ShuffleCallback()] if data_args.streaming else None,
     )
 
+    orig_push_to_hub = trainer.args.push_to_hub
+    trainer.args.push_to_hub = False
+
     # 12. Training
     if training_args.do_train:
         checkpoint = None
@@ -617,10 +772,29 @@ def main():
     if model_args.model_index_name is not None:
         kwargs["model_name"] = model_args.model_index_name
 
+    logger.info("*** Training stats written ***")
+    logger.info(json.dumps(kwargs, indent=4))
+
+    # Training complete notification
+    logger.info("*** Training and eval complete ***")
+    logger.info(SENDING_NOTIFICATION)
+    with open(os.path.join(training_args.output_dir, "train_results.json"), "r") as f:
+        train_results = json.load(f)
+    with open(os.path.join(training_args.output_dir, "eval_results.json"), "r") as f:
+        eval_results = json.load(f)
+    notify_me(recipient=RECIPIENT_ADDRESS,
+              message=f"Training complete! {train_results = } {eval_results = }")
+
+    trainer.args.push_to_hub = orig_push_to_hub
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else:
         trainer.create_model_card(**kwargs)
+
+    with open(os.path.join(training_args.output_dir, "README.md"), "r") as f:
+        readme = f.read()
+    notify_me(recipient=RECIPIENT_ADDRESS,
+              message=f"Model pushed to hub! {readme = }")
 
     return results
 
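The credential bootstrap added at the top of the script falls back to a local ./creds.txt when HF_TOKEN or WANDB_TOKEN is missing from the environment. Below is a hypothetical example of that file's layout with placeholder values; since the script parses each line with `key, value = line.split("=")`, values must not themselves contain "=".

```python
# Hypothetical ./creds.txt contents consumed by the credential bootstrap (placeholders only).
example_creds = (
    "HF_TOKEN=hf_xxxxxxxxxxxxxxxx\n"
    "WANDB_TOKEN=xxxxxxxxxxxxxxxxxxxxxxxx\n"
    "EMAIL_ADDRESS=user@example.com\n"
    "EMAIL_PASSWORD=app-specific-password\n"
)

# Same parsing scheme as the script: one KEY=VALUE pair per line.
for line in example_creds.splitlines():
    key, value = line.split("=")
    print(f"{key} -> {value.strip()}")
```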