# NOTE: removed non-code scraping artifacts ("Spaces:" / "Runtime error" lines) that broke parsing
""" | |
deploy-as-bot\gradio_chatbot.py | |
A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses. | |
Note that the URL is displayed once the script is run. | |
""" | |
# --- imports & module-level setup ---
import argparse
import logging
import os
import sys
import time
import warnings
from datetime import datetime
from os.path import dirname
from pathlib import Path

# Make the repo root importable BEFORE any project-local import
# (original code appended to sys.path only after importing ai_single_response,
# which breaks when the script is launched from another directory).
sys.path.append(dirname(dirname(os.path.abspath(__file__))))

import gradio as gr
import nltk
from cleantext import clean
from flask import (
    Flask,
    request,
    session,
    jsonify,
    abort,
    send_file,
    render_template,
    redirect,
)
from transformers import pipeline

from ai_single_response import query_gpt_model

nltk.download("stopwords")  # still unsure where this error originates from

# from gradio.networking import get_state, set_state
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
logging.basicConfig()

cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
def gramformer_correct(corrector, qphrase: str) -> str:
    """
    gramformer_correct - correct a string using a text2text generation pipeline model from transformers.

    Args:
        corrector (transformers.pipeline): transformers pipeline object, already created w/ relevant model
        qphrase (str): text to be corrected

    Returns:
        str: corrected text; falls back to the cleaned input if correction fails
    """
    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit;
        # catching Exception keeps the best-effort fallback without hiding interrupts
        print("NOTE - failed to correct with gramformer")
        return clean(qphrase)
def ask_gpt(message: str) -> str:
    """
    ask_gpt - query the GPT model with a prompt message and return a grammar-corrected response.

    Args:
        message (str): prompt message to respond to

    Returns:
        str: model response as a string
    """
    st = time.perf_counter()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    if len(prompt) > 200:
        prompt = prompt[-200:]  # truncate, keeping the most recent 200 characters
    resp = query_gpt_model(
        prompt_msg=prompt,
        speaker="person alpha",
        responder="person beta",
        kparam=150,
        temp=0.75,
        top_p=0.65,  # optimize this with hyperparam search
    )
    # NOTE(review): `corrector` is a module-level global created in the __main__ block —
    # calling ask_gpt from elsewhere before that runs will raise NameError; confirm intended
    bot_resp = gramformer_correct(corrector, qphrase=resp["out_text"])
    rt = round(time.perf_counter() - st, 2)
    print(f"took {rt} sec to respond")
    return bot_resp
def chat(trivia_query):
    """
    chat - the Gradio callback: query the model and format the exchange for display.

    Args:
        trivia_query (str): the user's trivia question

    Returns:
        str: an html string to display
    """
    history = []
    response = ask_gpt(trivia_query)
    history.append(f"<b>{trivia_query}</b> <br> <br> <b>{response}</b>")
    # gr.set_state was removed from the public gradio API in newer releases
    # (see the commented-out gradio.networking import at the top of the file);
    # guard so the callback keeps working instead of raising AttributeError
    if hasattr(gr, "set_state"):
        gr.set_state(history)  # save the history
    # history holds a single formatted exchange; join keeps it future-proof
    html = "".join(history)
    return html
def get_parser() -> argparse.ArgumentParser:
    """
    get_parser - a helper function for the argparse module.

    Returns:
        argparse.ArgumentParser: the argparser relevant for this script
    """
    parser = argparse.ArgumentParser(
        description="submit a message and have a 774M parameter GPT model respond"
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="ballpark-trivia-L",
        help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
        "config.json). No models? Run the script download_models.py",
    )
    parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="prithivida/grammar_error_correcter_v1",
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )
    return parser
if __name__ == "__main__": | |
args = get_parser().parse_args() | |
default_model = str(args.model) | |
model_loc = cwd.parent / default_model | |
model_loc = str(model_loc.resolve()) | |
gram_model = args.gram_model | |
print(f"using model stored here: \n {model_loc} \n") | |
corrector = pipeline("text2text-generation", model=gram_model, device=-1) | |
print("Finished loading the gramformer model - ", datetime.now()) | |
iface = gr.Interface( | |
chat, | |
inputs=["text"], | |
outputs="html", | |
title=f"Ballpark Trivia: {default_model} Model", | |
description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.", | |
article="Further details can be found in the [model card](https: // huggingface.co/pszemraj/Ballpark-Trivia-L). If you are interested in a more deceptively incorrect model, there is also [an XL version](https://huggingface.co/pszemraj/Ballpark-Trivia-XL) on my page.\n\n" | |
"**Important Notes & About:**\n\n" | |
"1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n" | |
"2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ", | |
css=""" | |
.chatbox {display:flex;flex-direction:column} | |
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%} | |
.user_msg {background-color:cornflowerblue;color:white;align-self:start} | |
.resp_msg {background-color:lightgray;align-self:self-end} | |
""", | |
allow_screenshot=True, | |
allow_flagging=False, | |
enable_queue=True, # allows for dealing with multiple users simultaneously | |
theme="darkhuggingface", | |
) | |
# launch the gradio interface and start the server | |
iface.launch( | |
share=True, | |
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version) | |
) | |