Spaces:
Runtime error
Runtime error
File size: 6,410 Bytes
e332ae1 6b53043 adebe05 e332ae1 b649bd8 6b53043 e332ae1 093b964 b649bd8 093b964 6b53043 b649bd8 6b53043 e332ae1 6b53043 e332ae1 6b53043 e332ae1 2a3e09e e332ae1 6b53043 e332ae1 6b53043 e332ae1 6b53043 2a3e09e e332ae1 2a3e09e e332ae1 635e033 e332ae1 635e033 e332ae1 635e033 2a3e09e 971c338 635e033 6b53043 e332ae1 6b53043 e332ae1 2a3e09e e770698 4922cde dd38199 e332ae1 4922cde e332ae1 adebe05 1cf48f8 62981f2 635e033 e332ae1 b649bd8 2a3e09e b649bd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""
deploy-as-bot\gradio_chatbot.py
A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses.
Note that the URL is displayed once the script is run.
"""
from flask import (
Flask,
request,
session,
jsonify,
abort,
send_file,
render_template,
redirect,
)
from ai_single_response import query_gpt_model
from datetime import datetime
from transformers import pipeline
from cleantext import clean
from pathlib import Path
import warnings
import time
import argparse
import logging
import gradio as gr
import os
import sys
from os.path import dirname
import nltk
nltk.download("stopwords")  # still unsure where this error originates from
# NOTE(review): this path tweak runs AFTER the ai_single_response import above —
# confirm the import actually resolves when the script is launched from other dirs
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
# from gradio.networking import get_state, set_state
# silence the transformers gradient_checkpointing deprecation chatter
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
logging.basicConfig()  # root logger, default (WARNING) level
cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
def gramformer_correct(corrector, qphrase: str) -> str:
    """
    gramformer_correct - correct a string with a text2text-generation pipeline.

    Args:
        corrector (transformers.pipeline): text2text-generation pipeline object,
            already created with the relevant grammar-correction model
        qphrase (str): text to be corrected

    Returns:
        str: the corrected text, or the cleaned original text if correction fails
    """
    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; any model/pipeline failure falls back to the cleaned input
        print("NOTE - failed to correct with gramformer")
        return clean(qphrase)
def ask_gpt(message: str) -> str:
    """
    ask_gpt - query the GPT model with a prompt message and grammar-correct the reply.

    Args:
        message (str): prompt message to respond to

    Returns:
        str: model response as a string
    """
    st = time.perf_counter()  # start the response timer
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    if len(prompt) > 200:
        prompt = prompt[-200:]  # truncate: keep only the LAST 200 chars of long prompts
    resp = query_gpt_model(
        prompt_msg=prompt,
        speaker="person alpha",
        responder="person beta",
        kparam=150,
        temp=0.75,
        top_p=0.65,  # optimize this with hyperparam search
    )
    # NOTE(review): relies on the module-level `corrector` pipeline created in
    # __main__ — this function only works after the script entry point has run
    bot_resp = gramformer_correct(corrector, qphrase=resp["out_text"])
    rt = round(time.perf_counter() - st, 2)
    print(f"took {rt} sec to respond")
    return bot_resp
def chat(trivia_query):
    """
    chat - gradio callback: query the model and render the exchange as HTML.

    Args:
        trivia_query (str): the trivia question entered by the user

    Returns:
        str: an HTML string to display in the interface
    """
    reply = ask_gpt(trivia_query)
    exchanges = [f"<b>{trivia_query}</b> <br> <br> <b>{reply}</b>"]
    gr.set_state(exchanges)  # save the history
    # concatenate every stored exchange into one HTML payload
    return "".join(exchanges)
def get_parser():
    """
    get_parser - construct the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: parser accepting --model and --gram-model
    """
    arg_parser = argparse.ArgumentParser(
        description="submit a message and have a 774M parameter GPT model respond"
    )
    arg_parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="ballpark-trivia-L",
        help=(
            "folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
            "config.json). No models? Run the script download_models.py"
        ),
    )
    arg_parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="prithivida/grammar_error_correcter_v1",
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )
    return arg_parser
if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    # model folder lives as a sibling of the current working directory
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model
    print(f"using model stored here: \n {model_loc} \n")
    # grammar-correction pipeline on CPU (device=-1); used by ask_gpt via gramformer_correct
    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    iface = gr.Interface(
        chat,
        inputs=["text"],
        outputs="html",
        title=f"Ballpark Trivia: {default_model} Model",
        description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
        # fixed broken markdown link: the URL previously contained spaces
        # ("https: // huggingface.co/...") which broke the rendered link
        article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-L). If you are interested in a more deceptively incorrect model, there is also [an XL version](https://huggingface.co/pszemraj/Ballpark-Trivia-XL) on my page.\n\n"
        "**Important Notes & About:**\n\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ",
        css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
""",
        allow_screenshot=True,
        allow_flagging=False,
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )
    # launch the gradio interface and start the server
    iface.launch(
        share=True,
        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )
|