marinone94 commited on
Commit
a8eb6d3
1 Parent(s): ed0b2a7
Files changed (1) hide show
  1. run_speech_recognition_ctc.py +11 -0
run_speech_recognition_ctc.py CHANGED
@@ -15,6 +15,7 @@
15
 
16
  """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
 
 
18
  import functools
19
  import json
20
  import logging
@@ -28,6 +29,7 @@ from typing import Dict, List, Optional, Union
28
  import datasets
29
  import numpy as np
30
  import torch
 
31
  from datasets import DatasetDict, load_dataset, load_metric
32
 
33
  import transformers
@@ -355,6 +357,15 @@ def main():
355
  else:
356
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
357
 
 
 
 
 
 
 
 
 
 
358
  # Detecting last checkpoint.
359
  last_checkpoint = None
360
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
 
15
 
16
  """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
 
18
+ import datetime
19
  import functools
20
  import json
21
  import logging
 
29
  import datasets
30
  import numpy as np
31
  import torch
32
+ import wandb
33
  from datasets import DatasetDict, load_dataset, load_metric
34
 
35
  import transformers
 
357
  else:
358
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
359
 
360
+ # TODO: Replace with check of wandb env vars
361
+ try:
362
+ os.environ["WANDB_PROJECT"] = os.getcwd().split("/")[-1]
363
+ wandb.login()
364
+ training_args.report_to = ["wandb"]
365
+ training_args.run_name = f"{datetime.datetime.utcnow()}".replace(" ", "T")
366
+ except:
367
+ pass
368
+
369
  # Detecting last checkpoint.
370
  last_checkpoint = None
371
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: