dvmazur commited on
Commit
42c6b2b
·
1 Parent(s): 0f72e0c

Upload convert-gpt-j.ipynb

Browse files
Files changed (1) hide show
  1. convert-gpt-j.ipynb +406 -0
convert-gpt-j.ipynb ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "import torch.nn.functional as F\n",
12
+ "\n",
13
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
14
+ "import transformers\n",
15
+ "%config Completer.use_jedi = False\n",
16
+ "\n",
17
+ "\n",
18
+ "model_name = \"EleutherAI/gpt-j-6B\"\n",
19
+ "gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)\n",
20
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 2,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
30
+ " assert chunk_size % 4096 == 0\n",
31
+ " code = None\n",
32
+ " chunks = []\n",
33
+ " absmaxes = []\n",
34
+ " flat_tensor = matrix.view(-1)\n",
35
+ " for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
36
+ " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
37
+ " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
38
+ " chunks.append(quantized_chunk)\n",
39
+ " absmaxes.append(absmax_chunk)\n",
40
+ " \n",
41
+ " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
42
+ " absmax = torch.cat(absmaxes)\n",
43
+ " return matrix_i8, (absmax, code)"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 3,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from typing import Tuple\n",
53
+ "from torch.cuda.amp import custom_fwd, custom_bwd\n",
54
+ "\n",
55
+ "\n",
56
+ "class DequantizeAndLinear(torch.autograd.Function):\n",
57
+ " \n",
58
+ " @staticmethod\n",
59
+ " @custom_fwd\n",
60
+ " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
61
+ " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
62
+ " \n",
63
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
64
+ " ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
65
+ " ctx._has_bias = bias is not None\n",
66
+ " return F.linear(input, weights_deq, bias)\n",
67
+ " \n",
68
+ " @staticmethod\n",
69
+ " @custom_bwd\n",
70
+ " def backward(ctx, grad_output: torch.Tensor):\n",
71
+ " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
72
+ " input, weights_quantized, absmax, code = ctx.saved_tensors\n",
73
+ " # grad_output: [*batch, out_features]\n",
74
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
75
+ " grad_input = grad_output @ weights_deq\n",
76
+ " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
77
+ " return grad_input, None, None, None, grad_bias\n",
78
+ "\n",
79
+ "\n",
80
+ "class BNBLinearWithAdapter(nn.Module):\n",
81
+ " def __init__(self, weight, absmax, code, bias=None, adapter_dim=0):\n",
82
+ " assert isinstance(bias, nn.Parameter) or bias is None\n",
83
+ " super().__init__()\n",
84
+ " self.out_features, self.in_features = weight.shape\n",
85
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
86
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
87
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
88
+ " self.bias = bias\n",
89
+ " \n",
90
+ " if adapter_dim > 0:\n",
91
+ " self.adapter = nn.Sequential(\n",
92
+ " nn.Linear(self.in_features, adapter_dim, bias=False),\n",
93
+ " nn.Linear(adapter_dim, self.out_features, bias=False),\n",
94
+ " )\n",
95
+ " \n",
96
+ " nn.init.zeros_(self.adapter[1].weight)\n",
97
+ " else:\n",
98
+ " self.adapter = None\n",
99
+ " \n",
100
+ " def forward(self, input):\n",
101
+ " out = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
102
+ " \n",
103
+ " if self.adapter:\n",
104
+ " return self.adapter(input) + out\n",
105
+ " \n",
106
+ " return out\n",
107
+ " \n",
108
+ " \n",
109
+ " @classmethod\n",
110
+ " def from_linear(cls, linear: nn.Linear, **kwargs) -> \"FrozenBNBLinear\":\n",
111
+ " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
112
+ " return cls(weights_int8, *state, linear.bias, **kwargs)\n",
113
+ " \n",
114
+ " def __repr__(self):\n",
115
+ " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
116
+ "\n",
117
+ "\n",
118
+ "class BNBEmbeddingWithAdapter(nn.Module):\n",
119
+ " def __init__(self, weight, absmax, code, adapter_dim=0):\n",
120
+ " super().__init__()\n",
121
+ " self.num_embeddings, self.embedding_dim = weight.shape\n",
122
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
123
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
124
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
125
+ " \n",
126
+ " if adapter_dim > 0:\n",
127
+ " self.adapter = nn.Sequential(\n",
128
+ " nn.Embedding(self.num_embeddings, adapter_dim),\n",
129
+ " nn.Linear(adapter_dim, self.embedding_dim, bias=False),\n",
130
+ " )\n",
131
+ " \n",
132
+ " nn.init.zeros_(self.adapter[1].weight)\n",
133
+ " else:\n",
134
+ " self.adapter = None\n",
135
+ " \n",
136
+ " def forward(self, input, **kwargs):\n",
137
+ " with torch.no_grad():\n",
138
+ " # note: both quantuized weights and input indices are *not* differentiable\n",
139
+ " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
140
+ " out = F.embedding(input, weight_deq, **kwargs)\n",
141
+ " if self.adapter:\n",
142
+ " return out + self.adapter(input, **kwargs)\n",
143
+ " \n",
144
+ " return out\n",
145
+ " \n",
146
+ " @classmethod\n",
147
+ " def from_embedding(cls, embedding: nn.Embedding, **kwargs) -> \"FrozenBNBEmbedding\":\n",
148
+ " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
149
+ " return cls(weights_int8, *state, **kwargs)\n",
150
+ " \n",
151
+ " def __repr__(self):\n",
152
+ " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\""
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 4,
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "def bnbfy_(model, adapter_dim: int = 0):\n",
162
+ " for module in list(model.modules()):\n",
163
+ " for name, child in module.named_children():\n",
164
+ " if isinstance(child, nn.Linear):\n",
165
+ " print(name, child)\n",
166
+ " setattr(module, name, BNBLinearWithAdapter.from_linear(child, adapter_dim=adapter_dim))\n",
167
+ " \n",
168
+ " elif isinstance(child, nn.Embedding):\n",
169
+ " print(name, child)\n",
170
+ " setattr(module, name, BNBEmbeddingWithAdapter.from_embedding(child, adapter_dim=adapter_dim))"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 5,
176
+ "metadata": {},
177
+ "outputs": [
178
+ {
179
+ "name": "stdout",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "lm_head Linear(in_features=4096, out_features=50400, bias=True)\n",
183
+ "wte Embedding(50400, 4096)\n",
184
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
185
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
186
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
187
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
188
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
189
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
190
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
191
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
192
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
193
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
194
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
195
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
196
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
197
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
198
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
199
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
200
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
201
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
202
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
203
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
204
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
205
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
206
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
207
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
208
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
209
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
210
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
211
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
212
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
213
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
214
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
215
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
216
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
217
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
218
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
219
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
220
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
221
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
222
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
223
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
224
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
225
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
226
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
227
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
228
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
229
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
230
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
231
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
232
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
233
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
234
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
235
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
236
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
237
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
238
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
239
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
240
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
241
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
242
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
243
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
244
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
245
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
246
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
247
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
248
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
249
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
250
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
251
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
252
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
253
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
254
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
255
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
256
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
257
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
258
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
259
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
260
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
261
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
262
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
263
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
264
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
265
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
266
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
267
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
268
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
269
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
270
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
271
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
272
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
273
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
274
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
275
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
276
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
277
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
278
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
279
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
280
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
281
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
282
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
283
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
284
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
285
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
286
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
287
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
288
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
289
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
290
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
291
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
292
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
293
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
294
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
295
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
296
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
297
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
298
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
299
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
300
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
301
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
302
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
303
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
304
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
305
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
306
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
307
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
308
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
309
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
310
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
311
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
312
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n"
313
+ ]
314
+ },
315
+ {
316
+ "name": "stdout",
317
+ "output_type": "stream",
318
+ "text": [
319
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
320
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
321
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
322
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
323
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
324
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
325
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
326
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
327
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
328
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
329
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
330
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
331
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
332
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
333
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
334
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
335
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
336
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
337
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
338
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
339
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
340
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
341
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
342
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
343
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
344
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
345
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
346
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
347
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
348
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
349
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
350
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
351
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
352
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
353
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
354
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
355
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
356
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
357
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n"
358
+ ]
359
+ }
360
+ ],
361
+ "source": [
362
+ "bnbfy_(gpt, adapter_dim=0)"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": 7,
368
+ "metadata": {},
369
+ "outputs": [
370
+ {
371
+ "name": "stderr",
372
+ "output_type": "stream",
373
+ "text": [
374
+ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
375
+ ]
376
+ }
377
+ ],
378
+ "source": [
379
+ "prompt = tokenizer(\"A cat sat on a mat and\", return_tensors='pt')\n",
380
+ "out = gpt.generate(**prompt, min_length=8, max_length=8, do_sample=True)\n",
381
+ "tokenizer.decode(out[0])"
382
+ ]
383
+ }
384
+ ],
385
+ "metadata": {
386
+ "kernelspec": {
387
+ "display_name": "py38",
388
+ "language": "python",
389
+ "name": "py38"
390
+ },
391
+ "language_info": {
392
+ "codemirror_mode": {
393
+ "name": "ipython",
394
+ "version": 3
395
+ },
396
+ "file_extension": ".py",
397
+ "mimetype": "text/x-python",
398
+ "name": "python",
399
+ "nbconvert_exporter": "python",
400
+ "pygments_lexer": "ipython3",
401
+ "version": "3.8.1"
402
+ }
403
+ },
404
+ "nbformat": 4,
405
+ "nbformat_minor": 2
406
+ }