jsnfly committed on
Commit
004e907
1 Parent(s): 3cb1ad3

add training notebooks

README.md CHANGED
@@ -37,7 +37,16 @@ the decoder from [dbmdz/german-gpt2](https://huggingface.co/dbmdz/german-gpt2).
 
 It was trained using a two-step process:
 * fine-tuning only the cross-attention weights and the decoder using the pre-computed outputs of the Wav2Vec model
+   * relatively fast training
+   * also works on a small GPU (e.g. 8 GB)
+   * but may take a lot of disk space
+   * should already yield decent results
 * fine-tuning the model end-to-end
+   * much slower
+   * needs a bigger GPU
 
 There is also one trick which seemed to improve performance significantly: adding position embeddings to the
 encoder outputs and initializing them with the pre-trained position embeddings of the GPT2 model (see `eval.py`).
+
+The training notebooks are still early drafts. Results can probably also be improved a lot by using, for example,
+a learning rate schedule.
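To make the position-embedding trick concrete, here is a minimal sketch that mirrors what `training/model.py` below does; `hidden_size` and `max_positions` are placeholders (GPT2's position table has 1024 rows), and `decoder` stands for the loaded GPT2 decoder:

```python
import torch
from torch import nn

hidden_size, max_positions = 768, 1024  # placeholder sizes

# Learned position embeddings for the encoder outputs, initialized from GPT2's `wpe` table.
pos_emb = nn.Embedding(max_positions, hidden_size)
# with torch.no_grad():
#     pos_emb.weight.copy_(decoder.transformer.wpe.weight)

def add_position_embeddings(encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    """Add position embeddings to encoder outputs of shape (batch, seq_len, hidden_size)."""
    positions = torch.arange(encoder_hidden_states.shape[1], device=encoder_hidden_states.device)
    return encoder_hidden_states + pos_emb(positions)
```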
training/data_loading.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from pathlib import Path
4
+
5
+
6
+ class S2TDataset(Dataset):
7
+ def __init__(self, data_path):
8
+ self.path = Path(data_path)
9
+ self.files = list(self.path.iterdir())
10
+
11
+ def __len__(self):
12
+ return len(self.files)
13
+
14
+ def __getitem__(self, idx):
15
+ file_path = self.files[idx]
16
+ eg = torch.load(file_path)
17
+ eg['file_path'] = file_path
18
+ return eg
19
+
20
+
21
+ # TODO: Somehow the attention masks do not work yet (bad performance), but training also works without using them.
22
+ def make_collate_fn(tokenizer):
23
+ def collate_fn(examples):
24
+ wav2vec_feats = [eg['wave2vec_features'] for eg in examples]
25
+ max_len = len(max(wav2vec_feats, key=len))
26
+ padded_feats, attention_masks = [], []
27
+ for feats in wav2vec_feats:
28
+ num_pads = max_len - len(feats)
29
+ padded_feats.append(torch.cat([feats, torch.zeros((num_pads, feats.shape[-1]), device=feats.device)]))
30
+ if num_pads > 0:
31
+ mask = torch.zeros((max_len,), device=feats.device).long()
32
+ mask[:-num_pads] = 1
33
+ else:
34
+ mask = torch.ones((max_len,), device=feats.device).long()
35
+ attention_masks.append(mask)
36
+
37
+ encoder_hidden_states = torch.stack(padded_feats, dim=0)
38
+ encoder_attention_masks = torch.stack(attention_masks, dim=0).bool()
39
+ input_ids = tokenizer([eg['sentence'] for eg in examples], return_tensors='pt', padding=True).input_ids
40
+ return encoder_hidden_states, encoder_attention_masks, input_ids
41
+ return collate_fn
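For orientation, a small usage sketch of `S2TDataset` and `make_collate_fn` (the path and batch size are placeholders; each file on disk is a `torch.save`d dict with `'sentence'` and `'wave2vec_features'`, as written by the feature-extraction step in the decoder-only notebook below):

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data_loading import S2TDataset, make_collate_fn

tokenizer = AutoTokenizer.from_pretrained('dbmdz/german-gpt2')
tokenizer.add_special_tokens({'pad_token': '_', 'bos_token': '~'})

ds = S2TDataset('../data/common_voice/de/train')  # placeholder path
dl = DataLoader(ds, batch_size=16, shuffle=True, collate_fn=make_collate_fn(tokenizer))

encoder_hidden_states, encoder_attention_masks, input_ids = next(iter(dl))
print(encoder_hidden_states.shape, encoder_attention_masks.shape, input_ids.shape)
```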
training/decoder_only_training.ipynb ADDED
@@ -0,0 +1,371 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "521e21ab",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# This notebook is currently designed for a GPU using fp16. Hyperparameters however are barely tuned."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "1732f970",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import random\n",
21
+ "import torch\n",
22
+ "from pathlib import Path"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "f55f4047",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "EXPERIMENT_NAME = '00'\n",
33
+ "DATA_PATH = Path('../data/common_voice/de')\n",
34
+ "\n",
35
+ "model_dir = Path('decoder_only/de') / EXPERIMENT_NAME\n",
36
+ "log_dir = model_dir / 'logs'\n",
37
+ "log_dir.mkdir(exist_ok=True, parents=True)\n",
38
+ "\n",
39
+ "config = {\n",
40
+ " 'use_train_frac': 1.0, # When using all samples the wav2vec-outputs take up ~275GB disk space!!(~360,000 samples)\n",
41
+ " 'use_val_frac': 0.2,\n",
42
+ " 'encoder_id': 'jonatasgrosman/wav2vec2-large-xlsr-53-german',\n",
43
+ " 'decoder_id': 'dbmdz/german-gpt2',\n",
44
+ " 'decoder_pad_token': '_',\n",
45
+ " 'decoder_bos_token': '~',\n",
46
+ " 'num_beams': 1,\n",
47
+ " 'batch_size': 16,\n",
48
+ " 'weight_decay': 0.,\n",
49
+ " 'accumulate_grad': 2,\n",
50
+ " 'max_epochs': 10,\n",
51
+ " 'max_len': 36 # len(max(tokenizer(common_voice['validation']['sentence'] + common_voice['test']['sentence']).input_ids, key=len))\n",
52
+ "}"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "id": "eb3de6a4",
58
+ "metadata": {},
59
+ "source": [
60
+ "# Feature Extraction"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "id": "b176328e",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "from huggingface_hub import notebook_login\n",
71
+ "from datasets import load_dataset\n",
72
+ "from datasets.features import Audio\n",
73
+ "from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "54e70696",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "notebook_login()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "f0d22752",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "def extract_features_to_files(model, feature_extractor, dataset_split, batch_size, output_path):\n",
94
+ " output_path = Path(output_path)\n",
95
+ " output_path.mkdir(parents=True, exist_ok=True)\n",
96
+ "\n",
97
+ " model.eval().cuda()\n",
98
+ " for i in range(0, len(dataset_split), batch_size):\n",
99
+ " batch = dataset_split[i:i+batch_size]\n",
100
+ " sent_batch = batch['sentence']\n",
101
+ " audio_batch = batch['audio']\n",
102
+ " for i, eg in enumerate(audio_batch):\n",
103
+ " # Remove the longest examples, should be only three and these may lead to OOM- or Index-Errors.\n",
104
+ " if len(eg['array']) > 300_000:\n",
105
+ " print('Too Long.')\n",
106
+ " audio_batch.pop(i)\n",
107
+ " sent_batch.pop(i)\n",
108
+ " features = feature_extractor([eg['array'] for eg in audio_batch],\n",
109
+ " sampling_rate=16_000,\n",
110
+ " return_tensors='pt',\n",
111
+ " padding='longest')\n",
112
+ "\n",
113
+ " with torch.no_grad():\n",
114
+ " out = model(features.input_values.cuda(), attention_mask=features.attention_mask.cuda())\n",
115
+ "\n",
116
+ " assert len(sent_batch) == len(audio_batch) == len(out.last_hidden_state)\n",
117
+ " for sent, audio, hs in zip(sent_batch, audio_batch, out.last_hidden_state.bfloat16().cpu()):\n",
118
+ " file_name = audio['path'].split('/')[-1]\n",
119
+ " torch.save(\n",
120
+ " # .clone() is necessary: https://github.com/pytorch/pytorch/issues/1995\n",
121
+ " {'sentence': sent, 'wave2vec_features': hs.clone()},\n",
122
+ " output_path / file_name\n",
123
+ " )"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "06324b6f",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "if not DATA_PATH.exists():\n",
134
+ " \n",
135
+ " common_voice = load_dataset('mozilla-foundation/common_voice_7_0', 'de', use_auth_token=True)\n",
136
+ " \n",
137
+ " random.seed(419)\n",
138
+ " train_inds = list(range(len(common_voice['train'])))\n",
139
+ " random.shuffle(train_inds)\n",
140
+ " val_inds = list(range(len(common_voice['validation'])))\n",
141
+ " random.shuffle(val_inds)\n",
142
+ " \n",
143
+ " train_inds = train_inds[:int(config['use_train_frac'] * len(train_inds))]\n",
144
+ " train = common_voice['train'].select(train_inds)\n",
145
+ " train = train.cast_column('audio', Audio(sampling_rate=16_000))\n",
146
+ " \n",
147
+ " val_inds = val_inds[:int(config['use_val_frac'] * len(val_inds))]\n",
148
+ " val = common_voice['validation'].select(val_inds)\n",
149
+ " val = val.cast_column('audio', Audio(sampling_rate=16_000))\n",
150
+ " \n",
151
+ " # Load Model for feature extraction.\n",
152
+ " wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n",
153
+ " wave2vec = Wav2Vec2Model.from_pretrained(config['encoder_id'])\n",
154
+ " wave2vec.eval().cuda()\n",
155
+ " \n",
156
+ " extract_features_to_files(wave2vec, wave2vec_extractor, train, batch_size=8, output_path=DATA_PATH / 'train')\n",
157
+ " extract_features_to_files(wave2vec, wave2vec_extractor, val, batch_size=8, output_path=DATA_PATH / 'val')\n",
158
+ " \n",
159
+ " wave2vec.cpu()\n",
160
+ " torch.cuda.empty_cache()"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "b2ae2a47",
166
+ "metadata": {},
167
+ "source": [
168
+ "# Training"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "188ef54f",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "import json\n",
179
+ "from accelerate import Accelerator\n",
180
+ "from torch.utils.data import DataLoader\n",
181
+ "from torch.optim import AdamW\n",
182
+ "from torch.utils.tensorboard import SummaryWriter\n",
183
+ "from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor\n",
184
+ "from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2BaseModelOutput\n",
185
+ "from data_loading import make_collate_fn, S2TDataset\n",
186
+ "from wer import calculate_wer # Not what's used in eval.py.\n",
187
+ "from model import Wav2VecGPT2Model"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "id": "41518c81",
194
+ "metadata": {
195
+ "scrolled": false
196
+ },
197
+ "outputs": [],
198
+ "source": [
199
+ "tokenizer = AutoTokenizer.from_pretrained(config['decoder_id'])\n",
200
+ "tokenizer.add_special_tokens({'pad_token': config['decoder_pad_token'], 'bos_token': config['decoder_bos_token']})\n",
201
+ "\n",
202
+ "model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(\n",
203
+ " config['encoder_id'], config['decoder_id'], max_length=config['max_len'], num_beams=config['num_beams']\n",
204
+ ")\n",
205
+ "\n",
206
+ "model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
207
+ "model.config.pad_token_id = tokenizer.pad_token_id"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "id": "a95ec028",
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "collate_fn = make_collate_fn(tokenizer)\n",
218
+ "\n",
219
+ "train_ds = S2TDataset(DATA_PATH / 'train')\n",
220
+ "train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=4)\n",
221
+ "\n",
222
+ "val_ds = S2TDataset(DATA_PATH / 'val')\n",
223
+ "val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=4)"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "id": "0aaeeced",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "high_lr_modules = ['cross_attn', 'crossattention', 'enc_to_dec_proj', 'encoder_outputs_pos_emb']\n",
234
+ "high_lr_params = [p for n, p in model.named_parameters() if any(m in n for m in high_lr_modules)]\n",
235
+ "\n",
236
+ "optimizer_grouped_parameters = [\n",
237
+ " {\n",
238
+ " \"params\": high_lr_params,\n",
239
+ " \"lr\": 5e-4,\n",
240
+ " },\n",
241
+ " {\n",
242
+ " \"params\": [p for n, p in model.decoder.named_parameters() if not any(m in n for m in high_lr_modules)],\n",
243
+ " \"lr\": 1e-6,\n",
244
+ " },\n",
245
+ "]\n",
246
+ "optimizer = AdamW(optimizer_grouped_parameters, weight_decay=0.)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "id": "cf98d090",
253
+ "metadata": {},
254
+ "outputs": [],
255
+ "source": [
256
+ "accelerator = Accelerator(fp16=True)\n",
257
+ "print(f'Using {accelerator.device}.')"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "da9e928e",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "id": "f191f256",
274
+ "metadata": {
275
+ "scrolled": false
276
+ },
277
+ "outputs": [],
278
+ "source": [
279
+ "with open(log_dir / 'config.json', 'w') as config_file:\n",
280
+ " json.dump(config, config_file, indent=4)\n",
281
+ " \n",
282
+ "writer = SummaryWriter(log_dir)\n",
283
+ "val_golds = [eg['sentence'] for eg in val_ds]\n",
284
+ "best_val_wer = 10.\n",
285
+ "global_train_step = 0\n",
286
+ "\n",
287
+ "for epoch in range(config['max_epochs']):\n",
288
+ " \n",
289
+ " model.train()\n",
290
+ " model.encoder.cpu() # Model gets moved to gpu for evaluation (see below).\n",
291
+ " torch.cuda.empty_cache()\n",
292
+ " for batch_step, (encoder_hidden_states, att_mask, input_ids) in enumerate(train_dl):\n",
293
+ " if encoder_hidden_states.shape[1] > 1024:\n",
294
+ " # That's too long for the position embeddings. \n",
295
+ " # TODO: handle this in model code.\n",
296
+ " print(f'SKIPPED: {encoder_hidden_states.shape}')\n",
297
+ " continue\n",
298
+ " global_train_step += 1\n",
299
+ " \n",
300
+ " out = model(labels=input_ids, encoder_outputs=Wav2Vec2BaseModelOutput(encoder_hidden_states))\n",
301
+ " accelerator.backward(out.loss)\n",
302
+ " writer.add_scalar('train_loss', out.loss.item(), global_train_step)\n",
303
+ " \n",
304
+ " if (batch_step + 1) % config['accumulate_grad'] == 0:\n",
305
+ " optimizer.step()\n",
306
+ " optimizer.zero_grad()\n",
307
+ " \n",
308
+ " if batch_step % 300 == 0:\n",
309
+ " print(out.loss.item())\n",
310
+ " \n",
311
+ " model.eval()\n",
312
+ " model.cuda() # Necessary for input_ids to be initialized on the correct device.\n",
313
+ " val_preds = []\n",
314
+ " for encoder_hidden_states, att_mask, _ in val_dl:\n",
315
+ " with torch.no_grad():\n",
316
+ " generated = model.generate(\n",
317
+ " encoder_outputs=Wav2Vec2BaseModelOutput(last_hidden_state=encoder_hidden_states)\n",
318
+ " )\n",
319
+ " val_preds += tokenizer.batch_decode(generated)\n",
320
+ " val_preds = [pred.lstrip('~').rstrip('_') for pred in val_preds]\n",
321
+ " wer = calculate_wer(val_preds, val_golds)\n",
322
+ " writer.add_scalar('val_wer', wer, epoch)\n",
323
+ " print('WER: ', wer)\n",
324
+ " \n",
325
+ " if wer < best_val_wer:\n",
326
+ " torch.save(model.state_dict(), model_dir / 'model.pt')\n",
327
+ " print('Saved Model.')\n",
328
+ " best_val_wer = wer"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "id": "d84a7e5c",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "# # Load saved pytorch model and save with all necessary model files.\n",
339
+ "# output_path = model_dir /'full_model'\n",
340
+ "# \n",
341
+ "# model.load_state_dict(torch.load(model_dir / 'model.pt'))\n",
342
+ "# \n",
343
+ "# tokenizer.save_pretrained(output_path)\n",
344
+ "# wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n",
345
+ "# wave2vec_extractor.save_pretrained(output_path)\n",
346
+ "# model.save_pretrained(output_path)"
347
+ ]
348
+ }
349
+ ],
350
+ "metadata": {
351
+ "kernelspec": {
352
+ "display_name": "Python 3 (ipykernel)",
353
+ "language": "python",
354
+ "name": "python3"
355
+ },
356
+ "language_info": {
357
+ "codemirror_mode": {
358
+ "name": "ipython",
359
+ "version": 3
360
+ },
361
+ "file_extension": ".py",
362
+ "mimetype": "text/x-python",
363
+ "name": "python",
364
+ "nbconvert_exporter": "python",
365
+ "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
+ }
368
+ },
369
+ "nbformat": 4,
370
+ "nbformat_minor": 5
371
+ }
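As the README notes, a learning-rate schedule is an obvious next refinement for this training loop. A minimal sketch with `transformers.get_linear_schedule_with_warmup`; the warmup and step counts are placeholders, not tuned values:

```python
from transformers import get_linear_schedule_with_warmup

# Placeholder numbers; num_training_steps would be roughly
# len(train_dl) * config['max_epochs'] // config['accumulate_grad'].
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=20_000
)

# In the training loop, step the scheduler together with the optimizer:
#     optimizer.step()
#     scheduler.step()
#     optimizer.zero_grad()
```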
training/end2end_training.ipynb ADDED
@@ -0,0 +1,269 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "9e852db9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# This notebook is currently designed for a GPU using fp16. Hyperparameters however are barely tuned."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "e730080b",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import json\n",
21
+ "import random\n",
22
+ "import torch\n",
23
+ "from pathlib import Path\n",
24
+ "from accelerate import Accelerator\n",
25
+ "from datasets import load_dataset, concatenate_datasets\n",
26
+ "from datasets.features import Audio\n",
27
+ "from torch.utils.data import Dataset, DataLoader\n",
28
+ "from torch.optim import AdamW\n",
29
+ "from torch.utils.tensorboard import SummaryWriter\n",
30
+ "from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor\n",
31
+ "from wer import calculate_wer # Not what's used in eval.py.\n",
32
+ "from model import Wav2VecGPT2Model"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "72af6337",
39
+ "metadata": {
40
+ "scrolled": true
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "common_voice = load_dataset('mozilla-foundation/common_voice_7_0', 'de', use_auth_token=True)"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "6396e61d",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "EXPERIMENT_NAME = '00'\n",
55
+ "\n",
56
+ "model_dir = Path('end2end/de') / EXPERIMENT_NAME\n",
57
+ "log_dir = model_dir / 'logs'\n",
58
+ "log_dir.mkdir(exist_ok=True, parents=True)\n",
59
+ "\n",
60
+ "config = {\n",
61
+ " 'encoder_id': 'jonatasgrosman/wav2vec2-large-xlsr-53-german',\n",
62
+ " 'decoder_id': 'dbmdz/german-gpt2',\n",
63
+ " 'decoder_pad_token': '_',\n",
64
+ " 'decoder_bos_token': '~',\n",
65
+ " 'num_beams': 1,\n",
66
+ " 'num_val_examples': 1500,\n",
67
+ " 'batch_size': 8,\n",
68
+ " 'base_lr': 3e-4,\n",
69
+ " 'weight_decay': 0.,\n",
70
+ " 'accumulate_grad': 4,\n",
71
+ " 'max_epochs': 10,\n",
72
+ " 'max_len': 36 # len(max(tokenizer(common_voice['validation']['sentence'] + common_voice['test']['sentence']).input_ids, key=len))"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "6c632a61",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "tokenizer = AutoTokenizer.from_pretrained(config['decoder_id'])\n",
83
+ "tokenizer.add_special_tokens({'pad_token': config['decoder_pad_token'], 'bos_token': config['decoder_bos_token']})\n",
84
+ "\n",
85
+ "wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n",
86
+ "\n",
87
+ "model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(\n",
88
+ " config['encoder_id'], config['decoder_id'], max_length=config['max_len'], num_beams=config['num_beams']\n",
89
+ ")\n",
90
+ "\n",
91
+ "model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
92
+ "model.config.pad_token_id = tokenizer.pad_token_id"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "id": "30e5b73c",
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "# Load model from decoder-only training.\n",
103
+ "model.load_state_dict(torch.load('decoder_only/de/00/model.pt'))"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "5466e908",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "class AudioDataset(Dataset):\n",
114
+ " \n",
115
+ " def __init__(self, ds):\n",
116
+ " self.ds = ds\n",
117
+ " \n",
118
+ " def __len__(self):\n",
119
+ " return len(self.ds)\n",
120
+ " \n",
121
+ " def __getitem__(self, idx):\n",
122
+ " eg = self.ds[idx]\n",
123
+ " return eg['audio']['array'], eg['sentence']\n",
124
+ " \n",
125
+ "def collate_fn(examples):\n",
126
+ " # Remove the longest examples, should be only three and these may lead to OOM- or Index-Errors.\n",
127
+ " examples = [eg for eg in examples if len(eg[0]) < 300_000]\n",
128
+ " \n",
129
+ " audio_features = wave2vec_extractor(\n",
130
+ " [eg[0] for eg in examples], sampling_rate=16_000, return_tensors='pt', padding='longest'\n",
131
+ " ).input_values\n",
132
+ " \n",
133
+ " input_ids = tokenizer(\n",
134
+ " [eg[1] for eg in examples], return_tensors='pt', padding=True\n",
135
+ " ).input_ids\n",
136
+ " \n",
137
+ " return audio_features, input_ids"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "id": "0453ccc1",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "train = common_voice['train'].cast_column('audio', Audio(sampling_rate=16_000))\n",
148
+ "val = common_voice['validation'].cast_column('audio', Audio(sampling_rate=16_000))"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "ad81c9ab",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "random.seed(419)\n",
159
+ "val_inds = list(range(len(common_voice['validation'])))\n",
160
+ "random.shuffle(val_inds)\n",
161
+ "\n",
162
+ "train_ds = AudioDataset(concatenate_datasets([train, val.select(val_inds[config['num_val_examples']:])]))\n",
163
+ "val_ds = AudioDataset(val.select(val_inds[:config['num_val_examples']]))\n",
164
+ "\n",
165
+ "train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=4)\n",
166
+ "val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=4)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "f0d1c290",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "accelerator = Accelerator(fp16=True)\n",
177
+ "print(f'Using {accelerator.device}.')"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "2af1f2f1",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "optimizer = AdamW(model.parameters(), lr=config['base_lr'], weight_decay=config['weight_decay'])"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "id": "6921d32c",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "id": "d699c404",
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "with open(log_dir / 'config.json', 'w') as config_file:\n",
208
+ " json.dump(config, config_file, indent=4)\n",
209
+ " \n",
210
+ "writer = SummaryWriter(log_dir)\n",
211
+ "val_golds = common_voice['validation'].select(val_inds[:config['num_val_examples']])['sentence']\n",
212
+ "best_val_wer = 10.\n",
213
+ "global_train_step = 0\n",
214
+ "\n",
215
+ "for epoch in range(config['max_epochs']):\n",
216
+ " model.train()\n",
217
+ " for batch_step, (audio_features, input_ids) in enumerate(train_dl):\n",
218
+ " global_train_step += 1\n",
219
+ " \n",
220
+ " out = model(labels=input_ids, input_values=audio_features)\n",
221
+ " accelerator.backward(out.loss)\n",
222
+ " writer.add_scalar('train_loss', out.loss.item(), global_train_step)\n",
223
+ " \n",
224
+ " if (batch_step + 1) % config['accumulate_grad'] == 0:\n",
225
+ " optimizer.step()\n",
226
+ " optimizer.zero_grad()\n",
227
+ " if batch_step % 300 == 0:\n",
228
+ " print(out.loss.item())\n",
229
+ " \n",
230
+ " model.eval()\n",
231
+ " val_preds = []\n",
232
+ " for audio_features, input_ids in val_dl:\n",
233
+ " with torch.no_grad():\n",
234
+ " generated = model.generate(audio_features)\n",
235
+ " val_preds += tokenizer.batch_decode(generated)\n",
236
+ " val_preds = [pred.lstrip('~').rstrip('_') for pred in val_preds]\n",
237
+ " wer = calculate_wer(val_preds, val_golds)\n",
238
+ " writer.add_scalar('val_wer', wer, epoch)\n",
239
+ " print('WER: ', wer)\n",
240
+ " \n",
241
+ " if wer < best_val_wer:\n",
242
+ " torch.save(model.state_dict(), model_dir / 'model.pt')\n",
243
+ " print('Saved Model.')\n",
244
+ " best_val_wer = wer"
245
+ ]
246
+ }
247
+ ],
248
+ "metadata": {
249
+ "kernelspec": {
250
+ "display_name": "Python 3 (ipykernel)",
251
+ "language": "python",
252
+ "name": "python3"
253
+ },
254
+ "language_info": {
255
+ "codemirror_mode": {
256
+ "name": "ipython",
257
+ "version": 3
258
+ },
259
+ "file_extension": ".py",
260
+ "mimetype": "text/x-python",
261
+ "name": "python",
262
+ "nbconvert_exporter": "python",
263
+ "pygments_lexer": "ipython3",
264
+ "version": "3.9.7"
265
+ }
266
+ },
267
+ "nbformat": 4,
268
+ "nbformat_minor": 5
269
+ }
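After end-to-end training, inference roughly follows the validation loop above: extract input values, `generate`, decode, and strip the special tokens. A sketch under assumptions (the Common Voice sample and the checkpoint path are placeholders; `eval.py` contains the repository's actual evaluation code):

```python
import torch
from datasets import load_dataset
from datasets.features import Audio
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor

from model import Wav2VecGPT2Model

tokenizer = AutoTokenizer.from_pretrained('dbmdz/german-gpt2')
tokenizer.add_special_tokens({'pad_token': '_', 'bos_token': '~'})
extractor = Wav2Vec2FeatureExtractor.from_pretrained('jonatasgrosman/wav2vec2-large-xlsr-53-german')

model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(
    'jonatasgrosman/wav2vec2-large-xlsr-53-german', 'dbmdz/german-gpt2', max_length=36
)
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.load_state_dict(torch.load('end2end/de/00/model.pt', map_location='cpu'))  # placeholder checkpoint path
model.eval()

# One 16 kHz example from the Common Voice test split (requires authentication, as above).
cv_test = load_dataset('mozilla-foundation/common_voice_7_0', 'de', split='test', use_auth_token=True)
cv_test = cv_test.cast_column('audio', Audio(sampling_rate=16_000))
waveform = cv_test[0]['audio']['array']

inputs = extractor(waveform, sampling_rate=16_000, return_tensors='pt')
with torch.no_grad():
    generated = model.generate(inputs.input_values)
print(tokenizer.batch_decode(generated)[0].lstrip('~').rstrip('_'))
```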
training/model.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from transformers import SpeechEncoderDecoderModel
3
+ from torch import nn
4
+ from torch.nn import CrossEntropyLoss
5
+ from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right
6
+ from transformers.modeling_outputs import Seq2SeqLMOutput
7
+
8
+ class Wav2VecGPT2Model(SpeechEncoderDecoderModel):
9
+ """
10
+ Basically the same as `SpeechEncoderDecoderModel` but position embeddings (initialized with GPT2's position
11
+ embeddings) are added to the encoder outputs.
12
+ """
13
+ def __init__(self, *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ self.encoder_outputs_pos_emb = nn.Embedding(1024, self.decoder.config.hidden_size)
16
+ with torch.no_grad():
17
+ self.encoder_outputs_pos_emb.weight.copy_(self.decoder.transformer.wpe.weight)
18
+ self.enc_to_dec_proj_ln = nn.LayerNorm(self.decoder.config.hidden_size,
19
+ eps=self.decoder.config.layer_norm_epsilon)
20
+
21
+ def __getattribute__(self, name):
22
+ # Fake class so it is recognized as seq2seq model.
23
+ if name == '__class__':
24
+ return SpeechEncoderDecoderModel
25
+ return SpeechEncoderDecoderModel.__getattribute__(self, name)
26
+
27
+ def forward(
28
+ self,
29
+ inputs=None,
30
+ attention_mask=None,
31
+ decoder_input_ids=None,
32
+ decoder_attention_mask=None,
33
+ encoder_outputs=None,
34
+ past_key_values=None,
35
+ decoder_inputs_embeds=None,
36
+ labels=None,
37
+ use_cache=None,
38
+ output_attentions=None,
39
+ output_hidden_states=None,
40
+ input_values=None,
41
+ input_features=None,
42
+ return_dict=None,
43
+ **kwargs,
44
+ ):
45
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
46
+
47
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
48
+
49
+ kwargs_decoder = {
50
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
51
+ }
52
+
53
+ if encoder_outputs is None and inputs is None:
54
+ if input_values is not None and input_features is not None:
55
+ raise ValueError("You cannot specify both input_values and input_features at the same time")
56
+ elif input_values is not None:
57
+ inputs = input_values
58
+ elif input_features is not None:
59
+ inputs = input_features
60
+ else:
61
+ raise ValueError("You have to specify either input_values or input_features")
62
+
63
+ encoder_outputs = self.encoder(
64
+ inputs,
65
+ attention_mask=attention_mask,
66
+ output_attentions=output_attentions,
67
+ output_hidden_states=output_hidden_states,
68
+ return_dict=return_dict,
69
+ **kwargs_encoder,
70
+ )
71
+
72
+ encoder_hidden_states = encoder_outputs[0]
73
+
74
+ # optionally project encoder_hidden_states
75
+ if (
76
+ self.encoder_output_dim != self.decoder.config.hidden_size
77
+ and self.decoder.config.cross_attention_hidden_size is None
78
+ ):
79
+ # TODO: Truncate and warn if the sequence length is greater than 1024!
80
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
81
+ encoder_hidden_states += self.encoder_outputs_pos_emb(
82
+ torch.arange(0, encoder_hidden_states.shape[1], device=encoder_hidden_states.device)
83
+ )
84
+ encoder_hidden_states = self.enc_to_dec_proj_ln(encoder_hidden_states)
85
+
86
+ # compute correct encoder attention mask
87
+ if attention_mask is not None:
88
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
89
+ encoder_hidden_states.shape[1], attention_mask
90
+ )
91
+ else:
92
+ encoder_attention_mask = None
93
+
94
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
95
+ decoder_input_ids = shift_tokens_right(
96
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
97
+ )
98
+
99
+ # Decode
100
+ decoder_outputs = self.decoder(
101
+ input_ids=decoder_input_ids,
102
+ attention_mask=decoder_attention_mask,
103
+ encoder_hidden_states=encoder_hidden_states,
104
+ encoder_attention_mask=encoder_attention_mask,
105
+ inputs_embeds=decoder_inputs_embeds,
106
+ output_attentions=output_attentions,
107
+ output_hidden_states=output_hidden_states,
108
+ use_cache=use_cache,
109
+ past_key_values=past_key_values,
110
+ return_dict=return_dict,
111
+ **kwargs_decoder,
112
+ )
113
+
114
+ # Compute loss independent from decoder (as some shift the logits inside them)
115
+ loss = None
116
+ if labels is not None:
117
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
118
+ loss_fct = CrossEntropyLoss()
119
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
120
+
121
+ if not return_dict:
122
+ if loss is not None:
123
+ return (loss,) + decoder_outputs + encoder_outputs
124
+ else:
125
+ return decoder_outputs + encoder_outputs
126
+
127
+ return Seq2SeqLMOutput(
128
+ loss=loss,
129
+ logits=decoder_outputs.logits,
130
+ past_key_values=decoder_outputs.past_key_values,
131
+ decoder_hidden_states=decoder_outputs.hidden_states,
132
+ decoder_attentions=decoder_outputs.attentions,
133
+ cross_attentions=decoder_outputs.cross_attentions,
134
+ encoder_last_hidden_state=encoder_outputs[0],
135
+ encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None), # TODO: only temporary (inconsistent)
136
+ encoder_attentions=getattr(encoder_outputs, 'attentions', None),
137
+ )
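On the truncation TODO above: the position table only covers 1024 encoder frames, which is why the decoder-only notebook skips longer batches. One possible (hypothetical, not part of the repository) way to handle it inside the model would be to clip and warn before adding the embeddings:

```python
import warnings
import torch

MAX_ENCODER_POSITIONS = 1024  # size of the position-embedding table above

def clip_encoder_outputs(encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    """Truncate encoder outputs that exceed the supported number of positions."""
    if encoder_hidden_states.shape[1] > MAX_ENCODER_POSITIONS:
        warnings.warn(
            f'Encoder output length {encoder_hidden_states.shape[1]} exceeds '
            f'{MAX_ENCODER_POSITIONS} positions; truncating.'
        )
        encoder_hidden_states = encoder_hidden_states[:, :MAX_ENCODER_POSITIONS]
    return encoder_hidden_states
```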
training/wer.py ADDED
@@ -0,0 +1,12 @@
1
+ import jiwer
2
+
3
+ def calculate_wer(predictions, golds):
4
+
5
+ transformation = jiwer.Compose([
6
+ jiwer.ToLowerCase(),
7
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
8
+ jiwer.RemoveMultipleSpaces(),
9
+ jiwer.Strip(),
10
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
11
+ ])
12
+ return jiwer.wer(golds, predictions, truth_transform=transformation, hypothesis_transform=transformation)
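A quick sanity check of the helper with made-up strings: the transformation lower-cases and collapses whitespace before scoring, so these pairs count as identical.

```python
from wer import calculate_wer

preds = ['das ist ein test', 'hallo  welt']
golds = ['Das ist ein Test', 'Hallo Welt']
print(calculate_wer(preds, golds))  # expected: 0.0
```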