Sin2pi commited on
Commit
effaa91
1 Parent(s): 99b9c53

Delete model.ipynb

Browse files
Files changed (1) hide show
  1. model.ipynb +0 -1183
model.ipynb DELETED
@@ -1,1183 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import base64, gzip, evaluate, math, os, sys, time\n",
10
- "import gzip, neologdn\n",
11
- "import collections\n",
12
- "import copy\n",
13
- "import functools\n",
14
- "from functools import partial, wraps\n",
15
- "from threading import Thread\n",
16
- "import gc\n",
17
- "import importlib.metadata\n",
18
- "import inspect\n",
19
- "import itertools\n",
20
- "import torch\n",
21
- "from torch import amp, Tensor, optim\n",
22
- "from torch.utils.checkpoint import checkpoint\n",
23
- "from contextlib import contextmanager\n",
24
- "from dataclasses import dataclass\n",
25
- "from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel\n",
26
- "from transformers.models.whisper.generation_whisper import WhisperGenerationMixin\n",
27
- "from transformers.optimization import Adafactor, AdafactorSchedule\n",
28
- "from huggingface_hub import PyTorchModelHubMixin\n",
29
- "from datasets import IterableDatasetDict, Audio, load_dataset, load_from_disk\n",
30
- "import numpy as np\n",
31
- "import torch, transformers, warnings\n",
32
- "from typing import Dict, Iterable, Optional, Tuple, Union, List, Any, Type\n",
33
- "import torch.nn.functional as F\n",
34
- "from torch import Tensor, nn\n",
35
- "import torchaudio, torchaudio.transforms as T\n",
36
- "from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, WhisperTokenizer, WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration\n",
37
- "from whisper.decoding import decode as decode_function\n",
38
- "from whisper.decoding import detect_language as detect_language_function\n",
39
- "from whisper.transcribe import transcribe as transcribe_function\n",
40
- "\n",
41
- "try:\n",
42
- " from torch.nn.functional import scaled_dot_product_attention\n",
43
- "\n",
44
- " SDPA_AVAILABLE = True\n",
45
- "except (ImportError, RuntimeError, OSError):\n",
46
- " scaled_dot_product_attention = None\n",
47
- " SDPA_AVAILABLE = False\n",
48
- "\n",
49
- "transformers.utils.logging.set_verbosity_error()\n",
50
- "warnings.filterwarnings(action=\"ignore\")\n",
51
- "warnings.warn = lambda *args,**kwargs: None\n",
52
- "\n",
53
- "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
54
- "\n",
55
- "dtype = torch.float32\n"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "metadata": {},
62
- "outputs": [],
63
- "source": [
64
- "### Model ###\n",
65
- "\n",
66
- "\n",
67
- "\n",
68
- "class LayerNorm(nn.Module):\n",
69
- " def __init__(self, num_features, eps=1e-6):\n",
70
- " super(LayerNorm, self).__init__()\n",
71
- " self.gamma = nn.Parameter(torch.ones(num_features))\n",
72
- " self.beta = nn.Parameter(torch.zeros(num_features))\n",
73
- " self.eps = eps\n",
74
- "\n",
75
- " def forward(self, x):\n",
76
- " mean = x.mean(dim=-1, keepdim=True)\n",
77
- " std = x.std(dim=-1, keepdim=True)\n",
78
- " x = (x - mean) / (std + self.eps)\n",
79
- " return self.gamma * x + self.beta\n",
80
- "\n",
81
- "class Linear(nn.Module):\n",
82
- " def __init__(self, in_features: int, out_features: int, dropout_rate = 0.001, use_batchnorm: bool = True, activation: str = 'relu'):\n",
83
- " super(Linear, self).__init__()\n",
84
- " self.linear = nn.Linear(in_features, out_features)\n",
85
- " self.dropout = nn.Dropout(dropout_rate)\n",
86
- " self.use_batchnorm = use_batchnorm\n",
87
- " self.activation = activation\n",
88
- "\n",
89
- " if self.use_batchnorm:\n",
90
- " self.batchnorm = nn.BatchNorm1d(out_features)\n",
91
- " self.reset_parameters()\n",
92
- "\n",
93
- " def reset_parameters(self):\n",
94
- " nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)\n",
95
- " if self.linear.bias is not None:\n",
96
- " nn.init.zeros_(self.linear.bias)\n",
97
- "\n",
98
- " def forward(self, x):\n",
99
- " batch_size, seq_len, _ = x.size()\n",
100
- " x = x.view(-1, x.size(-1)) \n",
101
- " x = self.linear(x)\n",
102
- "\n",
103
- " if self.use_batchnorm:\n",
104
- " x = self.batchnorm(x)\n",
105
- "\n",
106
- " x = self.apply_activation(x)\n",
107
- " x = self.dropout(x)\n",
108
- " x = x.view(batch_size, seq_len, -1) \n",
109
- " \n",
110
- " return x\n",
111
- "\n",
112
- " def apply_activation(self, x):\n",
113
- " if self.activation == 'relu':\n",
114
- " return F.relu(x)\n",
115
- " elif self.activation == 'tanh':\n",
116
- " return torch.tanh(x)\n",
117
- " elif self.activation == 'sigmoid':\n",
118
- " return torch.sigmoid(x)\n",
119
- " else:\n",
120
- " raise ValueError(f'Unsupported activation function: {self.activation}')\n",
121
- "\n",
122
- "class Conv1d(nn.Conv1d):\n",
123
- " def __init__(self, *args, **kwargs):\n",
124
- " super().__init__(*args, **kwargs)\n",
125
- " self.reset_parameters()\n",
126
- "\n",
127
- " def reset_parameters(self):\n",
128
- " nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')\n",
129
- " if self.bias is not None:\n",
130
- " nn.init.zeros_(self.bias)\n",
131
- "\n",
132
- " def _conv_forward(self, x, weight, bias) -> Tensor:\n",
133
- " weight = self.weight.to(x.dtype)\n",
134
- " bias = None if self.bias is None else self.bias.to(x.dtype)\n",
135
- " return super()._conv_forward(x, weight, bias)\n",
136
- "\n",
137
- "class BiasedCrossAttention(nn.Module):\n",
138
- " def __init__(self, n_state, n_head, dropout_rate=0.001):\n",
139
- " super().__init__()\n",
140
- " self.n_head = n_head\n",
141
- " self.n_state = n_state\n",
142
- " self.head_dim = n_state // n_head\n",
143
- "\n",
144
- " self.query = nn.Linear(n_state, n_state)\n",
145
- " self.key = nn.Linear(n_state, n_state, bias=False)\n",
146
- " self.value = nn.Linear(n_state, n_state)\n",
147
- " self.out = nn.Linear(n_state, n_state)\n",
148
- "\n",
149
- " self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))\n",
150
- " self.dropout = nn.Dropout(dropout_rate)\n",
151
- " self.norm = LayerNorm(n_state)\n",
152
- " \n",
153
- " def forward(self, q, k, v, mask=None):\n",
154
- " batch_size, seq_length, _ = q.size()\n",
155
- "\n",
156
- " q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
157
- " k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
158
- " v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
159
- "\n",
160
- " qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias\n",
161
- " if mask is not None:\n",
162
- " qk = qk.masked_fill(mask == 0, float('-inf'))\n",
163
- "\n",
164
- " w = F.softmax(qk, dim=-1)\n",
165
- " w = self.dropout(w)\n",
166
- "\n",
167
- " out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)\n",
168
- " out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))\n",
169
- " return out\n",
170
- "\n",
171
- "class DynamicConvAttention(nn.Module):\n",
172
- " def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.001):\n",
173
- " super().__init__()\n",
174
- " self.n_state = n_state\n",
175
- " self.n_head = n_head\n",
176
- " self.kernel_size = kernel_size\n",
177
- "\n",
178
- " self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)\n",
179
- " self.dropout = nn.Dropout(dropout_rate)\n",
180
- "\n",
181
- " self.query = nn.Linear(n_state, n_state)\n",
182
- " self.key = nn.Linear(n_state, n_state, bias=False)\n",
183
- " self.value = nn.Linear(n_state, n_state)\n",
184
- " self.out_proj = nn.Linear(n_state, n_state)\n",
185
- "\n",
186
- " self.norm = LayerNorm(n_state)\n",
187
- "\n",
188
- " def forward(self, x):\n",
189
- " batch_size, seq_len, embed_dim = x.size()\n",
190
- " if embed_dim != self.n_state:\n",
191
- " raise ValueError(f\"Expected embed_dim of {self.n_state}, but got {embed_dim}\")\n",
192
- "\n",
193
- " q = self.query(x)\n",
194
- " k = self.key(x)\n",
195
- " v = self.value(x)\n",
196
- "\n",
197
- " x = x.permute(0, 2, 1)\n",
198
- " conv_out = self.conv(x)\n",
199
- " conv_out = conv_out.permute(0, 2, 1)\n",
200
- " conv_out = self.norm(conv_out)\n",
201
- " conv_out = self.dropout(conv_out)\n",
202
- "\n",
203
- " attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)\n",
204
- " attention_out = torch.matmul(attention_out, v)\n",
205
- " \n",
206
- " combined_out = conv_out + attention_out\n",
207
- " combined_out = self.norm(combined_out)\n",
208
- " \n",
209
- " return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)\n",
210
- "\n",
211
- "class HybridAttention(nn.Module):\n",
212
- " def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.001):\n",
213
- " super().__init__()\n",
214
- " self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
215
- " self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
216
- " self.ln_local = LayerNorm(n_state)\n",
217
- " self.ln_global = LayerNorm(n_state)\n",
218
- "\n",
219
- " self.dropout = nn.Dropout(dropout_rate)\n",
220
- " self.window_size = window_size\n",
221
- "\n",
222
- " def forward(self, x):\n",
223
- " x_local = self.ln_local(x)\n",
224
- " x_global = self.ln_global(x)\n",
225
- " x_local = x_local.permute(1, 0, 2)\n",
226
- " x_global = x_global.permute(1, 0, 2)\n",
227
- " local_out = self.sliding_window_attention(x_local)\n",
228
- " global_out, _ = self.global_attn(x_global, x_global, x_global)\n",
229
- " combined_out = local_out + global_out\n",
230
- " combined_out = combined_out.permute(1, 0, 2)\n",
231
- " return self.dropout(combined_out)\n",
232
- "\n",
233
- " def sliding_window_attention(self, x):\n",
234
- " batch_size, seq_len, n_state = x.size()\n",
235
- " window_size = min(self.window_size, max(1, seq_len // 4))\n",
236
- " output = torch.zeros_like(x, device=x.device, dtype=x.dtype)\n",
237
- "\n",
238
- " for i in range(0, seq_len, window_size):\n",
239
- " end = min(i + window_size, seq_len)\n",
240
- " query = x[i:end, :, :]\n",
241
- " start = max(0, i - window_size)\n",
242
- " key = x[start:end, :, :]\n",
243
- " value = x[start:end, :, :]\n",
244
- " attn_output, _ = self.local_attn(query, key, value)\n",
245
- " output[i:end, :, :] = attn_output[:end - i, :, :]\n",
246
- "\n",
247
- " return output\n",
248
- "\n",
249
- "# def givens_rotation_matrix(n_state, i, j, theta):\n",
250
- "# G = torch.eye(n_state)\n",
251
- "# G[i, i] = math.cos(theta)\n",
252
- "# G[i, j] = -math.sin(theta)\n",
253
- "# G[j, i] = math.sin(theta)\n",
254
- "# G[j, j] = math.cos(theta)\n",
255
- "# return G\n",
256
- "\n",
257
- "# class GivensRotations(nn.Module):\n",
258
- "# def __init__(self, h_dim, num_rotations):\n",
259
- "# super().__init__()\n",
260
- "# self.h_dim = h_dim\n",
261
- "# self.num_rotations = num_rotations\n",
262
- "# self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
263
- "\n",
264
- "# def forward(self, x):\n",
265
- "# if x.dim() != 4:\n",
266
- "# raise ValueError(f\"Expected input tensor to be 4D, but got {x.dim()}D\")\n",
267
- " \n",
268
- "# batch_size, seq_len, n_head, h_dim = x.size()\n",
269
- " \n",
270
- "# if h_dim != self.h_dim:\n",
271
- "# raise ValueError(f\"Expected h_dim of {self.h_dim}, but got {h_dim}\")\n",
272
- " \n",
273
- "# x = x.view(-1, h_dim) \n",
274
- "# for k in range(self.num_rotations):\n",
275
- "# i, j = k % self.h_dim, (k + 1) % self.h_dim\n",
276
- "# G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])\n",
277
- "# x = torch.matmul(x, G.to(x.device))\n",
278
- " \n",
279
- "# x = x.view(batch_size, seq_len, n_head, h_dim) \n",
280
- "# return x\n",
281
- "## old\n",
282
- "# class RotaryEmbeddingWithRotation(nn.Module):\n",
283
- "# def __init__(self, n_state, n_head, checkpointing=False, base=10000):\n",
284
- "# super().__init__()\n",
285
- "# self.n_state = n_state\n",
286
- "# self.n_head = n_head\n",
287
- "# self.h_dim = n_state // n_head\n",
288
- "# self.base = base\n",
289
- "# self.checkpointing = checkpointing\n",
290
- "\n",
291
- "# self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
292
- "# inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
293
- "# self.register_buffer('inv_freq', inv_freq)\n",
294
- "\n",
295
- "# def update_base(self, new_base):\n",
296
- "# self.base = new_base\n",
297
- "# inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
298
- "# self.register_buffer('inv_freq', inv_freq)\n",
299
- "\n",
300
- "# def reset_parameters(self):\n",
301
- "# nn.init.orthogonal_(self.rotation_matrix)\n",
302
- "\n",
303
- "# def forward(self, x):\n",
304
- "# if self.checkpointing:\n",
305
- "# return checkpoint(self._forward, x)\n",
306
- "# else:\n",
307
- "# return self._forward(x)\n",
308
- "\n",
309
- "# def _forward(self, x):\n",
310
- "# if x.dim() == 3:\n",
311
- "# batch_size, seq_len, n_state = x.size()\n",
312
- "# elif x.dim() == 4:\n",
313
- "# batch_size, seq_len, n_head, h_dim = x.size()\n",
314
- "# n_state = n_head * h_dim\n",
315
- "# x = x.view(batch_size, seq_len, n_state)\n",
316
- "# else:\n",
317
- "# raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
318
- "\n",
319
- "# if n_state != self.n_state:\n",
320
- "# raise ValueError(f\"Expected n_state of {self.n_state}, but got {n_state}\")\n",
321
- "\n",
322
- "# x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n",
323
- "# x = x.reshape(-1, self.h_dim)\n",
324
- "# rotated_x = torch.matmul(x, self.rotation_matrix)\n",
325
- "# rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n",
326
- "\n",
327
- "# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n",
328
- "# sin = sinusoid_inp.sin()[None, :, None, :]\n",
329
- "# cos = sinusoid_inp.cos()[None, :, None, :]\n",
330
- "# x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]\n",
331
- "# rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
332
- " \n",
333
- "# rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)\n",
334
- "# return rotated_x\n",
335
- "\n",
336
- "# ## new\n",
337
- "# class CombinedRotaryEmbedding(nn.Module):\n",
338
- "# def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):\n",
339
- "# super().__init__()\n",
340
- "# self.n_state = n_state\n",
341
- "# self.n_head = n_head\n",
342
- "# self.h_dim = n_state // n_head\n",
343
- "# self.num_rotations = num_rotations\n",
344
- "# self.base = base\n",
345
- "# self.checkpointing = checkpointing\n",
346
- " \n",
347
- "# self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
348
- "# self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)\n",
349
- "\n",
350
- "# self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
351
- " \n",
352
- "# self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
353
- " \n",
354
- "# def givens_rotation_matrix(self, n_state, i, j, theta):\n",
355
- "# G = torch.eye(n_state, device=theta.device)\n",
356
- "# G[i, i] = math.cos(theta)\n",
357
- "# G[i, j] = -math.sin(theta)\n",
358
- "# G[j, i] = math.sin(theta)\n",
359
- "# G[j, j] = math.cos(theta)\n",
360
- "# return G\n",
361
- " \n",
362
- "# def update_base(self, new_base):\n",
363
- "# self.base = new_base\n",
364
- "# self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
365
- " \n",
366
- "# def reset_parameters(self):\n",
367
- "# nn.init.orthogonal_(self.rotation_matrix)\n",
368
- "# nn.init.zeros_(self.thetas)\n",
369
- " \n",
370
- "# def forward(self, x):\n",
371
- "# if self.checkpointing:\n",
372
- "# return checkpoint(self._forward, x)\n",
373
- "# else:\n",
374
- "# return self._forward(x)\n",
375
- " \n",
376
- "# def _forward(self, x):\n",
377
- "# if x.dim() not in [3, 4]:\n",
378
- "# raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
379
- " \n",
380
- "# if x.dim() == 3:\n",
381
- "# batch_size, seq_len, n_state = x.size()\n",
382
- "# x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
383
- "# else:\n",
384
- "# batch_size, seq_len, n_head, h_dim = x.size()\n",
385
- "# if n_head != self.n_head or h_dim != self.h_dim:\n",
386
- "# raise ValueError(f\"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}\")\n",
387
- " \n",
388
- "# x = x.reshape(-1, self.h_dim)\n",
389
- " \n",
390
- "# for k in range(self.num_rotations):\n",
391
- "# i, j = self.rotation_pairs[k].long()\n",
392
- "# theta = self.thetas[k]\n",
393
- "# G = self.givens_rotation_matrix(self.h_dim, i, j, theta)\n",
394
- "# x = torch.matmul(x, G)\n",
395
- " \n",
396
- "# x = torch.matmul(x, self.rotation_matrix)\n",
397
- " \n",
398
- "# x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
399
- " \n",
400
- "# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq)\n",
401
- "# sin = sinusoid_inp.sin()[None, :, None, :]\n",
402
- "# cos = sinusoid_inp.cos()[None, :, None, :]\n",
403
- " \n",
404
- "# x1, x2 = x[..., ::2], x[..., 1::2]\n",
405
- "# x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
406
- " \n",
407
- "# x = x.view(batch_size, seq_len, self.n_state)\n",
408
- " \n",
409
- "# return x\n",
410
- "\n",
411
- "class CombinedRotaryEmbedding(nn.Module):\n",
412
- " def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):\n",
413
- " super().__init__()\n",
414
- " self.n_state = n_state\n",
415
- " self.n_head = n_head\n",
416
- " self.h_dim = n_state // n_head\n",
417
- " self.num_rotations = num_rotations\n",
418
- " self.base = base\n",
419
- " self.checkpointing = checkpointing\n",
420
- " \n",
421
- " self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
422
- " self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)\n",
423
- "\n",
424
- " # Add a scaling factor for thetas\n",
425
- " self.theta_scale = nn.Parameter(torch.ones(1)) \n",
426
- "\n",
427
- " self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
428
- " \n",
429
- " self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
430
- " \n",
431
- " def givens_rotation_matrix(self, n_state, i, j, theta):\n",
432
- " G = torch.eye(n_state, device=theta.device)\n",
433
- " G[i, i] = math.cos(theta)\n",
434
- " G[i, j] = -math.sin(theta)\n",
435
- " G[j, i] = math.sin(theta)\n",
436
- " G[j, j] = math.cos(theta)\n",
437
- " return G\n",
438
- " \n",
439
- " def update_base(self, new_base):\n",
440
- " self.base = new_base\n",
441
- " self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
442
- " \n",
443
- " def reset_parameters(self):\n",
444
- " nn.init.orthogonal_(self.rotation_matrix)\n",
445
- " nn.init.zeros_(self.thetas)\n",
446
- " \n",
447
- " def forward(self, x):\n",
448
- " if self.checkpointing:\n",
449
- " return checkpoint(self._forward, x)\n",
450
- " else:\n",
451
- " return self._forward(x)\n",
452
- " \n",
453
- " def _forward(self, x):\n",
454
- " if x.dim() not in [3, 4]:\n",
455
- " raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
456
- " \n",
457
- " if x.dim() == 3:\n",
458
- " batch_size, seq_len, n_state = x.size()\n",
459
- " x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
460
- " else:\n",
461
- " batch_size, seq_len, n_head, h_dim = x.size()\n",
462
- " if n_head != self.n_head or h_dim != self.h_dim:\n",
463
- " raise ValueError(f\"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}\")\n",
464
- " \n",
465
- " x = x.reshape(-1, self.h_dim)\n",
466
- " \n",
467
- " for k in range(self.num_rotations):\n",
468
- " i, j = self.rotation_pairs[k].long()\n",
469
- " \n",
470
- " # Apply the scaling factor to theta\n",
471
- " theta = self.thetas[k] * self.theta_scale \n",
472
- " \n",
473
- " G = self.givens_rotation_matrix(self.h_dim, i, j, theta)\n",
474
- " x = torch.matmul(x, G)\n",
475
- " \n",
476
- " x = torch.matmul(x, self.rotation_matrix)\n",
477
- " \n",
478
- " x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
479
- " \n",
480
- " sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n",
481
- " sin = sinusoid_inp.sin()[None, :, None, :]\n",
482
- " cos = sinusoid_inp.cos()[None, :, None, :]\n",
483
- " \n",
484
- " x1, x2 = x[..., ::2], x[..., 1::2]\n",
485
- " x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
486
- " \n",
487
- " x = x.view(batch_size, seq_len, self.n_state)\n",
488
- " \n",
489
- " return x\n",
490
- "\n",
491
- "\n",
492
- "class LearnedSinusoidalEmbeddings(nn.Module):\n",
493
- " def __init__(self, n_ctx, n_state, checkpointing=False):\n",
494
- " super().__init__()\n",
495
- " self.n_ctx = n_ctx\n",
496
- " self.n_state = n_state\n",
497
- " self.checkpointing = checkpointing\n",
498
- "\n",
499
- " position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)\n",
500
- " div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))\n",
501
- " features = torch.zeros(n_ctx, n_state)\n",
502
- " features[:, 0::2] = torch.sin(position * div_term)\n",
503
- " features[:, 1::2] = torch.cos(position * div_term)\n",
504
- " self.register_buffer('sinusoidal_features', features)\n",
505
- "\n",
506
- " self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())\n",
507
- "\n",
508
- " def forward(self, positions):\n",
509
- " if self.checkpointing:\n",
510
- " position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)\n",
511
- " else:\n",
512
- " position_embeddings = self.positional_embeddings[positions]\n",
513
- "\n",
514
- " position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)\n",
515
- " return position_embeddings\n",
516
- "\n",
517
- "class MultiHeadAttention(nn.Module):\n",
518
- " use_sdpa = True\n",
519
- "\n",
520
- " def __init__(self, n_state: int, n_head: int, max_rel_dist: int = 1, base: int = 10000):\n",
521
- " super().__init__()\n",
522
- " assert n_state % n_head == 0, \"n_state must be divisible by n_head\"\n",
523
- " self.n_head = n_head\n",
524
- " self.h_dim = n_state // n_head\n",
525
- " assert self.h_dim % 2 == 0, \"Head dimension must be even for rotary embeddings\"\n",
526
- "\n",
527
- " self.positional_scaling = nn.Parameter(torch.ones(1))\n",
528
- "\n",
529
- " self.query = nn.Linear(n_state, n_state)\n",
530
- " self.key = nn.Linear(n_state, n_state, bias=False)\n",
531
- " self.value = nn.Linear(n_state, n_state)\n",
532
- " self.out = nn.Linear(n_state, n_state)\n",
533
- "\n",
534
- " self.max_rel_dist = max_rel_dist\n",
535
- " self.base = base\n",
536
- " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
537
- " self.register_buffer('inv_freq', inv_freq)\n",
538
- " self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)\n",
539
- " self.rel_pos_bias.weight.data.fill_(0)\n",
540
- "\n",
541
- " self.combined_rotary = CombinedRotaryEmbedding(\n",
542
- " n_state=n_state,\n",
543
- " n_head=n_head,\n",
544
- " num_rotations=self.h_dim // 2,\n",
545
- " base=base,\n",
546
- " checkpointing=False \n",
547
- " )\n",
548
- "\n",
549
- " if device:\n",
550
- " self.to(device)\n",
551
- "\n",
552
- " def update_base(self, new_base): \n",
553
- " self.base = new_base \n",
554
- " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)) \n",
555
- " self.register_buffer('inv_freq', inv_freq) \n",
556
- " self.combined_rotary.update_base(new_base)\n",
557
- "\n",
558
- " def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
559
- " q = self.query(x)\n",
560
- "\n",
561
- " if kv_cache is None or xa is None or 'k' not in kv_cache:\n",
562
- " k_input = x if xa is None else xa\n",
563
- " k = self.key(k_input)\n",
564
- " v = self.value(k_input)\n",
565
- " if kv_cache is not None:\n",
566
- " kv_cache['k'] = k\n",
567
- " kv_cache['v'] = v\n",
568
- " else:\n",
569
- " k = kv_cache['k']\n",
570
- " v = kv_cache['v']\n",
571
- "\n",
572
- " q = q.view(q.shape[0], q.shape[1], self.n_head, -1)\n",
573
- " k = k.view(k.shape[0], k.shape[1], self.n_head, -1)\n",
574
- " v = v.view(v.shape[0], v.shape[1], self.n_head, -1)\n",
575
- "\n",
576
- " q = self.combined_rotary(q) \n",
577
- " k = self.combined_rotary(k)\n",
578
- "\n",
579
- " q = q.view(q.shape[0], q.shape[1], -1)\n",
580
- " k = k.view(k.shape[0], k.shape[1], -1)\n",
581
- "\n",
582
- " wv, qk = self.qkv_attention(q, k, v, mask)\n",
583
- " return self.out(wv), qk\n",
584
- " \n",
585
- " def qkv_attention(self, q, k, v, mask = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n",
586
- " n_batch, n_ctx, n_state = q.shape\n",
587
- "\n",
588
- " scale = (n_state // self.n_head) ** -0.25\n",
589
- " q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
590
- " k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
591
- " v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
592
- "\n",
593
- " qk = (q * scale) @ (k * scale).transpose(-1, -2)\n",
594
- "\n",
595
- " seq_len_q = q.size(2)\n",
596
- " seq_len_k = k.size(2)\n",
597
- "\n",
598
- " positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)\n",
599
- " positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1\n",
600
- " rel_bias = self.rel_pos_bias(positions) \n",
601
- " rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0) \n",
602
- "\n",
603
- " qk = qk + rel_bias\n",
604
- "\n",
605
- " if mask is not None:\n",
606
- " qk = qk + mask[:n_ctx, :n_ctx]\n",
607
- " qk = qk.float()\n",
608
- "\n",
609
- " w = F.softmax(qk, dim=-1).to(q.dtype)\n",
610
- " out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n",
611
- " qk = qk.detach()\n",
612
- "\n",
613
- " return out, qk\n",
614
- "\n",
615
- "class ResidualAttentionBlock(nn.Module):\n",
616
- " def __init__(self, n_state, n_head, cross_attention = False, max_rel_dist = 1, checkpointing=False):\n",
617
- " super().__init__()\n",
618
- "\n",
619
- " self.attn = MultiHeadAttention(n_state, n_head)\n",
620
- " self.attn_ln = LayerNorm(n_state)\n",
621
- " self.checkpointing = checkpointing\n",
622
- " self.max_rel_dist = max_rel_dist\n",
623
- "\n",
624
- " self.cross_attn = (\n",
625
- " MultiHeadAttention(n_state, n_head) if cross_attention else None\n",
626
- " )\n",
627
- " self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None\n",
628
- "\n",
629
- " n_mlp = n_state * 4\n",
630
- " self.mlp = nn.Sequential(\n",
631
- " Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)\n",
632
- " )\n",
633
- " self.mlp_ln = LayerNorm(n_state)\n",
634
- "\n",
635
- " def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
636
- " if self.checkpointing:\n",
637
- " x = checkpoint(self._attn_forward, x, mask, kv_cache)\n",
638
- " else:\n",
639
- " x = self._attn_forward(x, mask, kv_cache)\n",
640
- "\n",
641
- " if self.cross_attn:\n",
642
- " if self.checkpointing:\n",
643
- " x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)\n",
644
- " else:\n",
645
- " x = self._cross_attn_forward(x, xa, kv_cache)\n",
646
- "\n",
647
- " if self.checkpointing:\n",
648
- " x = checkpoint(self._mlp_forward, x)\n",
649
- " else:\n",
650
- " x = self._mlp_forward(x)\n",
651
- "\n",
652
- " return x\n",
653
- "\n",
654
- " def _attn_forward(self, x, mask, kv_cache):\n",
655
- " residual = x\n",
656
- " x = self.attn_ln(x)\n",
657
- " x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]\n",
658
- " return x\n",
659
- "\n",
660
- " def _cross_attn_forward(self, x, xa, kv_cache):\n",
661
- " residual = x\n",
662
- " x = self.cross_attn_ln(x)\n",
663
- " x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]\n",
664
- " return x\n",
665
- "\n",
666
- " def _mlp_forward(self, x):\n",
667
- " residual = x\n",
668
- " x = self.mlp_ln(x)\n",
669
- " x = residual + self.mlp(x)\n",
670
- " return x\n",
671
- "\n",
672
- "class AudioEncoder(nn.Module):\n",
673
- " def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
674
- " super().__init__()\n",
675
- " self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)\n",
676
- " self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)\n",
677
- " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n",
678
- " self.checkpointing = checkpointing\n",
679
- " self.h_dim = n_state // n_head\n",
680
- "\n",
681
- " self.combined_rotary = CombinedRotaryEmbedding(\n",
682
- " n_state=n_state,\n",
683
- " n_head=n_head,\n",
684
- " num_rotations=self.h_dim // 2,\n",
685
- " base=base,\n",
686
- " checkpointing=False \n",
687
- " )\n",
688
- "\n",
689
- " self.blocks = nn.ModuleList(\n",
690
- " [ResidualAttentionBlock(n_state, n_head, max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]\n",
691
- " )\n",
692
- " self.ln_post = LayerNorm(n_state)\n",
693
- "\n",
694
- " def update_base(self, new_base):\n",
695
- " self.combined_rotary.update_base(new_base)\n",
696
- " for block in self.blocks:\n",
697
- " if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
698
- " block.attn.update_base(new_base)\n",
699
- " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
700
- " block.cross_attn.update_base(new_base)\n",
701
- "\n",
702
- " def forward(self, x):\n",
703
- " if self.checkpointing:\n",
704
- " x = checkpoint(self._conv_forward, x)\n",
705
- " else:\n",
706
- " x = self._conv_forward(x)\n",
707
- "\n",
708
- " for block in self.blocks:\n",
709
- " if self.checkpointing:\n",
710
- " x = checkpoint(block, x)\n",
711
- " else:\n",
712
- " x = block(x)\n",
713
- "\n",
714
- " x = self.ln_post(x)\n",
715
- " return x\n",
716
- "\n",
717
- " def _conv_forward(self, x):\n",
718
- " x = F.gelu(self.conv1(x))\n",
719
- " x = F.gelu(self.conv2(x))\n",
720
- " x = x.permute(0, 2, 1)\n",
721
- "\n",
722
- " x = self.combined_rotary(x)\n",
723
- "\n",
724
- " pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)\n",
725
- " x = x + pos_emb\n",
726
- " return x\n",
727
- "\n",
728
- "class TextDecoder(nn.Module):\n",
729
- " def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
730
- " super().__init__()\n",
731
- " self.token_embedding = nn.Embedding(vocab_size, n_state)\n",
732
- " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n",
733
- " self.checkpointing = checkpointing\n",
734
- " self.n_head = n_head\n",
735
- " self.h_dim = n_state // n_head\n",
736
- " \n",
737
- " self.combined_rotary = CombinedRotaryEmbedding(\n",
738
- " n_state=n_state,\n",
739
- " n_head=n_head,\n",
740
- " num_rotations=self.h_dim // 2, \n",
741
- " base=base,\n",
742
- " checkpointing=False \n",
743
- " )\n",
744
- "\n",
745
- " self.blocks = nn.ModuleList([\n",
746
- " ResidualAttentionBlock(n_state, n_head, max_rel_dist, cross_attention, checkpointing=checkpointing)\n",
747
- " for _ in range(n_layer)\n",
748
- " ])\n",
749
- " self.ln = LayerNorm(n_state)\n",
750
- " mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)\n",
751
- " self.register_buffer(\"mask\", mask, persistent=False)\n",
752
- "\n",
753
- " def update_base(self, new_base):\n",
754
- " self.combined_rotary.update_base(new_base)\n",
755
- " for block in self.blocks:\n",
756
- " if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
757
- " block.attn.update_base(new_base)\n",
758
- " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
759
- " block.cross_attn.update_base(new_base)\n",
760
- "\n",
761
- " def forward(self, x, xa, kv_cache = None):\n",
762
- " if self.checkpointing:\n",
763
- " x = checkpoint(self._embedding_forward, x, xa, kv_cache)\n",
764
- " else:\n",
765
- " x = self._embedding_forward(x, xa, kv_cache)\n",
766
- "\n",
767
- " for block in self.blocks:\n",
768
- " if self.checkpointing:\n",
769
- " x = checkpoint(block, x, xa, self.mask, kv_cache)\n",
770
- " else:\n",
771
- " x = block(x, xa, self.mask, kv_cache)\n",
772
- "\n",
773
- " x = self.ln(x)\n",
774
- " logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()\n",
775
- "\n",
776
- " return logits\n",
777
- "\n",
778
- " def _embedding_forward(self, x, xa, kv_cache):\n",
779
- " offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n",
780
- " positions = torch.arange(x.shape[1], device=x.device) + offset\n",
781
- " pos_emb = self.positional_embedding(positions).unsqueeze(0)\n",
782
- "\n",
783
- " x = self.token_embedding(x) + pos_emb\n",
784
- " x = x.to(xa.dtype)\n",
785
- "\n",
786
- " batch_size, seq_length, embedding_dim = x.shape\n",
787
- " num_heads = self.n_head\n",
788
- " head_dim = embedding_dim // num_heads\n",
789
- " x = x.view(batch_size, seq_length, num_heads, head_dim)\n",
790
- "\n",
791
- " x = self.combined_rotary(x)\n",
792
- "\n",
793
- " x = x.view(batch_size, seq_length, embedding_dim)\n",
794
- " return x\n",
795
- " \n",
796
- "class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):\n",
797
- " config_class = WhisperConfig\n",
798
- "\n",
799
- " def __init__(self, config: WhisperConfig):\n",
800
- " super().__init__(config)\n",
801
- " self.config = config\n",
802
- "\n",
803
- " self.n_mels = self.config.num_mel_bins\n",
804
- " self.n_audio_ctx = self.config.max_source_positions\n",
805
- " self.n_audio_state = self.config.d_model\n",
806
- " self.n_audio_head = self.config.encoder_attention_heads\n",
807
- " self.n_audio_layer = self.config.encoder_layers\n",
808
- " self.vocab_size = self.config.vocab_size\n",
809
- " self.n_text_ctx = self.config.max_target_positions\n",
810
- " self.n_text_state = self.config.d_model\n",
811
- " self.n_text_head = self.config.decoder_attention_heads\n",
812
- " self.n_text_layer = self.config.decoder_layers\n",
813
- " self.checkpointing = self.config.checkpointing\n",
814
- " self.max_rel_dist = self.config.max_rel_dist\n",
815
- " self.cross_attention = self.config.cross_attention\n",
816
- " self.base = self.config.base\n",
817
- "\n",
818
- " self.encoder = AudioEncoder(\n",
819
- " self.config.n_mels,\n",
820
- " self.config.n_audio_ctx,\n",
821
- " self.config.n_audio_state,\n",
822
- " self.config.n_audio_head,\n",
823
- " self.config.n_audio_layer,\n",
824
- " self.config.checkpointing,\n",
825
- " self.config.max_rel_dist,\n",
826
- " self.config.cross_attention,\n",
827
- " self.config.base,\n",
828
- " )\n",
829
- " self.decoder = TextDecoder(\n",
830
- " self.config.vocab_size,\n",
831
- " self.config.n_text_ctx,\n",
832
- " self.config.n_text_state,\n",
833
- " self.config.n_text_head,\n",
834
- " self.config.n_text_layer,\n",
835
- " self.config.checkpointing,\n",
836
- " self.config.max_rel_dist,\n",
837
- " self.config.cross_attention,\n",
838
- " self.config.base,\n",
839
- " )\n",
840
- "\n",
841
- " all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)\n",
842
- " all_heads[self.config.n_text_layer // 2:] = True\n",
843
- " self.register_buffer(\"alignment_heads\", all_heads.to_sparse(), persistent=False)\n",
844
- "\n",
845
- " self.best_loss = float('inf')\n",
846
- " self.base = 10000 \n",
847
- "\n",
848
- " def update_base(self, new_base):\n",
849
- " self.encoder.combined_rotary.update_base(new_base)\n",
850
- " self.decoder.combined_rotary.update_base(new_base)\n",
851
- "\n",
852
- " for name, module in self.encoder.named_modules():\n",
853
- " if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
854
- " module.update_base(new_base)\n",
855
- "\n",
856
- " for name, module in self.decoder.named_modules():\n",
857
- " if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
858
- " module.update_base(new_base)\n",
859
- "\n",
860
- " def adjust_base(self, loss, factor=1.05):\n",
861
- " if loss < self.best_loss:\n",
862
- " new_base = self.base * factor\n",
863
- " else:\n",
864
- " new_base = self.base / factor\n",
865
- "\n",
866
- " self.update_base(new_base)\n",
867
- " self.best_loss = loss\n",
868
- " # print(f\"Adjusted base: {new_base}\")\n",
869
- "\n",
870
- " @staticmethod\n",
871
- " def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:\n",
872
- " shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n",
873
- " shifted_input_ids[:, 1:] = input_ids[:, :-1]\n",
874
- " shifted_input_ids[:, 0] = decoder_start_token_id\n",
875
- " shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n",
876
- " return shifted_input_ids\n",
877
- "\n",
878
- " def forward(self, input_features, labels=None, dec_input_ids=None):\n",
879
- " if labels is not None:\n",
880
- " if dec_input_ids is None:\n",
881
- " dec_input_ids = self.shift_tokens_right(\n",
882
- " labels, self.config.pad_token_id, self.config.decoder_start_token_id\n",
883
- " )\n",
884
- "\n",
885
- " encoded_features = self.encoder(input_features).to(device)\n",
886
- " logits = self.decoder(dec_input_ids, encoded_features)\n",
887
- "\n",
888
- " loss = None\n",
889
- " if labels is not None:\n",
890
- " loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) \n",
891
- " labels = labels.to(logits.device).long()\n",
892
- " loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))\n",
893
- "\n",
894
- " self.adjust_base(loss.item())\n",
895
- "\n",
896
- " return {\n",
897
- " \"loss\": loss,\n",
898
- " \"logits\": logits,\n",
899
- " }\n",
900
- "\n",
901
- " def _initialize_weights(self):\n",
902
- " nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)\n",
903
- " if hasattr(self.decoder.positional_embedding, 'weight'):\n",
904
- " nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)\n",
905
- " for block in self.decoder.blocks:\n",
906
- " for layer in block.children():\n",
907
- " if isinstance(layer, nn.Linear):\n",
908
- " nn.init.xavier_normal_(layer.weight)\n",
909
- " if layer.bias is not None:\n",
910
- " nn.init.zeros_(layer.bias)\n",
911
- "\n",
912
- " nn.init.constant_(self.decoder.ln.gamma, 1)\n",
913
- " if self.decoder.ln.beta is not None:\n",
914
- " nn.init.constant_(self.decoder.ln.beta, 0)\n",
915
- "\n",
916
- " nn.init.xavier_normal_(self.encoder.conv1.weight)\n",
917
- " if self.encoder.conv1.bias is not None:\n",
918
- " nn.init.zeros_(self.encoder.conv1.bias)\n",
919
- "\n",
920
- " nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')\n",
921
- " if self.encoder.conv2.bias is not None:\n",
922
- " nn.init.zeros_(self.encoder.conv2.bias)\n",
923
- "\n",
924
- " nn.init.constant_(self.encoder.ln_post.gamma, 1)\n",
925
- " if self.encoder.ln_post.beta is not None:\n",
926
- " nn.init.constant_(self.encoder.ln_post.beta, 0)\n",
927
- " \n",
928
- " def apply_initialization(self):\n",
929
- " self._initialize_weights()\n",
930
- "\n",
931
- " def set_alignment_heads(self, dump: bytes):\n",
932
- " array = np.frombuffer(\n",
933
- " gzip.decompress(base64.b85decode(dump)), dtype=bool\n",
934
- " ).copy()\n",
935
- " mask = torch.from_numpy(array).reshape(\n",
936
- " self.config.n_text_layer, self.config.n_text_head\n",
937
- " )\n",
938
- " self.register_buffer(\"alignment_heads\", mask.to_sparse(), persistent=False)\n",
939
- "\n",
940
- " def embed_audio(self, mel):\n",
941
- " return self.encoder(mel)\n",
942
- "\n",
943
- " def logits(self, labels, input_features):\n",
944
- " return self.decoder(labels, input_features)\n",
945
- "\n",
946
- " @property\n",
947
- " def device(self):\n",
948
- " return next(self.parameters()).device\n",
949
- "\n",
950
- " @property\n",
951
- " def is_multilingual(self):\n",
952
- " return self.config.vocab_size >= len(tokenizer)\n",
953
- "\n",
954
- " @property\n",
955
- " def num_languages(self):\n",
956
- " return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)\n",
957
- "\n",
958
- " def install_kv_cache_hooks(self, cache = None):\n",
959
- " cache = {**cache} if cache is not None else {}\n",
960
- " hooks = []\n",
961
- "\n",
962
- " def save_to_cache(module, _, output):\n",
963
- " if module not in cache or output.shape[1] > self.config.n_text_ctx:\n",
964
- " cache[module] = output\n",
965
- " else:\n",
966
- " cache[module] = torch.cat([cache[module], output], dim=1).detach()\n",
967
- " return cache[module]\n",
968
- "\n",
969
- " def install_hooks(layer: nn.Module):\n",
970
- " if isinstance(layer, MultiHeadAttention):\n",
971
- " hooks.append(layer.key.register_forward_hook(save_to_cache))\n",
972
- " hooks.append(layer.value.register_forward_hook(save_to_cache))\n",
973
- "\n",
974
- " self.decoder.apply(install_hooks)\n",
975
- " return cache, hooks\n",
976
- "\n",
977
- " detect_language = detect_language_function\n",
978
- " transcribe = transcribe_function\n",
979
- " decode = decode_function\n",
980
- "\n",
981
- " def get_encoder(self):\n",
982
- " return self.encoder\n",
983
- "\n",
984
- " def prepare_inputs_for_generation(self, input_ids, **kwargs):\n",
985
- " return {'input_features': input_ids}\n",
986
- "\n",
987
- " def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):\n",
988
- " return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id\n",
989
- "\n",
990
- " def can_generate(self):\n",
991
- " return True\n",
992
- " \n",
993
- " def generate(self, inputs, **kwargs):\n",
994
- " encoder_outputs = self.encoder(inputs)\n",
995
- " decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)\n",
996
- " outputs = self.decoder(decoder_input_ids, encoder_outputs)\n",
997
- " return outputs.argmax(dim=-1)\n",
998
- "\n",
999
- " def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):\n",
1000
- " if not self.supports_gradient_checkpointing:\n",
1001
- " raise ValueError(f\"{self.__class__.__name__} does not support gradient checkpointing.\")\n",
1002
- "\n",
1003
- " if gradient_checkpointing_kwargs is None:\n",
1004
- " gradient_checkpointing_kwargs = {\"use_reentrant\": True}\n",
1005
- "\n",
1006
- " gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)\n",
1007
- "\n",
1008
- " _is_using_old_format = \"value\" in inspect.signature(self._set_gradient_checkpointing).parameters\n",
1009
- "\n",
1010
- " if not _is_using_old_format:\n",
1011
- " self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)\n",
1012
- " else:\n",
1013
- " self.apply(partial(self._set_gradient_checkpointing, value=True))\n",
1014
- "\n",
1015
- " if getattr(self, \"_hf_peft_config_loaded\", False):\n",
1016
- " self.enable_input_require_grads()\n",
1017
- "\n",
1018
- "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\", feature_size=128, n_fft=1024, hop_length=256, sampling_rate=16000)#, win_length=2048, sampling_rate=16000, mel_fmin=0.0, mel_fmax=8000.0)\n",
1019
- "tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
1020
- "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
1021
- "\n",
1022
- "config = WhisperConfig(\n",
1023
- " n_mels=128,\n",
1024
- " n_audio_ctx=1500,\n",
1025
- " n_audio_state=1024,\n",
1026
- " n_audio_head=16,\n",
1027
- " n_audio_layer=24,\n",
1028
- " vocab_size=51865,\n",
1029
- " n_text_ctx=448,\n",
1030
- " n_text_state=1024,\n",
1031
- " n_text_head=16,\n",
1032
- " n_text_layer=20,\n",
1033
- " max_rel_dist=200,\n",
1034
- " cross_attention=True,\n",
1035
- " checkpointing=True,\n",
1036
- " base=10000,\n",
1037
- " bos_token_id = 50257,\n",
1038
- " eos_token_id = 50257,\n",
1039
- " pad_token_id = 50257,\n",
1040
- " decoder_start_token_id = 50258,\n",
1041
- " is_encoder_decoder = True,\n",
1042
- " init_std=0.02,\n",
1043
- " )\n",
1044
- "\n",
1045
- "model = Echo(config).to(device)\n",
1046
- "model.apply_initialization()\n",
1047
- "model.save_pretrained(\"./models/echo\")\n",
1048
- "# model = Echo.from_pretrained(\"./models/echo2\")\n",
1049
- "\n"
1050
- ]
1051
- },
1052
- {
1053
- "cell_type": "code",
1054
- "execution_count": null,
1055
- "metadata": {},
1056
- "outputs": [],
1057
- "source": [
1058
- "raw_datasets = IterableDatasetDict()\n",
1059
- "\n",
1060
- "raw_datasets[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_17_0\", \"ja\", split=\"train\", trust_remote_code=True, streaming=True) # set split=\"train+validation\" for low-resource\n",
1061
- "raw_datasets[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_17_0\", \"ja\", split=\"test\", trust_remote_code=True, streaming=True).take(100)\n",
1062
- "\n",
1063
- "raw_datasets = raw_datasets.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
1064
- "\n",
1065
- "tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
1066
- "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
1067
- "\n",
1068
- "def prepare_dataset(batch):\n",
1069
- " audio = batch[\"audio\"]\n",
1070
- " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
1071
- " transcription = batch[\"sentence\"]\n",
1072
- " batch[\"labels\"] = tokenizer(transcription).input_ids\n",
1073
- " return batch\n",
1074
- "\n",
1075
- "vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)).with_format(\"torch\")\n",
1076
- "\n",
1077
- "\n",
1078
- "@dataclass\n",
1079
- "class DataCollatorSpeechSeq2SeqWithPadding:\n",
1080
- " processor: Any\n",
1081
- "\n",
1082
- " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
1083
- " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
1084
- " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
1085
- " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
1086
- " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
1087
- " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
1088
- " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
1089
- " labels = labels[:, 1:]\n",
1090
- " batch[\"labels\"] = labels\n",
1091
- " return batch\n",
1092
- " \n",
1093
- "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\n",
1094
- "\n",
1095
- "metric = evaluate.load(\"cer\")\n",
1096
- "\n",
1097
- "def compute_metrics(pred):\n",
1098
- " pred_logits = pred.predictions\n",
1099
- " label_ids = pred.label_ids\n",
1100
- "\n",
1101
- " if isinstance(pred_logits, tuple):\n",
1102
- " pred_ids = pred_logits[0]\n",
1103
- " else:\n",
1104
- " pred_ids = pred_logits\n",
1105
- " if pred_ids.ndim == 3:\n",
1106
- " pred_ids = np.argmax(pred_ids, axis=-1)\n",
1107
- "\n",
1108
- " label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
1109
- " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
1110
- " label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
1111
- " cer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
1112
- " return {\"cer\": cer}\n",
1113
- "\n",
1114
- "training_args = Seq2SeqTrainingArguments(\n",
1115
- " output_dir=\"./test11\", \n",
1116
- " per_device_train_batch_size=1,\n",
1117
- " per_device_eval_batch_size=1,\n",
1118
- " gradient_accumulation_steps=1,\n",
1119
- " eval_accumulation_steps=1,\n",
1120
- " # num_train_epochs=1,\n",
1121
- " tf32=True,\n",
1122
- " bf16=True,\n",
1123
- " learning_rate=1e-5,\n",
1124
- " warmup_steps=500,\n",
1125
- " evaluation_strategy=\"steps\",\n",
1126
- " # predict_with_generate=True,\n",
1127
- " # generation_max_length=225,\n",
1128
- " max_steps=40000,\n",
1129
- " save_steps=1000,\n",
1130
- " eval_steps=100,\n",
1131
- " logging_steps=5,\n",
1132
- " report_to=[\"tensorboard\"],\n",
1133
- " load_best_model_at_end=True,\n",
1134
- " metric_for_best_model=\"loss\",\n",
1135
- " greater_is_better=False,\n",
1136
- " push_to_hub=False,\n",
1137
- " optim=\"adafactor\",\n",
1138
- " weight_decay=0.0025,\n",
1139
- " disable_tqdm=False,\n",
1140
- " save_total_limit=2,\n",
1141
- " torch_empty_cache_steps=10,\n",
1142
- ")\n",
1143
- "\n",
1144
- "trainer = Seq2SeqTrainer(\n",
1145
- " args=training_args,\n",
1146
- " model=model,\n",
1147
- " train_dataset=vectorized_datasets[\"train\"],\n",
1148
- " eval_dataset=vectorized_datasets[\"test\"],\n",
1149
- " data_collator=data_collator,\n",
1150
- " compute_metrics=compute_metrics,\n",
1151
- " tokenizer=processor.feature_extractor,\n",
1152
- ")\n",
1153
- "\n",
1154
- "# trainer.add_callback(CustomCallback)\n",
1155
- "\n",
1156
- "trainer.train()\n",
1157
- "\n",
1158
- "import tensorboard"
1159
- ]
1160
- }
1161
- ],
1162
- "metadata": {
1163
- "kernelspec": {
1164
- "display_name": "Python 3",
1165
- "language": "python",
1166
- "name": "python3"
1167
- },
1168
- "language_info": {
1169
- "codemirror_mode": {
1170
- "name": "ipython",
1171
- "version": 3
1172
- },
1173
- "file_extension": ".py",
1174
- "mimetype": "text/x-python",
1175
- "name": "python",
1176
- "nbconvert_exporter": "python",
1177
- "pygments_lexer": "ipython3",
1178
- "version": "3.10.0"
1179
- }
1180
- },
1181
- "nbformat": 4,
1182
- "nbformat_minor": 2
1183
- }