"""
deploy-as-bot\gradio_chatbot.py
A script for deploying the chatbot to Gradio. Gradio is a basic "deploy" interface that lets other users test your model from a web URL; it also provides basic functionality such as letting users flag odd responses.
Note that the URL is displayed once the script is run.
Set the working directory to */deploy-as-bot in the terminal before running.
"""
from utils import remove_trailing_punctuation, DisableLogger
import os
import sys
from os.path import dirname
# add the parent directory of this script's folder to sys.path so sibling modules can be imported
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
import gradio as gr
import logging
import argparse
import time
import warnings
from pathlib import Path
from transformers import pipeline
from datetime import datetime
from ai_single_response import query_gpt_model
logging.basicConfig(
filename=f"LOGFILE-{Path(__file__).stem}.log",
filemode="a",
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
with DisableLogger():
from cleantext import clean
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing.*")
cwd = Path.cwd()
my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
def gramformer_correct(corrector, qphrase: str):
"""
    gramformer_correct - correct a string using a text2text-generation pipeline from transformers

    Args:
        corrector (transformers.Pipeline): pipeline object, already created with the relevant model
        qphrase (str): the text to be corrected

    Returns:
        str: the corrected text
"""
try:
corrected = corrector(
clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
)
return corrected[0]["generated_text"]
    except Exception:
print("NOTE - failed to correct with gramformer")
return clean(
qphrase
) # fallback is to return the cleaned up version of the message
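# Usage sketch for gramformer_correct (the input sentence is illustrative;
# the model ID is this script's default --gram-model):
#
#   corrector = pipeline("text2text-generation", model="prithivida/grammar_error_correcter_v1", device=-1)
#   fixed = gramformer_correct(corrector, qphrase="he go to school yesterday")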
def ask_gpt(message: str, sender: str = ""):
"""
    ask_gpt - queries the relevant model with a prompt message and an (optional) speaker name.
    Note: this version is modified for the Gradio local-server deployment.

    Args:
        message (str): prompt message to respond to
        sender (str, optional): the speaker, i.e. who said the message. Defaults to "".

    Returns:
        str: the model response as a string
"""
st = time.time()
prompt = clean(message) # clean user input
prompt = prompt.strip() # get rid of any extra whitespace
if len(prompt) > 100:
prompt = prompt[:100] # truncate
sender = clean(sender.strip())
if len(sender) > 2:
try:
prompt_speaker = clean(sender)
        except Exception:
prompt_speaker = None # fallback
else:
prompt_speaker = None # fallback
resp = query_gpt_model(
folder_path=model_loc,
prompt_msg=prompt,
speaker=prompt_speaker,
kparam=150, # top k responses
temp=0.75, # temperature
top_p=0.65, # nucleus sampling
)
bot_resp = gramformer_correct(
corrector, qphrase=resp["out_text"]
) # correct grammar
bot_resp = remove_trailing_punctuation(
bot_resp
) # remove trailing punctuation to seem more natural
rt = round(time.time() - st, 2)
print(f"took {rt} sec to respond")
return bot_resp
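# Usage sketch for ask_gpt (assumes the module-level globals model_loc and
# corrector have been initialised, as done under __main__ below; the name
# and message are illustrative):
#
#   reply = ask_gpt("how are you today?", sender="Jane Doe")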
def chat(first_and_last_name, message):
"""
    chat - helper function that wires the Gradio interface to the model.

    Args:
        first_and_last_name (str or None): speaker of the prompt, if provided
        message (str): the prompt message to respond to

    Returns:
        str: an HTML string to display in the chat box
"""
history = gr.get_state() or []
response = ask_gpt(message, sender=first_and_last_name)
history.append(("You: " + message, " GPT-Model: " + response + " [end] "))
gr.set_state(history) # save the history
html = ""
for user_msg, resp_msg in history:
html += f"{user_msg}"
html += f"{resp_msg}"
html += ""
return html
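# The returned markup looks roughly like the sketch below (the CSS class
# names match the stylesheet passed to gr.Interface further down):
#
#   <div class='chatbox'>
#     <div class='user_msg'>You: hi</div>
#     <div class='resp_msg'> GPT-Model: hello there [end] </div>
#   </div>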
def get_parser():
"""
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: the argument parser relevant for this script
"""
parser = argparse.ArgumentParser(
description="host a chatbot on gradio",
)
parser.add_argument(
"--model",
required=False,
type=str,
default="GPT2_trivNatQAdailydia_774M_175Ksteps", # folder name of model
help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
"config.json). No models? Run the script download_models.py",
)
parser.add_argument(
"--gram-model",
required=False,
type=str,
default="prithivida/grammar_error_correcter_v1", # huggingface model
help="text2text generation model ID from huggingface for the model to correct grammar",
)
return parser
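# Example (hypothetical folder name, parsed without touching sys.argv):
#
#   args = get_parser().parse_args(["--model", "my_model_folder"])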
if __name__ == "__main__":
args = get_parser().parse_args()
default_model = str(args.model)
model_loc = cwd.parent / default_model
model_loc = str(model_loc.resolve())
gram_model = args.gram_model
# init items for the pipeline
iface = gr.Interface(
chat,
inputs=["text", "text"],
outputs="html",
title=f"GPT-Chatbot Demo: {default_model} Model",
description=f"A basic interface with a GPT2-based model, specifically {default_model}. Treat it like a friend!",
article="**Important Notes & About:**\n"
"1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
"2. entering a username is completely optional.\n"
"3. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says sshould be fact-checked before being regarded as a true statement.\n ",
css="""
.chatbox {display:flex;flex-direction:column}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
.resp_msg {background-color:lightgray;align-self:self-end}
""",
allow_screenshot=True,
allow_flagging=True, # allow users to flag responses as inappropriate
flagging_dir="gradio_data",
flagging_options=[
"great response",
"doesn't make sense",
"bad/offensive response",
],
enable_queue=True, # allows for dealing with multiple users simultaneously
theme="darkhuggingface",
)
corrector = pipeline("text2text-generation", model=gram_model, device=-1)
print("Finished loading the gramformer model - ", datetime.now())
print(f"using model stored here: \n {model_loc} \n")
# launch the gradio interface and start the server
iface.launch(share=True)