first

- requirements.txt +2 -2
- run.sh +4 -0
- run_whisper.py +12 -13
- xla_spawn.py +83 -0
requirements.txt
CHANGED
@@ -101,8 +101,8 @@ tensorboard-plugin-wit==1.8.1
 threadpoolctl==3.1.0
 tokenizers==0.13.1
 tomli==2.0.1
-torch
-torchaudio
+torch>=1.12.1
+torchaudio>=0.12.1
 tqdm==4.64.1
 transformers @ git+https://github.com/huggingface/transformers@504db92e7da010070c36e185332420a1d52c12b2
 typing_extensions==4.4.0
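A quick way to sanity-check the pinned stack after pip install -r requirements.txt (a minimal sketch only; it assumes the pinned transformers commit already ships the Whisper classes, which is the usual reason to install from git rather than a release):

import torch
import torchaudio
# should import without error if the pinned transformers commit includes Whisper
from transformers import WhisperForConditionalGeneration, WhisperProcessor

print(torch.__version__, torchaudio.__version__)  # expect >=1.12.1 and >=0.12.1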
run.sh
ADDED
@@ -0,0 +1,4 @@
+
+python xla_spawn.py --num_cores=4 run_whisper.py
+
+
run_whisper.py
CHANGED
@@ -88,23 +88,23 @@ def main():
     # Map the source and target columns
     # Whisper expects these to be "audio" and "sentence". Change if anything else in the dataset
     source = "audio"
-    target = "
+    target = "sentence_text"
 
 
     # Load a sample dataset
     speech_data = DatasetDict()
 
     # Examples
-
-
+    speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
+    speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
     # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
     #speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)
 
     # The smallest dataset I found
-    speech_data["train"] = load_dataset(
-        "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
-    speech_data["test"] = load_dataset(
-        "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
+    #speech_data["train"] = load_dataset(
+    #    "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
+    #speech_data["test"] = load_dataset(
+    #    "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
 
 
     # Rename columns
@@ -148,15 +148,13 @@ def main():
 
     # Training arguments
     training_args = Seq2SeqTrainingArguments(
-        output_dir="
-
-        per_device_train_batch_size=4,
+        output_dir="./first-whisper-test2",  # change to a repo name of your choice
+        per_device_train_batch_size=64,
         gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
-        learning_rate=
+        learning_rate=2e-5,
         warmup_steps=500,
-        max_steps=
+        max_steps=5000,  # Changed from 4000
         gradient_checkpointing=True,
-        fp16=True,
         group_by_length=True,
         evaluation_strategy="steps",
         per_device_eval_batch_size=8,
@@ -189,6 +187,7 @@ def main():
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)
+    print("The XLA is initiated")
     main()
 
 
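The inline comment on gradient_accumulation_steps is easiest to read together with run.sh: with per_device_train_batch_size=64, accumulation of 1, and the 4 TPU cores spawned by the launcher, the effective global batch size works out as below (a sketch of the arithmetic only, using values taken from this diff):

per_device_train_batch_size = 64
gradient_accumulation_steps = 1
num_cores = 4  # --num_cores=4 in run.sh
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_cores
print(effective_batch_size)  # 256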
xla_spawn.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A simple launcher script for TPU training
+
+Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
+
+::
+    >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
+               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
+               arguments of your training script)
+
+"""
+
+
+import importlib
+import sys
+from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+
+def parse_args():
+    """
+    Helper function parsing the command line options
+    @retval ArgumentParser
+    """
+    parser = ArgumentParser(
+        description=(
+            "PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
+        )
+    )
+
+    # Optional arguments for the launch helper
+    parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use. 1 or 8 on v3-8. 1 or 4 on v4-8")
+
+    # positional
+    parser.add_argument(
+        "training_script",
+        type=str,
+        help=(
+            "The full path to the single TPU training "
+            "program/script to be launched in parallel, "
+            "followed by all the arguments for the "
+            "training script"
+        ),
+    )
+
+    # rest from the training program
+    parser.add_argument("training_script_args", nargs=REMAINDER)
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    # Import training_script as a module.
+    script_fpath = Path(args.training_script)
+    sys.path.append(str(script_fpath.parent.resolve()))
+    mod_name = script_fpath.stem
+    mod = importlib.import_module(mod_name)
+
+    # Patch sys.argv
+    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
+
+    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
+
+
+if __name__ == "__main__":
+    main()
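For reference, after importing the training script as a module and patching sys.argv, the launcher boils down to a single xmp.spawn call; the run.sh invocation can therefore be sketched directly in Python roughly as follows (a sketch only, assuming a TPU host with torch_xla installed and run_whisper.py on sys.path):

import sys

import torch_xla.distributed.xla_multiprocessing as xmp

import run_whisper  # the training script, imported as a module

if __name__ == "__main__":
    # mirror the launcher's sys.argv patch, then spawn one process per core
    sys.argv = ["run_whisper.py", "--tpu_num_cores", "4"]
    xmp.spawn(run_whisper._mp_fn, args=(), nprocs=4)  # equivalent of --num_cores=4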