""" deploy-as-bot\gradio_chatbot.py A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses. Note that the URL is displayed once the script is run. """ from flask import ( Flask, request, session, jsonify, abort, send_file, render_template, redirect, ) from ai_single_response import query_gpt_model from datetime import datetime from transformers import pipeline from cleantext import clean from pathlib import Path import warnings import time import argparse import logging import gradio as gr import os import sys from os.path import dirname import nltk from grammar_improve import ( load_ns_checker, neuspell_correct, remove_repeated_words, remove_trailing_punctuation, build_symspell_obj, symspeller, ) from utils import ( cleantxt_wrap, corr, ) nltk.download("stopwords") # TODO: find where this requirement originates from sys.path.append(dirname(dirname(os.path.abspath(__file__)))) warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*") logging.basicConfig() cwd = Path.cwd() my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects def ask_gpt(message: str): """ ask_gpt - queries the relevant model with a prompt message and returns the response. NOTE: because this is for models trained with person alpha and person beta, there is no need for customizing / changing the name settings and so on Args: message (str): prompt message to respond to, usually a question Returns: [str]: [model response as a string] """ st = time.perf_counter() prompt = clean(message) # clean user input prompt = prompt.strip() # get rid of any extra whitespace if len(prompt) > 200: prompt = prompt[-200:] # truncate resp = query_gpt_model( prompt_msg=prompt, speaker="person alpha", responder="person beta", kparam=32, # temp=0.6, top_p=0.7, batch_size=32, # TODO - allow users to adjust these ) # using top_P and top_k to avoid the "too many hypotheses" error, not using temp if basic_sc: cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell) else: cln_resp = neuspell_correct(resp["out_text"], checker=ns_checker) bot_resp = corr(remove_repeated_words(cln_resp)) print(f"the prompt was:\n {message} and the response was:\n {bot_resp}\n") rt = round(time.perf_counter() - st, 2) print(f"took {rt} sec to respond") return remove_trailing_punctuation(bot_resp) def chat(trivia_query): """ chat - helper function that makes the whole gradio thing work. Args: trivia_query (str): the question to ask the bot Returns: [str]: the bot's response """ history = [] response = ask_gpt(trivia_query) history.append(f"{trivia_query}

def chat(trivia_query):
    """
    chat - the handler that connects ask_gpt() to the Gradio interface.

    Args:
        trivia_query (str): the question to ask the bot

    Returns:
        str: the bot's response, rendered as HTML
    """
    history = []
    response = ask_gpt(trivia_query)
    # wrap the exchange in the CSS classes styled in the Interface below
    history.append(
        f"<div class='user_msg'>{trivia_query}</div>"
        f"<div class='resp_msg'>{response}</div>"
    )
    gr.set_state(history)  # save the history
    html = "<div class='chatbox'>"
    for item in history:
        html += item
    html += "</div>"
    return html

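
# Illustrative return value of chat() (the response text is hypothetical),
# showing the markup that the CSS classes in the Interface below are styling:
#
#   chat("Who was the first person to walk on the moon?")
#   -> "<div class='chatbox'><div class='user_msg'>Who was the first person to
#      walk on the moon?</div><div class='resp_msg'>neil armstrong</div></div>"
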
{response}") gr.set_state(history) # save the history html = "" for item in history: html += f"{item}" html += "" return html def get_parser(): """ get_parser - a helper function for the argparse module """ parser = argparse.ArgumentParser( description="submit a question, GPT model responds" ) parser.add_argument( "-m", "--model", required=False, type=str, default="ballpark-trivia-L", help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + " "config.json)", ) parser.add_argument( "--basic-sc", required=False, default=True, action="store_false", help="turn on symspell (baseline) correction instead of the more advanced neural net models", ) return parser if __name__ == "__main__": args = get_parser().parse_args() default_model = str(args.model) model_loc = cwd.parent / default_model model_loc = str(model_loc.resolve()) basic_sc = args.basic_sc if basic_sc: print("defaulting to symspell for spell checking") schnellspell = build_symspell_obj() else: print("using advanced spell checker (Neuspell)") ns_checker = load_ns_checker() print(f"using model stored here: \n {model_loc} \n") iface = gr.Interface( chat, inputs=["text"], outputs="html", examples=["What is Katy Perry's birthday?", "Who was the first person to walk on the moon?", "Name of the capital of France?", "How many icebergs are in the ocean?"], title=f"Ballpark Trivia: {default_model} Model", description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.", article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-L). If you are interested in a more deceptively incorrect model, there is also [an XL version](https://huggingface.co/pszemraj/Ballpark-Trivia-XL) on my page.\n\n" "**Important Notes & About:**\n\n" "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n" "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ", css=""" .chatbox {display:flex;flex-direction:column} .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%} .user_msg {background-color:cornflowerblue;color:white;align-self:start} .resp_msg {background-color:lightgray;align-self:self-end} """, allow_screenshot=True, allow_flagging='auto', # enable_queue=True, # allows for dealing with multiple users simultaneously theme="darkhuggingface", ) # launch the gradio interface and start the server iface.launch( prevent_thread_lock=True, share=True, enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version) )