#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ai_single_response.py

An executable way to call the model. example:
*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time

query_gpt_model is used throughout the code, and is the "fundamental" building block of the bot and how everything works. I would recommend testing this function with a few different models.

"""
import argparse
import pprint as pp
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from grammar_improve import remove_trailing_punctuation
from utils import print_spacer, cleantxt_wrap

warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing.*")

from aitextgen import aitextgen


def extract_response(full_resp: list, plist: list, verbose: bool = False):
    """
    extract_response - helper fn for ai_single_response.py. By default aitextgen returns the prompt and the response, we just want the response

    Args:
        full_resp (list): a list of strings, each string is a response
        plist (list): a list of strings, each string is a prompt

        verbose (bool, optional): 4 debug. Defaults to False.
    """
    full_resp = [cleantxt_wrap(ele) for ele in full_resp]
    plist = [cleantxt_wrap(pr) for pr in plist]
    p_len = len(plist)
    assert (
        len(full_resp) >= p_len
    ), "the model output should contain at least as many lines as the input prompt"

    if set(plist).issubset(full_resp):
        del full_resp[:p_len]  # remove the prompt lines from the output
    else:
        print("the full model output was:\n")
        pp.pprint(full_resp)
        print_spacer()
        print("the input prompt was:\n")
        pp.pprint(plist)
        print_spacer()
        sys.exit("Exiting: some prompt lines were not found in the model output")
    if verbose:
        print("the isolated responses are:\n")
        pp.pprint(full_resp)
        print_spacer()
        print("the input prompt was:\n")
        pp.pprint(plist)
        print_spacer()
    return full_resp  # list of only the model generated responses
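
# A minimal illustration of extract_response (hypothetical values, and assuming
# cleantxt_wrap leaves these already-clean strings unchanged):
#   full = ["person alpha:", "hi there", "", "person beta:", "hello!"]
#   prompts = ["person alpha:", "hi there", "", "person beta:"]
#   extract_response(full, prompts)  # -> ["hello!"]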


def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - extract the bot's reply from the model response. This is
    needed because, depending on the generation length, the model may return more
    than one speaker turn.

    Args:
        name_resp (str): the name of the responder (the bot)
        model_resp (list): the model response, as a list of lines
        name_spk (str): the name of the speaker; extraction stops at their next turn
        verbose (bool, optional): print debug info. Defaults to False.

    Returns:
        fn_resp (list of str): the lines belonging to the bot's reply
    """

    fn_resp = []

    name_counter = 0
    break_safe = False
    for resline in model_resp:
        if resline.startswith(name_resp):
            name_counter += 1
            break_safe = True  # this line starts with the bot's name, so what follows is its reply
            continue
        if name_spk is not None and name_spk.lower() in resline.lower():
            break
        if ":" in resline and name_counter > 0:
            if break_safe:
                # we know this is a response from the bot even tho ':' is in the line
                fn_resp.append(resline)
                break_safe = False
            else:
                # we do not know this is a response from the bot. could be name of another person.. bot is "finished" response
                break
        else:
            fn_resp.append(resline)
            break_safe = False
    if verbose:
        print("the full response is:\n")
        print("\n".join(fn_resp))

    return fn_resp
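
# A minimal illustration of get_bot_response (hypothetical values):
#   lines = ["person beta:", "hello there!", "person alpha:", "hi"]
#   get_bot_response(name_resp="person beta", model_resp=lines,
#                    name_spk="person alpha")  # -> ["hello there!"]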


def query_gpt_model(
    prompt_msg: str,
    folder_path=None,
    speaker=None,
    responder=None,
    resp_length=40,
    resp_min=10,
    kparam=150,
    temp=0.75,
    top_p=0.65,
    batch_size=64,
    verbose=False,
    use_gpu=False,
    beams=4,
):
    """
    query_gpt_model - the main function that calls the model.

    Parameters:
    -----------
    prompt_msg (str): the prompt to send to the model
    folder_path (str or Path, optional): local folder with the model files; if None
        or not found, a default model is downloaded from the Hugging Face hub
    speaker (str, optional): the name of the speaker. Defaults to None.
    responder (str, optional): the name of the responder. Defaults to None.
    resp_length (int, optional): the maximum length of the response. Defaults to 40.
    resp_min (int, optional): the minimum length of the response. Defaults to 10.
    kparam (int, optional): k for top-k sampling. Defaults to 150.
    temp (float, optional): the sampling temperature. Defaults to 0.75.
    top_p (float, optional): p for nucleus (top-p) sampling. Defaults to 0.65.
    batch_size (int, optional): the generation batch size. Defaults to 64.
    verbose (bool, optional): print debug info. Defaults to False.
    use_gpu (bool, optional): run the model on a GPU. Defaults to False.
    beams (int, optional): reserved for beam search; currently unused. Defaults to 4.
    """
    if folder_path is not None and Path(folder_path).is_dir():
        # load a locally fine-tuned model from its folder (pytorch_model.bin + config.json)
        ai = aitextgen(model_folder=str(folder_path), to_gpu=use_gpu)
    else:
        ai = aitextgen(
            model="pszemraj/Ballpark-Trivia-L",  # confirmed working
            # model="pszemraj/Ballpark-Trivia-XL",  # does not seem to work; TODO: test further after it loads
            to_gpu=use_gpu,
        )

    if responder is None:
        responder = "person beta"  # match the CLI default
    p_list = []  # track conversation
    if speaker is not None:
        p_list.append(speaker.lower() + ":" + "\n")
    p_list.append(prompt_msg.lower() + "\n")
    p_list.append("\n")
    p_list.append(responder.lower() + ":" + "\n")
    this_prompt = "".join(p_list)
    pr_len = len(this_prompt)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)
    # call the model
    print("\n... generating...")
    this_result = ai.generate(
        n=1,
        batch_size=batch_size,
        # the prompt input counts toward the text-length constraints
        max_length=resp_length + pr_len,
        # min_length=resp_min + pr_len,
        prompt=this_prompt,
        temperature=temp,
        top_k=kparam,
        top_p=top_p,
        do_sample=True,
        return_as_list=True,
        use_cache=True,
    )
    if verbose:
        print("\n... generated:\n")
        pp.pprint(this_result)  # for debugging
    # process the full result to get the ~bot response~ piece
    this_result = str(this_result[0]).split(
        "\n"
    )  # TODO: adjust hardcoded value for index to dynamic (if n>1)
    og_res = this_result.copy()
    og_prompt = p_list.copy()
    diff_list = extract_response(
        this_result, p_list, verbose=verbose
    )  # isolate the responses from the prompts
    # extract the bot response from the model generated text
    bot_dialogue = get_bot_response(
        name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
    )
    print(f"FOR DEBUG: {bot_dialogue}")
    bot_resp = ", ".join(bot_dialogue)
    bot_resp = bot_resp.strip()
    # remove the last ',' '.' chars
    bot_resp = remove_trailing_punctuation(bot_resp)
    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    og_prompt.append(bot_resp + "\n")
    og_prompt.append("\n")

    print("\nfinished!")
    # return the bot response and the full conversation

    return {"out_text": bot_resp, "full_conv": og_prompt}  # model responses


# Set up the parsing of command-line arguments
def get_parser():
    """
    get_parser [a helper function for the argparse module]

    Returns: argparse.ArgumentParser
    """

    parser = argparse.ArgumentParser(
        description="submit a message and have a pretrained GPT model respond"
    )
    parser.add_argument(
        "--prompt",
        required=True,  # MUST HAVE A PROMPT
        type=str,
        help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",
        help="folder (relative to the repo root) containing the model files "
        "(pytorch_model.bin + config.json). No models? Run the script download_models.py",
    )

    parser.add_argument(
        "--speaker",
        required=False,
        default=None,
        help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--responder",
        required=False,
        default="person beta",
        help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
    )

    parser.add_argument(
        "--topk",
        required=False,
        type=int,
        default=150,
        help="top-k sampling: the number of highest-probability tokens sampled from "
        "at each step (positive integer). lower = less random responses",
    )

    parser.add_argument(
        "--temp",
        required=False,
        type=float,
        default=0.75,
        help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
    )

    parser.add_argument(
        "--topp",
        required=False,
        type=float,
        default=0.65,
        help="nucleus (top-p) sampling fraction (0-1): the cumulative probability "
        "mass of tokens considered at each step",
    )

    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="pass this argument if you want all the printouts",
    )
    parser.add_argument(
        "--time",
        default=False,
        action="store_true",
        help="pass this argument if you want to know runtime",
    )
    return parser
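
# Example: build and parse a minimal invocation programmatically (only --prompt is required):
#   args = get_parser().parse_args(["--prompt", "hey, what's up?"])
#   args.responder  # -> "person beta" (the default)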


if __name__ == "__main__":
    # parse the command line arguments
    args = get_parser().parse_args()
    query = args.prompt
    model_dir = str(args.model)
    model_loc = Path.cwd() / model_dir
    spkr = args.speaker
    rspndr = args.responder
    k_results = args.topk
    my_temp = args.temp
    my_top_p = args.topp
    want_verbose = args.verbose
    want_rt = args.time

    st = time.perf_counter()

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=query,
        speaker=spkr,
        responder=rspndr,
        kparam=k_results,
        temp=my_temp,
        top_p=my_top_p,
        verbose=want_verbose,
        use_gpu=False,
    )

    output = resp["out_text"]
    pp.pprint(output, indent=4)

    rt = round(time.perf_counter() - st, 1)

    if want_rt:
        print("took {runtime} seconds to generate. \n".format(runtime=rt))

    if want_verbose:
        print("finished - ", datetime.now())
        p_list = resp["full_conv"]
        print("A transcript of your chat is as follows: \n")
        p_list = [item.strip() for item in p_list]
        pp.pprint(p_list)