crumb committed on
Commit e6a91e0
1 Parent(s): f5922a7

Upload shakespeare-inference.ipynb

Files changed (1)
  1. shakespeare-inference.ipynb +255 -0
shakespeare-inference.ipynb ADDED
@@ -0,0 +1,255 @@
+ {
+   "nbformat": 4,
+   "nbformat_minor": 0,
+   "metadata": {
+     "colab": {
+       "name": "inference",
+       "provenance": []
+     },
+     "kernelspec": {
+       "name": "python3",
+       "display_name": "Python 3"
+     },
+     "language_info": {
+       "name": "python"
+     },
+     "accelerator": "GPU",
+     "gpuClass": "standard"
+   },
+   "cells": [
+     {
+       "cell_type": "code",
+       "source": [
+         "!pip install transformers==4.14.1\n",
+         "!pip install bitsandbytes-cuda111==0.26.0\n",
+         "\n",
+         "from IPython import display\n",
+         "display.clear_output()"
+       ],
+       "metadata": {
+         "id": "Q8cuAdVDGXR6"
+       },
+       "execution_count": 1,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 2,
+       "metadata": {
+         "id": "8mkqaWlNGLKn"
+       },
+       "outputs": [],
+       "source": [
+         "import transformers\n",
+         "import torch\n",
+         "import torch.nn.functional as F\n",
+         "from torch import nn\n",
+         "from torch.cuda.amp import custom_fwd, custom_bwd\n",
+         "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
+         "from tqdm.auto import tqdm"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "source": [
+         "#@title convert to 8bit\n",
+         "class FrozenBNBLinear(nn.Module):\n",
+         "    def __init__(self, weight, absmax, code, bias=None):\n",
+         "        assert isinstance(bias, nn.Parameter) or bias is None\n",
+         "        super().__init__()\n",
+         "        self.out_features, self.in_features = weight.shape\n",
+         "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
+         "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
+         "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
+         "        self.adapter = None\n",
+         "        self.bias = bias\n",
+         "\n",
+         "    def forward(self, input):\n",
+         "        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
+         "        if self.adapter:\n",
+         "            output += self.adapter(input)\n",
+         "        return output\n",
+         "\n",
+         "    @classmethod\n",
+         "    def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n",
+         "        weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
+         "        return cls(weights_int8, *state, linear.bias)\n",
+         "\n",
+         "    def __repr__(self):\n",
+         "        return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
+         "\n",
+         "\n",
+         "class DequantizeAndLinear(torch.autograd.Function):\n",
+         "    @staticmethod\n",
+         "    @custom_fwd\n",
+         "    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
+         "                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
+         "        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
+         "        ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
+         "        ctx._has_bias = bias is not None\n",
+         "        return F.linear(input, weights_deq, bias)\n",
+         "\n",
+         "    @staticmethod\n",
+         "    @custom_bwd\n",
+         "    def backward(ctx, grad_output: torch.Tensor):\n",
+         "        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
+         "        input, weights_quantized, absmax, code = ctx.saved_tensors\n",
+         "        # grad_output: [*batch, out_features]\n",
+         "        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
+         "        grad_input = grad_output @ weights_deq\n",
+         "        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
+         "        return grad_input, None, None, None, grad_bias\n",
+         "\n",
+         "\n",
+         "class FrozenBNBEmbedding(nn.Module):\n",
+         "    def __init__(self, weight, absmax, code):\n",
+         "        super().__init__()\n",
+         "        self.num_embeddings, self.embedding_dim = weight.shape\n",
+         "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
+         "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
+         "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
+         "        self.adapter = None\n",
+         "\n",
+         "    def forward(self, input, **kwargs):\n",
+         "        with torch.no_grad():\n",
+         "            # note: both quantized weights and input indices are *not* differentiable\n",
+         "            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
+         "            output = F.embedding(input, weight_deq, **kwargs)\n",
+         "        if self.adapter:\n",
+         "            output += self.adapter(input)\n",
+         "        return output\n",
+         "\n",
+         "    @classmethod\n",
+         "    def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n",
+         "        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
+         "        return cls(weights_int8, *state)\n",
+         "\n",
+         "    def __repr__(self):\n",
+         "        return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n",
+         "\n",
+         "\n",
+         "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
+         "    assert chunk_size % 4096 == 0\n",
+         "    code = None\n",
+         "    chunks = []\n",
+         "    absmaxes = []\n",
+         "    flat_tensor = matrix.view(-1)\n",
+         "    for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
+         "        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
+         "        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
+         "        chunks.append(quantized_chunk)\n",
+         "        absmaxes.append(absmax_chunk)\n",
+         "\n",
+         "    matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
+         "    absmax = torch.cat(absmaxes)\n",
+         "    return matrix_i8, (absmax, code)\n",
+         "\n",
+         "\n",
+         "def convert_to_int8(model):\n",
+         "    \"\"\"Convert linear and embedding modules to 8-bit with optional adapters\"\"\"\n",
+         "    for module in list(model.modules()):\n",
+         "        for name, child in module.named_children():\n",
+         "            if isinstance(child, nn.Linear):\n",
+         "                setattr(\n",
+         "                    module,\n",
+         "                    name,\n",
+         "                    FrozenBNBLinear(\n",
+         "                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n",
+         "                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
+         "                        code=torch.zeros(256),\n",
+         "                        bias=child.bias,\n",
+         "                    ),\n",
+         "                )\n",
+         "            elif isinstance(child, nn.Embedding):\n",
+         "                setattr(\n",
+         "                    module,\n",
+         "                    name,\n",
+         "                    FrozenBNBEmbedding(\n",
+         "                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n",
+         "                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
+         "                        code=torch.zeros(256),\n",
+         "                    )\n",
+         "                )\n",
+         "class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):\n",
+         "    def __init__(self, config):\n",
+         "        super().__init__(config)\n",
+         "\n",
+         "        convert_to_int8(self.attn)\n",
+         "        convert_to_int8(self.mlp)\n",
+         "\n",
+         "class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):\n",
+         "    def __init__(self, config):\n",
+         "        super().__init__(config)\n",
+         "        convert_to_int8(self)\n",
+         "\n",
+         "class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):\n",
+         "    def __init__(self, config):\n",
+         "        super().__init__(config)\n",
+         "        convert_to_int8(self)\n",
+         "\n",
+         "\n",
+         "transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J"
+       ],
+       "metadata": {
+         "cellView": "form",
+         "id": "fmpdVvfVG7Pc"
+       },
+       "execution_count": 3,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "source": [
+         "tokenizer = transformers.AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
+         "gpt = GPTJForCausalLM.from_pretrained(\"crumb/gpt-j-6b-shakespeare\", low_cpu_mem_usage=True)\n",
+         "\n",
+         "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+         "gpt = gpt.to(device)"
+       ],
+       "metadata": {
+         "id": "ttKTRoUlG5YM"
+       },
+       "execution_count": null,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "source": [
+         "prompt = \"\"\"ROMEO: I would I were thy bird. \n",
+         "JULIET: Sweet, so would I, Yet I should kill thee with much cherishing. Good night, good night! Parting is such sweet\"\"\"\n",
+         "prompt = tokenizer(prompt, return_tensors='pt')\n",
+         "prompt = {key: value.to(device) for key, value in prompt.items()}\n",
+         "out = gpt.generate(**prompt, min_length=32, max_length=64, do_sample=True)\n",
+         "out = tokenizer.decode(out[0])\n",
+         "print(out)"
+       ],
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "kSXSZz_kGcfm",
+         "outputId": "d91dda66-88ab-4e52-bfb9-3df8092abe2f"
+       },
+       "execution_count": 9,
+       "outputs": [
+         {
+           "output_type": "stream",
+           "name": "stderr",
+           "text": [
+             "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+           ]
+         },
+         {
+           "output_type": "stream",
+           "name": "stdout",
+           "text": [
+             "ROMEO: I would I were thy bird. \n",
+             "JULIET: Sweet, so would I, Yet I should kill thee with much cherishing. Good night, good night! Parting is such sweet sorrow, As a lost angel's song, in answer to An evil dream.\n",
+             "\n",
+             "ROMEO\n"
+           ]
+         }
+       ]
+     }
+   ]
+ }
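
For readers skimming the diff: the "convert to 8bit" cell replaces every `nn.Linear` and `nn.Embedding` in GPT-J with frozen modules that store uint8 codes plus one `absmax` scale per 4096-element block and dequantize on the fly, which is what lets the 6B-parameter checkpoint run on a single Colab GPU. Below is a minimal sketch of the blockwise quantize/dequantize round trip those modules rely on, assuming the same `bitsandbytes` functional API the notebook installs; the toy tensor and prints are illustrative, not taken from the notebook.

```python
import torch
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

# Toy fp32 "weight" matrix (illustrative); its numel is a multiple of the
# 4096-element block size, matching the assert in quantize_blockise_lowmemory.
weights = torch.randn(4096, 256)

# Quantize to uint8 codes plus one absmax scale per block; `code` is the shared
# 256-entry quantization codebook (hence code=torch.zeros(256) in convert_to_int8).
weights_q, (absmax, code) = quantize_blockwise(weights)

# The same dequantization DequantizeAndLinear.forward performs before F.linear.
weights_deq = dequantize_blockwise(weights_q, absmax=absmax, code=code)

print(weights_q.dtype)                      # torch.uint8
print((weights - weights_deq).abs().max())  # small blockwise rounding error
```

Because `from_pretrained` builds its layers through the monkey-patched `transformers.models.gptj.modeling_gptj.GPTJBlock`, the `crumb/gpt-j-6b-shakespeare` weights load straight into these 8-bit modules with no separate conversion step.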