ihanif committed
Commit 8d73d0c · 1 Parent(s): 5b93cd5

feat: add requirements and training script

Files changed (2)
  1. requirements.txt +14 -0
  2. whisper_small_ps_augmented.py +308 -0
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch>=1.7
+ torchaudio
+ git+https://github.com/huggingface/transformers
+ git+https://github.com/huggingface/datasets
+ librosa
+ jiwer
+ evaluate>=0.3.0
+ more-itertools
+ tensorboard
+ audiomentations
+ soundfile
+ gradio
+ wandb
+ holoviews[recommended]
whisper_small_ps_augmented.py ADDED
@@ -0,0 +1,308 @@
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+ from audiomentations import Compose, TimeStretch, PitchShift
+ from datasets import Audio
+ from datasets import load_dataset, DatasetDict
+ import jiwer
+ import warnings
+ import pandas as pd
+ from io import StringIO
+ from datasets import Dataset, IterableDatasetDict, load_dataset, interleave_datasets, Audio
+ import evaluate
+
+ import torch
+ import string
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Union
+
+ from transformers import WhisperForConditionalGeneration
+ from transformers import WhisperProcessor
+ from transformers import Seq2SeqTrainingArguments
+ from transformers import Seq2SeqTrainer
+ from transformers import WhisperTokenizer
+ from transformers import WhisperFeatureExtractor
+ import wandb
+ from IPython.display import clear_output
+ from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
+ import numpy as np
+ from huggingface_hub import notebook_login
+ from transformers import TrainerCallback
+ from transformers.integrations import WandbCallback
+ from transformers.trainer_pt_utils import IterableDatasetShard
+ from torch.utils.data import IterableDataset
+ from datasets import load_dataset, Audio
+ from pathlib import Path
+ import numpy as np
+ import holoviews as hv
+ import panel as pn
+ import tempfile
+ from bokeh.resources import INLINE
+ hv.extension("bokeh", logo=False)
+
+ warnings.filterwarnings('ignore')
+
+ clear_output()
+ torch.cuda.is_available()
+
+ """## Load Dataset
+ Load the Pashto (ps_af) subset of FLEURS and
+ combine the train and validation splits.
+ """
+
+ # notebook_login()
+
+
+ fleurs = DatasetDict()
+ fleurs["train"] = load_dataset(
+     "google/fleurs", "ps_af", split="train+validation", use_auth_token=True)
+ fleurs["test"] = load_dataset(
+     "google/fleurs", "ps_af", split="test", use_auth_token=True)
+
+ fleurs = fleurs.remove_columns(
+     ["id", "num_samples", "path", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])
+
+ print(fleurs)
+
+
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
+     "openai/whisper-small")
+
+
+ tokenizer = WhisperTokenizer.from_pretrained(
+     "openai/whisper-small", language="Pashto", task="transcribe")
+
+ """### Combine To Create A WhisperProcessor"""
+
+
+ processor = WhisperProcessor.from_pretrained(
+     "openai/whisper-small", language="Pashto", task="transcribe")
+
+ """### Prepare Data"""
+
+ fleurs = fleurs.cast_column("audio", Audio(sampling_rate=16000))
+
+
+ augment_waveform = Compose([
+     TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3,
+                 leave_length_unchanged=False),
+     PitchShift(min_semitones=-4, max_semitones=4, p=0.3),
+ ])
+
+
+ def augment_dataset(batch):
+
+     audio = batch["audio"]["array"]
+     # apply augmentation
+     augmented_audio = augment_waveform(samples=audio, sample_rate=16000)
+
+     batch["audio"]["array"] = augmented_audio
+
+     return batch
+
+
+ print('Augment train set:')
+ fleurs['train'] = fleurs['train'].map(augment_dataset, num_proc=1)
+
+ """We can apply the data preparation function to all of our training examples using the dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use; setting `num_proc` > 1 enables multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially."""
+
+
+ do_lower_case = True
+ do_remove_punctuation = True
+
+ normalizer = BasicTextNormalizer()
+
+
+ def prepare_dataset(batch):
+     # load and (possibly) resample audio data to 16kHz
+     audio = batch["audio"]
+
+     # compute log-Mel input features from input audio array
+     batch["input_features"] = processor.feature_extractor(
+         audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
+     # compute input length of audio sample in seconds
+     batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
+
+     # optional pre-processing steps
+     transcription = batch["transcription"]
+     if do_lower_case:
+         transcription = transcription.lower()
+     if do_remove_punctuation:
+         transcription = normalizer(transcription).strip()
+
+     # encode target text to label ids
+     batch["labels"] = processor.tokenizer(transcription).input_ids
+     return batch
+
+
+ print('Extract features and normalize data:')
+ fleurs = fleurs.map(
+     prepare_dataset, remove_columns=fleurs.column_names['train'], num_proc=1).with_format('torch')
139
+
140
+ """Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature-extractor which could affect the stability of training. We define a function that returns `True` for samples that are less than 30s, and `False` for those that are longer:"""
141
+
142
+ max_input_length = 30.0
143
+
144
+
145
+ def is_audio_in_length_range(length):
146
+ return length < max_input_length
147
+
148
+
149
+ """We apply our filter function to all samples of our training dataset through 🤗 Datasets' `.filter` method:"""
150
+
151
+ fleurs['train'] = fleurs['train'].filter(
152
+ is_audio_in_length_range,
153
+ input_columns=["input_length"],
154
+ )
155
+
156
+ fleurs["train"] = fleurs["train"].shuffle(seed=42, writer_batch_size=100)
+
+
+ @dataclass
+ class DataCollatorSpeechSeq2SeqWithPadding:
+     processor: Any
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need different padding methods
+         # first treat the audio inputs by simply returning torch tensors
+         input_features = [{"input_features": feature["input_features"]}
+                           for feature in features]
+         batch = self.processor.feature_extractor.pad(
+             input_features, return_tensors="pt")
+
+         # get the tokenized label sequences
+         label_features = [{"input_ids": feature["labels"]}
+                           for feature in features]
+         # pad the labels to max length
+         labels_batch = self.processor.tokenizer.pad(
+             label_features, return_tensors="pt")
+
+         # replace padding with -100 to ignore loss correctly
+         labels = labels_batch["input_ids"].masked_fill(
+             labels_batch.attention_mask.ne(1), -100)
+
+         # if bos token is appended in previous tokenization step,
+         # cut bos token here as it's appended later anyway
+         if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
+             labels = labels[:, 1:]
+
+         batch["labels"] = labels
+
+         return batch
+
+
+ """Let's initialise the data collator we've just defined:"""
+
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
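+
+ """Optional sanity check (illustrative sketch, not required for training): a small
+ batch pushed through the collator should come out as fixed-size log-Mel features and
+ right-padded label ids."""
+
+ # example_batch = data_collator([fleurs["train"][0], fleurs["train"][1]])
+ # print(example_batch["input_features"].shape)  # (2, 80, 3000) for whisper-small inputs
+ # print(example_batch["labels"].shape)          # (2, length of the longest label in the pair)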
+
+ """### Evaluation Metrics
+
+ We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
+ ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:
+ """
+
+
+ wer_metric = evaluate.load("wer")
+ cer_metric = evaluate.load("cer")
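+
+ # Toy example (illustrative): one substituted word in a six-word reference gives
+ # WER = 1/6 ≈ 0.167, i.e. roughly 16.7 after the *100 scaling used below.
+ # print(wer_metric.compute(predictions=["the cat sat on the mat"],
+ #                          references=["the cat sat on a mat"]))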
+
+ # evaluate with the 'normalised' WER
+ do_normalize_eval = True
+
+
+ def compute_metrics(pred):
+     pred_ids = pred.predictions
+     label_ids = pred.label_ids
+
+     # replace -100 with the pad_token_id
+     label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
+
+     # we do not want to group tokens when computing the metrics
+     pred_str = processor.tokenizer.batch_decode(
+         pred_ids, skip_special_tokens=True)
+     label_str = processor.tokenizer.batch_decode(
+         label_ids, skip_special_tokens=True)
+
+     if do_normalize_eval:
+         pred_str = [normalizer(pred) for pred in pred_str]
+         label_str = [normalizer(label) for label in label_str]
+
+     wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
+     cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
+
+     return {"wer": wer, "cer": cer}
+
+
+ """### Load a Pre-Trained Checkpoint"""
+
+
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+
+ """Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:"""
+
+ model.config.forced_decoder_ids = None
+ model.config.suppress_tokens = []
+ model.config.use_cache = False
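+
+ # Note (illustrative, not part of the training flow): for standalone inference later,
+ # the language/task prompt is typically re-applied, e.g.
+ # model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
+ #     language="pashto", task="transcribe")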
+
+ """### Define the Training Configuration
+
+ In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).
+ """
+
+
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./",
+     per_device_train_batch_size=2,
+     # increase by 2x for every 2x decrease in batch size
+     gradient_accumulation_steps=16,
+     learning_rate=1e-5,
+     warmup_steps=30,
+     max_steps=300,
+     gradient_checkpointing=True,
+     fp16=True,
+     evaluation_strategy="steps",
+     per_device_eval_batch_size=2,
+     predict_with_generate=True,
+     generation_max_length=225,
+     save_steps=100,
+     eval_steps=100,
+     logging_steps=10,
+     report_to=["tensorboard"],
+     load_best_model_at_end=True,
+     metric_for_best_model="wer",
+     greater_is_better=False,
+     push_to_hub=True,
+     optim='adamw_bnb_8bit',
+     overwrite_output_dir=True,
+ )
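+
+ # Worked out for reference: the effective batch size is
+ # per_device_train_batch_size * gradient_accumulation_steps = 2 * 16 = 32 per device.
+ # print(training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)  # 32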
+
+
+ trainer = Seq2SeqTrainer(
+     args=training_args,
+     model=model,
+     train_dataset=fleurs['train'],
+     eval_dataset=fleurs['test'],
+     data_collator=data_collator,
+     compute_metrics=compute_metrics,
+     tokenizer=processor.feature_extractor,
+ )
+
+ """We'll save the processor object once before starting training. Since the processor is not trainable, it won't change over the course of training:"""
+
+ processor.save_pretrained(training_args.output_dir)
+
+ trainer.train()
+
+ """We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate keyword arguments (kwargs):"""
+
+ kwargs = {
+     "dataset_tags": "google/fleurs",
+     "dataset": "google/fleurs",  # a 'pretty' name for the training dataset
+     "language": "ps_af",
+     "model_name": "Whisper Small Pashto - Augmented",  # a 'pretty' name for your model
+     "finetuned_from": "openai/whisper-small",
+     "tasks": "automatic-speech-recognition",
+     "tags": "whisper-event",
+ }
+
+ """The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command and save the preprocessor object we created:"""
+
+ trainer.push_to_hub(**kwargs)
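+
+ """Once the push has finished, the fine-tuned checkpoint can be used for inference.
+ The snippet below is an illustrative sketch only: the repo id and audio path are
+ hypothetical placeholders, not values produced by this script."""
+
+ # from transformers import pipeline
+ # asr = pipeline("automatic-speech-recognition", model="<hub-username>/<repo-name>")
+ # print(asr("example_audio.wav")["text"])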