File size: 6,410 Bytes
e332ae1
6b53043
 
adebe05
e332ae1
b649bd8
 
 
 
 
 
 
 
 
 
6b53043
 
 
 
 
 
 
 
 
 
 
 
e332ae1
093b964
 
b649bd8
093b964
6b53043
b649bd8
6b53043
 
 
 
e332ae1
 
6b53043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e332ae1
 
 
 
6b53043
e332ae1
 
 
 
 
 
2a3e09e
e332ae1
 
6b53043
 
e332ae1
 
 
 
 
6b53043
 
 
e332ae1
6b53043
2a3e09e
e332ae1
 
 
 
2a3e09e
e332ae1
 
635e033
e332ae1
 
 
635e033
e332ae1
 
 
635e033
2a3e09e
971c338
635e033
 
 
 
 
 
 
 
 
6b53043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e332ae1
 
 
6b53043
 
 
 
 
 
 
 
e332ae1
 
 
 
2a3e09e
e770698
4922cde
dd38199
e332ae1
4922cde
e332ae1
 
 
 
 
 
adebe05
1cf48f8
62981f2
635e033
e332ae1
 
 
b649bd8
 
2a3e09e
b649bd8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
deploy-as-bot/gradio_chatbot.py
A script for deploying the chatbot with Gradio. Gradio is a basic "deploy" interface that lets other users test your model from a web URL. It also provides some basic functionality, such as letting users flag weird responses.
Note that the URL is displayed once the script is run.
"""
from flask import (
    Flask,
    request,
    session,
    jsonify,
    abort,
    send_file,
    render_template,
    redirect,
)
from ai_single_response import query_gpt_model
from datetime import datetime
from transformers import pipeline
from cleantext import clean
from pathlib import Path
import warnings
import time
import argparse
import logging
import gradio as gr
import os
import sys
from os.path import dirname
import nltk

# Download the NLTK "stopwords" corpus at import time.
# NOTE(review): the original author was unsure which dependency needs this - confirm before removing.
nltk.download("stopwords")  # still unsure where this error originates from

# Append the parent of this file's directory to the module search path.
# NOTE(review): this runs after the import statements at the top of the file,
# so it cannot influence how those imports are resolved.
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
# from gradio.networking import get_state, set_state
# Ignore warnings whose message matches the gradient_checkpointing pattern.
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
logging.basicConfig()  # configure the root logger with default settings
cwd = Path.cwd()  # working directory at start-up; the model folder is located relative to its parent
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects


def gramformer_correct(corrector, qphrase: str):
    """
    gramformer_correct - correct a string using a text2textgen pipeline model from transformers

    The phrase is first normalised with cleantext's clean(). If the corrector
    call (or reading its output) fails, a note is printed and the cleaned but
    uncorrected phrase is returned as a best-effort fallback.

    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]
    Returns:
        [str]: [corrected text, or the cleaned input text if correction failed]
    """
    cleaned = clean(qphrase)  # computed once; doubles as the fallback value
    try:
        corrected = corrector(
            cleaned, return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate instead of being silently swallowed.
        print("NOTE - failed to correct with gramformer")
        return cleaned


def ask_gpt(message: str):
    """
    ask_gpt - clean a user message, query the GPT model with it and grammar-correct the reply

    The response time (in seconds, rounded to 2 decimals) is printed to stdout.
    Grammar correction uses the module-level `corrector` pipeline, which is
    created in the `__main__` section of this script.

    Args:
        message (str): prompt message to respond to
    Returns:
        [str]: [grammar-corrected model response as a string]
    """
    started = time.perf_counter()

    # clean the user input, drop surrounding whitespace, then keep at most
    # the final 200 characters (slicing a shorter string returns it unchanged)
    query = clean(message).strip()[-200:]

    generation_settings = {
        "speaker": "person alpha",
        "responder": "person beta",
        "kparam": 150,
        "temp": 0.75,
        "top_p": 0.65,  # optimize this with hyperparam search
    }
    model_reply = query_gpt_model(prompt_msg=query, **generation_settings)

    bot_resp = gramformer_correct(corrector, qphrase=model_reply["out_text"])
    elapsed = round(time.perf_counter() - started, 2)
    print(f"took {elapsed} sec to respond")
    return bot_resp


def chat(trivia_query):
    """
    chat - helper function that makes the whole gradio thing work.

    Sends the question to ask_gpt() and formats the question/answer pair as
    HTML. Both pieces of text are HTML-escaped before being embedded, since
    the question is raw user input and the answer is free-form model output.

    Args:
        trivia_query (str): [the question entered in the gradio text box]

    Returns:
        [str]: [returns an html string to display]
    """
    # function-scope import keeps this fix self-contained
    from html import escape

    response = ask_gpt(trivia_query)
    # str() mirrors the implicit conversion the f-string performed before escaping
    question_html = escape(str(trivia_query))
    answer_html = escape(str(response))
    # the history is rebuilt on every call, so it only ever holds this exchange
    history = [f"<b>{question_html}</b> <br> <br> <b>{answer_html}</b>"]
    gr.set_state(history)  # save the history
    return "".join(history)


def get_parser():
    """
    get_parser - build the command-line argument parser for this script

    Two optional string arguments are registered: --model (folder holding the
    GPT model files) and --gram-model (huggingface ID of the grammar model).

    Returns:
        [argparse.ArgumentParser]: [the argparser relevant for this script]
    """
    cli = argparse.ArgumentParser(
        description="submit a message and have a 774M parameter GPT model respond"
    )
    # (flag, default, help) for every option; all are optional strings
    option_table = (
        (
            "--model",
            "ballpark-trivia-L",
            "folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
            "config.json). No models? Run the script download_models.py",
        ),
        (
            "--gram-model",
            "prithivida/grammar_error_correcter_v1",
            "text2text generation model ID from huggingface for the model to correct grammar",
        ),
    )
    for flag, fallback, help_text in option_table:
        cli.add_argument(
            flag, required=False, type=str, default=fallback, help=help_text
        )
    return cli


if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the grammar-correction
    # pipeline, then build and launch the gradio interface.
    args = get_parser().parse_args()
    default_model = str(args.model)
    # the model folder is resolved relative to the parent of the working directory
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model
    print(f"using model stored here: \n {model_loc} \n")
    # `corrector` is deliberately a module-level name: ask_gpt() reads it as a global.
    # device=-1 is passed straight to the pipeline (CPU, per the transformers pipeline docs).
    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    iface = gr.Interface(
        chat,
        inputs=["text"],
        outputs="html",
        title=f"Ballpark Trivia: {default_model} Model",
        description="Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
        # the model-card URL must not contain spaces, otherwise the markdown link is broken
        article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-L).  If you are interested in a more deceptively incorrect model, there is also [an XL version](https://huggingface.co/pszemraj/Ballpark-Trivia-XL) on my page.\n\n"
        "**Important Notes & About:**\n\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ",
        css="""
        .chatbox {display:flex;flex-direction:column}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
    """,
        allow_screenshot=True,
        allow_flagging=False,
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )

    # launch the gradio interface and start the server
    iface.launch(
        share=True,
        enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )