pere committed on
Commit
8168661
1 Parent(s): f5c74a6
Files changed (2) hide show
  1. run_whisper_finetuning.py +4 -0
  2. run_xla_test.sh +46 -0
run_whisper_finetuning.py CHANGED
@@ -340,6 +340,10 @@ def main():
340
  parser = HfArgumentParser(
341
  (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
342
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
 
 
 
343
 
344
  # Metrics
345
 
 
340
  parser = HfArgumentParser(
341
  (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
342
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
343
+
344
+ #Debug
345
+ import torch_xla.debug.metrics as met
346
+ print(met.metrics_report())
347
 
348
  # Metrics
349
 
run_xla_test.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Whisper Finetuning script for the NST dataset
2
+ # This is a test script for XLA on TPU
3
+
4
+ python xla_spawn.py --num_cores=4 run_whisper_finetuning.py\
5
+ --model_name_or_path="openai/whisper-small" \
6
+ --output_dir="../whisper-NST-TPU-test2" \
7
+ --overwrite_output_dir=True \
8
+ --language="Norwegian" \
9
+ --task="transcribe" \
10
+ --dataset_name="NbAiLab/NST" \
11
+ --dataset_config="no-close" \
12
+ --do_train=True \
13
+ --do_eval=True \
14
+ --audio_column_name="audio" \
15
+ --text_column_name="text" \
16
+ --per_device_train_batch_size=16 \
17
+ --per_device_train_batch_size=16 \
18
+ --learning_rate=2e-5 \
19
+ --warmup_steps=0 \
20
+ --max_steps=10 \
21
+ --gradient_checkpointing=True \
22
+ --gradient_accumulation_steps=1 \
23
+ --group_by_length=False \
24
+ --evaluation_strategy="steps" \
25
+ --save_steps=10 \
26
+ --eval_steps=10 \
27
+ --max_eval_samples=2 \
28
+ --logging_steps=10 \
29
+ --load_best_model_at_end=True \
30
+ --metric_for_best_model="wer" \
31
+ --greater_is_better=False \
32
+ --report_to="tensorboard" \
33
+ --predict_with_generate=True \
34
+ --generation_max_length=225 \
35
+ --print_training_arguments=True \
36
+ --push_to_hub=True
37
+
38
+
39
+ # Very likely some of these parameters need to be added:
40
+ # tpu_name (:obj:`str`, `optional`):
41
+ # The name of the TPU the process is running on.
42
+ # tpu_zone (:obj:`str`, `optional`):
43
+ # The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
44
+ # from metadata.
45
+ # xla (:obj:`bool`, `optional`):
46
+ # Whether to activate the XLA compilation or not.