Upload ai_msgbot_gpt_j_6b_8bit_with_hub.py
Browse files
ai_msgbot_gpt_j_6b_8bit_with_hub.py
ADDED
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""ai-msgbot-gpt-j-6b-8bit with hub.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/12IXeac5sEUL7dX2bQfB8BZ46lHwK8-dT
|
8 |
+
|
9 |
+
# <center> ai-msgbot - conversational 6B GPT-J 8bit demo
|
10 |
+
|
11 |
+
|
12 |
+
> This notebook demos interaction with a 6B GPT-J finetuned for dialogue via methods in [ai-msgbot](https://github.com/pszemraj/ai-msgbot)
|
13 |
+
|
14 |
+
|
15 |
+
By [Peter](https://github.com/pszemraj). This notebook and `ai-msgbot` are [licensed under creative commons](https://github.com/pszemraj/ai-msgbot/blob/main/LICENSE). Models trained on given datasets are subject to those datasets' licenses.
|
16 |
+
|
17 |
+
|
18 |
+
## usage
|
19 |
+
|
20 |
+
1. select the checkpoint of the model to use for generation in the `model_checkpoint` dropdown
|
21 |
+
2. Run all cells to load everything
|
22 |
+
3. adjust the prompt fields at the bottom of the notebook to whatever you want, see how AI responds.
|
23 |
+
|
24 |
+
|
25 |
+
A fine-tuning example etc. will come _eventually_
|
26 |
+
|
27 |
+
|
28 |
+
---
|
29 |
+
|
30 |
+
# setup
|
31 |
+
"""
|
32 |
+
|
33 |
+
#@markdown setup logging
|
34 |
+
import logging
|
35 |
+
from pathlib import Path
|
36 |
+
for handler in logging.root.handlers[:]:
|
37 |
+
logging.root.removeHandler(handler)
|
38 |
+
|
39 |
+
das_logfile = Path.cwd() / "8bit_inference.log"
|
40 |
+
|
41 |
+
logging.basicConfig(
|
42 |
+
level=logging.INFO,
|
43 |
+
filename=das_logfile,
|
44 |
+
filemode='w',
|
45 |
+
format="%(asctime)s %(levelname)s %(message)s",
|
46 |
+
datefmt="%m/%d/%Y %I:%M:%S",
|
47 |
+
)
|
48 |
+
|
49 |
+
#@markdown add auto-Colab formatting with `IPython.display`
|
50 |
+
from IPython.display import HTML, display
|
51 |
+
# colab formatting
|
52 |
+
def set_css():
|
53 |
+
display(
|
54 |
+
HTML(
|
55 |
+
"""
|
56 |
+
<style>
|
57 |
+
pre {
|
58 |
+
white-space: pre-wrap;
|
59 |
+
}
|
60 |
+
</style>
|
61 |
+
"""
|
62 |
+
)
|
63 |
+
)
|
64 |
+
|
65 |
+
get_ipython().events.register("pre_run_cell", set_css)
|
66 |
+
|
67 |
+
from pathlib import Path
|
68 |
+
|
69 |
+
"""### GPU info"""
|
70 |
+
|
71 |
+
!nvidia-smi
|
72 |
+
|
73 |
+
"""## install and import
|
74 |
+
|
75 |
+
_this notebook uses a specific version of `torch` which can take a while to install._
|
76 |
+
"""
|
77 |
+
|
78 |
+
!pip install transformers==4.24.0 -q
|
79 |
+
!pip install bitsandbytes==0.32.2 -q
|
80 |
+
!pip install datasets==1.16.1 -q
|
81 |
+
!pip install torch==1.11 -q
|
82 |
+
!pip install accelerate==0.12.0 -q
|
83 |
+
!pip install pysbd==0.3.4 -q
|
84 |
+
|
85 |
+
# Commented out IPython magic to ensure Python compatibility.
|
86 |
+
# %%capture
|
87 |
+
# import transformers
|
88 |
+
#
|
89 |
+
# import pandas as pd
|
90 |
+
#
|
91 |
+
# import torch
|
92 |
+
# import torch.nn.functional as F
|
93 |
+
# from torch import nn
|
94 |
+
# from torch.cuda.amp import custom_fwd, custom_bwd
|
95 |
+
#
|
96 |
+
# import bitsandbytes as bnb
|
97 |
+
# from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
|
98 |
+
#
|
99 |
+
# from tqdm.auto import tqdm
|
100 |
+
|
101 |
+
#@markdown utils
|
102 |
+
from transformers.utils.logging import set_verbosity
|
103 |
+
|
104 |
+
set_verbosity(40)
|
105 |
+
|
106 |
+
import warnings
|
107 |
+
# ignore hf pipeline complaints
|
108 |
+
warnings.filterwarnings("ignore", category=UserWarning, module='transformers')
|
109 |
+
|
110 |
+
"""## Converting the model to 8 bits
|
111 |
+
|
112 |
+
"""
|
113 |
+
|
114 |
+
#@title define 8bit classes
|
115 |
+
|
116 |
+
#@markdown - bitsandbytes lib
|
117 |
+
class FrozenBNBLinear(nn.Module):
|
118 |
+
def __init__(self, weight, absmax, code, bias=None):
|
119 |
+
assert isinstance(bias, nn.Parameter) or bias is None
|
120 |
+
super().__init__()
|
121 |
+
self.out_features, self.in_features = weight.shape
|
122 |
+
self.register_buffer("weight", weight.requires_grad_(False))
|
123 |
+
self.register_buffer("absmax", absmax.requires_grad_(False))
|
124 |
+
self.register_buffer("code", code.requires_grad_(False))
|
125 |
+
self.adapter = None
|
126 |
+
self.bias = bias
|
127 |
+
|
128 |
+
def forward(self, input):
|
129 |
+
output = DequantizeAndLinear.apply(
|
130 |
+
input, self.weight, self.absmax, self.code, self.bias
|
131 |
+
)
|
132 |
+
if self.adapter:
|
133 |
+
output += self.adapter(input)
|
134 |
+
return output
|
135 |
+
|
136 |
+
@classmethod
|
137 |
+
def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
|
138 |
+
weights_int8, state = quantize_blockise_lowmemory(linear.weight)
|
139 |
+
return cls(weights_int8, *state, linear.bias)
|
140 |
+
|
141 |
+
def __repr__(self):
|
142 |
+
return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
|
143 |
+
|
144 |
+
|
145 |
+
class DequantizeAndLinear(torch.autograd.Function):
|
146 |
+
@staticmethod
|
147 |
+
@custom_fwd
|
148 |
+
def forward(
|
149 |
+
ctx,
|
150 |
+
input: torch.Tensor,
|
151 |
+
weights_quantized: torch.ByteTensor,
|
152 |
+
absmax: torch.FloatTensor,
|
153 |
+
code: torch.FloatTensor,
|
154 |
+
bias: torch.FloatTensor,
|
155 |
+
):
|
156 |
+
weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
|
157 |
+
ctx.save_for_backward(input, weights_quantized, absmax, code)
|
158 |
+
ctx._has_bias = bias is not None
|
159 |
+
return F.linear(input, weights_deq, bias)
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
@custom_bwd
|
163 |
+
def backward(ctx, grad_output: torch.Tensor):
|
164 |
+
assert (
|
165 |
+
not ctx.needs_input_grad[1]
|
166 |
+
and not ctx.needs_input_grad[2]
|
167 |
+
and not ctx.needs_input_grad[3]
|
168 |
+
)
|
169 |
+
input, weights_quantized, absmax, code = ctx.saved_tensors
|
170 |
+
# grad_output: [*batch, out_features]
|
171 |
+
weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
|
172 |
+
grad_input = grad_output @ weights_deq
|
173 |
+
grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
|
174 |
+
return grad_input, None, None, None, grad_bias
|
175 |
+
|
176 |
+
|
177 |
+
class FrozenBNBEmbedding(nn.Module):
|
178 |
+
def __init__(self, weight, absmax, code):
|
179 |
+
super().__init__()
|
180 |
+
self.num_embeddings, self.embedding_dim = weight.shape
|
181 |
+
self.register_buffer("weight", weight.requires_grad_(False))
|
182 |
+
self.register_buffer("absmax", absmax.requires_grad_(False))
|
183 |
+
self.register_buffer("code", code.requires_grad_(False))
|
184 |
+
self.adapter = None
|
185 |
+
|
186 |
+
def forward(self, input, **kwargs):
|
187 |
+
with torch.no_grad():
|
188 |
+
# note: both quantuized weights and input indices are *not* differentiable
|
189 |
+
weight_deq = dequantize_blockwise(
|
190 |
+
self.weight, absmax=self.absmax, code=self.code
|
191 |
+
)
|
192 |
+
output = F.embedding(input, weight_deq, **kwargs)
|
193 |
+
if self.adapter:
|
194 |
+
output += self.adapter(input)
|
195 |
+
return output
|
196 |
+
|
197 |
+
@classmethod
|
198 |
+
def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
|
199 |
+
weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
|
200 |
+
return cls(weights_int8, *state)
|
201 |
+
|
202 |
+
def __repr__(self):
|
203 |
+
return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
|
204 |
+
|
205 |
+
|
206 |
+
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
|
207 |
+
assert chunk_size % 4096 == 0
|
208 |
+
code = None
|
209 |
+
chunks = []
|
210 |
+
absmaxes = []
|
211 |
+
flat_tensor = matrix.view(-1)
|
212 |
+
for i in range((matrix.numel() - 1) // chunk_size + 1):
|
213 |
+
input_chunk = flat_tensor[i * chunk_size : (i + 1) * chunk_size].clone()
|
214 |
+
quantized_chunk, (absmax_chunk, code) = quantize_blockwise(
|
215 |
+
input_chunk, code=code
|
216 |
+
)
|
217 |
+
chunks.append(quantized_chunk)
|
218 |
+
absmaxes.append(absmax_chunk)
|
219 |
+
matrix_i8 = torch.cat(chunks).reshape_as(matrix)
|
220 |
+
absmax = torch.cat(absmaxes)
|
221 |
+
return matrix_i8, (absmax, code)
|
222 |
+
|
223 |
+
|
224 |
+
def convert_to_int8(model):
|
225 |
+
"""Convert linear and embedding modules to 8-bit with optional adapters"""
|
226 |
+
for module in list(model.modules()):
|
227 |
+
for name, child in module.named_children():
|
228 |
+
if isinstance(child, nn.Linear):
|
229 |
+
print(name, child)
|
230 |
+
setattr(
|
231 |
+
module,
|
232 |
+
name,
|
233 |
+
FrozenBNBLinear(
|
234 |
+
weight=torch.zeros(
|
235 |
+
child.out_features, child.in_features, dtype=torch.uint8
|
236 |
+
),
|
237 |
+
absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
|
238 |
+
code=torch.zeros(256),
|
239 |
+
bias=child.bias,
|
240 |
+
),
|
241 |
+
)
|
242 |
+
elif isinstance(child, nn.Embedding):
|
243 |
+
setattr(
|
244 |
+
module,
|
245 |
+
name,
|
246 |
+
FrozenBNBEmbedding(
|
247 |
+
weight=torch.zeros(
|
248 |
+
child.num_embeddings, child.embedding_dim, dtype=torch.uint8
|
249 |
+
),
|
250 |
+
absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
|
251 |
+
code=torch.zeros(256),
|
252 |
+
),
|
253 |
+
)
|
254 |
+
|
255 |
+
#@markdown Patch GPT-J before loading:
|
256 |
+
|
257 |
+
|
258 |
+
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
|
259 |
+
def __init__(self, config):
|
260 |
+
super().__init__(config)
|
261 |
+
|
262 |
+
convert_to_int8(self.attn)
|
263 |
+
convert_to_int8(self.mlp)
|
264 |
+
|
265 |
+
|
266 |
+
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
|
267 |
+
def __init__(self, config):
|
268 |
+
super().__init__(config)
|
269 |
+
convert_to_int8(self)
|
270 |
+
|
271 |
+
|
272 |
+
class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
|
273 |
+
def __init__(self, config):
|
274 |
+
super().__init__(config)
|
275 |
+
convert_to_int8(self)
|
276 |
+
|
277 |
+
|
278 |
+
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
|
279 |
+
|
280 |
+
# Commented out IPython magic to ensure Python compatibility.
|
281 |
+
# %%capture
|
282 |
+
# #@markdown `add_adapters()`
|
283 |
+
#
|
284 |
+
# def add_adapters(model, adapter_dim=4, p = 0.1):
|
285 |
+
# assert adapter_dim > 0
|
286 |
+
#
|
287 |
+
# for name, module in model.named_modules():
|
288 |
+
# if isinstance(module, FrozenBNBLinear):
|
289 |
+
# if "attn" in name or "mlp" in name or "head" in name:
|
290 |
+
# print("Adding adapter to", name)
|
291 |
+
# module.adapter = nn.Sequential(
|
292 |
+
# nn.Linear(module.in_features, adapter_dim, bias=False),
|
293 |
+
# nn.Dropout(p=p),
|
294 |
+
# nn.Linear(adapter_dim, module.out_features, bias=False),
|
295 |
+
# )
|
296 |
+
# print("Initializing", name)
|
297 |
+
# nn.init.zeros_(module.adapter[2].weight)
|
298 |
+
#
|
299 |
+
# else:
|
300 |
+
# print("Not adding adapter to", name)
|
301 |
+
# elif isinstance(module, FrozenBNBEmbedding):
|
302 |
+
# print("Adding adapter to", name)
|
303 |
+
# module.adapter = nn.Sequential(
|
304 |
+
# nn.Embedding(module.num_embeddings, adapter_dim),
|
305 |
+
# nn.Dropout(p=p),
|
306 |
+
# nn.Linear(adapter_dim, module.embedding_dim, bias=False),
|
307 |
+
# )
|
308 |
+
# print("Initializing", name)
|
309 |
+
# nn.init.zeros_(module.adapter[2].weight)
|
310 |
+
#
|
311 |
+
|
312 |
+
#@markdown set up config
|
313 |
+
config = transformers.GPTJConfig.from_pretrained("hivemind/gpt-j-6B-8bit")
|
314 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
|
315 |
+
config.pad_token_id = config.eos_token_id
|
316 |
+
tokenizer.pad_token = config.pad_token_id
|
317 |
+
|
318 |
+
"""# load model
|
319 |
+
|
320 |
+
"""
|
321 |
+
|
322 |
+
from contextlib import contextmanager
|
323 |
+
import sys, os, gc
|
324 |
+
import logging
|
325 |
+
from tqdm.auto import tqdm
|
326 |
+
#@markdown define `load_8bit_from_hub()`
|
327 |
+
|
328 |
+
@contextmanager
|
329 |
+
def suppress_stdout():
|
330 |
+
with open(os.devnull, "w") as devnull:
|
331 |
+
old_stdout = sys.stdout
|
332 |
+
sys.stdout = devnull
|
333 |
+
try:
|
334 |
+
yield
|
335 |
+
finally:
|
336 |
+
sys.stdout = old_stdout
|
337 |
+
|
338 |
+
def load_8bit_from_hub(model_id:str, **kwargs):
|
339 |
+
pbar = tqdm(desc="instantiating model..", total=3)
|
340 |
+
|
341 |
+
with suppress_stdout():
|
342 |
+
gc.collect()
|
343 |
+
model = GPTJForCausalLM.from_pretrained(model_id,
|
344 |
+
device_map='auto',
|
345 |
+
low_cpu_mem_usage=True,
|
346 |
+
**kwargs)
|
347 |
+
pbar.update()
|
348 |
+
add_adapters(model)
|
349 |
+
pbar.update()
|
350 |
+
model = model.to("cuda" if torch.cuda.is_available() else -1)
|
351 |
+
pbar.update()
|
352 |
+
return model
|
353 |
+
|
354 |
+
from huggingface_hub import notebook_login
|
355 |
+
|
356 |
+
notebook_login()
|
357 |
+
|
358 |
+
model_name = "ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps" #@param ["ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps"]
|
359 |
+
|
360 |
+
# load_8bit_from_hub() is a wrapper around AutoModel.from_pretrained() and will
|
361 |
+
# passthrough all kwargs to that
|
362 |
+
model = load_8bit_from_hub(model_name, use_auth_token=True, )
|
363 |
+
|
364 |
+
"""# generate text
|
365 |
+
|
366 |
+
## standard generation
|
367 |
+
`
|
368 |
+
|
369 |
+
with torch:
|
370 |
+
|
371 |
+
> with "standard" generation it's recommended to put the **speaker token labels** at the end of your prompt so the model "knows" to respond.
|
372 |
+
|
373 |
+
i.e `Person Alpha:` or `Person Beta:` for these two models.
|
374 |
+
"""
|
375 |
+
|
376 |
+
prompt = "Person Alpha: what is the theory of being \"woke\" all about?\\n Person Beta: " # @param {type:"string"}
|
377 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
378 |
+
with torch.no_grad():
|
379 |
+
prompt = tokenizer(prompt, return_tensors="pt")
|
380 |
+
prompt = {key: value.to(device) for key, value in prompt.items()}
|
381 |
+
out = model.generate(
|
382 |
+
**prompt,
|
383 |
+
min_length=24,
|
384 |
+
max_length=96,
|
385 |
+
top_k=30,
|
386 |
+
top_p=0.9,
|
387 |
+
temperature=0.4,
|
388 |
+
do_sample=True,
|
389 |
+
repetition_penalty=1.2,
|
390 |
+
no_repeat_ngram_size=3,
|
391 |
+
pad_token_id=tokenizer.eos_token_id,
|
392 |
+
)
|
393 |
+
result = tokenizer.decode(
|
394 |
+
out[0],
|
395 |
+
remove_invalid_values=True,
|
396 |
+
skip_special_tokens=True,
|
397 |
+
clean_up_tokenization_spaces=True,
|
398 |
+
)
|
399 |
+
result
|
400 |
+
|
401 |
+
"""---
|
402 |
+
|
403 |
+
## 'Extract' bot response
|
404 |
+
- transformers `pipeline` object
|
405 |
+
- generate with better params
|
406 |
+
- extract the bot's response with `get_bot_response()` - start to use [ai-msgbot](https://github.com/pszemraj/ai-msgbot) _like it was meant to be used_
|
407 |
+
"""
|
408 |
+
|
409 |
+
from transformers import pipeline
|
410 |
+
|
411 |
+
generator = pipeline(
|
412 |
+
"text-generation",
|
413 |
+
model=model,
|
414 |
+
tokenizer="EleutherAI/gpt-j-6B",
|
415 |
+
device= 0 if torch.cuda.is_available() else -1,
|
416 |
+
)
|
417 |
+
|
418 |
+
"""### generation functions
|
419 |
+
|
420 |
+
for extracting the response, beam search vs. sampling, etc
|
421 |
+
"""
|
422 |
+
|
423 |
+
# @markdown `get_bot_response(name_resp: str, model_resp: list, name_spk: str, verbose: bool = False)`
|
424 |
+
# @markdown - this extracts the response from "Person Beta" from the total generation
|
425 |
+
import pysbd
|
426 |
+
|
427 |
+
seg = pysbd.Segmenter(language="en", clean=False)
|
428 |
+
|
429 |
+
import re
|
430 |
+
|
431 |
+
|
432 |
+
def split_sentences(text, use_regex=False, min_len=2):
|
433 |
+
"""given a string, splits it into sentences based on punctuation marks."""
|
434 |
+
|
435 |
+
if use_regex:
|
436 |
+
sentences = re.split(r'(?<=[.!?]) +', string)
|
437 |
+
else:
|
438 |
+
# https://github.com/nipunsadvilkar/pySBD
|
439 |
+
sentences = seg.segment(text)
|
440 |
+
return [s.strip() for s in sentences if len(s.strip()) > min_len]
|
441 |
+
|
442 |
+
|
443 |
+
def validate_response(response_text):
|
444 |
+
|
445 |
+
if isinstance(response_text, list):
|
446 |
+
|
447 |
+
return response_text
|
448 |
+
# if len(response_text) > 1 else split_sentences(str(response_text))
|
449 |
+
elif isinstance(response_text, str):
|
450 |
+
return split_sentences(response_text)
|
451 |
+
else:
|
452 |
+
raise ValueError(f"response input {response_text} not a list or str..")
|
453 |
+
|
454 |
+
|
455 |
+
def get_bot_response(
|
456 |
+
name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
|
457 |
+
):
|
458 |
+
"""
|
459 |
+
get_bot_response - gets the bot response to a prompt, checking to ensure that additional statements by the "speaker" are not included in the response.
|
460 |
+
Args:
|
461 |
+
name_resp (str): the name of the responder
|
462 |
+
model_resp (list): the model response
|
463 |
+
name_spk (str): the name of the speaker
|
464 |
+
verbose (bool, optional): Defaults to False.
|
465 |
+
Returns:
|
466 |
+
bot_response (str): the bot response, isolated down to just text without the "name tokens" or further messages from the speaker.
|
467 |
+
"""
|
468 |
+
|
469 |
+
model_resp = validate_response(model_resp)
|
470 |
+
logging.info(f"isolating response from:\t{model_resp}")
|
471 |
+
fn_resp = []
|
472 |
+
|
473 |
+
name_counter = 0
|
474 |
+
break_safe = False
|
475 |
+
for resline in model_resp:
|
476 |
+
if name_resp.lower() in resline.lower():
|
477 |
+
name_counter += 1
|
478 |
+
break_safe = True
|
479 |
+
continue
|
480 |
+
if ":" in resline and name_resp.lower() not in resline.lower():
|
481 |
+
break
|
482 |
+
if name_spk.lower() in resline.lower() and not break_safe:
|
483 |
+
break
|
484 |
+
else:
|
485 |
+
fn_resp.append(resline)
|
486 |
+
if verbose:
|
487 |
+
print("the full response is:\n")
|
488 |
+
print("\n".join(fn_resp))
|
489 |
+
if isinstance(fn_resp, list):
|
490 |
+
fn_resp = fn_resp[0] if len(fn_resp) == 1 else " ".join(fn_resp)
|
491 |
+
return fn_resp
|
492 |
+
|
493 |
+
import pprint as pp
|
494 |
+
|
495 |
+
# @markdown define `generate_sampling(prompt: str, ...)`
|
496 |
+
|
497 |
+
|
498 |
+
def generate_sampling(
|
499 |
+
prompt: str,
|
500 |
+
suffix:str=None,
|
501 |
+
temperature=0.4,
|
502 |
+
top_k: int = 40,
|
503 |
+
top_p=0.90,
|
504 |
+
min_length: int = 16,
|
505 |
+
max_length: int = 128,
|
506 |
+
no_repeat_ngram_size: int = 3,
|
507 |
+
repetition_penalty=1.5,
|
508 |
+
return_full_text=False,
|
509 |
+
verbose=False,
|
510 |
+
**kwargs,
|
511 |
+
) -> None:
|
512 |
+
|
513 |
+
logging.info(f"generating results for input:\n\t{prompt}\n\t...")
|
514 |
+
if verbose:
|
515 |
+
print(f"generating results for input:\n\t{prompt}\n\t...")
|
516 |
+
prompt = f"{prompt}{suffix}" if suffix is not None else prompt
|
517 |
+
|
518 |
+
_prompt_tokens = len(generator.tokenizer(prompt).input_ids)
|
519 |
+
result = generator(
|
520 |
+
prompt,
|
521 |
+
min_length=min_length+_prompt_tokens,
|
522 |
+
temperature=temperature,
|
523 |
+
top_k=top_k,
|
524 |
+
top_p=top_p,
|
525 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
526 |
+
repetition_penalty=repetition_penalty,
|
527 |
+
remove_invalid_values=True,
|
528 |
+
clean_up_tokenization_spaces=True,
|
529 |
+
do_sample=True,
|
530 |
+
return_full_text=return_full_text,
|
531 |
+
max_new_tokens=max_length+_prompt_tokens,
|
532 |
+
pad_token_id=generator.tokenizer.eos_token_id,
|
533 |
+
**kwargs,
|
534 |
+
)
|
535 |
+
|
536 |
+
output = result[0]["generated_text"]
|
537 |
+
logging.info(f"model output:\n\t{output}")
|
538 |
+
if verbose:
|
539 |
+
print(f"model output:\n\t{output}")
|
540 |
+
response = get_bot_response(
|
541 |
+
model_resp=output,
|
542 |
+
name_spk="Person Alpha",
|
543 |
+
name_resp="Person Beta",
|
544 |
+
verbose=False,
|
545 |
+
)
|
546 |
+
|
547 |
+
logging.info(f"extracted bot response:\n\t{response}")
|
548 |
+
|
549 |
+
pp.pprint(response)
|
550 |
+
|
551 |
+
return response
|
552 |
+
|
553 |
+
import pprint as pp
|
554 |
+
|
555 |
+
#@markdown define `generate_beams(prompt: str, num_beams:int =4, ...)`
|
556 |
+
|
557 |
+
|
558 |
+
def generate_beams(
|
559 |
+
prompt: str,
|
560 |
+
suffix:str=None,
|
561 |
+
num_beams=4,
|
562 |
+
min_length: int = 32,
|
563 |
+
max_length: int = 128,
|
564 |
+
no_repeat_ngram_size: int = 3,
|
565 |
+
repetition_penalty=2.5,
|
566 |
+
return_full_text=False,
|
567 |
+
verbose=False,
|
568 |
+
**kwargs,
|
569 |
+
) -> None:
|
570 |
+
|
571 |
+
logging.info(f"generating results for input:\n\t{prompt}\n\t...")
|
572 |
+
if verbose:
|
573 |
+
print(f"generating results for input:\n\t{prompt}\n\t")
|
574 |
+
|
575 |
+
prompt = f"{prompt}{suffix}" if suffix is not None else prompt
|
576 |
+
_prompt_tokens = len(generator.tokenizer(prompt).input_ids)
|
577 |
+
result = generator(
|
578 |
+
prompt,
|
579 |
+
min_length=min_length+_prompt_tokens,
|
580 |
+
num_beams=num_beams,
|
581 |
+
do_sample=False,
|
582 |
+
early_stopping=True,
|
583 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
584 |
+
repetition_penalty=repetition_penalty,
|
585 |
+
remove_invalid_values=True,
|
586 |
+
clean_up_tokenization_spaces=True,
|
587 |
+
return_full_text=return_full_text,
|
588 |
+
max_new_tokens=max_length+_prompt_tokens,
|
589 |
+
pad_token_id=generator.tokenizer.eos_token_id,
|
590 |
+
**kwargs,
|
591 |
+
)
|
592 |
+
|
593 |
+
output = result[0]["generated_text"]
|
594 |
+
logging.info(f"model output:\n\t{output}")
|
595 |
+
if verbose:
|
596 |
+
print(f"model output:\n\t{output}")
|
597 |
+
response = get_bot_response(
|
598 |
+
model_resp=output,
|
599 |
+
name_spk="Person Alpha",
|
600 |
+
name_resp="Person Beta",
|
601 |
+
verbose=False,
|
602 |
+
)
|
603 |
+
|
604 |
+
|
605 |
+
logging.info(f"extracted bot response:\n\t{response}")
|
606 |
+
|
607 |
+
pp.pprint(response)
|
608 |
+
|
609 |
+
return response
|
610 |
+
|
611 |
+
import pprint as pp
|
612 |
+
|
613 |
+
#@markdown define `generate_csearch(prompt: str, num_beams:int =4, ...)`
|
614 |
+
|
615 |
+
|
616 |
+
def generate_csearch(
|
617 |
+
prompt: str,
|
618 |
+
suffix:str=None,
|
619 |
+
max_length: int = 96,
|
620 |
+
min_length: int = 24,
|
621 |
+
penalty_alpha: float=0.6,
|
622 |
+
top_k: int=5,
|
623 |
+
return_full_text=False,
|
624 |
+
verbose=False,
|
625 |
+
**kwargs,
|
626 |
+
) -> None:
|
627 |
+
|
628 |
+
logging.info(f"generating results for input:\n\t{prompt}\n\t...")
|
629 |
+
if verbose:
|
630 |
+
print(f"generating results for input:\n\t{prompt}\n\t")
|
631 |
+
|
632 |
+
prompt = f"{prompt}{suffix}" if suffix is not None else prompt
|
633 |
+
_prompt_tokens = len(generator.tokenizer(prompt).input_ids)
|
634 |
+
result = generator(
|
635 |
+
prompt,
|
636 |
+
min_length=min_length+_prompt_tokens,
|
637 |
+
max_new_tokens=max_length,
|
638 |
+
penalty_alpha=penalty_alpha,
|
639 |
+
top_k=top_k,
|
640 |
+
remove_invalid_values=True,
|
641 |
+
clean_up_tokenization_spaces=True,
|
642 |
+
return_full_text=return_full_text,
|
643 |
+
pad_token_id=generator.tokenizer.eos_token_id,
|
644 |
+
**kwargs,
|
645 |
+
)
|
646 |
+
|
647 |
+
output = result[0]["generated_text"]
|
648 |
+
logging.info(f"model output:\n\t{output}")
|
649 |
+
if verbose:
|
650 |
+
print(f"model output:\n\t{output}")
|
651 |
+
response = get_bot_response(
|
652 |
+
model_resp=output,
|
653 |
+
name_spk="Person Alpha",
|
654 |
+
name_resp="Person Beta",
|
655 |
+
verbose=False,
|
656 |
+
)
|
657 |
+
|
658 |
+
|
659 |
+
logging.info(f"extracted bot response:\n\t{response}")
|
660 |
+
|
661 |
+
pp.pprint(response)
|
662 |
+
|
663 |
+
return response
|
664 |
+
|
665 |
+
"""### generate - sampling
|
666 |
+
|
667 |
+
> **NOTE:** that here the `suffix="\nPerson Beta: ",` is passed so it does not need to be added to a prompt
|
668 |
+
"""
|
669 |
+
|
670 |
+
# Commented out IPython magic to ensure Python compatibility.
|
671 |
+
# %%time
|
672 |
+
#
|
673 |
+
# prompt = "How do we harness space energy?" #@param {type:"string"}
|
674 |
+
# temperature = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.1}
|
675 |
+
# top_k = 30 #@param {type:"slider", min:10, max:60, step:10}
|
676 |
+
#
|
677 |
+
#
|
678 |
+
# result = generate_sampling(
|
679 |
+
# prompt,
|
680 |
+
# suffix="\nPerson Beta: ",
|
681 |
+
# max_length=128,
|
682 |
+
# min_length=32,
|
683 |
+
# temperature=temperature,
|
684 |
+
# top_k=top_k,
|
685 |
+
# )
|
686 |
+
#
|
687 |
+
|
688 |
+
prompt = "What is the purpose of life?" # @param {type:"string"}
|
689 |
+
temperature = 0.5 # @param {type:"slider", min:0.1, max:1, step:0.1}
|
690 |
+
top_k = 30 # @param {type:"slider", min:10, max:60, step:10}
|
691 |
+
|
692 |
+
generated_result = generate_sampling(
|
693 |
+
prompt,
|
694 |
+
temperature=temperature,
|
695 |
+
top_k=top_k,
|
696 |
+
min_length=32,
|
697 |
+
suffix="\nPerson Beta: ",
|
698 |
+
)
|
699 |
+
|
700 |
+
"""### generate - beam search"""
|
701 |
+
|
702 |
+
# Commented out IPython magic to ensure Python compatibility.
|
703 |
+
# %%time
|
704 |
+
# prompt = "How was your day?" #@param {type:"string"}
|
705 |
+
# num_beams = 4 #@param {type:"slider", min:2, max:10, step:2}
|
706 |
+
# min_length = 16 #@param {type:"slider", min:8, max:128, step:8}
|
707 |
+
#
|
708 |
+
# generated_result = generate_beams(
|
709 |
+
# prompt,
|
710 |
+
# suffix="\nPerson Beta: ",
|
711 |
+
# min_length=min_length,
|
712 |
+
# num_beams=num_beams,
|
713 |
+
# )
|
714 |
+
|
715 |
+
"""### generate - contrastive search"""
|
716 |
+
|
717 |
+
# Commented out IPython magic to ensure Python compatibility.
|
718 |
+
# %%time
|
719 |
+
# prompt = "What do you do for fun?" #@param {type:"string"}
|
720 |
+
# top_k = 4 #@param {type:"slider", min:2, max:10, step:2}
|
721 |
+
# penalty_alpha = 0.6 #@param {type:"slider", min:0, max:1, step:0.1}
|
722 |
+
# min_length = 8 #@param {type:"slider", min:8, max:128, step:8}
|
723 |
+
#
|
724 |
+
# generated_result = generate_csearch(
|
725 |
+
# prompt,
|
726 |
+
# suffix="\nPerson Beta: ",
|
727 |
+
# min_length=min_length,
|
728 |
+
# penalty_alpha=penalty_alpha,
|
729 |
+
# top_k=top_k,
|
730 |
+
# num_beams=num_beams,
|
731 |
+
# )
|
732 |
+
|