pcalhoun committed
Commit 14410b7 · 1 Parent(s): dde8683

Create eval.py

Files changed (1)
eval.py +307 -0
eval.py ADDED
@@ -0,0 +1,307 @@
import os

hfloc = "pcalhoun/gpt-j-6b-8bit-pun-generator"
filename = "gptj8bit.pt"

from huggingface_hub import hf_hub_download

# the next few hundred lines are needed because HF still doesn't support directly loading models trained in 8bit
DOWNLOAD_LOC = hf_hub_download(repo_id=hfloc, filename=filename)

import logging
import math
import sys
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
import random, json, datetime

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    GPT2Tokenizer
)

from transformers import GPTNeoForSequenceClassification
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
import os
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, logging

import transformers
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from tqdm.auto import tqdm

class FrozenBNBLinear(nn.Module):
    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        output = torch.clone(DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias))
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"


77
+ class DequantizeAndLinear(torch.autograd.Function):
78
+ @staticmethod
79
+ @custom_fwd
80
+ def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
81
+ absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
82
+ weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
83
+ ctx.save_for_backward(input, weights_quantized, absmax, code)
84
+ ctx._has_bias = bias is not None
85
+ return F.linear(input, weights_deq, bias)
86
+
87
+ @staticmethod
88
+ @custom_bwd
89
+ def backward(ctx, grad_output: torch.Tensor):
90
+ assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
91
+ input, weights_quantized, absmax, code = ctx.saved_tensors
92
+ # grad_output: [*batch, out_features]
93
+ weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
94
+ grad_input = grad_output @ weights_deq
95
+ grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
96
+ return grad_input, None, None, None, grad_bias
97
+
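# Design note (descriptive only): the autograd function above dequantizes the
# uint8 weights on the fly for the forward pass and returns gradients only for
# the input (and the bias, when present). The quantized weight, absmax, and
# code buffers receive no gradient, so the base model stays frozen and any
# training signal has to flow through the low-rank adapters added further down.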

class FrozenBNBEmbedding(nn.Module):
    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"


def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    flat_tensor = matrix.view(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)

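# Illustrative round trip (a sketch, not called anywhere in this script): the
# helper above returns the uint8 matrix plus the (absmax, code) state that
# bitsandbytes' dequantize_blockwise needs, assuming the same bitsandbytes
# version as the calls above, where quantize_blockwise(x, code=code) returns
# (quantized, (absmax, code)).
#
#   demo = torch.randn(4096, 4)
#   demo_i8, (demo_absmax, demo_code) = quantize_blockise_lowmemory(demo)
#   demo_deq = dequantize_blockwise(demo_i8, absmax=demo_absmax, code=demo_code)
#   # demo_deq approximates demo up to the 8-bit blockwise quantization error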

def convert_to_int8(model):
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr(
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    )
                )

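# Descriptive note: convert_to_int8 swaps every nn.Linear / nn.Embedding for a
# frozen 8-bit placeholder filled with zeros; nothing is actually quantized
# here. The real quantized weights, absmax, and code buffers are expected to
# come from a checkpoint — in this script, the pickled fine-tuned model
# restored by torch.load further down.
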
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)

        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

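# Descriptive note: replacing GPTJBlock inside transformers' gptj module means
# that any GPT-J built through transformers from this point on has its
# transformer blocks converted to the frozen int8 placeholder layers. The
# EleutherAI/gpt-j-6B tokenizer already contains the <|extratoken_X|> entries
# (vocabulary padding up to 50400) that the pun markers below rely on; `config`
# is loaded here but not used again.
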
def add_adapters(model, adapter_dim=16):
    assert adapter_dim > 0
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)
        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)

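# Descriptive note: add_adapters attaches small trainable LoRA-style bottleneck
# layers (zero-initialized on the output side, so they start out as a no-op) to
# the frozen 8-bit modules. It is defined here for parity with the training
# setup but is not called in this eval script; whatever adapters were trained
# are presumably already part of the pickled model loaded below.
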
from transformers import StoppingCriteria, StoppingCriteriaList
# modified from the EndOfFunctionCriteria class I found somewhere around here:
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_stopping_criteria.html
class EndOfXCriteria(StoppingCriteria):
    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer
    def __call__(self, input_ids, scores, **kwargs):
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        done = []
        for decoded_generation in decoded_generations:
            done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
        return all(done)

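# Descriptive note: the criterion above decodes only the newly generated part
# of each sequence (everything past start_length tokens) and halts generation
# once every sequence in the batch contains one of the stop strings — here the
# <|extratoken_X|> markers.
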
from datasets import load_dataset
import os
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
import json

gpt = torch.load(DOWNLOAD_LOC, map_location=torch.device('cuda'))

eval_dir = "eval_logs"

# Punbot was partly trained as an experiment to see whether the <|extratoken_X|>s could be used for anything.
# I couldn't find much about them on Google. They're just a quirk of the way the vocabulary was fit
# to the TPUs, but I thought fine-tuning with tokens that weren't in the training set might help
# explain the "anomalous token" phenomenon.
# Additional info: https://paulcalhoun.substack.com/p/this-is-the-moment-ive-been-training

tk = ["<|extratoken_{}|>".format(i) for i in range(10, 145)]

PUN_BEGIN = "".join(tk[0:2])  # <|extratoken_10|><|extratoken_11|>
PUN_END = "".join(tk[2:4])  # <|extratoken_12|><|extratoken_13|>
HOMOPHONES_START = "".join(tk[4:6])  # <|extratoken_14|><|extratoken_15|>
HOMOPHONES_SEP = "".join(tk[6:8])  # <|extratoken_16|><|extratoken_17|>
HOMOPHONES_END = "".join(tk[8:10])  # <|extratoken_18|><|extratoken_19|>
PHONEMES_BEGIN = "".join(tk[10:12])  # <|extratoken_20|><|extratoken_21|>
PHONEMES_END = "".join(tk[12:14])  # <|extratoken_22|><|extratoken_23|>
GRAPHEMES_BEGIN = "".join(tk[14:16])  # <|extratoken_24|><|extratoken_25|>
GRAPHEMES_END = "".join(tk[16:18])  # <|extratoken_26|><|extratoken_27|>
EXPLANATION_BEGIN = "".join(tk[18:20])  # <|extratoken_28|><|extratoken_29|>
EXPLANATION_END = "".join(tk[20:22])  # <|extratoken_30|><|extratoken_31|>
KEYWORDS_BEGIN = "".join(tk[22:24])  # <|extratoken_32|><|extratoken_33|>
KEYWORDS_END = "".join(tk[24:26])  # <|extratoken_34|><|extratoken_35|>

GENERATE_KEYWORDS_THEN_EXPLANATION_OF_PUN = "".join(tk[30:32])  # <|extratoken_40|><|extratoken_41|>
GENERATE_PUN_FROM_EXPLANATION_THEN_KEYWORDS = "".join(tk[32:34])  # <|extratoken_42|><|extratoken_43|>
# GENERATE_EXPLANATION_THEN_KEYWORDS_FROM_PUN = "".join(tk[34:36])  # <|extratoken_44|><|extratoken_45|>
GENERATE_PUN_THEN_EXPLANATION_FROM_KEYWORDS = "".join(tk[36:38])  # <|extratoken_46|><|extratoken_47|>
GENERATE_HOMOPHONE_LIST_FROM_WORD = "".join(tk[38:40])  # <|extratoken_48|><|extratoken_49|>

TASK_START = "".join(tk[50:52])  # <|extratoken_60|><|extratoken_61|>
TASK_END = "".join(tk[52:54])  # <|extratoken_62|><|extratoken_63|>
GENERATE_EXPLANATION_THEN_PUN_FROM_KEYWORDS = "".join(tk[54:56])  # <|extratoken_64|><|extratoken_65|>

preprompt = TASK_START + GENERATE_PUN_THEN_EXPLANATION_FROM_KEYWORDS + PUN_BEGIN + GRAPHEMES_BEGIN

def generate(prompt="", preprompt=preprompt, stop_tokens=tk[40:], max_length=256, top_k=50, top_p=0.98):
    full_prompt = preprompt + prompt
    with torch.no_grad():
        inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda")
        tkn_len = len(tokenizer.tokenize(full_prompt))
        outputs = gpt.generate(
            inputs["input_ids"],
            stopping_criteria=StoppingCriteriaList(
                [EndOfXCriteria(
                    tkn_len,
                    tk,
                    tokenizer
                )]),
            max_length=max_length,
            do_sample=True, top_k=top_k, top_p=top_p, pad_token_id=50256)
    text = tokenizer.decode(outputs[0])
    fixed_text = text[len(full_prompt) - len(prompt):]
    for token in tk:
        fixed_text = fixed_text.replace(token, "")
    if not os.path.exists(eval_dir):
        os.makedirs(eval_dir)
    with open(eval_dir + "/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "".join([c.lower() if c.isalnum() else "_" for c in str(fixed_text)])[:20] + ".txt", "w") as f:
        f.write(text)
    return fixed_text


# create a single log file for this run, with a datetime stamp in the filename, and append each joke to it as it is generated
log_file = "run_logs/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_log.txt"

os.makedirs("run_logs", exist_ok=True)
with open(log_file, "w") as f:
    f.write("")

for n in range(99):
    pun = generate(prompt="")
    for token in tk:
        pun = pun.replace(token, "\n").replace("\n\n", "\n").strip()
    print(pun)
    print("\n\n")
    with open(log_file, "a") as f:
        f.write(pun + "\n\n")
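
The script always calls generate(prompt=""), letting the model invent each pun from scratch. A minimal sketch of steering it instead, assuming the seed text is simply appended after the GRAPHEMES_BEGIN marker as the preprompt above implies (the generate_seeded helper below is illustrative, not part of eval.py):

    # hypothetical wrapper around the generate() defined in eval.py:
    # seed the written form of the pun and optionally adjust nucleus sampling
    def generate_seeded(seed_text, n=3, top_p=0.95):
        puns = []
        for _ in range(n):
            out = generate(prompt=seed_text, top_p=top_p)
            # strip the marker tokens the same way the main loop does
            for token in tk:
                out = out.replace(token, "\n").replace("\n\n", "\n").strip()
            puns.append(out)
        return puns

    # e.g. generate_seeded("I used to be a banker") should yield puns that begin with that clause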