#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ai_single_response.py
An executable way to call the model. example:
*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time
query_gpt_model is used throughout the code, and is the "fundamental" building block of the bot and how everything works. I would recommend testing this function with a few different models.
"""
import argparse
import pprint as pp
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from grammar_improve import remove_trailing_punctuation
from utils import print_spacer, cleantxt_wrap
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
from aitextgen import aitextgen
def extract_response(full_resp: list, plist: list, verbose: bool = False):
"""
    extract_response - helper fn for ai_single_response.py. By default,
    aitextgen returns the prompt followed by the response; we want only the
    response lines.

    Args:
        full_resp (list): a list of strings, one per generated line
        plist (list): a list of strings, one per prompt line
        verbose (bool, optional): for debugging. Defaults to False.

    Returns:
        list: the model-generated response lines only
    """
full_resp = [cleantxt_wrap(ele) for ele in full_resp]
plist = [cleantxt_wrap(pr) for pr in plist]
p_len = len(plist)
    assert (
        len(full_resp) >= p_len
    ), "model output should have at least as many lines as the input prompt"
    if set(plist).issubset(full_resp):
        del full_resp[:p_len]  # remove the prompt lines from the output
    else:
        print("the model output was:\n")
        pp.pprint(full_resp)
        print_spacer()
        print("the input prompt was:\n")
        pp.pprint(plist)
        print_spacer()
        sys.exit("Exiting: some prompt lines were not found in the model output")
if verbose:
print("the isolated responses are:\n")
pp.pprint(full_resp)
print_spacer()
print("the input prompt was:\n")
pp.pprint(plist)
print_spacer()
return full_resp # list of only the model generated responses
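
# Hypothetical worked example for extract_response (values are illustrative,
# and assume cleantxt_wrap leaves these simple strings unchanged):
#   full_resp = ["person alpha:", "what is 2 + 2?", "", "person beta:", "4"]
#   plist     = ["person alpha:", "what is 2 + 2?", "", "person beta:"]
# every prompt line appears in the output, so the first len(plist) lines are
# deleted and extract_response returns ["4"].
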
def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - extract the bot's reply from the full model response.
    This is needed because, depending on the generation length, the model may
    return lines for more than one speaker.

    Args:
        name_resp (str): the name of the responder (the bot)
        model_resp (list): the model response, one line per element
        name_spk (str): the name of the speaker; used to detect where the bot's reply ends
        verbose (bool, optional): for debugging. Defaults to False.

    Returns:
        fn_resp (list of str): the lines attributed to the bot
    """
fn_resp = []
name_counter = 0
break_safe = False
    for resline in model_resp:
        if resline.startswith(name_resp):
            name_counter += 1
            break_safe = True  # next line is from the bot: this line is the bot's name tag
            continue
        if name_spk is not None and name_spk.lower() in resline.lower():
            break  # reached the speaker's name - the bot's reply is over
        if ":" in resline and name_counter > 0:
            if break_safe:
                # immediately follows the bot's name tag, so it is part of the
                # bot's reply even though it contains ':'
                fn_resp.append(resline)
                break_safe = False
            else:
                # ':' probably marks another speaker's name tag - the bot has
                # "finished" its reply
                break
        else:
            fn_resp.append(resline)
            break_safe = False
if verbose:
print("the full response is:\n")
print("\n".join(fn_resp))
return fn_resp
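
# Hypothetical example for get_bot_response (lines are illustrative): given
#   model_resp = ["person beta:", "hey!", "how are you?", "person alpha:", "fine thanks"]
# with name_resp="person beta" and name_spk="person alpha", the name-tag line
# flips break_safe, "hey!" and "how are you?" are collected, and the loop
# breaks at "person alpha:", returning ["hey!", "how are you?"].
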
def query_gpt_model(
prompt_msg: str,
speaker=None,
responder=None,
resp_length=40,
resp_min=10,
kparam=150,
temp=0.75,
top_p=0.65,
batch_size=64,
verbose=False,
use_gpu=False,
beams=4,
):
"""
query_gpt_model - the main function that calls the model.
Parameters:
-----------
prompt_msg (str): the prompt to be sent to the model
speaker (str, optional): the name of the speaker. Defaults to None.
responder (str, optional): the name of the responder. Defaults to None.
resp_length (int, optional): the length of the response. Defaults to 128.
resp_min (int, optional): the minimum length of the response. Defaults to 4.
kparam (int, optional): the k parameter for the top_p. Defaults to 150.
temp (float, optional): the temperature for the top_p. Defaults to 0.75.
top_p (float, optional): the top_p parameter for the top_p. Defaults to 0.65.
verbose (bool, optional): 4 debug. Defaults to False.
use_gpu (bool, optional): use gpu. Defaults to False.
"""
ai = aitextgen(
model="pszemraj/Ballpark-Trivia-L", # THIS WORKS
# model="pszemraj/Ballpark-Trivia-XL", # does not seem to work TODO: test further with after it loads
to_gpu=use_gpu,
)
    p_list = []  # track conversation
    if speaker is not None:  # guard: --speaker defaults to None, and None has no .lower()
        p_list.append(speaker.lower() + ":" + "\n")
    p_list.append(prompt_msg.lower() + "\n")
    p_list.append("\n")
    p_list.append(responder.lower() + ":" + "\n")
this_prompt = "".join(p_list)
pr_len = len(this_prompt)
if verbose:
print("overall prompt:\n")
pp.pprint(this_prompt, indent=4)
# call the model
print("\n... generating...")
this_result = ai.generate(
n=1,
batch_size=batch_size,
# the prompt input counts for text length constraints
max_length=resp_length + pr_len,
# min_length=resp_min + pr_len,
prompt=this_prompt,
# temperature=temp,
top_k=kparam,
top_p=top_p,
do_sample=True,
return_as_list=True,
use_cache=True,
)
if verbose:
print("\n... generated:\n")
pp.pprint(this_result) # for debugging
# process the full result to get the ~bot response~ piece
    this_result = str(this_result[0]).split(
        "\n"
    )  # TODO: make the hardcoded index dynamic (if n > 1)
og_res = this_result.copy()
og_prompt = p_list.copy()
diff_list = extract_response(
this_result, p_list, verbose=verbose
) # isolate the responses from the prompts
# extract the bot response from the model generated text
bot_dialogue = get_bot_response(
name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
)
print(f"FOR DEBUG: {bot_dialogue}")
bot_resp = ", ".join(bot_dialogue)
bot_resp = bot_resp.strip()
# remove the last ',' '.' chars
bot_resp = remove_trailing_punctuation(bot_resp)
if verbose:
print("\n... bot response:\n")
pp.pprint(bot_resp)
og_prompt.append(bot_resp + "\n")
og_prompt.append("\n")
print("\nfinished!")
# return the bot response and the full conversation
return {"out_text": bot_resp, "full_conv": og_prompt} # model responses
# Set up the parsing of command-line arguments
def get_parser():
"""
get_parser [a helper function for the argparse module]
Returns: argparse.ArgumentParser
"""
parser = argparse.ArgumentParser(
description="submit a message and have a pretrained GPT model respond"
)
parser.add_argument(
"--prompt",
required=True, # MUST HAVE A PROMPT
type=str,
help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
)
parser.add_argument(
"--model",
required=False,
type=str,
default="GPT2_trivNatQAdailydia_774M_175Ksteps",
help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
"config.json). No models? Run the script download_models.py",
)
parser.add_argument(
"--speaker",
required=False,
default=None,
help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
)
parser.add_argument(
"--responder",
required=False,
default="person beta",
help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
)
parser.add_argument(
"--topk",
required=False,
type=int,
default=150,
help="how many responses to sample (positive integer). lower = more random responses",
)
parser.add_argument(
"--temp",
required=False,
type=float,
default=0.75,
help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
)
parser.add_argument(
"--topp",
required=False,
type=float,
default=0.65,
help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="pass this argument if you want all the printouts",
)
parser.add_argument(
"--time",
default=False,
action="store_true",
help="pass this argument if you want to know runtime",
)
return parser
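
# Example CLI invocation (mirrors the module docstring):
#   python ai_single_response.py --prompt "hey, what's up?" --time
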
if __name__ == "__main__":
# parse the command line arguments
args = get_parser().parse_args()
query = args.prompt
model_dir = str(args.model)
    model_loc = Path.cwd() / model_dir  # NOTE: currently unused - query_gpt_model() loads a hardcoded model
spkr = args.speaker
rspndr = args.responder
k_results = args.topk
my_temp = args.temp
my_top_p = args.topp
want_verbose = args.verbose
want_rt = args.time
st = time.perf_counter()
resp = query_gpt_model(
prompt_msg=query,
speaker=spkr,
responder=rspndr,
kparam=k_results,
temp=my_temp,
top_p=my_top_p,
verbose=want_verbose,
use_gpu=False,
)
output = resp["out_text"]
pp.pprint(output, indent=4)
rt = round(time.perf_counter() - st, 1)
if want_rt:
print("took {runtime} seconds to generate. \n".format(runtime=rt))
if want_verbose:
print("finished - ", datetime.now())
p_list = resp["full_conv"]
print("A transcript of your chat is as follows: \n")
p_list = [item.strip() for item in p_list]
pp.pprint(p_list)