File size: 8,705 Bytes
e332ae1
99f4147
f28496d
e332ae1
f28496d
 
6b53043
 
 
 
 
 
 
 
 
 
e332ae1
093b964
b9d3930
c566631
d6b9e51
c566631
 
 
 
 
 
 
093b964
7681a60
63bb54c
7681a60
 
139439b
66d3a16
093b964
6b53043
 
f28496d
 
 
6b53043
 
 
e332ae1
0acac99
c3b001a
f28496d
c3b001a
f28496d
 
 
 
0acac99
f28496d
776c2d9
f28496d
 
 
 
 
 
 
 
5e7918b
f28496d
5e7918b
 
f28496d
e332ae1
f28496d
 
 
 
 
 
 
 
 
 
 
 
e332ae1
f28496d
2a3e09e
a9d9264
e332ae1
f28496d
 
 
c3b001a
f28496d
 
 
 
 
 
 
 
 
 
 
6b72548
f28496d
 
 
d6b9e51
 
 
 
 
 
47de588
d6b9e51
 
c3b001a
 
f28496d
 
 
 
915dc18
e332ae1
 
6b53043
 
 
 
 
c3b001a
6b53043
 
c566631
 
6b53043
 
f28496d
 
6b53043
2830ef7
47de588
2830ef7
4ab014a
 
47de588
2830ef7
50dfda4
f28496d
 
 
 
 
 
6b53043
e332ae1
 
 
6b53043
 
f28496d
 
57c24e9
f28496d
 
 
 
 
 
 
 
 
 
47de588
f28496d
50dfda4
47de588
f28496d
21e8f8d
50dfda4
6b53043
e332ae1
 
 
 
a9a06c0
61c1929
fe636ca
 
 
 
 
 
 
 
 
 
 
 
 
dbe5685
 
 
 
 
 
 
 
 
 
fe636ca
 
61c1929
2a3e09e
e770698
dbe5685
dd38199
e332ae1
d6b9e51
dbe5685
e332ae1
6088517
 
 
 
 
adebe05
6088517
6b72548
e332ae1
 
 
b649bd8
b052220
dbe5685
7681a60
7eab496
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
app.py - the main file for the app. It loads the text-generation pipeline and a spell checker, then builds the Gradio interface and launches it.

"""

import torch
from transformers import pipeline
from cleantext import clean
from pathlib import Path
import warnings
import time
import argparse
import logging
import gradio as gr
import os
import sys
from os.path import dirname
import nltk
from converse import discussion
from grammar_improve import (
    detect_propers,
    load_ns_checker,
    neuspell_correct,
    remove_repeated_words,
    remove_trailing_punctuation,
    build_symspell_obj,
    symspeller,
)

from utils import (
    cleantxt_wrap,
    corr,
)

# download the NLTK "stopwords" resource every time the module is loaded
nltk.download("stopwords")  # TODO: find where this requirement originates from

# append the parent of this file's directory to the module search path
# NOTE(review): this runs after the local imports above, so it cannot affect them - confirm intent
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
# ignore warnings whose message matches the gradient_checkpointing pattern
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
import transformers

# restrict the transformers library's own logging to errors
transformers.logging.set_verbosity_error()
logging.basicConfig()
cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects


def chat(trivia_query):
    """
    chat - answer one trivia question and render the exchange as HTML.

    Relies on the module-level `my_chatbot` pipeline created in the `__main__` block.

    Parameters:
        trivia_query (str): the question entered by the user

    Returns:
        str: an HTML snippet containing the question followed by the response, each in bold
    """
    # imported locally so the stdlib module name `html` stays out of the module namespace
    from html import escape

    response = ask_gpt(message=trivia_query, chat_pipe=my_chatbot)

    # both the user's question and the generated answer are untrusted text:
    # escape them so they are shown literally instead of being parsed as markup
    exchange = (trivia_query, response)
    return "".join(f"<b>{escape(str(item))}</b> <br><br>" for item in exchange)


def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    max_len=64,
    top_p=0.95,
    top_k=20,
    temperature=0.3,
):
    """

    ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts with the discussion function.

    Note: reads the module-level globals `basic_sc` and, depending on its value,
    `schnellspell` or `ns_checker`; these are created in the `__main__` block.

    Parameters:
        message (str): the question to ask the bot
        chat_pipe: the text-generation pipeline object handed to discussion()
        speaker (str): the name of the speaker (default: "person alpha")
        responder (str): the name of the responder (default: "person beta")
        max_len (int): the maximum length of the response, passed to discussion() as max_length (default: 64)
        top_p (float): the top probability threshold (default: 0.95)
        top_k (int): the top k threshold (default: 20)
        temperature (float): the temperature of the response (default: 0.3)

    Returns:
        str: the generated response after spell correction, repeated-word removal,
        corr() and removal of trailing punctuation
    """

    st = time.perf_counter()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    in_len = len(prompt)
    if in_len > 512:
        prompt = prompt[-512:]  # truncate to 512 chars, keeping the end of the prompt
        print(f"Truncated prompt to last 512 chars: started with {in_len}")
        max_len = min(max_len, 512)

    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_len,
        timeout=30,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]

    # check for proper nouns once; spell correction is skipped when any are detected
    if detect_propers(rawtxt):
        # no correction needed
        cln_resp = rawtxt.strip()
    elif basic_sc:
        cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
    else:
        cln_resp = neuspell_correct(rawtxt, checker=ns_checker)
    bot_resp = corr(remove_repeated_words(cln_resp))
    # stop the correction timer before printing so console I/O is not counted
    corr_rt = round(time.perf_counter() - gpt_et, 4)

    print(f"\nthe prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
    # round the sum as well: adding two rounded floats can reintroduce float noise
    total_rt = round(gpt_rt + corr_rt, 4)
    print(
        f"took {total_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
    )
    return remove_trailing_punctuation(bot_resp)


def get_parser():
    """
    get_parser - build the command-line argument parser for the app

    Returns:
        argparse.ArgumentParser: parser exposing -m/--model, --basic-sc and --verbose
    """
    cli = argparse.ArgumentParser(description="submit a question, GPT model responds ")

    # which model to load: a hub name or a local directory
    cli.add_argument(
        "-m",
        "--model",
        type=str,
        default="pszemraj/Ballpark-Trivia-XL",  # default model
        help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
    )

    # on/off switches; each one is False unless present on the command line
    switches = (
        (
            "--basic-sc",
            "turn on symspell (baseline) correction instead of the more advanced neural net models",
        ),
        ("--verbose", "turn on verbose logging"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli


if __name__ == "__main__":
    # Script entry point: parse CLI args, load the generation pipeline and the
    # spell checker, then build and launch the Gradio interface.
    # The names my_chatbot, basic_sc, schnellspell and ns_checker are module-level
    # globals read by chat() and ask_gpt(), so they must keep these exact names.
    args = get_parser().parse_args()
    # NOTE: args.verbose is parsed but not used anywhere in this block yet
    default_model = str(args.model)
    model_loc = Path(default_model)  # if the model is a path, use it
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker
    basic_sc = True  # TODO: remove once neuspell fixed
    device = 0 if torch.cuda.is_available() else -1
    print(f"CUDA avail is {torch.cuda.is_available()}")

    if model_loc.exists() and model_loc.is_dir():
        # local checkpoint directory: hand the resolved path over as a string,
        # since pipeline() expects a model identifier string or a model instance
        model_ref = str(model_loc.resolve())
    else:
        # otherwise treat the argument as a model name
        model_ref = default_model
    # stays on CPU if no GPU available
    my_chatbot = pipeline("text-generation", model=model_ref, device=device)
    print(f"using model {my_chatbot.model}")

    if basic_sc:
        print("Using the baseline spellchecker")
        schnellspell = build_symspell_obj()
    else:
        print("using Neuspell spell checker")
        ns_checker = load_ns_checker(fast=False)

    print(f"using model stored here: \n {model_loc} \n")
    iface = gr.Interface(
        chat,
        inputs=["text"],
        outputs="html",
        examples_per_page=10,
        examples=[
            "Which President gave us the metric system?",
            "Who let the dogs out?",
            "Where does the term \"ground floor\" come from?",
            "What is the highest point on the globe?",
            "Why do we wear white clothes on our wedding days?",
            "What does the oval and squiggle on a US passport represent?",
            "Why is an electrical socket called a \"socket\", and not, say, a \"bottle\"?",
            "Where are the most active volcanoes on the earth?",
            "What is a cold-blood or cold-blooded animal?",
            "Why do we play volleyball on August 20th?",
            "What is water?",
            "Difference between U, V and W",
            "What is the official language of Vatican City?",
            "In what city is the CDC located?",
            "What are the names of the two major political parties in France?",
            "Who was Charles de Gaulle?",
            "Where is Stonehenge located?",
            "How many moons does Saturn have?",
            "Who invented the telescope?",
            "Who is your daddy and what does he do?",
            "When did Christopher Columbus come to America?",
            "Why are there interstate highways that have only one lane on each side?",
            "Which flavor of ice cream is the most popular in Switzerland?",
            "Who wrote The Jungle?",
            "Where were Benedict Arnold and Gen. Washington when the war started?",
        ],
        title=f"Ballpark Trivia: {default_model} Model",
        description="Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
        article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-XL).\n\n"
        "**Important Notes & About:**\n\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
        "3. Some params are still being tweaked (in future, will have them as inputs) any feedback is welcome   :)\n",
        css="""
            .chatbox {display:flex;flex-direction:column}
            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
            .user_msg {background-color:cornflowerblue;color:white;align-self:start}
            .resp_msg {background-color:lightgray;align-self:self-end}
        """,
        allow_screenshot=True,
        allow_flagging="never",
        theme="dark",
    )

    # launch the gradio interface and start the server
    iface.launch(
        # prevent_thread_lock=True,
        # share=True,
        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )