test
Browse files- run_whisper_finetuning.py +4 -0
- run_xla_test.sh +46 -0
run_whisper_finetuning.py
CHANGED
@@ -340,6 +340,10 @@ def main():
|
|
340 |
parser = HfArgumentParser(
|
341 |
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
342 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
|
|
|
|
|
|
|
343 |
|
344 |
# Metrics
|
345 |
|
|
|
340 |
parser = HfArgumentParser(
|
341 |
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
342 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
343 |
+
|
344 |
+
# Debug: print the torch_xla metrics report
|
345 |
+
import torch_xla.debug.metrics as met
|
346 |
+
print(met.metrics_report())
|
347 |
|
348 |
# Metrics
|
349 |
|
run_xla_test.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Whisper Finetuning script for the NST dataset
|
2 |
+
# This is a test script for XLA on TPU
|
3 |
+
|
4 |
+
python xla_spawn.py --num_cores=4 run_whisper_finetuning.py\
|
5 |
+
--model_name_or_path="openai/whisper-small" \
|
6 |
+
--output_dir="../whisper-NST-TPU-test2" \
|
7 |
+
--overwrite_output_dir=True \
|
8 |
+
--language="Norwegian" \
|
9 |
+
--task="transcribe" \
|
10 |
+
--dataset_name="NbAiLab/NST" \
|
11 |
+
--dataset_config="no-close" \
|
12 |
+
--do_train=True \
|
13 |
+
--do_eval=True \
|
14 |
+
--audio_column_name="audio" \
|
15 |
+
--text_column_name="text" \
|
16 |
+
--per_device_train_batch_size=16 \
|
17 |
+
--per_device_eval_batch_size=16 \
|
18 |
+
--learning_rate=2e-5 \
|
19 |
+
--warmup_steps=0 \
|
20 |
+
--max_steps=10 \
|
21 |
+
--gradient_checkpointing=True \
|
22 |
+
--gradient_accumulation_steps=1 \
|
23 |
+
--group_by_length=False \
|
24 |
+
--evaluation_strategy="steps" \
|
25 |
+
--save_steps=10 \
|
26 |
+
--eval_steps=10 \
|
27 |
+
--max_eval_samples=2 \
|
28 |
+
--logging_steps=10 \
|
29 |
+
--load_best_model_at_end=True \
|
30 |
+
--metric_for_best_model="wer" \
|
31 |
+
--greater_is_better=False \
|
32 |
+
--report_to="tensorboard" \
|
33 |
+
--predict_with_generate=True \
|
34 |
+
--generation_max_length=225 \
|
35 |
+
--print_training_arguments=True \
|
36 |
+
--push_to_hub=True
|
37 |
+
|
38 |
+
|
39 |
+
# Very likely some of these parameters need to be added
|
40 |
+
# tpu_name (:obj:`str`, `optional`):
|
41 |
+
# The name of the TPU the process is running on.
|
42 |
+
# tpu_zone (:obj:`str`, `optional`):
|
43 |
+
# The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
|
44 |
+
# from metadata.
|
45 |
+
# xla (:obj:`bool`, `optional`):
|
46 |
+
# Whether to activate the XLA compilation or not.
|