# gpt-j-8bit-daily_dialogues / _ai_msgbot_gpt_j_6b_8bit_with_hub.py
# Uploaded by pszemraj (commit 1721ded, 21.8 kB).
# -*- coding: utf-8 -*-
"""-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/gist/pszemraj/e49c60aafe04acc52fcfdd1baefe12e4/-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb
# <center> ai-msgbot - conversational 6B GPT-J 8bit demo
> This notebook demos interaction with a 6B GPT-J finetuned for dialogue via methods in [ai-msgbot](https://github.com/pszemraj/ai-msgbot)
By [Peter](https://github.com/pszemraj). This notebook and `ai-msgbot` are [licensed under creative commons](https://github.com/pszemraj/ai-msgbot/blob/main/LICENSE). Models trained on given datasets are subject to those datasets' licenses.
## usage
1. select the checkpoint of the model to use for generation in the `model_checkpoint` dropdown
2. Run all cells to load everything
3. adjust the prompt fields at the bottom of the notebook to whatever you want, and see how the AI responds.
A fine-tuning example etc. will come _eventually_
---
# setup
"""
#@markdown setup logging
import logging
from pathlib import Path
# Start from a clean root logger: any handler already attached (the notebook
# runtime may install its own) would make basicConfig() below a no-op.
for existing_handler in list(logging.root.handlers):
    logging.root.removeHandler(existing_handler)

# Every log record from this notebook is written to ./8bit_inference.log,
# which is truncated on each run (filemode "w").
das_logfile = Path.cwd().joinpath("8bit_inference.log")
logging.basicConfig(
    filename=das_logfile,
    filemode="w",
    level=logging.INFO,
    datefmt="%m/%d/%Y %I:%M:%S",
    format="%(asctime)s %(levelname)s %(message)s",
)
#@markdown add auto-Colab formatting with `IPython.display`
from IPython.display import HTML, display
# colab formatting
def set_css():
    """Make ``<pre>`` output blocks wrap long lines instead of scrolling.

    Registered right below as an IPython ``pre_run_cell`` hook, so the style
    tag is re-emitted before every cell runs.
    """
    wrap_style = HTML(
        """
<style>
pre {
white-space: pre-wrap;
}
</style>
"""
    )
    display(wrap_style)


get_ipython().events.register("pre_run_cell", set_css)
from pathlib import Path
"""### GPU info"""
# Show the attached GPU(s). `!` lines are IPython shell escapes, so this file
# only runs as-is inside a notebook (Jupyter/Colab), not as a plain script.
!nvidia-smi
"""## install and import
_this notebook uses a specific version of `torch` which can take a while to install._
"""
# Pinned package versions this notebook was written against.
!pip install transformers==4.24.0 -q
!pip install bitsandbytes==0.32.2 -q
!pip install datasets==1.16.1 -q
!pip install torch==1.11 -q
!pip install accelerate==0.12.0 -q
!pip install pysbd==0.3.4 -q
# NOTE: the notebook-to-script export commented out this whole cell together
# with its `%%capture` magic; without these imports every later use of
# torch / nn / F / transformers / bitsandbytes raises NameError.
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn.functional as F
import transformers
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd
from tqdm.auto import tqdm
#@markdown utils
from transformers.utils.logging import set_verbosity
import warnings

# Only surface transformers log records at ERROR level (40) or above.
set_verbosity(logging.ERROR)
# Silence the UserWarnings the HF pipeline emits.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module='transformers',
)
"""## Converting the model to 8 bits
"""
#@title define 8bit classes
#@markdown - bitsandbytes lib
class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is held as frozen, blockwise-quantized 8-bit buffers.

    The weight is dequantized on every forward pass via ``DequantizeAndLinear``.
    If an ``adapter`` module has been attached, its output (computed from the
    same input) is added to the result.

    Args:
        weight: quantized weight of shape ``(out_features, in_features)``.
        absmax: per-block scale factors produced by the quantizer.
        code: quantization code book.
        bias: optional ``nn.Parameter`` registered as this module's bias.

    Raises:
        TypeError: if ``bias`` is neither ``None`` nor an ``nn.Parameter``.
    """

    def __init__(self, weight, absmax, code, bias=None):
        if not (bias is None or isinstance(bias, nn.Parameter)):
            raise TypeError(
                f"bias must be an nn.Parameter or None, got {type(bias).__name__}"
            )
        super().__init__()
        self.out_features, self.in_features = weight.shape
        # Buffers, not parameters: they follow .to()/state_dict but are never optimized.
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        output = DequantizeAndLinear.apply(
            input, self.weight, self.absmax, self.code, self.bias
        )
        # Compare against None explicitly: the truthiness of a container module
        # such as nn.Sequential depends on its length, not on whether it is set.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Quantize an existing ``nn.Linear`` into a ``FrozenBNBLinear``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function computing ``F.linear`` with a blockwise-quantized 8-bit weight.

    The weight is dequantized in ``forward`` and dequantized again in ``backward``;
    only the quantized tensors (plus the input) are saved between the two, presumably
    so the full-precision weight copy does not stay alive. Gradients are produced
    for ``input`` and ``bias`` only; the quantized weight, ``absmax`` and ``code``
    are treated as frozen.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input: torch.Tensor,
        weights_quantized: torch.ByteTensor,
        absmax: torch.FloatTensor,
        code: torch.FloatTensor,
        bias: torch.FloatTensor,
    ):
        """Return ``F.linear(input, dequantized_weight, bias)``; ``bias`` may be None."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Save the compact quantized tensors, not weights_deq; backward re-dequantizes.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        # Remember whether a bias was given so backward knows whether to return its grad.
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(input, weights_quantized, absmax, code, bias)``, in forward-argument order."""
        # The quantized weight, absmax and code are frozen: no gradient may be requested for them.
        assert (
            not ctx.needs_input_grad[1]
            and not ctx.needs_input_grad[2]
            and not ctx.needs_input_grad[3]
        )
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # F.linear computes input @ W^T, so d/d(input) = grad_output @ W.
        # Note: computed even when ctx.needs_input_grad[0] is False.
        grad_input = grad_output @ weights_deq
        # Bias gradient: sum grad_output over every leading (batch) dimension.
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
class FrozenBNBEmbedding(nn.Module):
    """Embedding table held as frozen, blockwise-quantized 8-bit buffers.

    The table is dequantized on every forward pass. If an ``adapter`` module has
    been attached, its output (computed from the same indices) is added to the
    looked-up embeddings.

    Args:
        weight: quantized table of shape ``(num_embeddings, embedding_dim)``.
        absmax: per-block scale factors produced by the quantizer.
        code: quantization code book.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        # Buffers, not parameters: they follow .to()/state_dict but are never optimized.
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(
                self.weight, absmax=self.absmax, code=self.code
            )
            output = F.embedding(input, weight_deq, **kwargs)
        # The adapter must run OUTSIDE no_grad, otherwise its (trainable)
        # parameters would never receive gradients.
        if self.adapter is not None:
            output = output + self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Quantize an existing ``nn.Embedding`` into a ``FrozenBNBEmbedding``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
    """Quantize ``matrix`` to 8 bits, one flat chunk of ``chunk_size`` elements at a time.

    Working chunk by chunk presumably keeps the temporaries small compared with
    quantizing the whole tensor at once.

    Args:
        matrix: tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: elements per chunk; must be a multiple of 4096.

    Returns:
        ``(matrix_i8, (absmax, code))``: the quantized tensor reshaped like
        ``matrix``, the concatenated per-chunk absmax values, and the code book
        (created by the first chunk and reused for all later ones).
    """
    # presumably keeps chunk edges aligned with bitsandbytes' 4096-element blocks
    assert chunk_size % 4096 == 0
    flat = matrix.view(-1)
    n_chunks = (matrix.numel() - 1) // chunk_size + 1
    quantized_parts, absmax_parts = [], []
    code = None
    for start in range(0, n_chunks * chunk_size, chunk_size):
        piece = flat[start : start + chunk_size].clone()
        q_piece, (piece_absmax, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(q_piece)
        absmax_parts.append(piece_absmax)
    quantized = torch.cat(quantized_parts).reshape_as(matrix)
    return quantized, (torch.cat(absmax_parts), code)
def convert_to_int8(model):
    """Convert linear and embedding modules to 8-bit with optional adapters.

    Every ``nn.Linear`` / ``nn.Embedding`` child anywhere inside ``model`` is
    replaced in place by a ``FrozenBNBLinear`` / ``FrozenBNBEmbedding`` whose
    buffers are zero-filled placeholders of the right size, ready to be filled
    by a subsequent state-dict load. Each replaced Linear is printed first.
    """
    # Snapshot the module list up front because setattr below rewires the tree.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            is_linear = isinstance(child, nn.Linear)
            if not is_linear and not isinstance(child, nn.Embedding):
                continue
            # One absmax slot per 4096-element block, plus a 256-entry code book.
            absmax = torch.zeros((child.weight.numel() - 1) // 4096 + 1)
            code = torch.zeros(256)
            if is_linear:
                print(child_name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(
                        child.out_features, child.in_features, dtype=torch.uint8
                    ),
                    absmax=absmax,
                    code=code,
                    bias=child.bias,
                )
            else:
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(
                        child.num_embeddings, child.embedding_dim, dtype=torch.uint8
                    ),
                    absmax=absmax,
                    code=code,
                )
            setattr(parent, child_name, replacement)
#@markdown Patch GPT-J before loading:
_gptj_modeling = transformers.models.gptj.modeling_gptj


class GPTJBlock(_gptj_modeling.GPTJBlock):
    """GPT-J block whose attention and MLP layers are converted to frozen 8-bit right after construction."""

    def __init__(self, config):
        super().__init__(config)
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)


class GPTJModel(_gptj_modeling.GPTJModel):
    """GPT-J base model with every Linear/Embedding converted to frozen 8-bit after construction."""

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


class GPTJForCausalLM(_gptj_modeling.GPTJForCausalLM):
    """GPT-J causal-LM with every Linear/Embedding (including the LM head) converted to frozen 8-bit."""

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


# Replace the block class inside transformers' GPT-J module so code there that
# builds blocks by this name picks up the 8-bit subclass.
_gptj_modeling.GPTJBlock = GPTJBlock
#@markdown `add_adapters()`
# NOTE: this cell was commented out by the notebook-to-script export, but
# load_8bit_from_hub() calls add_adapters(); it is restored as live code.
def add_adapters(model, adapter_dim=4, p=0.1):
    """Attach small trainable adapters to the frozen 8-bit layers of ``model``.

    * ``FrozenBNBLinear`` modules whose qualified name contains "attn", "mlp" or
      "head" get ``Linear(in, adapter_dim) -> Dropout(p) -> Linear(adapter_dim, out)``.
      All other ``FrozenBNBLinear`` modules are left alone (and reported).
    * ``FrozenBNBEmbedding`` modules get
      ``Embedding(num_embeddings, adapter_dim) -> Dropout(p) -> Linear(adapter_dim, dim)``.

    The last layer of every adapter is zero-initialised, so adapters add exactly
    zero to the layer output until they are trained.

    Args:
        model: module tree to modify in place.
        adapter_dim: bottleneck width of each adapter; must be positive.
        p: dropout probability inside the adapters.

    Raises:
        ValueError: if ``adapter_dim`` is not positive.
    """
    if adapter_dim <= 0:
        raise ValueError(f"adapter_dim must be positive, got {adapter_dim}")

    # Snapshot first: assigning `module.adapter` inserts new submodules into the tree.
    for name, module in list(model.named_modules()):
        if isinstance(module, FrozenBNBLinear):
            if "attn" in name or "mlp" in name or "head" in name:
                print("Adding adapter to", name)
                module.adapter = nn.Sequential(
                    nn.Linear(module.in_features, adapter_dim, bias=False),
                    nn.Dropout(p=p),
                    nn.Linear(adapter_dim, module.out_features, bias=False),
                )
                print("Initializing", name)
                nn.init.zeros_(module.adapter[2].weight)
            else:
                print("Not adding adapter to", name)
        elif isinstance(module, FrozenBNBEmbedding):
            print("Adding adapter to", name)
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Dropout(p=p),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            print("Initializing", name)
            nn.init.zeros_(module.adapter[2].weight)
#@markdown set up config
config = transformers.GPTJConfig.from_pretrained("hivemind/gpt-j-6B-8bit")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse end-of-sequence as the padding id.
config.pad_token_id = config.eos_token_id
# `pad_token` takes a token *string*, not an id: assigning the integer id (as
# before) does not make EOS the pad token. Use the EOS token string instead.
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): `config` is not passed to load_8bit_from_hub() in this notebook,
# so the model is built from its own hub config; confirm whether it should be.
"""# load model
"""
from contextlib import contextmanager
import sys, os, gc
import logging
from tqdm.auto import tqdm
#@markdown define `load_8bit_from_hub()`
@contextmanager
def suppress_stdout():
    """Send everything written to ``sys.stdout`` to the null device while active.

    The previous stream is put back on exit, including when the body raises.
    """
    with open(os.devnull, "w") as sink:
        saved_stdout, sys.stdout = sys.stdout, sink
        try:
            yield
        finally:
            sys.stdout = saved_stdout
def load_8bit_from_hub(model_id: str, **kwargs):
    """Load an 8-bit GPT-J checkpoint, attach adapters and move it to the best device.

    Args:
        model_id: hub repo id (or local path) given to ``GPTJForCausalLM.from_pretrained``.
        **kwargs: forwarded unchanged to ``from_pretrained``.

    Returns:
        The loaded model, on ``"cuda"`` when available, otherwise ``"cpu"``.
    """
    with tqdm(desc="instantiating model..", total=3) as pbar:
        # int8 conversion and add_adapters() print a line per layer; hide that noise.
        with suppress_stdout():
            gc.collect()
            model = GPTJForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
                **kwargs,
            )
            pbar.update()
            add_adapters(model)
            pbar.update()
        # `.to()` needs a device; the old `-1` fallback is a pipeline convention
        # that torch rejects as a device index.
        model = model.to("cuda" if torch.cuda.is_available() else "cpu")
        pbar.update()
    return model
#@title <font color="orange"> Select Model to Load </font>
model_name = "ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps" #@param ["ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps", "ethzanalytics/gpt-j-8bit-daily_dialogues", "ethzanalytics/gpt-j-6B-8bit-sharded"]
# load_8bit_from_hub() wraps GPTJForCausalLM.from_pretrained(); any extra
# keyword arguments given to it are passed straight through.
model = load_8bit_from_hub(model_name)
"""# generate text
## standard generation
_a plain `model.generate()` call, wrapped in `torch.no_grad()`_
> with "standard" generation it's recommended to put the **speaker token labels** at the end of your prompt so the model "knows" to respond,
i.e. `Person Alpha:` or `Person Beta:` for these two models.
"""
prompt = "Person Alpha: what is the theory of being \"woke\" all about?\\n Person Beta: " # @param {type:"string"}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The form value contains a literal backslash + "n" (note the `\\n` escape), not
# a newline; convert it so the speaker turns are actually newline-separated.
prompt_text = prompt.replace("\\n", "\n")
with torch.no_grad():
    # keep `prompt` as the form string; the encoded tensors get their own name
    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    out = model.generate(
        **inputs,
        min_length=24,
        max_length=96,
        top_k=30,
        top_p=0.9,
        temperature=0.4,
        do_sample=True,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        # generation option (it was previously passed to decode(), which ignores it)
        remove_invalid_values=True,
        pad_token_id=tokenizer.eos_token_id,
    )
result = tokenizer.decode(
    out[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
result
"""---
## 'Extract' bot response
- transformers `pipeline` object
- generate with better params
- extract the bot's response with `get_bot_response()` - start to use [ai-msgbot](https://github.com/pszemraj/ai-msgbot) _like it was meant to be used_
"""
from transformers import pipeline
# Reuse the tokenizer already loaded (and pad-configured) above instead of
# loading a second copy from the hub by repo id.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
"""### generation functions
for extracting the response, beam search vs. sampling, etc
"""
# @markdown `get_bot_response(name_resp: str, model_resp: list, name_spk: str, verbose: bool = False)`
# @markdown - this extracts the response from "Person Beta" from the total generation
import pysbd
# English sentence segmenter used by split_sentences(); clean=False skips pysbd's text-cleaning pass.
seg = pysbd.Segmenter(language="en", clean=False)
import re
def split_sentences(text, use_regex=False, min_len=2):
    """Split ``text`` into sentences based on punctuation marks.

    Args:
        text (str): the text to split.
        use_regex (bool, optional): split on the spaces that follow ``.``, ``!``
            or ``?`` instead of using the pysbd segmenter. Defaults to False.
        min_len (int, optional): stripped sentences of length ``<= min_len`` are
            dropped. Defaults to 2.

    Returns:
        list[str]: the stripped sentences longer than ``min_len``.
    """
    if use_regex:
        # previously referenced an undefined name `string` here (NameError)
        sentences = re.split(r'(?<=[.!?]) +', text)
    else:
        # https://github.com/nipunsadvilkar/pySBD
        sentences = seg.segment(text)
    stripped = (s.strip() for s in sentences)
    return [s for s in stripped if len(s) > min_len]
def validate_response(response_text):
    """Normalise a model response into a list of sentences.

    Lists are returned unchanged (same object); strings are split with
    ``split_sentences``; any other type raises ``ValueError``.
    """
    if isinstance(response_text, str):
        return split_sentences(response_text)
    if not isinstance(response_text, list):
        raise ValueError(f"response input {response_text} not a list or str..")
    return response_text
def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - isolate the responder's reply from raw model output.

    The output (a list of lines, or a string that is sentence-split first) is
    scanned in order, case-insensitively:

    * a line mentioning ``name_resp`` is skipped; after one has been seen, a
      bare mention of ``name_spk`` no longer ends the reply;
    * any other line containing ":" ends the reply (another turn began);
    * a line mentioning ``name_spk`` ends the reply unless ``name_resp`` was seen;
    * every other line is kept.

    Args:
        name_resp (str): the name of the responder
        model_resp (list): the model response (list of lines, or a str)
        name_spk (str): the name of the speaker
        verbose (bool, optional): print the kept lines. Defaults to False.
    Returns:
        bot_response (str): the single kept line, or all kept lines joined by
        spaces ("" when nothing was kept).
    """
    model_resp = validate_response(model_resp)
    logging.info(f"isolating response from:\t{model_resp}")
    responder = name_resp.lower()
    speaker = name_spk.lower()
    kept_lines = []
    seen_responder = False
    for line in model_resp:
        lowered = line.lower()
        if responder in lowered:
            seen_responder = True
            continue
        if ":" in line:
            break
        if speaker in lowered and not seen_responder:
            break
        kept_lines.append(line)
    if verbose:
        print("the full response is:\n")
        print("\n".join(kept_lines))
    return kept_lines[0] if len(kept_lines) == 1 else " ".join(kept_lines)
import pprint as pp
# @markdown define `generate_sampling(prompt: str, ...)`
def generate_sampling(
    prompt: str,
    suffix: str = None,
    temperature=0.4,
    top_k: int = 40,
    top_p=0.90,
    min_length: int = 16,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=1.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Sample a continuation of ``prompt`` and return the extracted "Person Beta" reply.

    Args:
        prompt: conversation text to continue.
        suffix: appended to ``prompt`` before generation when not None
            (e.g. a newline followed by the responder label).
        temperature, top_k, top_p: sampling controls passed to the pipeline.
        min_length: minimum number of tokens to generate beyond the prompt.
        max_length: maximum number of new tokens to generate.
        no_repeat_ngram_size, repetition_penalty: repetition controls.
        return_full_text: include the prompt in the pipeline output.
        verbose: also print the input and the raw model output.
        **kwargs: forwarded to the ``generator`` pipeline call.

    Returns:
        str: the reply isolated by ``get_bot_response`` (also pretty-printed).
    """
    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t...")
    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts the prompt tokens too, so offset it by the prompt length
        min_length=min_length + _prompt_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        do_sample=True,
        return_full_text=return_full_text,
        # max_new_tokens already excludes the prompt; adding the prompt length
        # (as before) let long prompts generate far beyond `max_length`
        max_new_tokens=max_length,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )
    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )
    logging.info(f"extracted bot response:\n\t{response}")
    pp.pprint(response)
    return response
import pprint as pp
#@markdown define `generate_beams(prompt: str, num_beams:int =4, ...)`
def generate_beams(
    prompt: str,
    suffix: str = None,
    num_beams=4,
    min_length: int = 32,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=2.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Beam-search a continuation of ``prompt`` and return the extracted "Person Beta" reply.

    Args:
        prompt: conversation text to continue.
        suffix: appended to ``prompt`` before generation when not None.
        num_beams: number of beams (sampling disabled, early stopping on).
        min_length: minimum number of tokens to generate beyond the prompt.
        max_length: maximum number of new tokens to generate.
        no_repeat_ngram_size, repetition_penalty: repetition controls.
        return_full_text: include the prompt in the pipeline output.
        verbose: also print the input and the raw model output.
        **kwargs: forwarded to the ``generator`` pipeline call.

    Returns:
        str: the reply isolated by ``get_bot_response`` (also pretty-printed).
    """
    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")
    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts the prompt tokens too, so offset it by the prompt length
        min_length=min_length + _prompt_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=True,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        # max_new_tokens already excludes the prompt; do not add its length again
        max_new_tokens=max_length,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )
    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )
    logging.info(f"extracted bot response:\n\t{response}")
    pp.pprint(response)
    return response
import pprint as pp
#@markdown define `generate_csearch(prompt: str, penalty_alpha: float = 0.6, top_k: int = 5, ...)`
def generate_csearch(
    prompt: str,
    suffix: str = None,
    max_length: int = 96,
    min_length: int = 24,
    penalty_alpha: float = 0.6,
    top_k: int = 5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Generate with ``penalty_alpha`` + ``top_k`` (contrastive search) and return the "Person Beta" reply.

    Args:
        prompt: conversation text to continue.
        suffix: appended to ``prompt`` before generation when not None.
        max_length: maximum number of new tokens to generate.
        min_length: minimum number of tokens to generate beyond the prompt.
        penalty_alpha, top_k: contrastive-search controls passed to the pipeline.
        return_full_text: include the prompt in the pipeline output.
        verbose: also print the input and the raw model output.
        **kwargs: forwarded to the ``generator`` pipeline call.

    Returns:
        str: the reply isolated by ``get_bot_response`` (also pretty-printed).
    """
    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")
    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts the prompt tokens too; max_new_tokens does not
        min_length=min_length + _prompt_tokens,
        max_new_tokens=max_length,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )
    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )
    logging.info(f"extracted bot response:\n\t{response}")
    pp.pprint(response)
    return response
"""### generate - sampling
> **NOTE:** here `suffix="\nPerson Beta: "` is passed, so the speaker label does not need to be added to the prompt
"""
# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# prompt = "How do we harness space energy?" #@param {type:"string"}
# temperature = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.1}
# top_k = 30 #@param {type:"slider", min:10, max:60, step:10}
#
#
# result = generate_sampling(
# prompt,
# suffix="\nPerson Beta: ",
# max_length=128,
# min_length=32,
# temperature=temperature,
# top_k=top_k,
# )
#
prompt = "What is the purpose of life?" # @param {type:"string"}
temperature = 0.5 # @param {type:"slider", min:0.1, max:1, step:0.1}
top_k = 30 # @param {type:"slider", min:10, max:60, step:10}
# The suffix appends the responder label, so the prompt itself needs no "Person Beta:".
generated_result = generate_sampling(
    prompt,
    suffix="\nPerson Beta: ",
    min_length=32,
    top_k=top_k,
    temperature=temperature,
)
"""### generate - beam search"""
# Commented out IPython magic to ensure Python compatibility.
# %%time
# prompt = "How was your day?" #@param {type:"string"}
# num_beams = 4 #@param {type:"slider", min:2, max:10, step:2}
# min_length = 16 #@param {type:"slider", min:8, max:128, step:8}
#
# generated_result = generate_beams(
# prompt,
# suffix="\nPerson Beta: ",
# min_length=min_length,
# num_beams=num_beams,
# )