Create eval.py
eval.py
ADDED
@@ -0,0 +1,307 @@
import os

hfloc = "pcalhoun/gpt-j-6b-8bit-pun-generator"
filename = "gptj8bit.pt"

from huggingface_hub import hf_hub_download

# The next few hundred lines exist because HF still doesn't support directly loading models trained in 8-bit.
DOWNLOAD_LOC = hf_hub_download(repo_id=hfloc, filename=filename)
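# Note: the downloaded checkpoint appears to be a fully pickled model object (it is loaded
# with torch.load below and used directly via gpt.generate), so the patched 8-bit GPT-J
# classes defined in this file must exist before it is unpickled.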

import logging
import math
import sys
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
import random, json, datetime

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    GPT2Tokenizer
)

from transformers import GPTNeoForSequenceClassification
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
import os
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, logging

import transformers
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from tqdm.auto import tqdm

class FrozenBNBLinear(nn.Module):
    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        output = torch.clone(DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias))
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
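# Storage layout used by the frozen 8-bit modules: "weight" holds one uint8 per parameter,
# "code" is the 256-entry lookup table produced by bitsandbytes' quantize_blockwise, and
# "absmax" holds one scale per quantization block (4096 values per block, judging by the
# buffer shapes in convert_to_int8 below). forward() dequantizes on the fly and then
# applies an ordinary F.linear.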


class DequantizeAndLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
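# Only the activations (and bias) receive gradients here: the quantized weight, absmax and
# code buffers are frozen, so backward() returns None for them and re-dequantizes the
# weights just to compute grad_input.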


class FrozenBNBEmbedding(nn.Module):
    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"


def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    flat_tensor = matrix.view(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
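# Rough shape check (illustrative, not executed anywhere): a 4096x4096 nn.Linear weight has
# 16,777,216 values, so it is quantized in 16 chunks of 2**20, yielding a uint8 matrix of
# the same shape, one shared 256-entry code table, and 4096 absmax scales (one per
# 4096-value block, assuming the block size implied by convert_to_int8 below).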


def convert_to_int8(model):
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr(
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    )
                )
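# convert_to_int8 only installs zero-filled placeholder buffers of the right shapes; the
# real quantized weights arrive when the checkpoint is loaded further down, which is
# presumably why the patched model classes below run it in their constructors.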

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)

        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

def add_adapters(model, adapter_dim=16):
    assert adapter_dim > 0
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)
        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)
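# These adapters are the low-rank trainable bypass (similar in spirit to LoRA) used to
# fine-tune the frozen 8-bit weights; the second layer is zero-initialised, so a freshly
# adapted model behaves exactly like the base model. add_adapters is not called in this
# eval script, since the loaded checkpoint presumably already contains its trained adapters.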

from transformers import StoppingCriteria, StoppingCriteriaList
# modified from the EndOfFunctionCriteria class I found somewhere around here:
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_stopping_criteria.html
class EndOfXCriteria(StoppingCriteria):
    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        done = []
        for decoded_generation in decoded_generations:
            done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
        return all(done)
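# Generation is considered finished once every sequence in the batch has produced at least
# one of the eof_strings after the prompt; generate() below passes the full extratoken list
# tk, so the first structural delimiter the model emits ends generation.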

from datasets import load_dataset
import os
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
import json

gpt = torch.load(DOWNLOAD_LOC, map_location=torch.device('cuda'))

eval_dir = "eval_logs"

# Punbot was partly trained as an experiment to see if <|extratoken_X|>s could be used for anything.
# I couldn't find much about them on Google. They're just a quirk of the way the vocabulary was fit
# to the TPUs, but I thought fine-tuning with tokens that weren't in the training set might help
# explain the "anomalous token" phenomenon.
# Additional info: https://paulcalhoun.substack.com/p/this-is-the-moment-ive-been-training

tk = ["<|extratoken_{}|>".format(i) for i in range(10, 145)]

PUN_BEGIN = "".join(tk[0:2])  # <|extratoken_10|><|extratoken_11|>
PUN_END = "".join(tk[2:4])  # <|extratoken_12|><|extratoken_13|>
HOMOPHONES_START = "".join(tk[4:6])  # <|extratoken_14|><|extratoken_15|>
HOMOPHONES_SEP = "".join(tk[6:8])  # <|extratoken_16|><|extratoken_17|>
HOMOPHONES_END = "".join(tk[8:10])  # <|extratoken_18|><|extratoken_19|>
PHONEMES_BEGIN = "".join(tk[10:12])  # <|extratoken_20|><|extratoken_21|>
PHONEMES_END = "".join(tk[12:14])  # <|extratoken_22|><|extratoken_23|>
GRAPHEMES_BEGIN = "".join(tk[14:16])  # <|extratoken_24|><|extratoken_25|>
GRAPHEMES_END = "".join(tk[16:18])  # <|extratoken_26|><|extratoken_27|>
EXPLANATION_BEGIN = "".join(tk[18:20])  # <|extratoken_28|><|extratoken_29|>
EXPLANATION_END = "".join(tk[20:22])  # <|extratoken_30|><|extratoken_31|>
KEYWORDS_BEGIN = "".join(tk[22:24])  # <|extratoken_32|><|extratoken_33|>
KEYWORDS_END = "".join(tk[24:26])  # <|extratoken_34|><|extratoken_35|>

GENERATE_KEYWORDS_THEN_EXPLANATION_OF_PUN = "".join(tk[30:32])  # <|extratoken_40|><|extratoken_41|>
GENERATE_PUN_FROM_EXPLANATION_THEN_KEYWORDS = "".join(tk[32:34])  # <|extratoken_42|><|extratoken_43|>
#GENERATE_EXPLANATION_THEN_KEYWORDS_FROM_PUN = "".join(tk[34:36])  # <|extratoken_44|><|extratoken_45|>
GENERATE_PUN_THEN_EXPLANATION_FROM_KEYWORDS = "".join(tk[36:38])  # <|extratoken_46|><|extratoken_47|>
GENERATE_HOMOPHONE_LIST_FROM_WORD = "".join(tk[38:40])  # <|extratoken_48|><|extratoken_49|>

TASK_START = "".join(tk[50:52])  # <|extratoken_60|><|extratoken_61|>
TASK_END = "".join(tk[52:54])  # <|extratoken_62|><|extratoken_63|>
GENERATE_EXPLANATION_THEN_PUN_FROM_KEYWORDS = "".join(tk[54:56])  # <|extratoken_64|><|extratoken_65|>

preprompt = TASK_START + GENERATE_PUN_THEN_EXPLANATION_FROM_KEYWORDS + PUN_BEGIN + GRAPHEMES_BEGIN
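# The preprompt therefore expands to
# <|extratoken_60|><|extratoken_61|><|extratoken_46|><|extratoken_47|><|extratoken_10|><|extratoken_11|><|extratoken_24|><|extratoken_25|>:
# task start, "generate pun then explanation from keywords", pun begin, graphemes begin.
# Any text passed to generate() below is appended after GRAPHEMES_BEGIN, so it presumably
# seeds the surface text (graphemes) of the pun itself.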

def generate(prompt="", preprompt=preprompt, stop_tokens=tk[40:], max_length=256, top_k=50, top_p=0.98):
    full_prompt = preprompt + prompt
    with torch.no_grad():
        inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda")
        tkn_len = len(tokenizer.tokenize(full_prompt))
        outputs = gpt.generate(
            inputs["input_ids"],
            stopping_criteria=StoppingCriteriaList(
                [EndOfXCriteria(
                    tkn_len,
                    tk,
                    tokenizer
                )]),
            max_length=max_length,
            do_sample=True, top_k=top_k, top_p=top_p, pad_token_id=50256)
    text = tokenizer.decode(outputs[0])
    fixed_text = text[len(full_prompt) - len(prompt):]
    for token in tk:
        fixed_text = fixed_text.replace(token, "")
    if not os.path.exists(eval_dir):
        os.makedirs(eval_dir)
    with open(eval_dir + "/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "".join([c.lower() if c.isalnum() else "_" for c in str(fixed_text)])[:20] + ".txt", "w") as f:
        f.write(text)
    return fixed_text
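# Example use (illustrative): generate() with the default empty prompt lets the model
# invent a pun from scratch, while something like generate(prompt="light bulb") would seed
# the opening words of the pun. The raw, delimiter-laden output is archived under
# eval_logs/ and the cleaned text is returned.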

# Create a single log file for this run, with a datetime stamp in the filename, and append
# each joke to it as it is generated.
log_file = "run_logs/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_log.txt"

os.makedirs("run_logs", exist_ok=True)
with open(log_file, "w") as f:
    f.write("")

for n in range(99):
    pun = generate(prompt="")
    for token in tk:
        pun = pun.replace(token, "\n").replace("\n\n", "\n").strip()
    print(pun)
    print("\n\n")
    with open(log_file, "a") as f:
        f.write(pun + "\n\n")