File size: 7,860 Bytes
cad4540 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
"""
Chatbot based on the notebook [How to generate text: using different decoding
methods for language generation with
Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
and the blog post [Create conversational agents using BLOOM:
Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).
This code needs more testing; it is not intended for production use.
It is a very basic chatbot that uses causal language models from Transformers,
conditioned on a given prompt.
An example of a basic prompt (in Spanish) is given in the file `prompt.txt`.
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Union
class ChatBot:
    """
    Main class wrapper around the transformers models in order to build a basic
    chatbot application.

    Parameters
    ----------
    base_model : str | AutoModelForCausalLM
        A name (path in the Hugging Face hub) for a model, or the model itself.
    tokenizer : AutoTokenizer | None
        Required when `base_model` is an already loaded model. When
        `base_model` is a string and no tokenizer is given, the tokenizer is
        loaded from the same `base_model` path.
    initial_prompt : str | None
        A prompt for the model. When None, it is read from `./prompt.txt`,
        which is also the reference example for the expected format.
    keep_context : bool
        Whether to accumulate the conversation (human turns and bot replies)
        as the chatbot is used, so later answers can see earlier turns.
    creative : bool
        Whether to generate text through sampling (with some very basic config)
        or to go with greedy algorithm. Check the notebook "How to generate
        text" (link above) for more information.
    max_tokens : int
        Max number of new tokens to generate for each reply.
    human_identifier : str
        The string that will identify the human speaker in the prompt (e.g.
        HUMAN).
    bot_identifier : str
        The string that will identify the bot speaker in the prompt (e.g.
        EXPERT).

    Raises
    ------
    ValueError
        If `base_model` is a model object and no `tokenizer` is given.
    """

    def __init__(self,
                 base_model: Union[str, AutoModelForCausalLM],
                 tokenizer: Optional[AutoTokenizer] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = 'HUMAN',
                 bot_identifier: str = 'EXPERT'):
        if isinstance(base_model, str):
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            # Honour an explicitly given tokenizer; otherwise load the one
            # that matches the model path.
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(base_model)
            self.tokenizer = tokenizer
        else:
            # `AutoTokenizer` is only a factory: concrete tokenizers are not
            # instances of it, so we check for presence rather than type.
            if tokenizer is None:
                raise ValueError(
                    "If the base model is given, the tokenizer should be "
                    "given as well"
                )
            self.model = base_model
            self.tokenizer = tokenizer

        if initial_prompt is None:
            # The example prompt is in Spanish, so decode it explicitly as
            # UTF-8 instead of relying on the platform default encoding.
            with open('./prompt.txt', 'r', encoding='utf-8') as fh:
                initial_prompt = fh.read()
        self.initial_prompt = initial_prompt

        self.keep_context = keep_context
        # Accumulated dialog turns (only grows when `keep_context` is True).
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text: str) -> str:
        """
        Generates a response from the prompt (and optionally the context) where
        it adds the `input_text` as if it was part of the HUMAN dialog
        (identified by `self.human_identifier`), and prompts the bot
        (identified by `self.bot_identifier`) for a response. As the bot might
        continue the conversation beyond the scope, it trims the output so it
        only shows the first dialog given by the bot, following the idea
        presented in the Medium blog post for creating conversational agents
        (link above).

        Parameters
        ----------
        input_text : str
            The question asked/phrase prompted by a human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """
        # The new dialog turn: the human line, then the bot prefix the model
        # should continue from (check the space after the colon).
        turn = (f'{self.human_identifier}: {input_text}\n'
                f'{self.bot_identifier}: ')
        prompt = self.initial_prompt + self.context + turn

        encoded = self.tokenizer(prompt, return_tensors='pt')
        input_ids = encoded['input_ids']
        prompt_length = input_ids.shape[1]

        generation_kwargs = {
            # Passing the mask avoids generate() having to guess it.
            'attention_mask': encoded.get('attention_mask'),
            'max_length': prompt_length + self.max_tokens,
        }
        if self.creative:
            # In case you want the bot to be creative, we sample using `top_k`
            # and `top_p`; otherwise generate() decodes greedily.
            generation_kwargs.update(do_sample=True, top_k=50, top_p=0.95)
        output = self.model.generate(input_ids, **generation_kwargs)[0]

        # Causal LMs return the prompt followed by the new tokens: decode only
        # the new ones instead of assuming the decoded prompt round-trips to
        # exactly `len(prompt)` characters. Special tokens are dropped.
        generated = self.tokenizer.decode(output[prompt_length:],
                                          skip_special_tokens=True)

        # Stop at the next human turn, if the model produced one. `find`
        # returns -1 when absent, which must not be used as a slice end
        # (it would drop the last character).
        stop = generated.find(f'{self.human_identifier}:')
        if stop != -1:
            generated = generated[:stop]
        reply = generated.strip()

        if self.keep_context:
            # Append only this turn; the initial prompt and previous context
            # are already re-added when building the next prompt.
            self.context += turn + reply + '\n'
        return reply
if __name__ == '__main__':
    # Command-line entry point: build a ChatBot from the CLI options and run
    # an interactive read-reply loop on stdin/stdout.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', '-m',
                        default='bigscience/bloom-560m',
                        help="Name of the base model to use for the chatbot")
    parser.add_argument('--prompt', '-p',
                        default='./prompt.txt',
                        help="Path to the file with the prompt to use")
    parser.add_argument('--keep-context', '-k',
                        action='store_true',
                        help="Keep context of the conversation.")
    parser.add_argument('--creative', '-c',
                        action='store_true',
                        help="Make the bot creative when answering.")
    parser.add_argument('--random-seed', '-r',
                        default=42,
                        help="Seed number for the creative bot.",
                        type=int)
    parser.add_argument('--max-tokens', '-t',
                        default=50,
                        help="Max number of tokens to generate per answer.",
                        type=int)
    parser.add_argument('--human-identifier', '-i',
                        default='HUMANO',
                        help="Name of the human identifier.")
    parser.add_argument('--bot-identifier', '-b',
                        default='EXPERTO',
                        help="Name of the bot identifier.")
    args = parser.parse_args()

    # Seed torch's RNG so sampling with --creative is reproducible.
    torch.manual_seed(args.random_seed)

    # Prompts may contain non-ASCII text (the example is in Spanish).
    with open(args.prompt, 'r', encoding='utf-8') as fh:
        initial_prompt = fh.read()

    chatbot = ChatBot(
        base_model=args.model_name,
        initial_prompt=initial_prompt,
        keep_context=args.keep_context,
        creative=args.creative,
        max_tokens=args.max_tokens,
        human_identifier=args.human_identifier,
        bot_identifier=args.bot_identifier
    )

    print("Write `exit` or `quit` to quit")
    while True:
        try:
            input_text = input('> ')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of a traceback.
            print()
            break
        if input_text.strip() in ('exit', 'quit'):
            break
        print(chatbot.chat(input_text))
|