pere committed
Commit c531109
1 Parent(s): 6799df5

First submit

Files changed (3)
  1. README.md +4 -0
  2. requirements.txt +116 -0
  3. run_whisper.py +187 -0
README.md CHANGED
@@ -1,3 +1,7 @@
 ---
 license: apache-2.0
 ---
+# Whisper Finetuning
+Whisper finetuning example script.
+
+
requirements.txt ADDED
@@ -0,0 +1,116 @@
+absl-py==1.3.0
+aiohttp==3.8.3
+aiosignal==1.2.0
+anyio==3.6.2
+appdirs==1.4.4
+async-timeout==4.0.2
+attrs==22.1.0
+audioread==3.0.0
+autopep8==2.0.0
+bcrypt==4.0.1
+cachetools==5.2.0
+certifi==2022.9.24
+cffi==1.15.1
+charset-normalizer==2.1.1
+click==8.1.3
+contourpy==1.0.6
+cryptography==38.0.3
+cycler==0.11.0
+datasets==2.6.1
+decorator==5.1.1
+dill==0.3.5.1
+evaluate==0.3.0
+fastapi==0.86.0
+ffmpy==0.3.0
+filelock==3.8.0
+fonttools==4.38.0
+frozenlist==1.3.1
+fsspec==2022.10.0
+google-auth==2.14.0
+google-auth-oauthlib==0.4.6
+gradio==3.9
+grpcio==1.50.0
+h11==0.12.0
+httpcore==0.15.0
+httpx==0.23.0
+huggingface-hub==0.10.1
+idna==3.4
+importlib-metadata==5.0.0
+Jinja2==3.1.2
+jiwer==2.5.1
+joblib==1.2.0
+kiwisolver==1.4.4
+Levenshtein==0.20.2
+librosa==0.9.2
+linkify-it-py==1.0.3
+llvmlite==0.39.1
+Markdown==3.4.1
+markdown-it-py==2.1.0
+MarkupSafe==2.1.1
+matplotlib==3.6.2
+mdit-py-plugins==0.3.1
+mdurl==0.1.2
+multidict==6.0.2
+multiprocess==0.70.13
+numba==0.56.4
+numpy==1.23.4
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+oauthlib==3.2.2
+orjson==3.8.1
+packaging==21.3
+pandas==1.5.1
+paramiko==2.12.0
+Pillow==9.3.0
+pooch==1.6.0
+protobuf==3.19.6
+pyarrow==10.0.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycodestyle==2.9.1
+pycparser==2.21
+pycryptodome==3.15.0
+pydantic==1.10.2
+pydub==0.25.1
+PyNaCl==1.5.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+python-multipart==0.0.5
+pytz==2022.6
+PyYAML==6.0
+rapidfuzz==2.13.2
+regex==2022.10.31
+requests==2.28.1
+requests-oauthlib==1.3.1
+resampy==0.4.2
+responses==0.18.0
+rfc3986==1.5.0
+rsa==4.9
+scikit-learn==1.1.3
+scipy==1.9.3
+sentencepiece==0.1.97
+six==1.16.0
+sniffio==1.3.0
+soundfile==0.11.0
+starlette==0.20.4
+tensorboard==2.10.1
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+threadpoolctl==3.1.0
+tokenizers==0.13.1
+tomli==2.0.1
+torch==1.12.1
+torchaudio==0.12.1
+tqdm==4.64.1
+transformers @ git+https://github.com/huggingface/transformers@504db92e7da010070c36e185332420a1d52c12b2
+typing_extensions==4.4.0
+uc-micro-py==1.0.1
+urllib3==1.26.12
+uvicorn==0.19.0
+websockets==10.4
+Werkzeug==2.2.2
+xxhash==3.1.0
+yarl==1.8.1
+zipp==3.10.0
run_whisper.py ADDED
@@ -0,0 +1,187 @@
+
+import torch
+from datasets import load_dataset, DatasetDict
+from datasets import Audio
+
+from transformers import WhisperFeatureExtractor
+from transformers import WhisperTokenizer
+from transformers import WhisperProcessor
+from transformers import WhisperForConditionalGeneration
+
+from transformers import Seq2SeqTrainingArguments
+from transformers import Seq2SeqTrainer
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Union
+import evaluate
+
+
+# Functions
+# Define a data collator
+@dataclass
+class DataCollatorSpeechSeq2SeqWithPadding:
+    processor: Any
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # split inputs and labels, since they have different lengths and need different padding methods
+        # first treat the audio inputs by simply returning torch tensors
+        input_features = [{"input_features": feature["input_features"]}
+                          for feature in features]
+        batch = self.processor.feature_extractor.pad(
+            input_features, return_tensors="pt")
+
+        # get the tokenized label sequences
+        label_features = [{"input_ids": feature["labels"]}
+                          for feature in features]
+        # pad the labels to max length
+        labels_batch = self.processor.tokenizer.pad(
+            label_features, return_tensors="pt")
+
+        # replace padding with -100 so these positions are ignored by the loss
+        labels = labels_batch["input_ids"].masked_fill(
+            labels_batch.attention_mask.ne(1), -100)
+
+        # if a bos token was appended in the previous tokenization step,
+        # cut it here, as it is appended again later anyway
+        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
+            labels = labels[:, 1:]
+
+        batch["labels"] = labels
+
+        return batch
+
+
+# Metrics
+def compute_metrics(pred):
+    pred_ids = pred.predictions
+    label_ids = pred.label_ids
+
+    # replace -100 with the pad_token_id
+    label_ids[label_ids == -100] = tokenizer.pad_token_id
+
+    # we do not want to group tokens when computing the metrics
+    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
+
+    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+
+    return {"wer": wer}
+
+# Prepare dataset
+
+
+def prepare_dataset(batch):
+    # load and resample the audio data from 48 to 16 kHz
+    audio = batch["audio"]
+
+    # compute log-Mel input features from the input audio array
+    batch["input_features"] = feature_extractor(
+        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
+
+    # encode the target text to label ids
+    batch["labels"] = tokenizer(batch["sentence"]).input_ids
+    return batch
+
+
+# Whisper training script
+
+# Map the source and target columns
+# Whisper expects these to be "audio" and "sentence". Change if they are named differently in the dataset
+source = "audio"
+target = "sentence"
+
+
+# Load a sample dataset
+speech_data = DatasetDict()
+
+# Examples
+# speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
+# speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
+# speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
+# speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)
+
+# The smallest dataset I found
+speech_data["train"] = load_dataset(
+    "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
+speech_data["test"] = load_dataset(
+    "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
+
+
+# Rename columns
+if "audio" not in speech_data.column_names["train"]:
+    speech_data = speech_data.rename_column(source, "audio")
+
+if "sentence" not in speech_data.column_names["train"]:
+    speech_data = speech_data.rename_column(target, "sentence")
+
+# Remove columns that are not needed - not sure this is strictly necessary
+remove_list = [i for i in speech_data.column_names["train"]
+               if i not in ["audio", "sentence"]]
+
+speech_data = speech_data.remove_columns(remove_list)
+
+# Initialise
+feature_extractor = WhisperFeatureExtractor.from_pretrained(
+    "openai/whisper-small")
+tokenizer = WhisperTokenizer.from_pretrained(
+    "openai/whisper-small", language="Norwegian", task="transcribe")
+processor = WhisperProcessor.from_pretrained(
+    "openai/whisper-small", language="Norwegian", task="transcribe")
+data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
+
+# Prepare data
+speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
+speech_data = speech_data.map(
+    prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)
+
+# Metrics
+metric = evaluate.load("wer")
+
+# Initialise a pretrained model
+# We need to set use_cache=False here because we use gradient checkpointing
+model = WhisperForConditionalGeneration.from_pretrained(
+    "openai/whisper-small", use_cache=False)
+
+# Override generation arguments: no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), and no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens))
+model.config.forced_decoder_ids = None
+model.config.suppress_tokens = []
+
+# Training arguments
+training_args = Seq2SeqTrainingArguments(
+    output_dir="./whisper-small-no-test",  # change to a repo name of your choice
+    # A batch size of at least 16 is reasonable. This small value is just for the test on Ficino
+    per_device_train_batch_size=4,
+    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
+    learning_rate=1e-5,
+    warmup_steps=500,
+    max_steps=1000,  # Changed from 4000
+    gradient_checkpointing=True,
+    fp16=True,
+    group_by_length=True,
+    evaluation_strategy="steps",
+    per_device_eval_batch_size=8,
+    predict_with_generate=True,
+    generation_max_length=225,
+    save_steps=500,
+    eval_steps=500,
+    logging_steps=25,
+    report_to=["tensorboard"],
+    load_best_model_at_end=True,
+    metric_for_best_model="wer",
+    greater_is_better=False,
+    push_to_hub=True,
+)
+
+trainer = Seq2SeqTrainer(
+    args=training_args,
+    model=model,
+    train_dataset=speech_data["train"],
+    eval_dataset=speech_data["test"],
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+    tokenizer=processor.feature_extractor,
+)
+
+
+# Start training
+trainer.train()
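
A quick illustration of the label-masking step in `DataCollatorSpeechSeq2SeqWithPadding` above. This toy snippet is not part of the commit, and the label values and mask are made up, but the `masked_fill` call is the same one the collator uses: padded positions become -100, which the loss in `transformers` ignores.

```python
import torch

# Hypothetical padded label batch and its attention mask (values are illustrative).
labels = torch.tensor([[7, 8, 9, 0],
                       [7, 8, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Same masking step as in the collator: padding positions are set to -100.
masked = labels.masked_fill(attention_mask.ne(1), -100)
print(masked)
# tensor([[   7,    8,    9, -100],
#         [   7,    8, -100, -100]])
```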
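Similarly, a minimal sketch of the word error rate reported by `compute_metrics`. The sentences are invented; `evaluate.load("wer")` is the same scorer the script uses, and the script multiplies by 100 to report a percentage.

```python
import evaluate

metric = evaluate.load("wer")

# One substituted word out of six reference words gives a WER of 1/6.
wer = 100 * metric.compute(
    predictions=["det er fint ver i dag"],
    references=["det er fint vær i dag"],
)
print(round(wer, 1))  # 16.7
```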
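Finally, a hedged sketch of how the fine-tuned checkpoint might be used for transcription after training. It assumes the processor has also been saved to `./whisper-small-no-test`; the script passes only `tokenizer=processor.feature_extractor` to the trainer, so an extra `processor.save_pretrained(training_args.output_dir)` may be needed. The silent waveform is a placeholder for real audio.

```python
import numpy as np
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

checkpoint = "./whisper-small-no-test"  # output_dir from the training run above

processor = WhisperProcessor.from_pretrained(checkpoint)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
model.eval()

# Placeholder input: one second of silence at 16 kHz; replace with real audio.
waveform = np.zeros(16000, dtype=np.float32)

inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    generated_ids = model.generate(inputs.input_features, max_length=225)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```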