pszemraj committed
Commit 2e9abb9
1 Parent(s): 12e7206

Upload ai_msgbot_gpt_j_6b_8bit_with_hub.py

Files changed (1):
  1. ai_msgbot_gpt_j_6b_8bit_with_hub.py (+732 lines)

ai_msgbot_gpt_j_6b_8bit_with_hub.py ADDED
# -*- coding: utf-8 -*-
"""ai-msgbot-gpt-j-6b-8bit with hub.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/12IXeac5sEUL7dX2bQfB8BZ46lHwK8-dT

# <center> ai-msgbot - conversational 6B GPT-J 8bit demo


> This notebook demos interaction with a 6B GPT-J model finetuned for dialogue via the methods in [ai-msgbot](https://github.com/pszemraj/ai-msgbot).


By [Peter](https://github.com/pszemraj). This notebook and `ai-msgbot` are [licensed under Creative Commons](https://github.com/pszemraj/ai-msgbot/blob/main/LICENSE). Models trained on given datasets are subject to those datasets' licenses.


## usage

1. select the model checkpoint to use for generation in the `model_checkpoint` dropdown
2. run all cells to load everything
3. adjust the prompt fields at the bottom of the notebook to whatever you want and see how the AI responds


A fine-tuning example, etc., will come _eventually_.


---

# setup
"""

#@markdown setup logging
import logging
from pathlib import Path
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

das_logfile = Path.cwd() / "8bit_inference.log"

logging.basicConfig(
    level=logging.INFO,
    filename=das_logfile,
    filemode='w',
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S",
)

#@markdown add auto-Colab formatting with `IPython.display`
from IPython.display import HTML, display
# colab formatting
def set_css():
    display(
        HTML(
            """
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  """
        )
    )

get_ipython().events.register("pre_run_cell", set_css)

from pathlib import Path

"""### GPU info"""

!nvidia-smi

"""## install and import

_this notebook uses a specific version of `torch` which can take a while to install._
"""

!pip install transformers==4.24.0 -q
!pip install bitsandbytes==0.32.2 -q
!pip install datasets==1.16.1 -q
!pip install torch==1.11 -q
!pip install accelerate==0.12.0 -q
!pip install pysbd==0.3.4 -q

# Commented out IPython magic to ensure Python compatibility.
# %%capture
# import transformers
#
# import pandas as pd
#
# import torch
# import torch.nn.functional as F
# from torch import nn
# from torch.cuda.amp import custom_fwd, custom_bwd
#
# import bitsandbytes as bnb
# from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
#
# from tqdm.auto import tqdm

#@markdown utils
from transformers.utils.logging import set_verbosity

set_verbosity(40)

import warnings
# ignore hf pipeline complaints
warnings.filterwarnings("ignore", category=UserWarning, module='transformers')

"""## Converting the model to 8 bits

"""

#@title define 8bit classes

#@markdown - bitsandbytes lib
class FrozenBNBLinear(nn.Module):
    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        output = DequantizeAndLinear.apply(
            input, self.weight, self.absmax, self.code, self.bias
        )
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"


class DequantizeAndLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input: torch.Tensor,
        weights_quantized: torch.ByteTensor,
        absmax: torch.FloatTensor,
        code: torch.FloatTensor,
        bias: torch.FloatTensor,
    ):
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        assert (
            not ctx.needs_input_grad[1]
            and not ctx.needs_input_grad[2]
            and not ctx.needs_input_grad[3]
        )
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias


class FrozenBNBEmbedding(nn.Module):
    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(
                self.weight, absmax=self.absmax, code=self.code
            )
            output = F.embedding(input, weight_deq, **kwargs)
        if self.adapter:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"


def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    flat_tensor = matrix.view(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size : (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(
            input_chunk, code=code
        )
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)
    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)


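# Illustrative sketch (not part of the original notebook): what the blockwise
# quantization round-trip looks like. `quantize_blockise_lowmemory` (above) packs a
# float matrix into uint8 chunks plus per-block `absmax` scales and a shared `code`
# lookup table; bitsandbytes' `dequantize_blockwise` then recovers a close approximation.
# Assumes the commented-out import cell above has been run (torch + bitsandbytes).
#
#   demo_matrix = torch.randn(4096, 64)
#   demo_int8, (demo_absmax, demo_code) = quantize_blockise_lowmemory(demo_matrix)
#   demo_restored = dequantize_blockwise(demo_int8, absmax=demo_absmax, code=demo_code)
#   print((demo_matrix - demo_restored).abs().mean())  # reconstruction error is small
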
def convert_to_int8(model):
    """Convert linear and embedding modules to 8-bit with optional adapters"""
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr(
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(
                            child.out_features, child.in_features, dtype=torch.uint8
                        ),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(
                            child.num_embeddings, child.embedding_dim, dtype=torch.uint8
                        ),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    ),
                )

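# Illustrative sketch (not part of the original notebook): `convert_to_int8` swaps
# nn.Linear / nn.Embedding children for their frozen 8-bit counterparts in place,
# which is easiest to see on a tiny toy module (the names below are hypothetical):
#
#   toy = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 4))
#   convert_to_int8(toy)
#   print(toy)  # children now print as FrozenBNBEmbedding(10, 8) / FrozenBNBLinear(8, 4)
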
#@markdown Patch GPT-J before loading:


class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)

        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

# Commented out IPython magic to ensure Python compatibility.
# %%capture
# #@markdown `add_adapters()`
#
# def add_adapters(model, adapter_dim=4, p = 0.1):
#     assert adapter_dim > 0
#
#     for name, module in model.named_modules():
#         if isinstance(module, FrozenBNBLinear):
#             if "attn" in name or "mlp" in name or "head" in name:
#                 print("Adding adapter to", name)
#                 module.adapter = nn.Sequential(
#                     nn.Linear(module.in_features, adapter_dim, bias=False),
#                     nn.Dropout(p=p),
#                     nn.Linear(adapter_dim, module.out_features, bias=False),
#                 )
#                 print("Initializing", name)
#                 nn.init.zeros_(module.adapter[2].weight)
#
#             else:
#                 print("Not adding adapter to", name)
#         elif isinstance(module, FrozenBNBEmbedding):
#             print("Adding adapter to", name)
#             module.adapter = nn.Sequential(
#                 nn.Embedding(module.num_embeddings, adapter_dim),
#                 nn.Dropout(p=p),
#                 nn.Linear(adapter_dim, module.embedding_dim, bias=False),
#             )
#             print("Initializing", name)
#             nn.init.zeros_(module.adapter[2].weight)
#

#@markdown set up config
config = transformers.GPTJConfig.from_pretrained("hivemind/gpt-j-6B-8bit")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
config.pad_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token  # pad_token expects a token string, not an id

"""# load model

"""

from contextlib import contextmanager
import sys, os, gc
import logging
from tqdm.auto import tqdm
#@markdown define `load_8bit_from_hub()`

@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

def load_8bit_from_hub(model_id: str, **kwargs):
    pbar = tqdm(desc="instantiating model..", total=3)

    with suppress_stdout():
        gc.collect()
        model = GPTJForCausalLM.from_pretrained(
            model_id,
            device_map='auto',
            low_cpu_mem_usage=True,
            **kwargs,
        )
        pbar.update()
        add_adapters(model)
        pbar.update()
        model = model.to("cuda" if torch.cuda.is_available() else "cpu")
        pbar.update()
    return model

from huggingface_hub import notebook_login

notebook_login()

model_name = "ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps" #@param ["ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps"]

# load_8bit_from_hub() is a wrapper around AutoModel.from_pretrained() and will
# pass all kwargs through to it
model = load_8bit_from_hub(model_name, use_auth_token=True)

"""# generate text

## standard generation

using vanilla `torch` and `model.generate()` directly:

> with "standard" generation it's recommended to put the **speaker token labels** at the end of your prompt so the model "knows" to respond,

i.e. `Person Alpha:` or `Person Beta:` for these two models.
"""

prompt = "Person Alpha: what is the theory of being \"woke\" all about?\\n Person Beta: " # @param {type:"string"}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.no_grad():
    prompt = tokenizer(prompt, return_tensors="pt")
    prompt = {key: value.to(device) for key, value in prompt.items()}
    out = model.generate(
        **prompt,
        min_length=24,
        max_length=96,
        top_k=30,
        top_p=0.9,
        temperature=0.4,
        do_sample=True,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
    )
result = tokenizer.decode(
    out[0],
    remove_invalid_values=True,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
result

"""---

## 'Extract' bot response
- transformers `pipeline` object
- generate with better params
- extract the bot's response with `get_bot_response()` - start to use [ai-msgbot](https://github.com/pszemraj/ai-msgbot) _like it was meant to be used_
"""

from transformers import pipeline

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer="EleutherAI/gpt-j-6B",
    device=0 if torch.cuda.is_available() else -1,
)

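# Illustrative sketch (not part of the original notebook): the pipeline can also be
# called directly; it returns a list of dicts with a "generated_text" key, which the
# helper functions defined below post-process into a clean "Person Beta" reply.
#
#   raw = generator("Person Alpha: hi, how are you?\nPerson Beta: ", max_new_tokens=32)
#   print(raw[0]["generated_text"])
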
"""### generation functions

for extracting the response, beam search vs. sampling, etc.
"""

# @markdown `get_bot_response(name_resp: str, model_resp: list, name_spk: str, verbose: bool = False)`
# @markdown - this extracts "Person Beta"'s response from the full generation
import pysbd

seg = pysbd.Segmenter(language="en", clean=False)

import re


def split_sentences(text, use_regex=False, min_len=2):
    """given a string, splits it into sentences based on punctuation marks."""

    if use_regex:
        sentences = re.split(r'(?<=[.!?]) +', text)
    else:
        # https://github.com/nipunsadvilkar/pySBD
        sentences = seg.segment(text)
    return [s.strip() for s in sentences if len(s.strip()) > min_len]


def validate_response(response_text):

    if isinstance(response_text, list):

        return response_text
        # if len(response_text) > 1 else split_sentences(str(response_text))
    elif isinstance(response_text, str):
        return split_sentences(response_text)
    else:
        raise ValueError(f"response input {response_text} not a list or str..")


def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - gets the bot response to a prompt, checking to ensure that additional statements by the "speaker" are not included in the response.
    Args:
        name_resp (str): the name of the responder
        model_resp (list): the model response
        name_spk (str): the name of the speaker
        verbose (bool, optional): Defaults to False.
    Returns:
        bot_response (str): the bot response, isolated down to just text without the "name tokens" or further messages from the speaker.
    """

    model_resp = validate_response(model_resp)
    logging.info(f"isolating response from:\t{model_resp}")
    fn_resp = []

    name_counter = 0
    break_safe = False
    for resline in model_resp:
        if name_resp.lower() in resline.lower():
            name_counter += 1
            break_safe = True
            continue
        if ":" in resline and name_resp.lower() not in resline.lower():
            break
        if name_spk.lower() in resline.lower() and not break_safe:
            break
        else:
            fn_resp.append(resline)
    if verbose:
        print("the full response is:\n")
        print("\n".join(fn_resp))
    if isinstance(fn_resp, list):
        fn_resp = fn_resp[0] if len(fn_resp) == 1 else " ".join(fn_resp)
    return fn_resp

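"""#### quick check: isolating the reply (illustrative example, not in the original notebook)

`get_bot_response` splits the generated text into sentences, keeps the lines spoken by
`name_resp`, and stops as soon as another speaker label shows up. The hand-written string
below stands in for a model generation (with `return_full_text=False`, the prompt and the
trailing `Person Beta:` label are not part of the model output):
"""

example_generation = (
    "I mostly read books and go hiking on weekends. "
    "Person Alpha: that sounds fun, where do you hike?"
)
print(
    get_bot_response(
        name_resp="Person Beta",
        model_resp=example_generation,
        name_spk="Person Alpha",
    )
)  # -> "I mostly read books and go hiking on weekends."
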
import pprint as pp

# @markdown define `generate_sampling(prompt: str, ...)`


def generate_sampling(
    prompt: str,
    suffix: str = None,
    temperature=0.4,
    top_k: int = 40,
    top_p=0.90,
    min_length: int = 16,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=1.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t...")
    prompt = f"{prompt}{suffix}" if suffix is not None else prompt

    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        min_length=min_length + _prompt_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        do_sample=True,
        return_full_text=return_full_text,
        max_new_tokens=max_length + _prompt_tokens,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )

    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response

import pprint as pp

#@markdown define `generate_beams(prompt: str, num_beams: int = 4, ...)`


def generate_beams(
    prompt: str,
    suffix: str = None,
    num_beams=4,
    min_length: int = 32,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=2.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")

    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        min_length=min_length + _prompt_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=True,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        max_new_tokens=max_length + _prompt_tokens,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )


    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response

import pprint as pp

#@markdown define `generate_csearch(prompt: str, ...)`


def generate_csearch(
    prompt: str,
    suffix: str = None,
    max_length: int = 96,
    min_length: int = 24,
    penalty_alpha: float = 0.6,
    top_k: int = 5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")

    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        min_length=min_length + _prompt_tokens,
        max_new_tokens=max_length,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )


    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response

"""### generate - sampling

> **NOTE:** here `suffix="\nPerson Beta: "` is passed to the helper, so it does not need to be added to the prompt manually.
"""

# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# prompt = "How do we harness space energy?" #@param {type:"string"}
# temperature = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.1}
# top_k = 30 #@param {type:"slider", min:10, max:60, step:10}
#
#
# result = generate_sampling(
#     prompt,
#     suffix="\nPerson Beta: ",
#     max_length=128,
#     min_length=32,
#     temperature=temperature,
#     top_k=top_k,
# )
#

prompt = "What is the purpose of life?" # @param {type:"string"}
temperature = 0.5 # @param {type:"slider", min:0.1, max:1, step:0.1}
top_k = 30 # @param {type:"slider", min:10, max:60, step:10}

generated_result = generate_sampling(
    prompt,
    temperature=temperature,
    top_k=top_k,
    min_length=32,
    suffix="\nPerson Beta: ",
)

"""### generate - beam search"""

# Commented out IPython magic to ensure Python compatibility.
# %%time
# prompt = "How was your day?" #@param {type:"string"}
# num_beams = 4 #@param {type:"slider", min:2, max:10, step:2}
# min_length = 16 #@param {type:"slider", min:8, max:128, step:8}
#
# generated_result = generate_beams(
#     prompt,
#     suffix="\nPerson Beta: ",
#     min_length=min_length,
#     num_beams=num_beams,
# )

"""### generate - contrastive search"""

# Commented out IPython magic to ensure Python compatibility.
# %%time
# prompt = "What do you do for fun?" #@param {type:"string"}
# top_k = 4 #@param {type:"slider", min:2, max:10, step:2}
# penalty_alpha = 0.6 #@param {type:"slider", min:0, max:1, step:0.1}
# min_length = 8 #@param {type:"slider", min:8, max:128, step:8}
#
# generated_result = generate_csearch(
#     prompt,
#     suffix="\nPerson Beta: ",
#     min_length=min_length,
#     penalty_alpha=penalty_alpha,
#     top_k=top_k,
# )

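"""### putting it together: a minimal chat loop

An illustrative sketch (not part of the original notebook) of how the pieces above can be
combined into a simple multi-turn exchange: keep a short rolling history with the speaker
labels, and let `generate_sampling()` isolate "Person Beta"'s reply each turn. The function
name and defaults below are hypothetical; it is defined here but not called.
"""


def chat_loop(max_turns: int = 5, max_history_chars: int = 1024):
    """simple REPL-style chat using generate_sampling(); enter an empty line to stop."""
    history = ""
    for _ in range(max_turns):
        user_msg = input("Person Alpha: ").strip()
        if not user_msg:
            break
        history += f"Person Alpha: {user_msg}\n"
        # generate_sampling() appends the suffix and extracts Person Beta's reply
        reply = generate_sampling(
            history[-max_history_chars:],
            suffix="Person Beta: ",
            min_length=16,
            max_length=96,
        )
        print(f"Person Beta: {reply}")
        history += f"Person Beta: {reply}\n"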