aishutin commited on
Commit
4c57f07
1 Parent(s): dffafe4

Upload 2 files

Browse files
Files changed (2) hide show
  1. bloom-7b1-quantization.ipynb +946 -0
  2. load_bloom.ipynb +616 -0
bloom-7b1-quantization.ipynb ADDED
@@ -0,0 +1,946 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "29d14fe0",
6
+ "metadata": {
7
+ "cellId": "hwmcjwsucnwczi4u66ftg",
8
+ "id": "e13eff4e-c134-4dac-9523-07b297164250"
9
+ },
10
+ "source": [
11
+ "# Example of Quantizating 7.1 billion Bloom with 8-bit weights\n",
12
+ "\n",
13
+ "Heavily inspired by [Hivemind's work](https://nbviewer.org/urls/huggingface.co/hivemind/gpt-j-6B-8bit/raw/main/convert-gpt-j.ipynb) and [joaoalvarenga's work](https://huggingface.co/joaoalvarenga/bloom-8bit)"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 1,
19
+ "id": "39f137ae",
20
+ "metadata": {
21
+ "cellId": "wg56t50s3la38havqevkme",
22
+ "colab": {
23
+ "base_uri": "https://localhost:8080/"
24
+ },
25
+ "id": "699e94eb-3ce1-4788-999b-fb6d593ba7e9",
26
+ "outputId": "764a6719-66d0-4ef7-df2d-4cfda0914f65"
27
+ },
28
+ "outputs": [],
29
+ "source": [
30
+ "#%pip install transformers==4.20.1\n",
31
+ "#%pip install bitsandbytes\n",
32
+ "#%pip install datasets\n",
33
+ "#%pip install accelerate"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "id": "53e4dd05",
39
+ "metadata": {
40
+ "cellId": "aklenvay105v0md7yy679m",
41
+ "id": "0afea72c-691d-4719-a84a-663f1891af6e"
42
+ },
43
+ "source": [
44
+ "### Load and convert original Bloom structure to 8-bit\n",
45
+ "\n",
46
+ "You can load an already compressed 8-bit version of Bloom from [OpenDungeon/bloom-7b1-8bit](https://huggingface.co/OpenDungeon/bloom-7b1-8bit/tree/main) with small monkey patching. But this notebook focuses on compression of Bloom, not usage."
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 2,
52
+ "id": "e1ca3df9",
53
+ "metadata": {
54
+ "cellId": "ktgxcupgtcf8hhh2k1r2ij",
55
+ "colab": {
56
+ "base_uri": "https://localhost:8080/"
57
+ },
58
+ "id": "xcdQSnYIk12Z",
59
+ "outputId": "8d0fff65-4d34-41bd-f750-278a35ac9533"
60
+ },
61
+ "outputs": [
62
+ {
63
+ "name": "stderr",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "/home/dm/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
67
+ " from .autonotebook import tqdm as notebook_tqdm\n",
68
+ "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n",
69
+ " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n"
70
+ ]
71
+ },
72
+ {
73
+ "name": "stdout",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "\n",
77
+ "===================================BUG REPORT===================================\n",
78
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
79
+ "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
80
+ "================================================================================\n",
81
+ "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...\n",
82
+ "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!\n",
83
+ "CUDA SETUP: Loading binary /home/dm/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so...\n"
84
+ ]
85
+ },
86
+ {
87
+ "name": "stderr",
88
+ "output_type": "stream",
89
+ "text": [
90
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('8bitexperiments/f746d450-b748-4d1f-b3e3-9e4fd3f72d6e')}\n",
91
+ " warn(\n",
92
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('module'), PosixPath('//matplotlib_inline.backend_inline')}\n",
93
+ " warn(\n",
94
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/cuda/lib64')}\n",
95
+ " warn(\n",
96
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cextension.py:48: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.\n",
97
+ " warn(\n"
98
+ ]
99
+ }
100
+ ],
101
+ "source": [
102
+ "import torch\n",
103
+ "import torch.nn as nn\n",
104
+ "import torch.nn.functional as F\n",
105
+ "import transformers\n",
106
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
107
+ "\n",
108
+ "model_name = \"bigscience/bloom-7b1\"\n",
109
+ "gpt = transformers.BloomForCausalLM.from_pretrained(model_name, cache_dir=\"mycache\")\n",
110
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, cache_dir=\"mycache\")"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 3,
116
+ "id": "b37255b0",
117
+ "metadata": {
118
+ "cellId": "wmew4wc0e3pztbva18lggg",
119
+ "id": "YjLHVyIOkdCH"
120
+ },
121
+ "outputs": [],
122
+ "source": [
123
+ "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
124
+ " assert chunk_size % 4096 == 0\n",
125
+ " code = None\n",
126
+ " chunks = []\n",
127
+ " absmaxes = []\n",
128
+ " flat_tensor = matrix.view(-1)\n",
129
+ " for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
130
+ " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
131
+ " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
132
+ " chunks.append(quantized_chunk)\n",
133
+ " absmaxes.append(absmax_chunk)\n",
134
+ " \n",
135
+ " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
136
+ " absmax = torch.cat(absmaxes)\n",
137
+ " return matrix_i8, (absmax, code)\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 4,
143
+ "id": "5c03f13b",
144
+ "metadata": {
145
+ "cellId": "zwcfu5ypstmsusllfldemc",
146
+ "id": "StJJ6oickpZs"
147
+ },
148
+ "outputs": [],
149
+ "source": [
150
+ "from typing import Tuple\n",
151
+ "from torch.cuda.amp import custom_fwd, custom_bwd\n",
152
+ "\n",
153
+ "\n",
154
+ "class DequantizeAndLinear(torch.autograd.Function):\n",
155
+ " @staticmethod\n",
156
+ " @custom_fwd\n",
157
+ " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
158
+ " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
159
+ "\n",
160
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
161
+ " ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
162
+ " ctx._has_bias = bias is not None\n",
163
+ " return F.linear(input, weights_deq, bias)\n",
164
+ "\n",
165
+ " @staticmethod\n",
166
+ " @custom_bwd\n",
167
+ " def backward(ctx, grad_output: torch.Tensor):\n",
168
+ " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
169
+ " input, weights_quantized, absmax, code = ctx.saved_tensors\n",
170
+ " # grad_output: [*batch, out_features]\n",
171
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
172
+ " grad_input = grad_output @ weights_deq\n",
173
+ " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
174
+ " return grad_input, None, None, None, grad_bias\n",
175
+ "\n",
176
+ "\n",
177
+ "class BNBLinearWithAdapter(nn.Module):\n",
178
+ " def __init__(self, weight, absmax, code, bias=None, adapter_dim=0):\n",
179
+ " assert isinstance(bias, nn.Parameter) or bias is None\n",
180
+ " super().__init__()\n",
181
+ " self.out_features, self.in_features = weight.shape\n",
182
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
183
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
184
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
185
+ " self.bias = bias\n",
186
+ "\n",
187
+ " if adapter_dim > 0:\n",
188
+ " self.adapter = nn.Sequential(\n",
189
+ " nn.Linear(self.in_features, adapter_dim, bias=False),\n",
190
+ " nn.Linear(adapter_dim, self.out_features, bias=False),\n",
191
+ " )\n",
192
+ "\n",
193
+ " nn.init.zeros_(self.adapter[1].weight)\n",
194
+ " else:\n",
195
+ " self.adapter = None\n",
196
+ "\n",
197
+ " def forward(self, input):\n",
198
+ " out = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
199
+ "\n",
200
+ " if self.adapter:\n",
201
+ " return self.adapter(input) + out\n",
202
+ "\n",
203
+ " return out\n",
204
+ "\n",
205
+ "\n",
206
+ " @classmethod\n",
207
+ " def from_linear(cls, linear: nn.Linear, **kwargs) -> \"FrozenBNBLinear\":\n",
208
+ " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
209
+ " return cls(weights_int8, *state, linear.bias, **kwargs)\n",
210
+ "\n",
211
+ " def __repr__(self):\n",
212
+ " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
213
+ "\n",
214
+ "\n",
215
+ "class BNBEmbeddingWithAdapter(nn.Module):\n",
216
+ " def __init__(self, weight, absmax, code, adapter_dim=0):\n",
217
+ " super().__init__()\n",
218
+ " self.num_embeddings, self.embedding_dim = weight.shape\n",
219
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
220
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
221
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
222
+ "\n",
223
+ " if adapter_dim > 0:\n",
224
+ " self.adapter = nn.Sequential(\n",
225
+ " nn.Embedding(self.num_embeddings, adapter_dim),\n",
226
+ " nn.Linear(adapter_dim, self.embedding_dim, bias=False),\n",
227
+ " )\n",
228
+ "\n",
229
+ " nn.init.zeros_(self.adapter[1].weight)\n",
230
+ " else:\n",
231
+ " self.adapter = None\n",
232
+ "\n",
233
+ " def forward(self, input, **kwargs):\n",
234
+ " with torch.no_grad():\n",
235
+ " # note: both quantuized weights and input indices are *not* differentiable\n",
236
+ " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
237
+ " out = F.embedding(input, weight_deq, **kwargs)\n",
238
+ " if self.adapter:\n",
239
+ " return out + self.adapter(input, **kwargs)\n",
240
+ "\n",
241
+ " return out\n",
242
+ "\n",
243
+ " @classmethod\n",
244
+ " def from_embedding(cls, embedding: nn.Embedding, **kwargs) -> \"FrozenBNBEmbedding\":\n",
245
+ " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
246
+ " return cls(weights_int8, *state, **kwargs)\n",
247
+ "\n",
248
+ " def __repr__(self):\n",
249
+ " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\""
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 5,
255
+ "id": "92a58957",
256
+ "metadata": {
257
+ "cellId": "due8kcyko4fv3vxzrbin3",
258
+ "id": "6LafYNhlktnt"
259
+ },
260
+ "outputs": [],
261
+ "source": [
262
+ "def bnbfy_(model, adapter_dim: int = 0): \n",
263
+ " for module in list(model.transformer.h.modules()):\n",
264
+ " for name, child in module.named_children():\n",
265
+ " if isinstance(child, nn.Linear):\n",
266
+ " print(name, child)\n",
267
+ " setattr(module, name, BNBLinearWithAdapter.from_linear(child, adapter_dim=adapter_dim))\n",
268
+ "\n",
269
+ " elif isinstance(child, nn.Embedding):\n",
270
+ " print(name, child)\n",
271
+ " setattr(module, name, BNBEmbeddingWithAdapter.from_embedding(child, adapter_dim=adapter_dim))"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 6,
277
+ "id": "f2d513c6-cd72-411d-9a25-9a21e5c2b87c",
278
+ "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "model size: 26966.156MB\n"
285
+ ]
286
+ }
287
+ ],
288
+ "source": [
289
+ "#!g1.1\n",
290
+ "param_size = 0\n",
291
+ "for param in gpt.parameters():\n",
292
+ " param_size += param.nelement() * param.element_size()\n",
293
+ "buffer_size = 0\n",
294
+ "for buffer in gpt.buffers():\n",
295
+ " buffer_size += buffer.nelement() * buffer.element_size()\n",
296
+ "\n",
297
+ "size_all_mb = (param_size + buffer_size) / 1024**2\n",
298
+ "print('model size: {:.3f}MB'.format(size_all_mb))"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 7,
304
+ "id": "ab52cd5c",
305
+ "metadata": {
306
+ "cellId": "5269rte0cil8omgnvcif"
307
+ },
308
+ "outputs": [
309
+ {
310
+ "data": {
311
+ "text/plain": [
312
+ "BloomForCausalLM(\n",
313
+ " (transformer): BloomModel(\n",
314
+ " (word_embeddings): Embedding(250880, 4096)\n",
315
+ " (word_embeddings_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
316
+ " (h): ModuleList(\n",
317
+ " (0): BloomBlock(\n",
318
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
319
+ " (self_attention): BloomAttention(\n",
320
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
321
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
322
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
323
+ " )\n",
324
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
325
+ " (mlp): BloomMLP(\n",
326
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
327
+ " (gelu_impl): BloomGelu()\n",
328
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
329
+ " )\n",
330
+ " )\n",
331
+ " (1): BloomBlock(\n",
332
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
333
+ " (self_attention): BloomAttention(\n",
334
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
335
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
336
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
337
+ " )\n",
338
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
339
+ " (mlp): BloomMLP(\n",
340
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
341
+ " (gelu_impl): BloomGelu()\n",
342
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
343
+ " )\n",
344
+ " )\n",
345
+ " (2): BloomBlock(\n",
346
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
347
+ " (self_attention): BloomAttention(\n",
348
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
349
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
350
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
351
+ " )\n",
352
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
353
+ " (mlp): BloomMLP(\n",
354
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
355
+ " (gelu_impl): BloomGelu()\n",
356
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
357
+ " )\n",
358
+ " )\n",
359
+ " (3): BloomBlock(\n",
360
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
361
+ " (self_attention): BloomAttention(\n",
362
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
363
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
364
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
365
+ " )\n",
366
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
367
+ " (mlp): BloomMLP(\n",
368
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
369
+ " (gelu_impl): BloomGelu()\n",
370
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
371
+ " )\n",
372
+ " )\n",
373
+ " (4): BloomBlock(\n",
374
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
375
+ " (self_attention): BloomAttention(\n",
376
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
377
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
378
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
379
+ " )\n",
380
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
381
+ " (mlp): BloomMLP(\n",
382
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
383
+ " (gelu_impl): BloomGelu()\n",
384
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
385
+ " )\n",
386
+ " )\n",
387
+ " (5): BloomBlock(\n",
388
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
389
+ " (self_attention): BloomAttention(\n",
390
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
391
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
392
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
393
+ " )\n",
394
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
395
+ " (mlp): BloomMLP(\n",
396
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
397
+ " (gelu_impl): BloomGelu()\n",
398
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
399
+ " )\n",
400
+ " )\n",
401
+ " (6): BloomBlock(\n",
402
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
403
+ " (self_attention): BloomAttention(\n",
404
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
405
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
406
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
407
+ " )\n",
408
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
409
+ " (mlp): BloomMLP(\n",
410
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
411
+ " (gelu_impl): BloomGelu()\n",
412
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
413
+ " )\n",
414
+ " )\n",
415
+ " (7): BloomBlock(\n",
416
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
417
+ " (self_attention): BloomAttention(\n",
418
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
419
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
420
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
421
+ " )\n",
422
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
423
+ " (mlp): BloomMLP(\n",
424
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
425
+ " (gelu_impl): BloomGelu()\n",
426
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
427
+ " )\n",
428
+ " )\n",
429
+ " (8): BloomBlock(\n",
430
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
431
+ " (self_attention): BloomAttention(\n",
432
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
433
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
434
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
435
+ " )\n",
436
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
437
+ " (mlp): BloomMLP(\n",
438
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
439
+ " (gelu_impl): BloomGelu()\n",
440
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
441
+ " )\n",
442
+ " )\n",
443
+ " (9): BloomBlock(\n",
444
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
445
+ " (self_attention): BloomAttention(\n",
446
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
447
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
448
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
449
+ " )\n",
450
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
451
+ " (mlp): BloomMLP(\n",
452
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
453
+ " (gelu_impl): BloomGelu()\n",
454
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
455
+ " )\n",
456
+ " )\n",
457
+ " (10): BloomBlock(\n",
458
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
459
+ " (self_attention): BloomAttention(\n",
460
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
461
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
462
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
463
+ " )\n",
464
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
465
+ " (mlp): BloomMLP(\n",
466
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
467
+ " (gelu_impl): BloomGelu()\n",
468
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
469
+ " )\n",
470
+ " )\n",
471
+ " (11): BloomBlock(\n",
472
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
473
+ " (self_attention): BloomAttention(\n",
474
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
475
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
476
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
477
+ " )\n",
478
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
479
+ " (mlp): BloomMLP(\n",
480
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
481
+ " (gelu_impl): BloomGelu()\n",
482
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
483
+ " )\n",
484
+ " )\n",
485
+ " (12): BloomBlock(\n",
486
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
487
+ " (self_attention): BloomAttention(\n",
488
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
489
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
490
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
491
+ " )\n",
492
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
493
+ " (mlp): BloomMLP(\n",
494
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
495
+ " (gelu_impl): BloomGelu()\n",
496
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
497
+ " )\n",
498
+ " )\n",
499
+ " (13): BloomBlock(\n",
500
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
501
+ " (self_attention): BloomAttention(\n",
502
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
503
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
504
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
505
+ " )\n",
506
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
507
+ " (mlp): BloomMLP(\n",
508
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
509
+ " (gelu_impl): BloomGelu()\n",
510
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
511
+ " )\n",
512
+ " )\n",
513
+ " (14): BloomBlock(\n",
514
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
515
+ " (self_attention): BloomAttention(\n",
516
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
517
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
518
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
519
+ " )\n",
520
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
521
+ " (mlp): BloomMLP(\n",
522
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
523
+ " (gelu_impl): BloomGelu()\n",
524
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
525
+ " )\n",
526
+ " )\n",
527
+ " (15): BloomBlock(\n",
528
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
529
+ " (self_attention): BloomAttention(\n",
530
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
531
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
532
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
533
+ " )\n",
534
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
535
+ " (mlp): BloomMLP(\n",
536
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
537
+ " (gelu_impl): BloomGelu()\n",
538
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
539
+ " )\n",
540
+ " )\n",
541
+ " (16): BloomBlock(\n",
542
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
543
+ " (self_attention): BloomAttention(\n",
544
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
545
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
546
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
547
+ " )\n",
548
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
549
+ " (mlp): BloomMLP(\n",
550
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
551
+ " (gelu_impl): BloomGelu()\n",
552
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
553
+ " )\n",
554
+ " )\n",
555
+ " (17): BloomBlock(\n",
556
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
557
+ " (self_attention): BloomAttention(\n",
558
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
559
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
560
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
561
+ " )\n",
562
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
563
+ " (mlp): BloomMLP(\n",
564
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
565
+ " (gelu_impl): BloomGelu()\n",
566
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
567
+ " )\n",
568
+ " )\n",
569
+ " (18): BloomBlock(\n",
570
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
571
+ " (self_attention): BloomAttention(\n",
572
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
573
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
574
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
575
+ " )\n",
576
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
577
+ " (mlp): BloomMLP(\n",
578
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
579
+ " (gelu_impl): BloomGelu()\n",
580
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
581
+ " )\n",
582
+ " )\n",
583
+ " (19): BloomBlock(\n",
584
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
585
+ " (self_attention): BloomAttention(\n",
586
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
587
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
588
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
589
+ " )\n",
590
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
591
+ " (mlp): BloomMLP(\n",
592
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
593
+ " (gelu_impl): BloomGelu()\n",
594
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
595
+ " )\n",
596
+ " )\n",
597
+ " (20): BloomBlock(\n",
598
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
599
+ " (self_attention): BloomAttention(\n",
600
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
601
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
602
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
603
+ " )\n",
604
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
605
+ " (mlp): BloomMLP(\n",
606
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
607
+ " (gelu_impl): BloomGelu()\n",
608
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
609
+ " )\n",
610
+ " )\n",
611
+ " (21): BloomBlock(\n",
612
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
613
+ " (self_attention): BloomAttention(\n",
614
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
615
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
616
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
617
+ " )\n",
618
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
619
+ " (mlp): BloomMLP(\n",
620
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
621
+ " (gelu_impl): BloomGelu()\n",
622
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
623
+ " )\n",
624
+ " )\n",
625
+ " (22): BloomBlock(\n",
626
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
627
+ " (self_attention): BloomAttention(\n",
628
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
629
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
630
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
631
+ " )\n",
632
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
633
+ " (mlp): BloomMLP(\n",
634
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
635
+ " (gelu_impl): BloomGelu()\n",
636
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
637
+ " )\n",
638
+ " )\n",
639
+ " (23): BloomBlock(\n",
640
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
641
+ " (self_attention): BloomAttention(\n",
642
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
643
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
644
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
645
+ " )\n",
646
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
647
+ " (mlp): BloomMLP(\n",
648
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
649
+ " (gelu_impl): BloomGelu()\n",
650
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
651
+ " )\n",
652
+ " )\n",
653
+ " (24): BloomBlock(\n",
654
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
655
+ " (self_attention): BloomAttention(\n",
656
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
657
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
658
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
659
+ " )\n",
660
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
661
+ " (mlp): BloomMLP(\n",
662
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
663
+ " (gelu_impl): BloomGelu()\n",
664
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
665
+ " )\n",
666
+ " )\n",
667
+ " (25): BloomBlock(\n",
668
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
669
+ " (self_attention): BloomAttention(\n",
670
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
671
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
672
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
673
+ " )\n",
674
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
675
+ " (mlp): BloomMLP(\n",
676
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
677
+ " (gelu_impl): BloomGelu()\n",
678
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
679
+ " )\n",
680
+ " )\n",
681
+ " (26): BloomBlock(\n",
682
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
683
+ " (self_attention): BloomAttention(\n",
684
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
685
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
686
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
687
+ " )\n",
688
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
689
+ " (mlp): BloomMLP(\n",
690
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
691
+ " (gelu_impl): BloomGelu()\n",
692
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
693
+ " )\n",
694
+ " )\n",
695
+ " (27): BloomBlock(\n",
696
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
697
+ " (self_attention): BloomAttention(\n",
698
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
699
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
700
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
701
+ " )\n",
702
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
703
+ " (mlp): BloomMLP(\n",
704
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
705
+ " (gelu_impl): BloomGelu()\n",
706
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
707
+ " )\n",
708
+ " )\n",
709
+ " (28): BloomBlock(\n",
710
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
711
+ " (self_attention): BloomAttention(\n",
712
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
713
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
714
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
715
+ " )\n",
716
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
717
+ " (mlp): BloomMLP(\n",
718
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
719
+ " (gelu_impl): BloomGelu()\n",
720
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
721
+ " )\n",
722
+ " )\n",
723
+ " (29): BloomBlock(\n",
724
+ " (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
725
+ " (self_attention): BloomAttention(\n",
726
+ " (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
727
+ " (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
728
+ " (attention_dropout): Dropout(p=0.0, inplace=False)\n",
729
+ " )\n",
730
+ " (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
731
+ " (mlp): BloomMLP(\n",
732
+ " (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
733
+ " (gelu_impl): BloomGelu()\n",
734
+ " (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
735
+ " )\n",
736
+ " )\n",
737
+ " )\n",
738
+ " (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
739
+ " )\n",
740
+ " (lm_head): Linear(in_features=4096, out_features=250880, bias=False)\n",
741
+ ")"
742
+ ]
743
+ },
744
+ "execution_count": 7,
745
+ "metadata": {},
746
+ "output_type": "execute_result"
747
+ }
748
+ ],
749
+ "source": [
750
+ "#!g1.1\n",
751
+ "gpt"
752
+ ]
753
+ },
754
+ {
755
+ "cell_type": "code",
756
+ "execution_count": 8,
757
+ "id": "9280b510",
758
+ "metadata": {
759
+ "cellId": "a7nstbdt9vo9qikpzvo48c",
760
+ "id": "jV3pGEGalDwz"
761
+ },
762
+ "outputs": [
763
+ {
764
+ "name": "stdout",
765
+ "output_type": "stream",
766
+ "text": [
767
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
768
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
769
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
770
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
771
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
772
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
773
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
774
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
775
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
776
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
777
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
778
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
779
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
780
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
781
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
782
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
783
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
784
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
785
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
786
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
787
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
788
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
789
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
790
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
791
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
792
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
793
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
794
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
795
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
796
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
797
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
798
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
799
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
800
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
801
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
802
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
803
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
804
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
805
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
806
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
807
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
808
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
809
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
810
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
811
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
812
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
813
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
814
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
815
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
816
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
817
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
818
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
819
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
820
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
821
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
822
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
823
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
824
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
825
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
826
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
827
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
828
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
829
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
830
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
831
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
832
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
833
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
834
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
835
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
836
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
837
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
838
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
839
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
840
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
841
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
842
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
843
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
844
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
845
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
846
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
847
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
848
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
849
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
850
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
851
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
852
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
853
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
854
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
855
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
856
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
857
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
858
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
859
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
860
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
861
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
862
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
863
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
864
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
865
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
866
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
867
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
868
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
869
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
870
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
871
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
872
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
873
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
874
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
875
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
876
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
877
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
878
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
879
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
880
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
881
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
882
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
883
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
884
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
885
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
886
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n"
887
+ ]
888
+ }
889
+ ],
890
+ "source": [
891
+ "#!g1.1\n",
892
+ "bnbfy_(gpt, adapter_dim=0)"
893
+ ]
894
+ },
895
+ {
896
+ "cell_type": "code",
897
+ "execution_count": null,
898
+ "id": "e35305f2",
899
+ "metadata": {
900
+ "cellId": "q5jafg9w9x0hg355icd4vo"
901
+ },
902
+ "outputs": [],
903
+ "source": [
904
+ "#!g1.1\n",
905
+ "param_size = 0\n",
906
+ "for param in gpt.parameters():\n",
907
+ " param_size += param.nelement() * param.element_size()\n",
908
+ "buffer_size = 0\n",
909
+ "for buffer in gpt.buffers():\n",
910
+ " buffer_size += buffer.nelement() * buffer.element_size()\n",
911
+ "\n",
912
+ "size_all_mb = (param_size + buffer_size) / 1024**2\n",
913
+ "print('model size: {:.3f}MB'.format(size_all_mb))\n",
914
+ "gpt.save_pretrained('bloom-7b1-8bit')"
915
+ ]
916
+ }
917
+ ],
918
+ "metadata": {
919
+ "accelerator": "GPU",
920
+ "colab": {
921
+ "collapsed_sections": [],
922
+ "provenance": []
923
+ },
924
+ "kernelspec": {
925
+ "display_name": "Python 3 (ipykernel)",
926
+ "language": "python",
927
+ "name": "python3"
928
+ },
929
+ "language_info": {
930
+ "codemirror_mode": {
931
+ "name": "ipython",
932
+ "version": 3
933
+ },
934
+ "file_extension": ".py",
935
+ "mimetype": "text/x-python",
936
+ "name": "python",
937
+ "nbconvert_exporter": "python",
938
+ "pygments_lexer": "ipython3",
939
+ "version": "3.8.10"
940
+ },
941
+ "notebookId": "8f3ce20e-06a1-44f2-9373-2b6424b859a3",
942
+ "notebookPath": "bloom8bit.ipynb"
943
+ },
944
+ "nbformat": 4,
945
+ "nbformat_minor": 5
946
+ }
load_bloom.ipynb ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e28407ed",
6
+ "metadata": {
7
+ "cellId": "b3rvfwilbqlfhl99ot1adn"
8
+ },
9
+ "source": [
10
+ "# Example of Fine-tuning 7.1 billion Bloom with 8-bit weights\n",
11
+ "\n",
12
+ "This notebook shows an example of how to fine tune Bloom with Low Rank Adapters. Heavily inspired by [Hivemind's work](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "5f5a6af2",
18
+ "metadata": {
19
+ "cellId": "q43y9u4kj5g2qn01pdohou"
20
+ },
21
+ "source": [
22
+ "### Load and convert original Bloom structure to 8-bit LoRA\n",
23
+ "\n",
24
+ "You can load an already compressed 8-bit version of Bloom from [joaoalvarenga/bloom-8bit](https://huggingface.co/joaoalvarenga/bloom-8bit), but first we need to make some adaptations into original model structure. Some of the following code is an adaptation from [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)."
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 1,
30
+ "id": "815f9f31",
31
+ "metadata": {
32
+ "cellId": "qwv8mzg52blrc6ghm3x9s"
33
+ },
34
+ "outputs": [
35
+ {
36
+ "name": "stderr",
37
+ "output_type": "stream",
38
+ "text": [
39
+ "/home/dm/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
40
+ " from .autonotebook import tqdm as notebook_tqdm\n",
41
+ "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n",
42
+ " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n"
43
+ ]
44
+ },
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "\n",
50
+ "===================================BUG REPORT===================================\n",
51
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
52
+ "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
53
+ "================================================================================\n",
54
+ "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...\n",
55
+ "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!\n",
56
+ "CUDA SETUP: Loading binary /home/dm/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so...\n"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stderr",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('8bitexperiments/71572fb0-3a48-4729-a863-ee5aa6e60c92')}\n",
64
+ " warn(\n",
65
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('//matplotlib_inline.backend_inline'), PosixPath('module')}\n",
66
+ " warn(\n",
67
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cuda_setup/paths.py:27: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/cuda/lib64')}\n",
68
+ " warn(\n",
69
+ "/home/dm/.local/lib/python3.8/site-packages/bitsandbytes/cextension.py:48: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.\n",
70
+ " warn(\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "import transformers\n",
76
+ "\n",
77
+ "import torch\n",
78
+ "import torch.nn.functional as F\n",
79
+ "from torch import nn\n",
80
+ "from torch.cuda.amp import custom_fwd, custom_bwd\n",
81
+ "\n",
82
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
83
+ "\n",
84
+ "from tqdm.auto import tqdm"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 2,
90
+ "id": "640b12f7",
91
+ "metadata": {
92
+ "cellId": "b5qu2yccmjrp0n86a0k2"
93
+ },
94
+ "outputs": [],
95
+ "source": [
96
+ "class FrozenBNBLinear(nn.Module):\n",
97
+ " def __init__(self, weight, absmax, code, bias=None):\n",
98
+ " assert isinstance(bias, nn.Parameter) or bias is None\n",
99
+ " super().__init__()\n",
100
+ " self.out_features, self.in_features = weight.shape\n",
101
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
102
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
103
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
104
+ " self.adapter = None\n",
105
+ " self.bias = bias\n",
106
+ " \n",
107
+ " def forward(self, input):\n",
108
+ " output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
109
+ " if self.adapter:\n",
110
+ " output += self.adapter(input)\n",
111
+ " return output\n",
112
+ " \n",
113
+ " @classmethod\n",
114
+ " def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n",
115
+ " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
116
+ " return cls(weights_int8, *state, linear.bias)\n",
117
+ " \n",
118
+ " def __repr__(self):\n",
119
+ " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
120
+ " \n",
121
+ " \n",
122
+ "class DequantizeAndLinear(torch.autograd.Function): \n",
123
+ " @staticmethod\n",
124
+ " @custom_fwd\n",
125
+ " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
126
+ " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
127
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
128
+ " ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
129
+ " ctx._has_bias = bias is not None\n",
130
+ " return F.linear(input, weights_deq, bias)\n",
131
+ " \n",
132
+ " @staticmethod\n",
133
+ " @custom_bwd\n",
134
+ " def backward(ctx, grad_output: torch.Tensor):\n",
135
+ " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
136
+ " input, weights_quantized, absmax, code = ctx.saved_tensors\n",
137
+ " # grad_output: [*batch, out_features]\n",
138
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
139
+ " grad_input = grad_output @ weights_deq\n",
140
+ " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
141
+ " return grad_input, None, None, None, grad_bias\n",
142
+ " \n",
143
+ " \n",
144
+ "class FrozenBNBEmbedding(nn.Module):\n",
145
+ " def __init__(self, weight, absmax, code):\n",
146
+ " super().__init__()\n",
147
+ " self.num_embeddings, self.embedding_dim = weight.shape\n",
148
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
149
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
150
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
151
+ " self.adapter = None\n",
152
+ " \n",
153
+ " def forward(self, input, **kwargs):\n",
154
+ " with torch.no_grad():\n",
155
+ " # note: both quantuized weights and input indices are *not* differentiable\n",
156
+ " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
157
+ " output = F.embedding(input, weight_deq, **kwargs)\n",
158
+ " if self.adapter:\n",
159
+ " output += self.adapter(input)\n",
160
+ " return output \n",
161
+ " \n",
162
+ " @classmethod\n",
163
+ " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n",
164
+ " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
165
+ " return cls(weights_int8, *state)\n",
166
+ " \n",
167
+ " def __repr__(self):\n",
168
+ " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n",
169
+ " \n",
170
+ " \n",
171
+ "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
172
+ " assert chunk_size % 4096 == 0\n",
173
+ " code = None\n",
174
+ " chunks = []\n",
175
+ " absmaxes = []\n",
176
+ " flat_tensor = matrix.view(-1)\n",
177
+ " for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
178
+ " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
179
+ " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
180
+ " chunks.append(quantized_chunk)\n",
181
+ " absmaxes.append(absmax_chunk)\n",
182
+ " \n",
183
+ " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
184
+ " absmax = torch.cat(absmaxes)\n",
185
+ " return matrix_i8, (absmax, code)\n",
186
+ "\n",
187
+ "\n",
188
+ "def convert_to_int8(model):\n",
189
+ " \"\"\"Convert linear and embedding modules to 8-bit with optional adapters\"\"\"\n",
190
+ " for module in list(model.modules()):\n",
191
+ " for name, child in module.named_children():\n",
192
+ " if isinstance(child, nn.Linear):\n",
193
+ " print(name, child)\n",
194
+ " setattr( \n",
195
+ " module,\n",
196
+ " name,\n",
197
+ " FrozenBNBLinear(\n",
198
+ " weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n",
199
+ " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
200
+ " code=torch.zeros(256),\n",
201
+ " bias=child.bias,\n",
202
+ " ),\n",
203
+ " )\n",
204
+ " elif isinstance(child, nn.Embedding):\n",
205
+ " setattr(\n",
206
+ " module,\n",
207
+ " name,\n",
208
+ " FrozenBNBEmbedding(\n",
209
+ " weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n",
210
+ " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
211
+ " code=torch.zeros(256),\n",
212
+ " )\n",
213
+ " )"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": 3,
219
+ "id": "e0dbb262",
220
+ "metadata": {
221
+ "cellId": "j9ds51fcwxxy0blcplb6"
222
+ },
223
+ "outputs": [],
224
+ "source": [
225
+ "class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):\n",
226
+ " def __init__(self, config):\n",
227
+ " super().__init__(config)\n",
228
+ " convert_to_int8(self.self_attention)\n",
229
+ " convert_to_int8(self.mlp)\n",
230
+ "\n",
231
+ "\n",
232
+ "class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):\n",
233
+ " def __init__(self, config):\n",
234
+ " super().__init__(config)\n",
235
+ " convert_to_int8(self)\n",
236
+ " \n",
237
+ "\n",
238
+ "class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):\n",
239
+ " def __init__(self, config):\n",
240
+ " super().__init__(config)\n",
241
+ " convert_to_int8(self)\n",
242
+ " \n",
243
+ "transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "a68bbee4",
250
+ "metadata": {
251
+ "cellId": "a5he2q7ulm4wkwqno10wsg"
252
+ },
253
+ "outputs": [
254
+ {
255
+ "name": "stdout",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
259
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
260
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
261
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
262
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
263
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
264
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
265
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
266
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
267
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
268
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
269
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
270
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
271
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
272
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
273
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
274
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
275
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
276
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
277
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
278
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
279
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
280
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
281
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
282
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
283
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
284
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
285
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
286
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
287
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
288
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
289
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
290
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
291
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
292
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
293
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
294
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
295
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
296
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
297
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
298
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
299
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
300
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
301
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
302
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
303
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
304
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
305
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
306
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
307
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
308
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
309
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
310
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
311
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
312
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
313
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
314
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
315
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
316
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
317
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
318
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
319
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
320
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
321
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
322
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
323
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
324
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
325
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
326
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
327
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
328
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
329
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
330
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
331
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
332
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
333
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
334
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
335
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
336
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
337
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
338
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
339
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
340
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
341
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
342
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
343
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
344
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
345
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
346
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
347
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
348
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
349
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
350
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
351
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
352
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
353
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
354
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
355
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
356
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
357
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
358
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
359
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
360
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
361
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
362
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
363
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
364
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
365
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
366
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
367
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
368
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
369
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
370
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
371
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
372
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
373
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n",
374
+ "query_key_value Linear(in_features=4096, out_features=12288, bias=True)\n",
375
+ "dense Linear(in_features=4096, out_features=4096, bias=True)\n",
376
+ "dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)\n",
377
+ "dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)\n"
378
+ ]
379
+ }
380
+ ],
381
+ "source": [
382
+ "#!g1.1\n",
383
+ "from transformers import BloomForCausalLM, AutoModel\n",
384
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"bigscience/bloom-7b1\", cache_dir=\"mycache\")\n",
385
+ "model = BloomForCausalLM.from_pretrained('bloom-8bit-v4.pt')\n",
386
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
387
+ "model.to(device)\n",
388
+ "pass"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": null,
394
+ "id": "11bdaf76",
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "model"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "5c9feef8",
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "prefix = \"\"\"\"It is a fantasy role-play game.\n",
409
+ "\n",
410
+ "Game Master: You are John, a wizard living in the kingdom of Larion. You have a staff and a spellbook. You finish your long journey and finally arrive at the ruin you've been looking for. You have come here searching for a mystical spellbook of great power called the book of essence. You look around and see the ancient ruins of an elf tower. The ruins have not been touched for decades. You look at the tower, and you can see a set of stone stairs that seem to lead somewhere deep inside the tower.\n",
411
+ "Player: I walk upstairs\n",
412
+ "Game Master: You climb up the stairs in the ruined tower. There is a door on the second floor of the tower, the door seems to be made of enchanted wood.\n",
413
+ "Player: I ask the door if I may to come in\n",
414
+ "Game Master: The door sighs open and you walk into the room.\n",
415
+ "Player: I take a look around\n",
416
+ "Game Master:\"\"\"\n",
417
+ "\n",
418
+ "print(end=prefix)\n",
419
+ "past_key_values = None # used to keep track of conversation history\n",
420
+ "input_dict = tokenizer([prefix], return_tensors='pt', padding=False)\n",
421
+ "\n",
422
+ "output=\"\"\n",
423
+ "\n",
424
+ "with torch.inference_mode():\n",
425
+ " for i in range(200):\n",
426
+ " outputs = model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)\n",
427
+ " last_logits = outputs.logits[0, -1]\n",
428
+ " \n",
429
+ " last_logits[last_logits.topk(k=10).indices] += 10 # other logits are now e^10 times less likely to be chosen\n",
430
+ "\n",
431
+ " past_key_values = outputs.past_key_values\n",
432
+ " token_ix = torch.multinomial(last_logits.softmax(-1), 1).item()\n",
433
+ " prefix = tokenizer.decode([token_ix])\n",
434
+ " output = output + tokenizer.decode([token_ix])\n",
435
+ " if 'player' in output or 'Player' in output:\n",
436
+ " break\n",
437
+ " if 'Master' in output:\n",
438
+ " break\n",
439
+ " print(end=tokenizer.decode([token_ix]), flush=True)\n",
440
+ "\n",
441
+ " input_dict = dict(input_ids=torch.tensor([[token_ix]]))\n",
442
+ "print()"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "execution_count": null,
448
+ "id": "7eba8c3f",
449
+ "metadata": {},
450
+ "outputs": [],
451
+ "source": [
452
+ "prefix = \"\"\"\"It is a fantasy role-play game.\n",
453
+ "\n",
454
+ "Game Master: You are John, a wizard living in the kingdom of Larion. You have a staff and a spellbook. You finish your long journey and finally arrive at the ruin you've been looking for. You have come here searching for a mystical spellbook of great power called the book of essence. You look around and see the ancient ruins of an elf tower. The ruins have not been touched for decades. You look at the tower, and you can see a set of stone stairs that seem to lead somewhere deep inside the tower.\n",
455
+ "Player: I walk upstairs\n",
456
+ "Game Master: You climb up the stairs in the ruined tower. There is a door on the second floor of the tower, the door seems to be made of enchanted wood.\n",
457
+ "Player: I ask the door if I may to come in\n",
458
+ "Game Master: The door sighs open and you walk into the room.\n",
459
+ "Player: I take a look around\n",
460
+ "Game Master:\"\"\"\n",
461
+ "\n",
462
+ "print(end=prefix)\n",
463
+ "past_key_values = None # used to keep track of conversation history\n",
464
+ "input_dict = tokenizer([prefix], return_tensors='pt', padding=False)\n",
465
+ "\n",
466
+ "output = \"\"\n",
467
+ "\n",
468
+ "with torch.inference_mode():\n",
469
+ " for i in range(200):\n",
470
+ " outputs = model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)\n",
471
+ " last_logits = outputs.logits[0, -1]\n",
472
+ " \n",
473
+ " last_logits[last_logits.topk(k=10).indices] += 10 # other logits are now e^10 times less likely to be chosen\n",
474
+ "\n",
475
+ " past_key_values = outputs.past_key_values\n",
476
+ " token_ix = torch.multinomial(last_logits.softmax(-1), 1).item()\n",
477
+ " prefix = tokenizer.decode([token_ix])\n",
478
+ " output = output + tokenizer.decode([token_ix])\n",
479
+ " if 'player' in output or 'Player' in output:\n",
480
+ " break\n",
481
+ " if 'Master' in output:\n",
482
+ " break\n",
483
+ " print(end=tokenizer.decode([token_ix]), flush=True)\n",
484
+ "\n",
485
+ " input_dict = dict(input_ids=torch.tensor([[token_ix]]))\n",
486
+ "print()"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "id": "4db07557",
493
+ "metadata": {},
494
+ "outputs": [],
495
+ "source": [
496
+ "#!g1.1\n",
497
+ "prompt = tokenizer(\"A cat sat on a mat and\", return_tensors='pt')\n",
498
+ "out = model.generate(**prompt, min_length=10, max_length=10, do_sample=True)\n",
499
+ "tokenizer.decode(out[0])"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "markdown",
504
+ "id": "5398feef",
505
+ "metadata": {
506
+ "cellId": "uero3zs1ebpefelhzioy2t",
507
+ "execution_id": "243e4f22-9ad3-412c-98cb-6b01253531c9"
508
+ },
509
+ "source": [
510
+ "### Fine-tune and save model"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "id": "a2fdb897",
517
+ "metadata": {
518
+ "cellId": "wmrjusxrcomgqirlydiuj"
519
+ },
520
+ "outputs": [],
521
+ "source": [
522
+ "#!g1.1\n",
523
+ "def add_adapters(model, adapter_dim=16):\n",
524
+ " assert adapter_dim > 0\n",
525
+ "\n",
526
+ " for module in model.modules():\n",
527
+ " if isinstance(module, FrozenBNBLinear):\n",
528
+ " module.adapter = nn.Sequential(\n",
529
+ " nn.Linear(module.in_features, adapter_dim, bias=False),\n",
530
+ " nn.Linear(adapter_dim, module.out_features, bias=False),\n",
531
+ " )\n",
532
+ " nn.init.zeros_(module.adapter[1].weight)\n",
533
+ " elif isinstance(module, FrozenBNBEmbedding):\n",
534
+ " module.adapter = nn.Sequential(\n",
535
+ " nn.Embedding(module.num_embeddings, adapter_dim),\n",
536
+ " nn.Linear(adapter_dim, module.embedding_dim, bias=False),\n",
537
+ " )\n",
538
+ " nn.init.zeros_(module.adapter[1].weight)\n",
539
+ "\n",
540
+ "add_adapters(model)\n",
541
+ "model.to(device)"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "id": "ceac0236",
548
+ "metadata": {
549
+ "cellId": "mzznv2rl07bt6x7aybc1h"
550
+ },
551
+ "outputs": [],
552
+ "source": [
553
+ "#!g1.1\n",
554
+ "from datasets import load_dataset\n",
555
+ "from bitsandbytes.optim import Adam8bit\n",
556
+ "\n",
557
+ "model.gradient_checkpointing_enable()\n",
558
+ "\n",
559
+ "wikisql = load_dataset(\"wikisql\", streaming=True)\n",
560
+ "optimizer = Adam8bit(model.parameters(), lr=1e-5)\n",
561
+ "\n",
562
+ "with torch.cuda.amp.autocast():\n",
563
+ " for row in tqdm(wikisql['train']):\n",
564
+ "\n",
565
+ " batch = tokenizer(row['question'] + row['sql']['human_readable'], truncation=True, max_length=128, return_tensors='pt')\n",
566
+ " batch = {k: v.cuda() for k, v in batch.items()}\n",
567
+ "\n",
568
+ " out = gpt.forward(**batch,)\n",
569
+ "\n",
570
+ " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
571
+ " reduction='mean')\n",
572
+ " print(loss)\n",
573
+ " loss.backward()\n",
574
+ "\n",
575
+ " optimizer.step()\n",
576
+ " optimizer.zero_grad()"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "id": "7d1b3b65",
583
+ "metadata": {
584
+ "cellId": "mirxlhno0w8wrmaaxj4u7"
585
+ },
586
+ "outputs": [],
587
+ "source": [
588
+ "#!g1.1\n",
589
+ "model.save_pretrained('bloom-8bit-fine-tuned')"
590
+ ]
591
+ }
592
+ ],
593
+ "metadata": {
594
+ "kernelspec": {
595
+ "display_name": "Python 3 (ipykernel)",
596
+ "language": "python",
597
+ "name": "python3"
598
+ },
599
+ "language_info": {
600
+ "codemirror_mode": {
601
+ "name": "ipython",
602
+ "version": 3
603
+ },
604
+ "file_extension": ".py",
605
+ "mimetype": "text/x-python",
606
+ "name": "python",
607
+ "nbconvert_exporter": "python",
608
+ "pygments_lexer": "ipython3",
609
+ "version": "3.8.10"
610
+ },
611
+ "notebookId": "433858c6-d0c2-461b-85f3-1153722e7367",
612
+ "notebookPath": "untitled.ipynb"
613
+ },
614
+ "nbformat": 4,
615
+ "nbformat_minor": 5
616
+ }