|
|
|
"""-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/gist/pszemraj/e49c60aafe04acc52fcfdd1baefe12e4/-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb |
|
|
|
# <center> ai-msgbot - conversational 6B GPT-J 8bit demo |
|
|
|
|
|
> This notebook demos interaction with a 6B-parameter GPT-J model fine-tuned for dialogue using the methods in [ai-msgbot](https://github.com/pszemraj/ai-msgbot).
|
|
|
|
|
By [Peter](https://github.com/pszemraj). This notebook and `ai-msgbot` are [licensed under creative commons](https://github.com/pszemraj/ai-msgbot/blob/main/LICENSE). Models trained on given datasets are subject to those datasets' licenses. |
|
|
|
|
|
## usage |
|
|
|
1. select the checkpoint of the model to use for generation in the `model_checkpoint` dropdown

2. run all cells to load everything

3. adjust the prompt fields at the bottom of the notebook to whatever you want and see how the AI responds
|
|
|
|
|
A fine-tuning example (and more) will come _eventually_.
|
|
|
|
|
--- |
|
|
|
# setup |
|
""" |
|
|
|
|
|
import logging |
|
from pathlib import Path |
|
for handler in logging.root.handlers[:]: |
|
logging.root.removeHandler(handler) |
|
|
|
das_logfile = Path.cwd() / "8bit_inference.log" |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
filename=das_logfile, |
|
filemode='w', |
|
format="%(asctime)s %(levelname)s %(message)s", |
|
datefmt="%m/%d/%Y %I:%M:%S", |
|
) |
|
|
|
|
|
from IPython.display import HTML, display |
|
|
|
def set_css(): |
|
display( |
|
HTML( |
|
""" |
|
<style> |
|
pre { |
|
white-space: pre-wrap; |
|
} |
|
</style> |
|
""" |
|
) |
|
) |
|
|
|
get_ipython().events.register("pre_run_cell", set_css) |
|
|
|
from pathlib import Path |
|
|
|
"""### GPU info""" |
|
|
|
!nvidia-smi |
|
|
|
"""## install and import |
|
|
|
_this notebook uses a specific version of `torch` which can take a while to install._ |
|
""" |
|
|
|
!pip install transformers==4.24.0 -q |
|
!pip install bitsandbytes==0.32.2 -q |
|
!pip install datasets==1.16.1 -q |
|
!pip install torch==1.11 -q |
|
!pip install accelerate==0.12.0 -q |
|
!pip install pysbd==0.3.4 -q |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# core dependencies used by the 8-bit model definitions below
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise
from torch.cuda.amp import custom_bwd, custom_fwd

from transformers.utils.logging import set_verbosity
|
|
|
set_verbosity(40)  # 40 == ERROR: suppress info/warning messages from transformers
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore", category=UserWarning, module='transformers') |
|
|
|
"""## Converting the model to 8 bits |
|
Define frozen 8-bit (blockwise-quantized) replacements for `nn.Linear` and `nn.Embedding`, plus GPT-J model classes built on top of them.
|
""" |
|
|
|
|
|
|
|
|
|
class FrozenBNBLinear(nn.Module):
    """`nn.Linear` replacement whose weight is stored as a frozen, blockwise-quantized 8-bit buffer."""

    def __init__(self, weight, absmax, code, bias=None):
|
assert isinstance(bias, nn.Parameter) or bias is None |
|
super().__init__() |
|
self.out_features, self.in_features = weight.shape |
|
self.register_buffer("weight", weight.requires_grad_(False)) |
|
self.register_buffer("absmax", absmax.requires_grad_(False)) |
|
self.register_buffer("code", code.requires_grad_(False)) |
|
self.adapter = None |
|
self.bias = bias |
|
|
|
def forward(self, input): |
|
output = DequantizeAndLinear.apply( |
|
input, self.weight, self.absmax, self.code, self.bias |
|
) |
|
if self.adapter: |
|
output += self.adapter(input) |
|
return output |
|
|
|
@classmethod |
|
def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear": |
|
weights_int8, state = quantize_blockise_lowmemory(linear.weight) |
|
return cls(weights_int8, *state, linear.bias) |
|
|
|
def __repr__(self): |
|
return f"{self.__class__.__name__}({self.in_features}, {self.out_features})" |
|
|
|
|
|
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd op: dequantize the 8-bit weight on the fly, then apply F.linear."""
|
@staticmethod |
|
@custom_fwd |
|
def forward( |
|
ctx, |
|
input: torch.Tensor, |
|
weights_quantized: torch.ByteTensor, |
|
absmax: torch.FloatTensor, |
|
code: torch.FloatTensor, |
|
bias: torch.FloatTensor, |
|
): |
|
weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code) |
|
ctx.save_for_backward(input, weights_quantized, absmax, code) |
|
ctx._has_bias = bias is not None |
|
return F.linear(input, weights_deq, bias) |
|
|
|
@staticmethod |
|
@custom_bwd |
|
def backward(ctx, grad_output: torch.Tensor): |
|
assert ( |
|
not ctx.needs_input_grad[1] |
|
and not ctx.needs_input_grad[2] |
|
and not ctx.needs_input_grad[3] |
|
) |
|
input, weights_quantized, absmax, code = ctx.saved_tensors |
|
|
|
weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code) |
|
grad_input = grad_output @ weights_deq |
|
grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None |
|
return grad_input, None, None, None, grad_bias |
|
|
|
|
|
class FrozenBNBEmbedding(nn.Module):
    """`nn.Embedding` replacement backed by frozen, blockwise-quantized 8-bit weights."""

    def __init__(self, weight, absmax, code):
|
super().__init__() |
|
self.num_embeddings, self.embedding_dim = weight.shape |
|
self.register_buffer("weight", weight.requires_grad_(False)) |
|
self.register_buffer("absmax", absmax.requires_grad_(False)) |
|
self.register_buffer("code", code.requires_grad_(False)) |
|
self.adapter = None |
|
|
|
def forward(self, input, **kwargs): |
|
with torch.no_grad(): |
|
|
|
weight_deq = dequantize_blockwise( |
|
self.weight, absmax=self.absmax, code=self.code |
|
) |
|
output = F.embedding(input, weight_deq, **kwargs) |
|
if self.adapter: |
|
output += self.adapter(input) |
|
return output |
|
|
|
@classmethod |
|
def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding": |
|
weights_int8, state = quantize_blockise_lowmemory(embedding.weight) |
|
return cls(weights_int8, *state) |
|
|
|
def __repr__(self): |
|
return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})" |
|
|
|
|
|
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
    """Blockwise-quantize a large tensor chunk-by-chunk to keep peak memory low."""
    assert chunk_size % 4096 == 0
|
code = None |
|
chunks = [] |
|
absmaxes = [] |
|
flat_tensor = matrix.view(-1) |
|
for i in range((matrix.numel() - 1) // chunk_size + 1): |
|
input_chunk = flat_tensor[i * chunk_size : (i + 1) * chunk_size].clone() |
|
quantized_chunk, (absmax_chunk, code) = quantize_blockwise( |
|
input_chunk, code=code |
|
) |
|
chunks.append(quantized_chunk) |
|
absmaxes.append(absmax_chunk) |
|
matrix_i8 = torch.cat(chunks).reshape_as(matrix) |
|
absmax = torch.cat(absmaxes) |
|
return matrix_i8, (absmax, code) |
|
|
|
|
|
def convert_to_int8(model): |
|
"""Convert linear and embedding modules to 8-bit with optional adapters""" |
|
for module in list(model.modules()): |
|
for name, child in module.named_children(): |
|
if isinstance(child, nn.Linear): |
|
print(name, child) |
|
setattr( |
|
module, |
|
name, |
|
FrozenBNBLinear( |
|
weight=torch.zeros( |
|
child.out_features, child.in_features, dtype=torch.uint8 |
|
), |
|
absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1), |
|
code=torch.zeros(256), |
|
bias=child.bias, |
|
), |
|
) |
|
elif isinstance(child, nn.Embedding): |
|
setattr( |
|
module, |
|
name, |
|
FrozenBNBEmbedding( |
|
weight=torch.zeros( |
|
child.num_embeddings, child.embedding_dim, dtype=torch.uint8 |
|
), |
|
absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1), |
|
code=torch.zeros(256), |
|
), |
|
) |
|
|
|
|
|
|
|
|
|
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
convert_to_int8(self.attn) |
|
convert_to_int8(self.mlp) |
|
|
|
|
|
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
convert_to_int8(self) |
|
|
|
|
|
class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
convert_to_int8(self) |
|
|
|
|
|
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock |
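# NOTE: `load_8bit_from_hub` below calls `add_adapters`, which is missing from this
# export. A minimal sketch, assuming the zero-initialized low-rank adapter setup used
# with these frozen 8-bit modules (the `adapter_dim=16` default is an assumption):
def add_adapters(model, adapter_dim=16):
    """Attach small trainable adapters to every frozen 8-bit linear/embedding module."""
    assert adapter_dim > 0
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
            # zero-init the second layer so the adapter starts as a no-op
            nn.init.zeros_(module.adapter[1].weight)
        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)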
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config = transformers.GPTJConfig.from_pretrained("hivemind/gpt-j-6B-8bit") |
|
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") |
|
config.pad_token_id = config.eos_token_id |
|
tokenizer.pad_token = tokenizer.eos_token  # assign the token string, not the integer id
|
|
|
"""# load model |
|
|
|
""" |
|
|
|
from contextlib import contextmanager |
|
import sys, os, gc |
|
import logging |
|
from tqdm.auto import tqdm |
|
|
|
|
|
@contextmanager |
|
def suppress_stdout(): |
|
with open(os.devnull, "w") as devnull: |
|
old_stdout = sys.stdout |
|
sys.stdout = devnull |
|
try: |
|
yield |
|
finally: |
|
sys.stdout = old_stdout |
|
|
|
def load_8bit_from_hub(model_id: str, **kwargs):
    """Load an 8-bit GPT-J checkpoint from the Hugging Face Hub, attach adapters, and move it to the GPU."""
    pbar = tqdm(desc="instantiating model..", total=3)
|
|
|
with suppress_stdout(): |
|
gc.collect() |
|
model = GPTJForCausalLM.from_pretrained(model_id, |
|
device_map='auto', |
|
low_cpu_mem_usage=True, |
|
**kwargs) |
|
pbar.update() |
|
add_adapters(model) |
|
pbar.update() |
|
        model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
pbar.update() |
|
return model |
|
|
|
|
|
model_name = "ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps" |
|
|
|
|
|
|
|
model = load_8bit_from_hub(model_name,) |
|
|
|
"""# generate text |
|
|
|
## standard generation

using `model.generate()` directly with `torch`:

> with "standard" generation it's recommended to put the **speaker token labels** at the end of your prompt so the model "knows" to respond.

i.e. `Person Alpha:` or `Person Beta:` for these two models.
|
""" |
|
|
|
prompt = "Person Alpha: what is the theory of being \"woke\" all about?\n Person Beta: "
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
with torch.no_grad(): |
|
prompt = tokenizer(prompt, return_tensors="pt") |
|
prompt = {key: value.to(device) for key, value in prompt.items()} |
|
out = model.generate( |
|
**prompt, |
|
min_length=24, |
|
max_length=96, |
|
top_k=30, |
|
top_p=0.9, |
|
temperature=0.4, |
|
do_sample=True, |
|
repetition_penalty=1.2, |
|
no_repeat_ngram_size=3, |
|
pad_token_id=tokenizer.eos_token_id, |
|
) |
|
result = tokenizer.decode( |
|
out[0], |
|
remove_invalid_values=True, |
|
skip_special_tokens=True, |
|
clean_up_tokenization_spaces=True, |
|
) |
|
result |
|
|
|
"""--- |
|
|
|
## 'Extract' bot response |
|
- transformers `pipeline` object |
|
- generate with better params |
|
- extract the bot's response with `get_bot_response()` - start to use [ai-msgbot](https://github.com/pszemraj/ai-msgbot) _like it was meant to be used_ |
|
""" |
|
|
|
from transformers import pipeline |
|
|
|
generator = pipeline( |
|
"text-generation", |
|
model=model, |
|
tokenizer="EleutherAI/gpt-j-6B", |
|
device= 0 if torch.cuda.is_available() else -1, |
|
) |
|
|
|
"""### generation functions |
|
|
|
for extracting the response, beam search vs. sampling, etc.
|
""" |
|
|
|
|
|
|
|
import pysbd |
|
|
|
seg = pysbd.Segmenter(language="en", clean=False) |
|
|
|
import re |
|
|
|
|
|
def split_sentences(text, use_regex=False, min_len=2):
    """given a string, splits it into sentences based on punctuation marks."""
    if use_regex:
        sentences = re.split(r"(?<=[.!?]) +", text)
    else:
        sentences = seg.segment(text)
    return [s.strip() for s in sentences if len(s.strip()) > min_len]
|
|
|
|
|
def validate_response(response_text): |
|
|
|
if isinstance(response_text, list): |
|
|
|
return response_text |
|
|
|
elif isinstance(response_text, str): |
|
return split_sentences(response_text) |
|
else: |
|
raise ValueError(f"response input {response_text} not a list or str..") |
|
|
|
|
|
def get_bot_response( |
|
name_resp: str, model_resp: list, name_spk: str, verbose: bool = False |
|
): |
|
""" |
|
get_bot_response - gets the bot response to a prompt, checking to ensure that additional statements by the "speaker" are not included in the response. |
|
Args: |
|
name_resp (str): the name of the responder |
|
model_resp (list): the model response |
|
name_spk (str): the name of the speaker |
|
verbose (bool, optional): Defaults to False. |
|
Returns: |
|
bot_response (str): the bot response, isolated down to just text without the "name tokens" or further messages from the speaker. |
|
""" |
|
|
|
model_resp = validate_response(model_resp) |
|
logging.info(f"isolating response from:\t{model_resp}") |
|
fn_resp = [] |
|
|
|
name_counter = 0 |
|
break_safe = False |
|
for resline in model_resp: |
|
if name_resp.lower() in resline.lower(): |
|
name_counter += 1 |
|
break_safe = True |
|
continue |
|
if ":" in resline and name_resp.lower() not in resline.lower(): |
|
break |
|
if name_spk.lower() in resline.lower() and not break_safe: |
|
break |
|
else: |
|
fn_resp.append(resline) |
|
if verbose: |
|
print("the full response is:\n") |
|
print("\n".join(fn_resp)) |
|
if isinstance(fn_resp, list): |
|
fn_resp = fn_resp[0] if len(fn_resp) == 1 else " ".join(fn_resp) |
|
return fn_resp |
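# quick illustration of the extraction helper on a made-up model output
# (real generations flow through the same path inside the generate_* helpers below):
example_output = (
    "The theory is about staying aware of social issues. "
    "Person Alpha: interesting, tell me more."
)
print(
    get_bot_response(
        name_resp="Person Beta",
        model_resp=example_output,
        name_spk="Person Alpha",
    )
)
# expected output: The theory is about staying aware of social issues.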
|
|
|
import pprint as pp |
|
|
|
|
|
|
|
|
|
def generate_sampling( |
|
prompt: str, |
|
suffix:str=None, |
|
temperature=0.4, |
|
top_k: int = 40, |
|
top_p=0.90, |
|
min_length: int = 16, |
|
max_length: int = 128, |
|
no_repeat_ngram_size: int = 3, |
|
repetition_penalty=1.5, |
|
return_full_text=False, |
|
verbose=False, |
|
**kwargs, |
|
) -> str:
|
|
|
logging.info(f"generating results for input:\n\t{prompt}\n\t...") |
|
if verbose: |
|
print(f"generating results for input:\n\t{prompt}\n\t...") |
|
prompt = f"{prompt}{suffix}" if suffix is not None else prompt |
|
|
|
_prompt_tokens = len(generator.tokenizer(prompt).input_ids) |
|
result = generator( |
|
prompt, |
|
min_length=min_length+_prompt_tokens, |
|
temperature=temperature, |
|
top_k=top_k, |
|
top_p=top_p, |
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
repetition_penalty=repetition_penalty, |
|
remove_invalid_values=True, |
|
clean_up_tokenization_spaces=True, |
|
do_sample=True, |
|
return_full_text=return_full_text, |
|
max_new_tokens=max_length+_prompt_tokens, |
|
pad_token_id=generator.tokenizer.eos_token_id, |
|
**kwargs, |
|
) |
|
|
|
output = result[0]["generated_text"] |
|
logging.info(f"model output:\n\t{output}") |
|
if verbose: |
|
print(f"model output:\n\t{output}") |
|
response = get_bot_response( |
|
model_resp=output, |
|
name_spk="Person Alpha", |
|
name_resp="Person Beta", |
|
verbose=False, |
|
) |
|
|
|
logging.info(f"extracted bot response:\n\t{response}") |
|
|
|
pp.pprint(response) |
|
|
|
return response |
|
|
|
import pprint as pp |
|
|
|
|
|
|
|
|
|
def generate_beams( |
|
prompt: str, |
|
suffix:str=None, |
|
num_beams=4, |
|
min_length: int = 32, |
|
max_length: int = 128, |
|
no_repeat_ngram_size: int = 3, |
|
repetition_penalty=2.5, |
|
return_full_text=False, |
|
verbose=False, |
|
**kwargs, |
|
) -> str:
|
|
|
logging.info(f"generating results for input:\n\t{prompt}\n\t...") |
|
if verbose: |
|
print(f"generating results for input:\n\t{prompt}\n\t") |
|
|
|
prompt = f"{prompt}{suffix}" if suffix is not None else prompt |
|
_prompt_tokens = len(generator.tokenizer(prompt).input_ids) |
|
result = generator( |
|
prompt, |
|
min_length=min_length+_prompt_tokens, |
|
num_beams=num_beams, |
|
do_sample=False, |
|
early_stopping=True, |
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
repetition_penalty=repetition_penalty, |
|
remove_invalid_values=True, |
|
clean_up_tokenization_spaces=True, |
|
return_full_text=return_full_text, |
|
max_new_tokens=max_length+_prompt_tokens, |
|
pad_token_id=generator.tokenizer.eos_token_id, |
|
**kwargs, |
|
) |
|
|
|
output = result[0]["generated_text"] |
|
logging.info(f"model output:\n\t{output}") |
|
if verbose: |
|
print(f"model output:\n\t{output}") |
|
response = get_bot_response( |
|
model_resp=output, |
|
name_spk="Person Alpha", |
|
name_resp="Person Beta", |
|
verbose=False, |
|
) |
|
|
|
|
|
logging.info(f"extracted bot response:\n\t{response}") |
|
|
|
pp.pprint(response) |
|
|
|
return response |
|
|
|
import pprint as pp |
|
|
|
|
|
|
|
|
|
def generate_csearch( |
|
prompt: str, |
|
suffix:str=None, |
|
max_length: int = 96, |
|
min_length: int = 24, |
|
penalty_alpha: float=0.6, |
|
top_k: int=5, |
|
return_full_text=False, |
|
verbose=False, |
|
**kwargs, |
|
) -> str:
|
|
|
logging.info(f"generating results for input:\n\t{prompt}\n\t...") |
|
if verbose: |
|
print(f"generating results for input:\n\t{prompt}\n\t") |
|
|
|
prompt = f"{prompt}{suffix}" if suffix is not None else prompt |
|
_prompt_tokens = len(generator.tokenizer(prompt).input_ids) |
|
result = generator( |
|
prompt, |
|
min_length=min_length+_prompt_tokens, |
|
max_new_tokens=max_length, |
|
penalty_alpha=penalty_alpha, |
|
top_k=top_k, |
|
remove_invalid_values=True, |
|
clean_up_tokenization_spaces=True, |
|
return_full_text=return_full_text, |
|
pad_token_id=generator.tokenizer.eos_token_id, |
|
**kwargs, |
|
) |
|
|
|
output = result[0]["generated_text"] |
|
logging.info(f"model output:\n\t{output}") |
|
if verbose: |
|
print(f"model output:\n\t{output}") |
|
response = get_bot_response( |
|
model_resp=output, |
|
name_spk="Person Alpha", |
|
name_resp="Person Beta", |
|
verbose=False, |
|
) |
|
|
|
|
|
logging.info(f"extracted bot response:\n\t{response}") |
|
|
|
pp.pprint(response) |
|
|
|
return response |
|
|
|
"""### generate - sampling |
|
|
|
> **NOTE:** here `suffix="\nPerson Beta: "` is passed, so it does not need to be added to the prompt manually.
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt = "What is the purpose of life?" |
|
temperature = 0.5 |
|
top_k = 30 |
|
|
|
generated_result = generate_sampling( |
|
prompt, |
|
temperature=temperature, |
|
top_k=top_k, |
|
min_length=32, |
|
suffix="\nPerson Beta: ", |
|
) |
|
|
|
"""### generate - beam search""" |
|
|
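# usage sketch mirroring the sampling example above; the prompt and settings are illustrative
prompt = "What is the purpose of life?"

generated_result = generate_beams(
    prompt,
    num_beams=4,
    min_length=32,
    suffix="\nPerson Beta: ",
)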
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
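"""### generate - contrastive search"""

# `generate_csearch` is defined above but never demonstrated; a minimal usage sketch
# with an illustrative prompt and the function's default contrastive-search settings:
generated_result = generate_csearch(
    "What is the purpose of life?",
    suffix="\nPerson Beta: ",
)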