Delete model.ipynb
Browse files- model.ipynb +0 -1183
model.ipynb
DELETED
@@ -1,1183 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"import base64, gzip, evaluate, math, os, sys, time\n",
|
10 |
-
"import gzip, neologdn\n",
|
11 |
-
"import collections\n",
|
12 |
-
"import copy\n",
|
13 |
-
"import functools\n",
|
14 |
-
"from functools import partial, wraps\n",
|
15 |
-
"from threading import Thread\n",
|
16 |
-
"import gc\n",
|
17 |
-
"import importlib.metadata\n",
|
18 |
-
"import inspect\n",
|
19 |
-
"import itertools\n",
|
20 |
-
"import torch\n",
|
21 |
-
"from torch import amp, Tensor, optim\n",
|
22 |
-
"from torch.utils.checkpoint import checkpoint\n",
|
23 |
-
"from contextlib import contextmanager\n",
|
24 |
-
"from dataclasses import dataclass\n",
|
25 |
-
"from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel\n",
|
26 |
-
"from transformers.models.whisper.generation_whisper import WhisperGenerationMixin\n",
|
27 |
-
"from transformers.optimization import Adafactor, AdafactorSchedule\n",
|
28 |
-
"from huggingface_hub import PyTorchModelHubMixin\n",
|
29 |
-
"from datasets import IterableDatasetDict, Audio, load_dataset, load_from_disk\n",
|
30 |
-
"import numpy as np\n",
|
31 |
-
"import torch, transformers, warnings\n",
|
32 |
-
"from typing import Dict, Iterable, Optional, Tuple, Union, List, Any, Type\n",
|
33 |
-
"import torch.nn.functional as F\n",
|
34 |
-
"from torch import Tensor, nn\n",
|
35 |
-
"import torchaudio, torchaudio.transforms as T\n",
|
36 |
-
"from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, WhisperTokenizer, WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration\n",
|
37 |
-
"from whisper.decoding import decode as decode_function\n",
|
38 |
-
"from whisper.decoding import detect_language as detect_language_function\n",
|
39 |
-
"from whisper.transcribe import transcribe as transcribe_function\n",
|
40 |
-
"\n",
|
41 |
-
"try:\n",
|
42 |
-
" from torch.nn.functional import scaled_dot_product_attention\n",
|
43 |
-
"\n",
|
44 |
-
" SDPA_AVAILABLE = True\n",
|
45 |
-
"except (ImportError, RuntimeError, OSError):\n",
|
46 |
-
" scaled_dot_product_attention = None\n",
|
47 |
-
" SDPA_AVAILABLE = False\n",
|
48 |
-
"\n",
|
49 |
-
"transformers.utils.logging.set_verbosity_error()\n",
|
50 |
-
"warnings.filterwarnings(action=\"ignore\")\n",
|
51 |
-
"warnings.warn = lambda *args,**kwargs: None\n",
|
52 |
-
"\n",
|
53 |
-
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
54 |
-
"\n",
|
55 |
-
"dtype = torch.float32\n"
|
56 |
-
]
|
57 |
-
},
|
58 |
-
{
|
59 |
-
"cell_type": "code",
|
60 |
-
"execution_count": null,
|
61 |
-
"metadata": {},
|
62 |
-
"outputs": [],
|
63 |
-
"source": [
|
64 |
-
"### Model ###\n",
|
65 |
-
"\n",
|
66 |
-
"\n",
|
67 |
-
"\n",
|
68 |
-
"class LayerNorm(nn.Module):\n",
|
69 |
-
" def __init__(self, num_features, eps=1e-6):\n",
|
70 |
-
" super(LayerNorm, self).__init__()\n",
|
71 |
-
" self.gamma = nn.Parameter(torch.ones(num_features))\n",
|
72 |
-
" self.beta = nn.Parameter(torch.zeros(num_features))\n",
|
73 |
-
" self.eps = eps\n",
|
74 |
-
"\n",
|
75 |
-
" def forward(self, x):\n",
|
76 |
-
" mean = x.mean(dim=-1, keepdim=True)\n",
|
77 |
-
" std = x.std(dim=-1, keepdim=True)\n",
|
78 |
-
" x = (x - mean) / (std + self.eps)\n",
|
79 |
-
" return self.gamma * x + self.beta\n",
|
80 |
-
"\n",
|
81 |
-
"class Linear(nn.Module):\n",
|
82 |
-
" def __init__(self, in_features: int, out_features: int, dropout_rate = 0.001, use_batchnorm: bool = True, activation: str = 'relu'):\n",
|
83 |
-
" super(Linear, self).__init__()\n",
|
84 |
-
" self.linear = nn.Linear(in_features, out_features)\n",
|
85 |
-
" self.dropout = nn.Dropout(dropout_rate)\n",
|
86 |
-
" self.use_batchnorm = use_batchnorm\n",
|
87 |
-
" self.activation = activation\n",
|
88 |
-
"\n",
|
89 |
-
" if self.use_batchnorm:\n",
|
90 |
-
" self.batchnorm = nn.BatchNorm1d(out_features)\n",
|
91 |
-
" self.reset_parameters()\n",
|
92 |
-
"\n",
|
93 |
-
" def reset_parameters(self):\n",
|
94 |
-
" nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)\n",
|
95 |
-
" if self.linear.bias is not None:\n",
|
96 |
-
" nn.init.zeros_(self.linear.bias)\n",
|
97 |
-
"\n",
|
98 |
-
" def forward(self, x):\n",
|
99 |
-
" batch_size, seq_len, _ = x.size()\n",
|
100 |
-
" x = x.view(-1, x.size(-1)) \n",
|
101 |
-
" x = self.linear(x)\n",
|
102 |
-
"\n",
|
103 |
-
" if self.use_batchnorm:\n",
|
104 |
-
" x = self.batchnorm(x)\n",
|
105 |
-
"\n",
|
106 |
-
" x = self.apply_activation(x)\n",
|
107 |
-
" x = self.dropout(x)\n",
|
108 |
-
" x = x.view(batch_size, seq_len, -1) \n",
|
109 |
-
" \n",
|
110 |
-
" return x\n",
|
111 |
-
"\n",
|
112 |
-
" def apply_activation(self, x):\n",
|
113 |
-
" if self.activation == 'relu':\n",
|
114 |
-
" return F.relu(x)\n",
|
115 |
-
" elif self.activation == 'tanh':\n",
|
116 |
-
" return torch.tanh(x)\n",
|
117 |
-
" elif self.activation == 'sigmoid':\n",
|
118 |
-
" return torch.sigmoid(x)\n",
|
119 |
-
" else:\n",
|
120 |
-
" raise ValueError(f'Unsupported activation function: {self.activation}')\n",
|
121 |
-
"\n",
|
122 |
-
"class Conv1d(nn.Conv1d):\n",
|
123 |
-
" def __init__(self, *args, **kwargs):\n",
|
124 |
-
" super().__init__(*args, **kwargs)\n",
|
125 |
-
" self.reset_parameters()\n",
|
126 |
-
"\n",
|
127 |
-
" def reset_parameters(self):\n",
|
128 |
-
" nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')\n",
|
129 |
-
" if self.bias is not None:\n",
|
130 |
-
" nn.init.zeros_(self.bias)\n",
|
131 |
-
"\n",
|
132 |
-
" def _conv_forward(self, x, weight, bias) -> Tensor:\n",
|
133 |
-
" weight = self.weight.to(x.dtype)\n",
|
134 |
-
" bias = None if self.bias is None else self.bias.to(x.dtype)\n",
|
135 |
-
" return super()._conv_forward(x, weight, bias)\n",
|
136 |
-
"\n",
|
137 |
-
"class BiasedCrossAttention(nn.Module):\n",
|
138 |
-
" def __init__(self, n_state, n_head, dropout_rate=0.001):\n",
|
139 |
-
" super().__init__()\n",
|
140 |
-
" self.n_head = n_head\n",
|
141 |
-
" self.n_state = n_state\n",
|
142 |
-
" self.head_dim = n_state // n_head\n",
|
143 |
-
"\n",
|
144 |
-
" self.query = nn.Linear(n_state, n_state)\n",
|
145 |
-
" self.key = nn.Linear(n_state, n_state, bias=False)\n",
|
146 |
-
" self.value = nn.Linear(n_state, n_state)\n",
|
147 |
-
" self.out = nn.Linear(n_state, n_state)\n",
|
148 |
-
"\n",
|
149 |
-
" self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))\n",
|
150 |
-
" self.dropout = nn.Dropout(dropout_rate)\n",
|
151 |
-
" self.norm = LayerNorm(n_state)\n",
|
152 |
-
" \n",
|
153 |
-
" def forward(self, q, k, v, mask=None):\n",
|
154 |
-
" batch_size, seq_length, _ = q.size()\n",
|
155 |
-
"\n",
|
156 |
-
" q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
|
157 |
-
" k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
|
158 |
-
" v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
|
159 |
-
"\n",
|
160 |
-
" qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias\n",
|
161 |
-
" if mask is not None:\n",
|
162 |
-
" qk = qk.masked_fill(mask == 0, float('-inf'))\n",
|
163 |
-
"\n",
|
164 |
-
" w = F.softmax(qk, dim=-1)\n",
|
165 |
-
" w = self.dropout(w)\n",
|
166 |
-
"\n",
|
167 |
-
" out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)\n",
|
168 |
-
" out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))\n",
|
169 |
-
" return out\n",
|
170 |
-
"\n",
|
171 |
-
"class DynamicConvAttention(nn.Module):\n",
|
172 |
-
" def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.001):\n",
|
173 |
-
" super().__init__()\n",
|
174 |
-
" self.n_state = n_state\n",
|
175 |
-
" self.n_head = n_head\n",
|
176 |
-
" self.kernel_size = kernel_size\n",
|
177 |
-
"\n",
|
178 |
-
" self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)\n",
|
179 |
-
" self.dropout = nn.Dropout(dropout_rate)\n",
|
180 |
-
"\n",
|
181 |
-
" self.query = nn.Linear(n_state, n_state)\n",
|
182 |
-
" self.key = nn.Linear(n_state, n_state, bias=False)\n",
|
183 |
-
" self.value = nn.Linear(n_state, n_state)\n",
|
184 |
-
" self.out_proj = nn.Linear(n_state, n_state)\n",
|
185 |
-
"\n",
|
186 |
-
" self.norm = LayerNorm(n_state)\n",
|
187 |
-
"\n",
|
188 |
-
" def forward(self, x):\n",
|
189 |
-
" batch_size, seq_len, embed_dim = x.size()\n",
|
190 |
-
" if embed_dim != self.n_state:\n",
|
191 |
-
" raise ValueError(f\"Expected embed_dim of {self.n_state}, but got {embed_dim}\")\n",
|
192 |
-
"\n",
|
193 |
-
" q = self.query(x)\n",
|
194 |
-
" k = self.key(x)\n",
|
195 |
-
" v = self.value(x)\n",
|
196 |
-
"\n",
|
197 |
-
" x = x.permute(0, 2, 1)\n",
|
198 |
-
" conv_out = self.conv(x)\n",
|
199 |
-
" conv_out = conv_out.permute(0, 2, 1)\n",
|
200 |
-
" conv_out = self.norm(conv_out)\n",
|
201 |
-
" conv_out = self.dropout(conv_out)\n",
|
202 |
-
"\n",
|
203 |
-
" attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)\n",
|
204 |
-
" attention_out = torch.matmul(attention_out, v)\n",
|
205 |
-
" \n",
|
206 |
-
" combined_out = conv_out + attention_out\n",
|
207 |
-
" combined_out = self.norm(combined_out)\n",
|
208 |
-
" \n",
|
209 |
-
" return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)\n",
|
210 |
-
"\n",
|
211 |
-
"class HybridAttention(nn.Module):\n",
|
212 |
-
" def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.001):\n",
|
213 |
-
" super().__init__()\n",
|
214 |
-
" self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
|
215 |
-
" self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
|
216 |
-
" self.ln_local = LayerNorm(n_state)\n",
|
217 |
-
" self.ln_global = LayerNorm(n_state)\n",
|
218 |
-
"\n",
|
219 |
-
" self.dropout = nn.Dropout(dropout_rate)\n",
|
220 |
-
" self.window_size = window_size\n",
|
221 |
-
"\n",
|
222 |
-
" def forward(self, x):\n",
|
223 |
-
" x_local = self.ln_local(x)\n",
|
224 |
-
" x_global = self.ln_global(x)\n",
|
225 |
-
" x_local = x_local.permute(1, 0, 2)\n",
|
226 |
-
" x_global = x_global.permute(1, 0, 2)\n",
|
227 |
-
" local_out = self.sliding_window_attention(x_local)\n",
|
228 |
-
" global_out, _ = self.global_attn(x_global, x_global, x_global)\n",
|
229 |
-
" combined_out = local_out + global_out\n",
|
230 |
-
" combined_out = combined_out.permute(1, 0, 2)\n",
|
231 |
-
" return self.dropout(combined_out)\n",
|
232 |
-
"\n",
|
233 |
-
" def sliding_window_attention(self, x):\n",
|
234 |
-
" batch_size, seq_len, n_state = x.size()\n",
|
235 |
-
" window_size = min(self.window_size, max(1, seq_len // 4))\n",
|
236 |
-
" output = torch.zeros_like(x, device=x.device, dtype=x.dtype)\n",
|
237 |
-
"\n",
|
238 |
-
" for i in range(0, seq_len, window_size):\n",
|
239 |
-
" end = min(i + window_size, seq_len)\n",
|
240 |
-
" query = x[i:end, :, :]\n",
|
241 |
-
" start = max(0, i - window_size)\n",
|
242 |
-
" key = x[start:end, :, :]\n",
|
243 |
-
" value = x[start:end, :, :]\n",
|
244 |
-
" attn_output, _ = self.local_attn(query, key, value)\n",
|
245 |
-
" output[i:end, :, :] = attn_output[:end - i, :, :]\n",
|
246 |
-
"\n",
|
247 |
-
" return output\n",
|
248 |
-
"\n",
|
249 |
-
"# def givens_rotation_matrix(n_state, i, j, theta):\n",
|
250 |
-
"# G = torch.eye(n_state)\n",
|
251 |
-
"# G[i, i] = math.cos(theta)\n",
|
252 |
-
"# G[i, j] = -math.sin(theta)\n",
|
253 |
-
"# G[j, i] = math.sin(theta)\n",
|
254 |
-
"# G[j, j] = math.cos(theta)\n",
|
255 |
-
"# return G\n",
|
256 |
-
"\n",
|
257 |
-
"# class GivensRotations(nn.Module):\n",
|
258 |
-
"# def __init__(self, h_dim, num_rotations):\n",
|
259 |
-
"# super().__init__()\n",
|
260 |
-
"# self.h_dim = h_dim\n",
|
261 |
-
"# self.num_rotations = num_rotations\n",
|
262 |
-
"# self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
|
263 |
-
"\n",
|
264 |
-
"# def forward(self, x):\n",
|
265 |
-
"# if x.dim() != 4:\n",
|
266 |
-
"# raise ValueError(f\"Expected input tensor to be 4D, but got {x.dim()}D\")\n",
|
267 |
-
" \n",
|
268 |
-
"# batch_size, seq_len, n_head, h_dim = x.size()\n",
|
269 |
-
" \n",
|
270 |
-
"# if h_dim != self.h_dim:\n",
|
271 |
-
"# raise ValueError(f\"Expected h_dim of {self.h_dim}, but got {h_dim}\")\n",
|
272 |
-
" \n",
|
273 |
-
"# x = x.view(-1, h_dim) \n",
|
274 |
-
"# for k in range(self.num_rotations):\n",
|
275 |
-
"# i, j = k % self.h_dim, (k + 1) % self.h_dim\n",
|
276 |
-
"# G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])\n",
|
277 |
-
"# x = torch.matmul(x, G.to(x.device))\n",
|
278 |
-
" \n",
|
279 |
-
"# x = x.view(batch_size, seq_len, n_head, h_dim) \n",
|
280 |
-
"# return x\n",
|
281 |
-
"## old\n",
|
282 |
-
"# class RotaryEmbeddingWithRotation(nn.Module):\n",
|
283 |
-
"# def __init__(self, n_state, n_head, checkpointing=False, base=10000):\n",
|
284 |
-
"# super().__init__()\n",
|
285 |
-
"# self.n_state = n_state\n",
|
286 |
-
"# self.n_head = n_head\n",
|
287 |
-
"# self.h_dim = n_state // n_head\n",
|
288 |
-
"# self.base = base\n",
|
289 |
-
"# self.checkpointing = checkpointing\n",
|
290 |
-
"\n",
|
291 |
-
"# self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
|
292 |
-
"# inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
|
293 |
-
"# self.register_buffer('inv_freq', inv_freq)\n",
|
294 |
-
"\n",
|
295 |
-
"# def update_base(self, new_base):\n",
|
296 |
-
"# self.base = new_base\n",
|
297 |
-
"# inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
|
298 |
-
"# self.register_buffer('inv_freq', inv_freq)\n",
|
299 |
-
"\n",
|
300 |
-
"# def reset_parameters(self):\n",
|
301 |
-
"# nn.init.orthogonal_(self.rotation_matrix)\n",
|
302 |
-
"\n",
|
303 |
-
"# def forward(self, x):\n",
|
304 |
-
"# if self.checkpointing:\n",
|
305 |
-
"# return checkpoint(self._forward, x)\n",
|
306 |
-
"# else:\n",
|
307 |
-
"# return self._forward(x)\n",
|
308 |
-
"\n",
|
309 |
-
"# def _forward(self, x):\n",
|
310 |
-
"# if x.dim() == 3:\n",
|
311 |
-
"# batch_size, seq_len, n_state = x.size()\n",
|
312 |
-
"# elif x.dim() == 4:\n",
|
313 |
-
"# batch_size, seq_len, n_head, h_dim = x.size()\n",
|
314 |
-
"# n_state = n_head * h_dim\n",
|
315 |
-
"# x = x.view(batch_size, seq_len, n_state)\n",
|
316 |
-
"# else:\n",
|
317 |
-
"# raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
|
318 |
-
"\n",
|
319 |
-
"# if n_state != self.n_state:\n",
|
320 |
-
"# raise ValueError(f\"Expected n_state of {self.n_state}, but got {n_state}\")\n",
|
321 |
-
"\n",
|
322 |
-
"# x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
323 |
-
"# x = x.reshape(-1, self.h_dim)\n",
|
324 |
-
"# rotated_x = torch.matmul(x, self.rotation_matrix)\n",
|
325 |
-
"# rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
326 |
-
"\n",
|
327 |
-
"# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n",
|
328 |
-
"# sin = sinusoid_inp.sin()[None, :, None, :]\n",
|
329 |
-
"# cos = sinusoid_inp.cos()[None, :, None, :]\n",
|
330 |
-
"# x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]\n",
|
331 |
-
"# rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
|
332 |
-
" \n",
|
333 |
-
"# rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)\n",
|
334 |
-
"# return rotated_x\n",
|
335 |
-
"\n",
|
336 |
-
"# ## new\n",
|
337 |
-
"# class CombinedRotaryEmbedding(nn.Module):\n",
|
338 |
-
"# def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):\n",
|
339 |
-
"# super().__init__()\n",
|
340 |
-
"# self.n_state = n_state\n",
|
341 |
-
"# self.n_head = n_head\n",
|
342 |
-
"# self.h_dim = n_state // n_head\n",
|
343 |
-
"# self.num_rotations = num_rotations\n",
|
344 |
-
"# self.base = base\n",
|
345 |
-
"# self.checkpointing = checkpointing\n",
|
346 |
-
" \n",
|
347 |
-
"# self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
|
348 |
-
"# self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)\n",
|
349 |
-
"\n",
|
350 |
-
"# self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
|
351 |
-
" \n",
|
352 |
-
"# self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
|
353 |
-
" \n",
|
354 |
-
"# def givens_rotation_matrix(self, n_state, i, j, theta):\n",
|
355 |
-
"# G = torch.eye(n_state, device=theta.device)\n",
|
356 |
-
"# G[i, i] = math.cos(theta)\n",
|
357 |
-
"# G[i, j] = -math.sin(theta)\n",
|
358 |
-
"# G[j, i] = math.sin(theta)\n",
|
359 |
-
"# G[j, j] = math.cos(theta)\n",
|
360 |
-
"# return G\n",
|
361 |
-
" \n",
|
362 |
-
"# def update_base(self, new_base):\n",
|
363 |
-
"# self.base = new_base\n",
|
364 |
-
"# self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
|
365 |
-
" \n",
|
366 |
-
"# def reset_parameters(self):\n",
|
367 |
-
"# nn.init.orthogonal_(self.rotation_matrix)\n",
|
368 |
-
"# nn.init.zeros_(self.thetas)\n",
|
369 |
-
" \n",
|
370 |
-
"# def forward(self, x):\n",
|
371 |
-
"# if self.checkpointing:\n",
|
372 |
-
"# return checkpoint(self._forward, x)\n",
|
373 |
-
"# else:\n",
|
374 |
-
"# return self._forward(x)\n",
|
375 |
-
" \n",
|
376 |
-
"# def _forward(self, x):\n",
|
377 |
-
"# if x.dim() not in [3, 4]:\n",
|
378 |
-
"# raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
|
379 |
-
" \n",
|
380 |
-
"# if x.dim() == 3:\n",
|
381 |
-
"# batch_size, seq_len, n_state = x.size()\n",
|
382 |
-
"# x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
383 |
-
"# else:\n",
|
384 |
-
"# batch_size, seq_len, n_head, h_dim = x.size()\n",
|
385 |
-
"# if n_head != self.n_head or h_dim != self.h_dim:\n",
|
386 |
-
"# raise ValueError(f\"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}\")\n",
|
387 |
-
" \n",
|
388 |
-
"# x = x.reshape(-1, self.h_dim)\n",
|
389 |
-
" \n",
|
390 |
-
"# for k in range(self.num_rotations):\n",
|
391 |
-
"# i, j = self.rotation_pairs[k].long()\n",
|
392 |
-
"# theta = self.thetas[k]\n",
|
393 |
-
"# G = self.givens_rotation_matrix(self.h_dim, i, j, theta)\n",
|
394 |
-
"# x = torch.matmul(x, G)\n",
|
395 |
-
" \n",
|
396 |
-
"# x = torch.matmul(x, self.rotation_matrix)\n",
|
397 |
-
" \n",
|
398 |
-
"# x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
399 |
-
" \n",
|
400 |
-
"# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq)\n",
|
401 |
-
"# sin = sinusoid_inp.sin()[None, :, None, :]\n",
|
402 |
-
"# cos = sinusoid_inp.cos()[None, :, None, :]\n",
|
403 |
-
" \n",
|
404 |
-
"# x1, x2 = x[..., ::2], x[..., 1::2]\n",
|
405 |
-
"# x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
|
406 |
-
" \n",
|
407 |
-
"# x = x.view(batch_size, seq_len, self.n_state)\n",
|
408 |
-
" \n",
|
409 |
-
"# return x\n",
|
410 |
-
"\n",
|
411 |
-
"class CombinedRotaryEmbedding(nn.Module):\n",
|
412 |
-
" def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):\n",
|
413 |
-
" super().__init__()\n",
|
414 |
-
" self.n_state = n_state\n",
|
415 |
-
" self.n_head = n_head\n",
|
416 |
-
" self.h_dim = n_state // n_head\n",
|
417 |
-
" self.num_rotations = num_rotations\n",
|
418 |
-
" self.base = base\n",
|
419 |
-
" self.checkpointing = checkpointing\n",
|
420 |
-
" \n",
|
421 |
-
" self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
|
422 |
-
" self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)\n",
|
423 |
-
"\n",
|
424 |
-
" # Add a scaling factor for thetas\n",
|
425 |
-
" self.theta_scale = nn.Parameter(torch.ones(1)) \n",
|
426 |
-
"\n",
|
427 |
-
" self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
|
428 |
-
" \n",
|
429 |
-
" self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
|
430 |
-
" \n",
|
431 |
-
" def givens_rotation_matrix(self, n_state, i, j, theta):\n",
|
432 |
-
" G = torch.eye(n_state, device=theta.device)\n",
|
433 |
-
" G[i, i] = math.cos(theta)\n",
|
434 |
-
" G[i, j] = -math.sin(theta)\n",
|
435 |
-
" G[j, i] = math.sin(theta)\n",
|
436 |
-
" G[j, j] = math.cos(theta)\n",
|
437 |
-
" return G\n",
|
438 |
-
" \n",
|
439 |
-
" def update_base(self, new_base):\n",
|
440 |
-
" self.base = new_base\n",
|
441 |
-
" self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
|
442 |
-
" \n",
|
443 |
-
" def reset_parameters(self):\n",
|
444 |
-
" nn.init.orthogonal_(self.rotation_matrix)\n",
|
445 |
-
" nn.init.zeros_(self.thetas)\n",
|
446 |
-
" \n",
|
447 |
-
" def forward(self, x):\n",
|
448 |
-
" if self.checkpointing:\n",
|
449 |
-
" return checkpoint(self._forward, x)\n",
|
450 |
-
" else:\n",
|
451 |
-
" return self._forward(x)\n",
|
452 |
-
" \n",
|
453 |
-
" def _forward(self, x):\n",
|
454 |
-
" if x.dim() not in [3, 4]:\n",
|
455 |
-
" raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
|
456 |
-
" \n",
|
457 |
-
" if x.dim() == 3:\n",
|
458 |
-
" batch_size, seq_len, n_state = x.size()\n",
|
459 |
-
" x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
460 |
-
" else:\n",
|
461 |
-
" batch_size, seq_len, n_head, h_dim = x.size()\n",
|
462 |
-
" if n_head != self.n_head or h_dim != self.h_dim:\n",
|
463 |
-
" raise ValueError(f\"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}\")\n",
|
464 |
-
" \n",
|
465 |
-
" x = x.reshape(-1, self.h_dim)\n",
|
466 |
-
" \n",
|
467 |
-
" for k in range(self.num_rotations):\n",
|
468 |
-
" i, j = self.rotation_pairs[k].long()\n",
|
469 |
-
" \n",
|
470 |
-
" # Apply the scaling factor to theta\n",
|
471 |
-
" theta = self.thetas[k] * self.theta_scale \n",
|
472 |
-
" \n",
|
473 |
-
" G = self.givens_rotation_matrix(self.h_dim, i, j, theta)\n",
|
474 |
-
" x = torch.matmul(x, G)\n",
|
475 |
-
" \n",
|
476 |
-
" x = torch.matmul(x, self.rotation_matrix)\n",
|
477 |
-
" \n",
|
478 |
-
" x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
|
479 |
-
" \n",
|
480 |
-
" sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n",
|
481 |
-
" sin = sinusoid_inp.sin()[None, :, None, :]\n",
|
482 |
-
" cos = sinusoid_inp.cos()[None, :, None, :]\n",
|
483 |
-
" \n",
|
484 |
-
" x1, x2 = x[..., ::2], x[..., 1::2]\n",
|
485 |
-
" x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
|
486 |
-
" \n",
|
487 |
-
" x = x.view(batch_size, seq_len, self.n_state)\n",
|
488 |
-
" \n",
|
489 |
-
" return x\n",
|
490 |
-
"\n",
|
491 |
-
"\n",
|
492 |
-
"class LearnedSinusoidalEmbeddings(nn.Module):\n",
|
493 |
-
" def __init__(self, n_ctx, n_state, checkpointing=False):\n",
|
494 |
-
" super().__init__()\n",
|
495 |
-
" self.n_ctx = n_ctx\n",
|
496 |
-
" self.n_state = n_state\n",
|
497 |
-
" self.checkpointing = checkpointing\n",
|
498 |
-
"\n",
|
499 |
-
" position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)\n",
|
500 |
-
" div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))\n",
|
501 |
-
" features = torch.zeros(n_ctx, n_state)\n",
|
502 |
-
" features[:, 0::2] = torch.sin(position * div_term)\n",
|
503 |
-
" features[:, 1::2] = torch.cos(position * div_term)\n",
|
504 |
-
" self.register_buffer('sinusoidal_features', features)\n",
|
505 |
-
"\n",
|
506 |
-
" self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())\n",
|
507 |
-
"\n",
|
508 |
-
" def forward(self, positions):\n",
|
509 |
-
" if self.checkpointing:\n",
|
510 |
-
" position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)\n",
|
511 |
-
" else:\n",
|
512 |
-
" position_embeddings = self.positional_embeddings[positions]\n",
|
513 |
-
"\n",
|
514 |
-
" position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)\n",
|
515 |
-
" return position_embeddings\n",
|
516 |
-
"\n",
|
517 |
-
"class MultiHeadAttention(nn.Module):\n",
|
518 |
-
" use_sdpa = True\n",
|
519 |
-
"\n",
|
520 |
-
" def __init__(self, n_state: int, n_head: int, max_rel_dist: int = 1, base: int = 10000):\n",
|
521 |
-
" super().__init__()\n",
|
522 |
-
" assert n_state % n_head == 0, \"n_state must be divisible by n_head\"\n",
|
523 |
-
" self.n_head = n_head\n",
|
524 |
-
" self.h_dim = n_state // n_head\n",
|
525 |
-
" assert self.h_dim % 2 == 0, \"Head dimension must be even for rotary embeddings\"\n",
|
526 |
-
"\n",
|
527 |
-
" self.positional_scaling = nn.Parameter(torch.ones(1))\n",
|
528 |
-
"\n",
|
529 |
-
" self.query = nn.Linear(n_state, n_state)\n",
|
530 |
-
" self.key = nn.Linear(n_state, n_state, bias=False)\n",
|
531 |
-
" self.value = nn.Linear(n_state, n_state)\n",
|
532 |
-
" self.out = nn.Linear(n_state, n_state)\n",
|
533 |
-
"\n",
|
534 |
-
" self.max_rel_dist = max_rel_dist\n",
|
535 |
-
" self.base = base\n",
|
536 |
-
" inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
|
537 |
-
" self.register_buffer('inv_freq', inv_freq)\n",
|
538 |
-
" self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)\n",
|
539 |
-
" self.rel_pos_bias.weight.data.fill_(0)\n",
|
540 |
-
"\n",
|
541 |
-
" self.combined_rotary = CombinedRotaryEmbedding(\n",
|
542 |
-
" n_state=n_state,\n",
|
543 |
-
" n_head=n_head,\n",
|
544 |
-
" num_rotations=self.h_dim // 2,\n",
|
545 |
-
" base=base,\n",
|
546 |
-
" checkpointing=False \n",
|
547 |
-
" )\n",
|
548 |
-
"\n",
|
549 |
-
" if device:\n",
|
550 |
-
" self.to(device)\n",
|
551 |
-
"\n",
|
552 |
-
" def update_base(self, new_base): \n",
|
553 |
-
" self.base = new_base \n",
|
554 |
-
" inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)) \n",
|
555 |
-
" self.register_buffer('inv_freq', inv_freq) \n",
|
556 |
-
" self.combined_rotary.update_base(new_base)\n",
|
557 |
-
"\n",
|
558 |
-
" def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
|
559 |
-
" q = self.query(x)\n",
|
560 |
-
"\n",
|
561 |
-
" if kv_cache is None or xa is None or 'k' not in kv_cache:\n",
|
562 |
-
" k_input = x if xa is None else xa\n",
|
563 |
-
" k = self.key(k_input)\n",
|
564 |
-
" v = self.value(k_input)\n",
|
565 |
-
" if kv_cache is not None:\n",
|
566 |
-
" kv_cache['k'] = k\n",
|
567 |
-
" kv_cache['v'] = v\n",
|
568 |
-
" else:\n",
|
569 |
-
" k = kv_cache['k']\n",
|
570 |
-
" v = kv_cache['v']\n",
|
571 |
-
"\n",
|
572 |
-
" q = q.view(q.shape[0], q.shape[1], self.n_head, -1)\n",
|
573 |
-
" k = k.view(k.shape[0], k.shape[1], self.n_head, -1)\n",
|
574 |
-
" v = v.view(v.shape[0], v.shape[1], self.n_head, -1)\n",
|
575 |
-
"\n",
|
576 |
-
" q = self.combined_rotary(q) \n",
|
577 |
-
" k = self.combined_rotary(k)\n",
|
578 |
-
"\n",
|
579 |
-
" q = q.view(q.shape[0], q.shape[1], -1)\n",
|
580 |
-
" k = k.view(k.shape[0], k.shape[1], -1)\n",
|
581 |
-
"\n",
|
582 |
-
" wv, qk = self.qkv_attention(q, k, v, mask)\n",
|
583 |
-
" return self.out(wv), qk\n",
|
584 |
-
" \n",
|
585 |
-
" def qkv_attention(self, q, k, v, mask = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n",
|
586 |
-
" n_batch, n_ctx, n_state = q.shape\n",
|
587 |
-
"\n",
|
588 |
-
" scale = (n_state // self.n_head) ** -0.25\n",
|
589 |
-
" q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
|
590 |
-
" k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
|
591 |
-
" v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
|
592 |
-
"\n",
|
593 |
-
" qk = (q * scale) @ (k * scale).transpose(-1, -2)\n",
|
594 |
-
"\n",
|
595 |
-
" seq_len_q = q.size(2)\n",
|
596 |
-
" seq_len_k = k.size(2)\n",
|
597 |
-
"\n",
|
598 |
-
" positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)\n",
|
599 |
-
" positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1\n",
|
600 |
-
" rel_bias = self.rel_pos_bias(positions) \n",
|
601 |
-
" rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0) \n",
|
602 |
-
"\n",
|
603 |
-
" qk = qk + rel_bias\n",
|
604 |
-
"\n",
|
605 |
-
" if mask is not None:\n",
|
606 |
-
" qk = qk + mask[:n_ctx, :n_ctx]\n",
|
607 |
-
" qk = qk.float()\n",
|
608 |
-
"\n",
|
609 |
-
" w = F.softmax(qk, dim=-1).to(q.dtype)\n",
|
610 |
-
" out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n",
|
611 |
-
" qk = qk.detach()\n",
|
612 |
-
"\n",
|
613 |
-
" return out, qk\n",
|
614 |
-
"\n",
|
615 |
-
"class ResidualAttentionBlock(nn.Module):\n",
|
616 |
-
" def __init__(self, n_state, n_head, cross_attention = False, max_rel_dist = 1, checkpointing=False):\n",
|
617 |
-
" super().__init__()\n",
|
618 |
-
"\n",
|
619 |
-
" self.attn = MultiHeadAttention(n_state, n_head)\n",
|
620 |
-
" self.attn_ln = LayerNorm(n_state)\n",
|
621 |
-
" self.checkpointing = checkpointing\n",
|
622 |
-
" self.max_rel_dist = max_rel_dist\n",
|
623 |
-
"\n",
|
624 |
-
" self.cross_attn = (\n",
|
625 |
-
" MultiHeadAttention(n_state, n_head) if cross_attention else None\n",
|
626 |
-
" )\n",
|
627 |
-
" self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None\n",
|
628 |
-
"\n",
|
629 |
-
" n_mlp = n_state * 4\n",
|
630 |
-
" self.mlp = nn.Sequential(\n",
|
631 |
-
" Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)\n",
|
632 |
-
" )\n",
|
633 |
-
" self.mlp_ln = LayerNorm(n_state)\n",
|
634 |
-
"\n",
|
635 |
-
" def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
|
636 |
-
" if self.checkpointing:\n",
|
637 |
-
" x = checkpoint(self._attn_forward, x, mask, kv_cache)\n",
|
638 |
-
" else:\n",
|
639 |
-
" x = self._attn_forward(x, mask, kv_cache)\n",
|
640 |
-
"\n",
|
641 |
-
" if self.cross_attn:\n",
|
642 |
-
" if self.checkpointing:\n",
|
643 |
-
" x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)\n",
|
644 |
-
" else:\n",
|
645 |
-
" x = self._cross_attn_forward(x, xa, kv_cache)\n",
|
646 |
-
"\n",
|
647 |
-
" if self.checkpointing:\n",
|
648 |
-
" x = checkpoint(self._mlp_forward, x)\n",
|
649 |
-
" else:\n",
|
650 |
-
" x = self._mlp_forward(x)\n",
|
651 |
-
"\n",
|
652 |
-
" return x\n",
|
653 |
-
"\n",
|
654 |
-
" def _attn_forward(self, x, mask, kv_cache):\n",
|
655 |
-
" residual = x\n",
|
656 |
-
" x = self.attn_ln(x)\n",
|
657 |
-
" x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]\n",
|
658 |
-
" return x\n",
|
659 |
-
"\n",
|
660 |
-
" def _cross_attn_forward(self, x, xa, kv_cache):\n",
|
661 |
-
" residual = x\n",
|
662 |
-
" x = self.cross_attn_ln(x)\n",
|
663 |
-
" x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]\n",
|
664 |
-
" return x\n",
|
665 |
-
"\n",
|
666 |
-
" def _mlp_forward(self, x):\n",
|
667 |
-
" residual = x\n",
|
668 |
-
" x = self.mlp_ln(x)\n",
|
669 |
-
" x = residual + self.mlp(x)\n",
|
670 |
-
" return x\n",
|
671 |
-
"\n",
|
672 |
-
"class AudioEncoder(nn.Module):\n",
|
673 |
-
" def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
|
674 |
-
" super().__init__()\n",
|
675 |
-
" self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)\n",
|
676 |
-
" self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)\n",
|
677 |
-
" self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n",
|
678 |
-
" self.checkpointing = checkpointing\n",
|
679 |
-
" self.h_dim = n_state // n_head\n",
|
680 |
-
"\n",
|
681 |
-
" self.combined_rotary = CombinedRotaryEmbedding(\n",
|
682 |
-
" n_state=n_state,\n",
|
683 |
-
" n_head=n_head,\n",
|
684 |
-
" num_rotations=self.h_dim // 2,\n",
|
685 |
-
" base=base,\n",
|
686 |
-
" checkpointing=False \n",
|
687 |
-
" )\n",
|
688 |
-
"\n",
|
689 |
-
" self.blocks = nn.ModuleList(\n",
|
690 |
-
" [ResidualAttentionBlock(n_state, n_head, max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]\n",
|
691 |
-
" )\n",
|
692 |
-
" self.ln_post = LayerNorm(n_state)\n",
|
693 |
-
"\n",
|
694 |
-
" def update_base(self, new_base):\n",
|
695 |
-
" self.combined_rotary.update_base(new_base)\n",
|
696 |
-
" for block in self.blocks:\n",
|
697 |
-
" if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
|
698 |
-
" block.attn.update_base(new_base)\n",
|
699 |
-
" if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
|
700 |
-
" block.cross_attn.update_base(new_base)\n",
|
701 |
-
"\n",
|
702 |
-
" def forward(self, x):\n",
|
703 |
-
" if self.checkpointing:\n",
|
704 |
-
" x = checkpoint(self._conv_forward, x)\n",
|
705 |
-
" else:\n",
|
706 |
-
" x = self._conv_forward(x)\n",
|
707 |
-
"\n",
|
708 |
-
" for block in self.blocks:\n",
|
709 |
-
" if self.checkpointing:\n",
|
710 |
-
" x = checkpoint(block, x)\n",
|
711 |
-
" else:\n",
|
712 |
-
" x = block(x)\n",
|
713 |
-
"\n",
|
714 |
-
" x = self.ln_post(x)\n",
|
715 |
-
" return x\n",
|
716 |
-
"\n",
|
717 |
-
" def _conv_forward(self, x):\n",
|
718 |
-
" x = F.gelu(self.conv1(x))\n",
|
719 |
-
" x = F.gelu(self.conv2(x))\n",
|
720 |
-
" x = x.permute(0, 2, 1)\n",
|
721 |
-
"\n",
|
722 |
-
" x = self.combined_rotary(x)\n",
|
723 |
-
"\n",
|
724 |
-
" pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)\n",
|
725 |
-
" x = x + pos_emb\n",
|
726 |
-
" return x\n",
|
727 |
-
"\n",
|
728 |
-
"class TextDecoder(nn.Module):\n",
|
729 |
-
" def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
|
730 |
-
" super().__init__()\n",
|
731 |
-
" self.token_embedding = nn.Embedding(vocab_size, n_state)\n",
|
732 |
-
" self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n",
|
733 |
-
" self.checkpointing = checkpointing\n",
|
734 |
-
" self.n_head = n_head\n",
|
735 |
-
" self.h_dim = n_state // n_head\n",
|
736 |
-
" \n",
|
737 |
-
" self.combined_rotary = CombinedRotaryEmbedding(\n",
|
738 |
-
" n_state=n_state,\n",
|
739 |
-
" n_head=n_head,\n",
|
740 |
-
" num_rotations=self.h_dim // 2, \n",
|
741 |
-
" base=base,\n",
|
742 |
-
" checkpointing=False \n",
|
743 |
-
" )\n",
|
744 |
-
"\n",
|
745 |
-
" self.blocks = nn.ModuleList([\n",
|
746 |
-
" ResidualAttentionBlock(n_state, n_head, max_rel_dist, cross_attention, checkpointing=checkpointing)\n",
|
747 |
-
" for _ in range(n_layer)\n",
|
748 |
-
" ])\n",
|
749 |
-
" self.ln = LayerNorm(n_state)\n",
|
750 |
-
" mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)\n",
|
751 |
-
" self.register_buffer(\"mask\", mask, persistent=False)\n",
|
752 |
-
"\n",
|
753 |
-
" def update_base(self, new_base):\n",
|
754 |
-
" self.combined_rotary.update_base(new_base)\n",
|
755 |
-
" for block in self.blocks:\n",
|
756 |
-
" if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
|
757 |
-
" block.attn.update_base(new_base)\n",
|
758 |
-
" if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
|
759 |
-
" block.cross_attn.update_base(new_base)\n",
|
760 |
-
"\n",
|
761 |
-
" def forward(self, x, xa, kv_cache = None):\n",
|
762 |
-
" if self.checkpointing:\n",
|
763 |
-
" x = checkpoint(self._embedding_forward, x, xa, kv_cache)\n",
|
764 |
-
" else:\n",
|
765 |
-
" x = self._embedding_forward(x, xa, kv_cache)\n",
|
766 |
-
"\n",
|
767 |
-
" for block in self.blocks:\n",
|
768 |
-
" if self.checkpointing:\n",
|
769 |
-
" x = checkpoint(block, x, xa, self.mask, kv_cache)\n",
|
770 |
-
" else:\n",
|
771 |
-
" x = block(x, xa, self.mask, kv_cache)\n",
|
772 |
-
"\n",
|
773 |
-
" x = self.ln(x)\n",
|
774 |
-
" logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()\n",
|
775 |
-
"\n",
|
776 |
-
" return logits\n",
|
777 |
-
"\n",
|
778 |
-
" def _embedding_forward(self, x, xa, kv_cache):\n",
|
779 |
-
" offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n",
|
780 |
-
" positions = torch.arange(x.shape[1], device=x.device) + offset\n",
|
781 |
-
" pos_emb = self.positional_embedding(positions).unsqueeze(0)\n",
|
782 |
-
"\n",
|
783 |
-
" x = self.token_embedding(x) + pos_emb\n",
|
784 |
-
" x = x.to(xa.dtype)\n",
|
785 |
-
"\n",
|
786 |
-
" batch_size, seq_length, embedding_dim = x.shape\n",
|
787 |
-
" num_heads = self.n_head\n",
|
788 |
-
" head_dim = embedding_dim // num_heads\n",
|
789 |
-
" x = x.view(batch_size, seq_length, num_heads, head_dim)\n",
|
790 |
-
"\n",
|
791 |
-
" x = self.combined_rotary(x)\n",
|
792 |
-
"\n",
|
793 |
-
" x = x.view(batch_size, seq_length, embedding_dim)\n",
|
794 |
-
" return x\n",
|
795 |
-
" \n",
|
796 |
-
"class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):\n",
|
797 |
-
" config_class = WhisperConfig\n",
|
798 |
-
"\n",
|
799 |
-
" def __init__(self, config: WhisperConfig):\n",
|
800 |
-
" super().__init__(config)\n",
|
801 |
-
" self.config = config\n",
|
802 |
-
"\n",
|
803 |
-
" self.n_mels = self.config.num_mel_bins\n",
|
804 |
-
" self.n_audio_ctx = self.config.max_source_positions\n",
|
805 |
-
" self.n_audio_state = self.config.d_model\n",
|
806 |
-
" self.n_audio_head = self.config.encoder_attention_heads\n",
|
807 |
-
" self.n_audio_layer = self.config.encoder_layers\n",
|
808 |
-
" self.vocab_size = self.config.vocab_size\n",
|
809 |
-
" self.n_text_ctx = self.config.max_target_positions\n",
|
810 |
-
" self.n_text_state = self.config.d_model\n",
|
811 |
-
" self.n_text_head = self.config.decoder_attention_heads\n",
|
812 |
-
" self.n_text_layer = self.config.decoder_layers\n",
|
813 |
-
" self.checkpointing = self.config.checkpointing\n",
|
814 |
-
" self.max_rel_dist = self.config.max_rel_dist\n",
|
815 |
-
" self.cross_attention = self.config.cross_attention\n",
|
816 |
-
" self.base = self.config.base\n",
|
817 |
-
"\n",
|
818 |
-
" self.encoder = AudioEncoder(\n",
|
819 |
-
" self.config.n_mels,\n",
|
820 |
-
" self.config.n_audio_ctx,\n",
|
821 |
-
" self.config.n_audio_state,\n",
|
822 |
-
" self.config.n_audio_head,\n",
|
823 |
-
" self.config.n_audio_layer,\n",
|
824 |
-
" self.config.checkpointing,\n",
|
825 |
-
" self.config.max_rel_dist,\n",
|
826 |
-
" self.config.cross_attention,\n",
|
827 |
-
" self.config.base,\n",
|
828 |
-
" )\n",
|
829 |
-
" self.decoder = TextDecoder(\n",
|
830 |
-
" self.config.vocab_size,\n",
|
831 |
-
" self.config.n_text_ctx,\n",
|
832 |
-
" self.config.n_text_state,\n",
|
833 |
-
" self.config.n_text_head,\n",
|
834 |
-
" self.config.n_text_layer,\n",
|
835 |
-
" self.config.checkpointing,\n",
|
836 |
-
" self.config.max_rel_dist,\n",
|
837 |
-
" self.config.cross_attention,\n",
|
838 |
-
" self.config.base,\n",
|
839 |
-
" )\n",
|
840 |
-
"\n",
|
841 |
-
" all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)\n",
|
842 |
-
" all_heads[self.config.n_text_layer // 2:] = True\n",
|
843 |
-
" self.register_buffer(\"alignment_heads\", all_heads.to_sparse(), persistent=False)\n",
|
844 |
-
"\n",
|
845 |
-
" self.best_loss = float('inf')\n",
|
846 |
-
" self.base = 10000 \n",
|
847 |
-
"\n",
|
848 |
-
" def update_base(self, new_base):\n",
|
849 |
-
" self.encoder.combined_rotary.update_base(new_base)\n",
|
850 |
-
" self.decoder.combined_rotary.update_base(new_base)\n",
|
851 |
-
"\n",
|
852 |
-
" for name, module in self.encoder.named_modules():\n",
|
853 |
-
" if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
|
854 |
-
" module.update_base(new_base)\n",
|
855 |
-
"\n",
|
856 |
-
" for name, module in self.decoder.named_modules():\n",
|
857 |
-
" if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
|
858 |
-
" module.update_base(new_base)\n",
|
859 |
-
"\n",
|
860 |
-
" def adjust_base(self, loss, factor=1.05):\n",
|
861 |
-
" if loss < self.best_loss:\n",
|
862 |
-
" new_base = self.base * factor\n",
|
863 |
-
" else:\n",
|
864 |
-
" new_base = self.base / factor\n",
|
865 |
-
"\n",
|
866 |
-
" self.update_base(new_base)\n",
|
867 |
-
" self.best_loss = loss\n",
|
868 |
-
" # print(f\"Adjusted base: {new_base}\")\n",
|
869 |
-
"\n",
|
870 |
-
" @staticmethod\n",
|
871 |
-
" def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:\n",
|
872 |
-
" shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n",
|
873 |
-
" shifted_input_ids[:, 1:] = input_ids[:, :-1]\n",
|
874 |
-
" shifted_input_ids[:, 0] = decoder_start_token_id\n",
|
875 |
-
" shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n",
|
876 |
-
" return shifted_input_ids\n",
|
877 |
-
"\n",
|
878 |
-
" def forward(self, input_features, labels=None, dec_input_ids=None):\n",
|
879 |
-
" if labels is not None:\n",
|
880 |
-
" if dec_input_ids is None:\n",
|
881 |
-
" dec_input_ids = self.shift_tokens_right(\n",
|
882 |
-
" labels, self.config.pad_token_id, self.config.decoder_start_token_id\n",
|
883 |
-
" )\n",
|
884 |
-
"\n",
|
885 |
-
" encoded_features = self.encoder(input_features).to(device)\n",
|
886 |
-
" logits = self.decoder(dec_input_ids, encoded_features)\n",
|
887 |
-
"\n",
|
888 |
-
" loss = None\n",
|
889 |
-
" if labels is not None:\n",
|
890 |
-
" loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) \n",
|
891 |
-
" labels = labels.to(logits.device).long()\n",
|
892 |
-
" loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))\n",
|
893 |
-
"\n",
|
894 |
-
" self.adjust_base(loss.item())\n",
|
895 |
-
"\n",
|
896 |
-
" return {\n",
|
897 |
-
" \"loss\": loss,\n",
|
898 |
-
" \"logits\": logits,\n",
|
899 |
-
" }\n",
|
900 |
-
"\n",
|
901 |
-
" def _initialize_weights(self):\n",
|
902 |
-
" nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)\n",
|
903 |
-
" if hasattr(self.decoder.positional_embedding, 'weight'):\n",
|
904 |
-
" nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)\n",
|
905 |
-
" for block in self.decoder.blocks:\n",
|
906 |
-
" for layer in block.children():\n",
|
907 |
-
" if isinstance(layer, nn.Linear):\n",
|
908 |
-
" nn.init.xavier_normal_(layer.weight)\n",
|
909 |
-
" if layer.bias is not None:\n",
|
910 |
-
" nn.init.zeros_(layer.bias)\n",
|
911 |
-
"\n",
|
912 |
-
" nn.init.constant_(self.decoder.ln.gamma, 1)\n",
|
913 |
-
" if self.decoder.ln.beta is not None:\n",
|
914 |
-
" nn.init.constant_(self.decoder.ln.beta, 0)\n",
|
915 |
-
"\n",
|
916 |
-
" nn.init.xavier_normal_(self.encoder.conv1.weight)\n",
|
917 |
-
" if self.encoder.conv1.bias is not None:\n",
|
918 |
-
" nn.init.zeros_(self.encoder.conv1.bias)\n",
|
919 |
-
"\n",
|
920 |
-
" nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')\n",
|
921 |
-
" if self.encoder.conv2.bias is not None:\n",
|
922 |
-
" nn.init.zeros_(self.encoder.conv2.bias)\n",
|
923 |
-
"\n",
|
924 |
-
" nn.init.constant_(self.encoder.ln_post.gamma, 1)\n",
|
925 |
-
" if self.encoder.ln_post.beta is not None:\n",
|
926 |
-
" nn.init.constant_(self.encoder.ln_post.beta, 0)\n",
|
927 |
-
" \n",
|
928 |
-
" def apply_initialization(self):\n",
|
929 |
-
" self._initialize_weights()\n",
|
930 |
-
"\n",
|
931 |
-
" def set_alignment_heads(self, dump: bytes):\n",
|
932 |
-
" array = np.frombuffer(\n",
|
933 |
-
" gzip.decompress(base64.b85decode(dump)), dtype=bool\n",
|
934 |
-
" ).copy()\n",
|
935 |
-
" mask = torch.from_numpy(array).reshape(\n",
|
936 |
-
" self.config.n_text_layer, self.config.n_text_head\n",
|
937 |
-
" )\n",
|
938 |
-
" self.register_buffer(\"alignment_heads\", mask.to_sparse(), persistent=False)\n",
|
939 |
-
"\n",
|
940 |
-
" def embed_audio(self, mel):\n",
|
941 |
-
" return self.encoder(mel)\n",
|
942 |
-
"\n",
|
943 |
-
" def logits(self, labels, input_features):\n",
|
944 |
-
" return self.decoder(labels, input_features)\n",
|
945 |
-
"\n",
|
946 |
-
" @property\n",
|
947 |
-
" def device(self):\n",
|
948 |
-
" return next(self.parameters()).device\n",
|
949 |
-
"\n",
|
950 |
-
" @property\n",
|
951 |
-
" def is_multilingual(self):\n",
|
952 |
-
" return self.config.vocab_size >= len(tokenizer)\n",
|
953 |
-
"\n",
|
954 |
-
" @property\n",
|
955 |
-
" def num_languages(self):\n",
|
956 |
-
" return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)\n",
|
957 |
-
"\n",
|
958 |
-
" def install_kv_cache_hooks(self, cache = None):\n",
|
959 |
-
" cache = {**cache} if cache is not None else {}\n",
|
960 |
-
" hooks = []\n",
|
961 |
-
"\n",
|
962 |
-
" def save_to_cache(module, _, output):\n",
|
963 |
-
" if module not in cache or output.shape[1] > self.config.n_text_ctx:\n",
|
964 |
-
" cache[module] = output\n",
|
965 |
-
" else:\n",
|
966 |
-
" cache[module] = torch.cat([cache[module], output], dim=1).detach()\n",
|
967 |
-
" return cache[module]\n",
|
968 |
-
"\n",
|
969 |
-
" def install_hooks(layer: nn.Module):\n",
|
970 |
-
" if isinstance(layer, MultiHeadAttention):\n",
|
971 |
-
" hooks.append(layer.key.register_forward_hook(save_to_cache))\n",
|
972 |
-
" hooks.append(layer.value.register_forward_hook(save_to_cache))\n",
|
973 |
-
"\n",
|
974 |
-
" self.decoder.apply(install_hooks)\n",
|
975 |
-
" return cache, hooks\n",
|
976 |
-
"\n",
|
977 |
-
" detect_language = detect_language_function\n",
|
978 |
-
" transcribe = transcribe_function\n",
|
979 |
-
" decode = decode_function\n",
|
980 |
-
"\n",
|
981 |
-
" def get_encoder(self):\n",
|
982 |
-
" return self.encoder\n",
|
983 |
-
"\n",
|
984 |
-
" def prepare_inputs_for_generation(self, input_ids, **kwargs):\n",
|
985 |
-
" return {'input_features': input_ids}\n",
|
986 |
-
"\n",
|
987 |
-
" def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):\n",
|
988 |
-
" return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id\n",
|
989 |
-
"\n",
|
990 |
-
" def can_generate(self):\n",
|
991 |
-
" return True\n",
|
992 |
-
" \n",
|
993 |
-
" def generate(self, inputs, **kwargs):\n",
|
994 |
-
" encoder_outputs = self.encoder(inputs)\n",
|
995 |
-
" decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)\n",
|
996 |
-
" outputs = self.decoder(decoder_input_ids, encoder_outputs)\n",
|
997 |
-
" return outputs.argmax(dim=-1)\n",
|
998 |
-
"\n",
|
999 |
-
" def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):\n",
|
1000 |
-
" if not self.supports_gradient_checkpointing:\n",
|
1001 |
-
" raise ValueError(f\"{self.__class__.__name__} does not support gradient checkpointing.\")\n",
|
1002 |
-
"\n",
|
1003 |
-
" if gradient_checkpointing_kwargs is None:\n",
|
1004 |
-
" gradient_checkpointing_kwargs = {\"use_reentrant\": True}\n",
|
1005 |
-
"\n",
|
1006 |
-
" gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)\n",
|
1007 |
-
"\n",
|
1008 |
-
" _is_using_old_format = \"value\" in inspect.signature(self._set_gradient_checkpointing).parameters\n",
|
1009 |
-
"\n",
|
1010 |
-
" if not _is_using_old_format:\n",
|
1011 |
-
" self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)\n",
|
1012 |
-
" else:\n",
|
1013 |
-
" self.apply(partial(self._set_gradient_checkpointing, value=True))\n",
|
1014 |
-
"\n",
|
1015 |
-
" if getattr(self, \"_hf_peft_config_loaded\", False):\n",
|
1016 |
-
" self.enable_input_require_grads()\n",
|
1017 |
-
"\n",
|
1018 |
-
"feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\", feature_size=128, n_fft=1024, hop_length=256, sampling_rate=16000)#, win_length=2048, sampling_rate=16000, mel_fmin=0.0, mel_fmax=8000.0)\n",
|
1019 |
-
"tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
|
1020 |
-
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
|
1021 |
-
"\n",
|
1022 |
-
"config = WhisperConfig(\n",
|
1023 |
-
" n_mels=128,\n",
|
1024 |
-
" n_audio_ctx=1500,\n",
|
1025 |
-
" n_audio_state=1024,\n",
|
1026 |
-
" n_audio_head=16,\n",
|
1027 |
-
" n_audio_layer=24,\n",
|
1028 |
-
" vocab_size=51865,\n",
|
1029 |
-
" n_text_ctx=448,\n",
|
1030 |
-
" n_text_state=1024,\n",
|
1031 |
-
" n_text_head=16,\n",
|
1032 |
-
" n_text_layer=20,\n",
|
1033 |
-
" max_rel_dist=200,\n",
|
1034 |
-
" cross_attention=True,\n",
|
1035 |
-
" checkpointing=True,\n",
|
1036 |
-
" base=10000,\n",
|
1037 |
-
" bos_token_id = 50257,\n",
|
1038 |
-
" eos_token_id = 50257,\n",
|
1039 |
-
" pad_token_id = 50257,\n",
|
1040 |
-
" decoder_start_token_id = 50258,\n",
|
1041 |
-
" is_encoder_decoder = True,\n",
|
1042 |
-
" init_std=0.02,\n",
|
1043 |
-
" )\n",
|
1044 |
-
"\n",
|
1045 |
-
"model = Echo(config).to(device)\n",
|
1046 |
-
"model.apply_initialization()\n",
|
1047 |
-
"model.save_pretrained(\"./models/echo\")\n",
|
1048 |
-
"# model = Echo.from_pretrained(\"./models/echo2\")\n",
|
1049 |
-
"\n"
|
1050 |
-
]
|
1051 |
-
},
|
1052 |
-
{
|
1053 |
-
"cell_type": "code",
|
1054 |
-
"execution_count": null,
|
1055 |
-
"metadata": {},
|
1056 |
-
"outputs": [],
|
1057 |
-
"source": [
|
1058 |
-
"raw_datasets = IterableDatasetDict()\n",
|
1059 |
-
"\n",
|
1060 |
-
"raw_datasets[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_17_0\", \"ja\", split=\"train\", trust_remote_code=True, streaming=True) # set split=\"train+validation\" for low-resource\n",
|
1061 |
-
"raw_datasets[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_17_0\", \"ja\", split=\"test\", trust_remote_code=True, streaming=True).take(100)\n",
|
1062 |
-
"\n",
|
1063 |
-
"raw_datasets = raw_datasets.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
|
1064 |
-
"\n",
|
1065 |
-
"tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
|
1066 |
-
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"japanese\", task=\"transcribe\")\n",
|
1067 |
-
"\n",
|
1068 |
-
"def prepare_dataset(batch):\n",
|
1069 |
-
" audio = batch[\"audio\"]\n",
|
1070 |
-
" batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
|
1071 |
-
" transcription = batch[\"sentence\"]\n",
|
1072 |
-
" batch[\"labels\"] = tokenizer(transcription).input_ids\n",
|
1073 |
-
" return batch\n",
|
1074 |
-
"\n",
|
1075 |
-
"vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)).with_format(\"torch\")\n",
|
1076 |
-
"\n",
|
1077 |
-
"\n",
|
1078 |
-
"@dataclass\n",
|
1079 |
-
"class DataCollatorSpeechSeq2SeqWithPadding:\n",
|
1080 |
-
" processor: Any\n",
|
1081 |
-
"\n",
|
1082 |
-
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
|
1083 |
-
" input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
|
1084 |
-
" batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
|
1085 |
-
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
|
1086 |
-
" labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
|
1087 |
-
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
|
1088 |
-
" if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
|
1089 |
-
" labels = labels[:, 1:]\n",
|
1090 |
-
" batch[\"labels\"] = labels\n",
|
1091 |
-
" return batch\n",
|
1092 |
-
" \n",
|
1093 |
-
"data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\n",
|
1094 |
-
"\n",
|
1095 |
-
"metric = evaluate.load(\"cer\")\n",
|
1096 |
-
"\n",
|
1097 |
-
"def compute_metrics(pred):\n",
|
1098 |
-
" pred_logits = pred.predictions\n",
|
1099 |
-
" label_ids = pred.label_ids\n",
|
1100 |
-
"\n",
|
1101 |
-
" if isinstance(pred_logits, tuple):\n",
|
1102 |
-
" pred_ids = pred_logits[0]\n",
|
1103 |
-
" else:\n",
|
1104 |
-
" pred_ids = pred_logits\n",
|
1105 |
-
" if pred_ids.ndim == 3:\n",
|
1106 |
-
" pred_ids = np.argmax(pred_ids, axis=-1)\n",
|
1107 |
-
"\n",
|
1108 |
-
" label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
|
1109 |
-
" pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
|
1110 |
-
" label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
|
1111 |
-
" cer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
|
1112 |
-
" return {\"cer\": cer}\n",
|
1113 |
-
"\n",
|
1114 |
-
"training_args = Seq2SeqTrainingArguments(\n",
|
1115 |
-
" output_dir=\"./test11\", \n",
|
1116 |
-
" per_device_train_batch_size=1,\n",
|
1117 |
-
" per_device_eval_batch_size=1,\n",
|
1118 |
-
" gradient_accumulation_steps=1,\n",
|
1119 |
-
" eval_accumulation_steps=1,\n",
|
1120 |
-
" # num_train_epochs=1,\n",
|
1121 |
-
" tf32=True,\n",
|
1122 |
-
" bf16=True,\n",
|
1123 |
-
" learning_rate=1e-5,\n",
|
1124 |
-
" warmup_steps=500,\n",
|
1125 |
-
" evaluation_strategy=\"steps\",\n",
|
1126 |
-
" # predict_with_generate=True,\n",
|
1127 |
-
" # generation_max_length=225,\n",
|
1128 |
-
" max_steps=40000,\n",
|
1129 |
-
" save_steps=1000,\n",
|
1130 |
-
" eval_steps=100,\n",
|
1131 |
-
" logging_steps=5,\n",
|
1132 |
-
" report_to=[\"tensorboard\"],\n",
|
1133 |
-
" load_best_model_at_end=True,\n",
|
1134 |
-
" metric_for_best_model=\"loss\",\n",
|
1135 |
-
" greater_is_better=False,\n",
|
1136 |
-
" push_to_hub=False,\n",
|
1137 |
-
" optim=\"adafactor\",\n",
|
1138 |
-
" weight_decay=0.0025,\n",
|
1139 |
-
" disable_tqdm=False,\n",
|
1140 |
-
" save_total_limit=2,\n",
|
1141 |
-
" torch_empty_cache_steps=10,\n",
|
1142 |
-
")\n",
|
1143 |
-
"\n",
|
1144 |
-
"trainer = Seq2SeqTrainer(\n",
|
1145 |
-
" args=training_args,\n",
|
1146 |
-
" model=model,\n",
|
1147 |
-
" train_dataset=vectorized_datasets[\"train\"],\n",
|
1148 |
-
" eval_dataset=vectorized_datasets[\"test\"],\n",
|
1149 |
-
" data_collator=data_collator,\n",
|
1150 |
-
" compute_metrics=compute_metrics,\n",
|
1151 |
-
" tokenizer=processor.feature_extractor,\n",
|
1152 |
-
")\n",
|
1153 |
-
"\n",
|
1154 |
-
"# trainer.add_callback(CustomCallback)\n",
|
1155 |
-
"\n",
|
1156 |
-
"trainer.train()\n",
|
1157 |
-
"\n",
|
1158 |
-
"import tensorboard"
|
1159 |
-
]
|
1160 |
-
}
|
1161 |
-
],
|
1162 |
-
"metadata": {
|
1163 |
-
"kernelspec": {
|
1164 |
-
"display_name": "Python 3",
|
1165 |
-
"language": "python",
|
1166 |
-
"name": "python3"
|
1167 |
-
},
|
1168 |
-
"language_info": {
|
1169 |
-
"codemirror_mode": {
|
1170 |
-
"name": "ipython",
|
1171 |
-
"version": 3
|
1172 |
-
},
|
1173 |
-
"file_extension": ".py",
|
1174 |
-
"mimetype": "text/x-python",
|
1175 |
-
"name": "python",
|
1176 |
-
"nbconvert_exporter": "python",
|
1177 |
-
"pygments_lexer": "ipython3",
|
1178 |
-
"version": "3.10.0"
|
1179 |
-
}
|
1180 |
-
},
|
1181 |
-
"nbformat": 4,
|
1182 |
-
"nbformat_minor": 2
|
1183 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|