"""
converse.py - this script has functions for handling the conversation between the user and the bot.
https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
"""
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

import pprint as pp
import time

from grammar_improve import remove_trailing_punctuation
from constrained_generation import constrained_generation

def discussion(
prompt_text: str,
speaker: str,
responder: str,
pipeline,
timeout=45,
min_length=8,
max_length=64,
top_p=0.95,
top_k=50,
temperature=0.7,
full_text=False,
length_penalty=0.8,
no_repeat_ngram_size=2,
num_return_sequences=1,
device=-1,
verbose=False,
constrained_beam_search=False,
):
"""
discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
Parameters
----------
prompt_text : str, the prompt to ask the bot, usually the user's question
speaker : str, the name of the person who is speaking the prompt
responder : str, the name of the person who is responding to the prompt
pipeline : transformers.Pipeline, the pipeline to use for generating the response
    timeout : int, optional, the number of seconds to wait before timing out, by default 45
    min_length : int, optional, the minimum number of tokens to generate, defaults to 8
    max_length : int, optional, the maximum number of tokens to generate, defaults to 64
    top_p : float, optional, the top probability to use for sampling, defaults to 0.95
    top_k : int, optional, the top k to use for sampling, defaults to 50
    temperature : float, optional, the temperature to use for sampling, defaults to 0.7
    full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
    length_penalty : float, optional, the length penalty applied during generation, defaults to 0.8
    no_repeat_ngram_size : int, optional, the size of n-grams that may not repeat, defaults to 2
    num_return_sequences : int, optional, the number of sequences to return, defaults to 1
    device : int, optional, the device to use for generation, defaults to -1 (CPU)
    verbose : bool, optional, whether to print the generated text, defaults to False
    constrained_beam_search : bool, optional, whether to generate with constrained beam search instead of sampling, defaults to False
Returns
-------
    dict, with keys "out_text" (the generated response) and "full_conv" (the full conversation as a list of strings)
"""
logging.debug(f"input args: {locals()}")
p_list = [] # track conversation
p_list.append(speaker.lower() + ":" + "\n")
p_list.append(prompt_text.lower() + "\n")
p_list.append("\n")
p_list.append(responder.lower() + ":" + "\n")
this_prompt = "".join(p_list)
if verbose:
print("overall prompt:\n")
pp.pprint(this_prompt, indent=4)
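    # two generation paths: constrained beam search, or plain sampling via gen_response()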
if constrained_beam_search:
logging.info("generating using constrained beam search ...")
response = constrained_generation(
prompt=this_prompt,
pipeline=pipeline,
min_generated_tokens=min_length,
max_generated_tokens=max_length,
no_repeat_ngram_size=no_repeat_ngram_size,
length_penalty=length_penalty,
repetition_penalty=1.0,
num_beams=4,
timeout=timeout,
verbose=False,
full_text=full_text,
speaker_name=speaker,
responder_name=responder,
)
bot_dialogue = consolidate_texts(
name_resp=responder,
model_resp=response.split("\n"),
name_spk=speaker,
verbose=verbose,
print_debug=True,
)
else:
logging.info("generating using sampling ...")
bot_dialogue = gen_response(
this_prompt,
pipeline,
speaker,
responder,
timeout=timeout,
min_length=min_length,
max_length=max_length,
top_p=top_p,
top_k=top_k,
temperature=temperature,
full_text=full_text,
no_repeat_ngram_size=no_repeat_ngram_size,
length_penalty=length_penalty,
num_return_sequences=num_return_sequences,
device=device,
verbose=verbose,
)
logging.debug(f"generation done. bot_dialogue: {bot_dialogue}")
if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
bot_resp = ", ".join(bot_dialogue)
elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
bot_resp = bot_dialogue[0]
else:
bot_resp = bot_dialogue
bot_resp = " ".join(bot_resp) if isinstance(bot_resp, list) else bot_resp
bot_resp = bot_resp.strip()
    # strip trailing punctuation (e.g. ',' or '.') from the response
bot_resp = remove_trailing_punctuation(bot_resp)
if verbose:
print("\nfinished!")
print("\n... bot response:\n")
pp.pprint(bot_resp)
p_list.append(bot_resp + "\n")
p_list.append("\n")
logging.info(f"finished generating response:\n\t{bot_resp}")
# return the bot response and the full conversation
return {"out_text": bot_resp, "full_conv": p_list}
def gen_response(
query: str,
pipeline,
speaker: str,
responder: str,
timeout=45,
min_length=12,
max_length=48,
top_p=0.95,
top_k=20,
temperature=0.5,
full_text=False,
num_return_sequences=1,
length_penalty: float = 0.8,
repetition_penalty: float = 3.5,
no_repeat_ngram_size=2,
device=-1,
verbose=False,
**kwargs,
):
"""
gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function.
Parameters
----------
query : str, the prompt to ask the bot, usually the user's question
speaker : str, the name of the person who is speaking the prompt
responder : str, the name of the person who is responding to the prompt
pipeline : transformers.Pipeline, the pipeline to use for generating the response
    timeout : int, optional, the number of seconds to wait before timing out, by default 45
    min_length : int, optional, the minimum number of tokens to generate, defaults to 12
    max_length : int, optional, the maximum number of tokens to generate, defaults to 48
    top_p : float, optional, the top probability to use for sampling, defaults to 0.95
    top_k : int, optional, the top k to use for sampling, defaults to 20
    temperature : float, optional, the temperature to use for sampling, defaults to 0.5
    full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
    num_return_sequences : int, optional, the number of sequences to return, defaults to 1
    length_penalty : float, optional, the length penalty applied during generation, defaults to 0.8
    repetition_penalty : float, optional, the penalty applied to repeated tokens, defaults to 3.5
    no_repeat_ngram_size : int, optional, the size of n-grams that may not repeat, defaults to 2
    device : int, optional, the device to use for generation, defaults to -1 (CPU)
    verbose : bool, optional, whether to print the generated text, defaults to False
Returns
-------
    list or str, the bot's portion of the generated text (see consolidate_texts)
"""
logging.debug(f"input args - gen_response() : {locals()}")
input_len = len(pipeline.tokenizer(query).input_ids)
    if max_length + input_len > 1024:
        # assumes a 1024-token context window; leave room for the prompt tokens
        max_length = max(1024 - input_len, 8)
        print(f"max_length too large, setting to {max_length}")
st = time.perf_counter()
response = pipeline(
query,
min_length=min_length + input_len,
max_length=max_length + input_len,
temperature=temperature,
top_k=top_k,
top_p=top_p,
num_return_sequences=num_return_sequences,
max_time=timeout,
return_full_text=full_text,
no_repeat_ngram_size=no_repeat_ngram_size,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
clean_up_tokenization_spaces=True,
remove_invalid_values=True,
**kwargs,
    )  # sampling-based generation, no beam search
rt = round(time.perf_counter() - st, 2)
    if verbose:
        print(f"took {rt} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging
    # process the full result to get the ~bot response~ piece
    this_result = str(response[0]["generated_text"]).split("\n")
    # TODO: make the hardcoded index (0) dynamic if num_return_sequences > 1
bot_dialogue = consolidate_texts(
name_resp=responder,
model_resp=this_result,
name_spk=speaker,
verbose=verbose,
print_debug=True,
)
if verbose:
print(f"DEBUG: {bot_dialogue} was original response pre-SC")
return bot_dialogue #
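

# Example usage of gen_response() (a minimal sketch; it is normally called via
# discussion(), but can be used directly, reusing `generator` from the example above):
#
#   bot_dialogue = gen_response(
#       "person alpha:\nhello there!\n\nperson beta:\n",
#       pipeline=generator,
#       speaker="person alpha",
#       responder="person beta",
#   )
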
def consolidate_texts(
model_resp: list,
name_resp: str = None,
name_spk: str = None,
verbose=False,
print_debug=False,
):
"""
    consolidate_texts - given a list of lines where a speaker's name precedes their text, return all consecutive lines belonging to the responder
    Parameters:
    model_resp (list): the list of strings to consolidate (usually the model output split on newlines)
    name_resp (str): the name of the person who is responding, defaults to "person beta"
    name_spk (str): the name of the person who is speaking, defaults to "person alpha"
    verbose (bool): whether to print the results
    print_debug (bool): whether to print debug info during the loop
    Returns:
    list or str, the consecutive messages of the responder (a single string if model_resp has only one element)
"""
assert len(model_resp) > 0, "model_resp is empty"
if len(model_resp) == 1:
return model_resp[0]
name_resp = "person beta" if name_resp is None else name_resp
name_spk = "person alpha" if name_spk is None else name_spk
if verbose:
print("====" * 10)
print(
f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}"
)
print(
f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}"
)
fn_resp = []
name_counter = 0
    break_safe = False  # flags that the last line seen was a responder name tag (currently unused)
for resline in model_resp:
        if name_resp.lower() in resline.lower():
            name_counter += 1
            break_safe = True  # the line contains the bot's name, so it is a name tag for the bot
            continue  # don't add this line to the list
if name_spk.lower() in resline.lower():
if print_debug:
print(f"\nDEBUG: \n\t{resline}\ncaused the break")
break # the name of the speaker is in the line, so we're done
        # a new "name:" tag from someone other than the responder ends the bot's turn
        if (
            ": " in resline or ":\n" in resline
        ) and name_resp.lower() not in resline.lower():
if print_debug:
print(f"\nDEBUG: \n\t{resline}\ncaused the break")
break
else:
fn_resp.append(resline)
break_safe = False
if verbose:
print("--" * 10)
print("\nthe full response is:\n")
print("\n".join(fn_resp))
print("--" * 10)
return fn_resp
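

# Example (a minimal sketch): given model output split on newlines,
# consolidate_texts() keeps only the responder's consecutive lines:
#
#   lines = ["person beta:", "hi there!", "how are you?", "person alpha:", "good, thanks"]
#   consolidate_texts(model_resp=lines, name_resp="person beta", name_spk="person alpha")
#   # -> ["hi there!", "how are you?"]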