#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ai_single_response.py

An executable way to call the model. Example:

*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time

query_gpt_model is used throughout the code and is the "fundamental" building block of the bot; it is worth testing that function with a few different models.
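
Programmatic use (an illustrative sketch; assumes this file is importable from the working directory):

    from ai_single_response import query_gpt_model

    resp = query_gpt_model(
        prompt_msg="hey, what's up?",
        speaker="person alpha",
        responder="person beta",
    )
    print(resp["out_text"])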
""" | |
import argparse | |
import pprint as pp | |
import sys | |
import time | |
import warnings | |
from datetime import datetime | |
from pathlib import Path | |
from grammar_improve import remove_trailing_punctuation | |
from utils import print_spacer, cleantxt_wrap | |
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*") | |
from aitextgen import aitextgen | |


def extract_response(full_resp: list, plist: list, verbose: bool = False):
    """
    extract_response - helper function for ai_single_response.py. By default, aitextgen returns the prompt and the response together; we want only the response.

    Args:
        full_resp (list): a list of strings, each one a line of the model output
        plist (list): a list of strings, each one a line of the prompt
        verbose (bool, optional): print debug info. Defaults to False.
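
    Example (illustrative; assumes cleantxt_wrap leaves short ASCII strings unchanged):

        >>> extract_response(
        ...     ["person alpha:", "hi", "person beta:", "hello there"],
        ...     ["person alpha:", "hi"],
        ... )
        ['person beta:', 'hello there']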
""" | |
full_resp = [cleantxt_wrap(ele) for ele in full_resp] | |
plist = [cleantxt_wrap(pr) for pr in plist] | |
p_len = len(plist) | |
assert ( | |
len(full_resp) >= p_len | |
), "model output should have as many lines or longer as the input." | |
if set(plist).issubset(full_resp): | |
del full_resp[:p_len] # remove the prompts from the responses | |
else: | |
print("the isolated responses are:\n") | |
pp.pprint(full_resp) | |
print_spacer() | |
print("the input prompt was:\n") | |
pp.pprint(plist) | |
print_spacer() | |
sys.exit("Exiting: some prompts not found in the responses") | |
if verbose: | |
print("the isolated responses are:\n") | |
pp.pprint(full_resp) | |
print_spacer() | |
print("the input prompt was:\n") | |
pp.pprint(plist) | |
print_spacer() | |
return full_resp # list of only the model generated responses | |


def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - from the model output, extract only the bot's response. This is needed because, depending on the generation length, the model may return more than one speaker turn.

    Args:
        name_resp (str): the name of the responder (the bot)
        model_resp (list): the model output, as a list of lines
        name_spk (str): the name of the speaker; the output is cut off once this name appears
        verbose (bool, optional): print debug info. Defaults to False.

    Returns:
        fn_resp (list of str): the lines that belong to the bot's response
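
    Example (illustrative sketch of the line filtering):

        >>> get_bot_response(
        ...     name_resp="person beta",
        ...     model_resp=["person beta:", "i am fine, thanks!", "person alpha:", "cool"],
        ...     name_spk="person alpha",
        ... )
        ['i am fine, thanks!']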
""" | |
fn_resp = [] | |
name_counter = 0 | |
break_safe = False | |
for resline in model_resp: | |
if resline.startswith(name_resp): | |
name_counter += 1 | |
break_safe = True # know the line is from bot as this line starts with the name of the bot | |
continue | |
if name_spk is not None and name_spk.lower() in resline.lower(): | |
break | |
if ":" in resline and name_counter > 0: | |
if break_safe: | |
# we know this is a response from the bot even tho ':' is in the line | |
fn_resp.append(resline) | |
break_safe = False | |
else: | |
# we do not know this is a response from the bot. could be name of another person.. bot is "finished" response | |
break | |
else: | |
fn_resp.append(resline) | |
break_safe = False | |
if verbose: | |
print("the full response is:\n") | |
print("\n".join(fn_resp)) | |
return fn_resp | |


def query_gpt_model(
    prompt_msg: str,
    folder_path=None,
    speaker=None,
    responder=None,
    resp_length=40,
    resp_min=10,
    kparam=150,
    temp=0.75,
    top_p=0.65,
    batch_size=64,
    verbose=False,
    use_gpu=False,
    beams=4,
):
""" | |
query_gpt_model - the main function that calls the model. | |
Parameters: | |
----------- | |
prompt_msg (str): the prompt to be sent to the model | |
speaker (str, optional): the name of the speaker. Defaults to None. | |
responder (str, optional): the name of the responder. Defaults to None. | |
resp_length (int, optional): the length of the response. Defaults to 128. | |
resp_min (int, optional): the minimum length of the response. Defaults to 4. | |
kparam (int, optional): the k parameter for the top_p. Defaults to 150. | |
temp (float, optional): the temperature for the top_p. Defaults to 0.75. | |
top_p (float, optional): the top_p parameter for the top_p. Defaults to 0.65. | |
verbose (bool, optional): 4 debug. Defaults to False. | |
use_gpu (bool, optional): use gpu. Defaults to False. | |
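
    Example (illustrative; the reply text depends on the model weights):

        resp = query_gpt_model("where is the Eiffel Tower?", speaker="person alpha")
        resp["out_text"]   # the isolated bot reply (str)
        resp["full_conv"]  # the prompt lines plus the reply (list of str)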
""" | |
ai = aitextgen( | |
model="pszemraj/Ballpark-Trivia-L", # THIS WORKS | |
# model="pszemraj/Ballpark-Trivia-XL", # does not seem to work TODO: test further with after it loads | |
to_gpu=use_gpu, | |
) | |
p_list = [] # track conversation | |
p_list.append(speaker.lower() + ":" + "\n") | |
p_list.append(prompt_msg.lower() + "\n") | |
p_list.append("\n") | |
p_list.append(responder.lower() + ":" + "\n") | |
this_prompt = "".join(p_list) | |
pr_len = len(this_prompt) | |
if verbose: | |
print("overall prompt:\n") | |
pp.pprint(this_prompt, indent=4) | |
# call the model | |
print("\n... generating...") | |
    this_result = ai.generate(
        n=1,
        batch_size=batch_size,
        # the prompt input counts toward the text-length constraints
        max_length=resp_length + pr_len,
        # min_length=resp_min + pr_len,
        prompt=this_prompt,
        # temperature=temp,
        top_k=kparam,
        top_p=top_p,
        do_sample=True,
        return_as_list=True,
        use_cache=True,
    )
    if verbose:
        print("\n... generated:\n")
        pp.pprint(this_result)  # for debugging
    # process the full result to get the ~bot response~ piece
    this_result = str(this_result[0]).split(
        "\n"
    )  # TODO: make the hardcoded index dynamic (if n > 1)
    og_res = this_result.copy()
    og_prompt = p_list.copy()
    diff_list = extract_response(
        this_result, p_list, verbose=verbose
    )  # isolate the responses from the prompts
    # extract the bot response from the model-generated text
    bot_dialogue = get_bot_response(
        name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
    )
    if verbose:
        print(f"FOR DEBUG: {bot_dialogue}")
    bot_resp = ", ".join(bot_dialogue)
    bot_resp = bot_resp.strip()
    # remove any trailing ',' or '.' chars
    bot_resp = remove_trailing_punctuation(bot_resp)
    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    og_prompt.append(bot_resp + "\n")
    og_prompt.append("\n")
    print("\nfinished!")
    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": og_prompt}  # model responses


# set up the parsing of command-line arguments
def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description="submit a message and have a pretrained GPT model respond"
    )
    parser.add_argument(
        "--prompt",
        required=True,  # MUST HAVE A PROMPT
        type=str,
        help="the message the bot is supposed to respond to. The prompt is said by the speaker and answered by the responder.",
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",
        help="folder - relative to the git directory of your repo - that has the model files in it (pytorch_model.bin + "
        "config.json). No models? Run the script download_models.py",
    )
    parser.add_argument(
        "--speaker",
        required=False,
        default=None,
        help="who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--responder",
        required=False,
        default="person beta",
        help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--topk",
        required=False,
        type=int,
        default=150,
        help="top-k sampling: how many candidate tokens are considered at each step (positive integer). lower = less random responses",
    )
    parser.add_argument(
        "--temp",
        required=False,
        type=float,
        default=0.75,
        help="specify the temperature hyperparameter (0-1). roughly thought of as 'model creativity'",
    )
    parser.add_argument(
        "--topp",
        required=False,
        type=float,
        default=0.65,
        help="nucleus sampling fraction (0-1): what fraction of the probability mass is considered?",
    )
    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="pass this argument if you want all the printouts",
    )
    parser.add_argument(
        "--time",
        default=False,
        action="store_true",
        help="pass this argument if you want to know the runtime",
    )
    return parser
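
# example CLI invocation using the flags defined above (illustrative):
#   python ai_single_response.py --prompt "what is the capital of France?" \
#       --speaker "person alpha" --responder "person beta" --temp 0.8 --time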

if __name__ == "__main__":
    # parse the command-line arguments
    args = get_parser().parse_args()
    query = args.prompt
    model_dir = str(args.model)
    model_loc = Path.cwd() / model_dir
    spkr = args.speaker
    rspndr = args.responder
    k_results = args.topk
    my_temp = args.temp
    my_top_p = args.topp
    want_verbose = args.verbose
    want_rt = args.time

    st = time.perf_counter()
    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=query,
        speaker=spkr,
        responder=rspndr,
        kparam=k_results,
        temp=my_temp,
        top_p=my_top_p,
        verbose=want_verbose,
        use_gpu=False,
    )
    output = resp["out_text"]
    pp.pprint(output, indent=4)

    rt = round(time.perf_counter() - st, 1)
    if want_rt:
        print("took {runtime} seconds to generate. \n".format(runtime=rt))

    if want_verbose:
        print("finished - ", datetime.now())
        p_list = resp["full_conv"]
        print("A transcript of your chat is as follows: \n")
        p_list = [item.strip() for item in p_list]
        pp.pprint(p_list)