crscardellino's picture
Fixed an error regarding the verification of a type
1e86311
raw
history blame
7.92 kB
"""
Chatbot based on the notebook [How to generate text: using different decoding
methods for language generation with
Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
and the blog post [Create conversational agents using BLOOM:
Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).
This code still needs testing and is not ready for production use.
It's a very basic chatbot that uses Causal Language Models from Transformers,
given a PROMPT.
An example of a basic PROMPT is given in the file `prompt.txt` for a Spanish
prompt.
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
from typing import Optional, Union
class ChatBot:
    """
    Main class wrapper around the transformers models in order to build a basic
    chatbot application.

    Parameters
    ----------
    base_model : str | PreTrainedModel
        A name (path in hugging face hub) for a model, or the model itself.
    tokenizer : PreTrainedTokenizerBase | None
        Needed in case the base_model is a given model, otherwise it will load
        the tokenizer from the same path given by base_model (and this
        argument is ignored).
    initial_prompt : str | None
        A prompt for the model. If None, the prompt is read from the file
        `./prompt.txt` (relative to the current working directory).
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with some very basic config)
        or to use the model's default decoding. Check the notebook "How to
        generate text" (link above) for more information.
    max_tokens : int
        Max number of tokens to generate in the chat.
    human_identifier : str
        The string that will identify the human speaker in the prompt (e.g.
        HUMAN).
    bot_identifier : str
        The string that will identify the bot speaker in the prompt (e.g.
        EXPERT).
    """

    def __init__(self,
                 base_model: Union[str, PreTrainedModel],
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = 'HUMAN',
                 bot_identifier: str = 'EXPERT'):
        if isinstance(base_model, str):
            # A string is treated as a model name/path: load both the model
            # and its tokenizer from it.
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            # An already instantiated model needs its tokenizer to be given
            assert isinstance(tokenizer, PreTrainedTokenizerBase),\
                "If the base model is given, the tokenizer should be given as well"
            self.model = base_model
            self.tokenizer = tokenizer

        if initial_prompt is None:
            # Read the default prompt with an explicit encoding so non-ASCII
            # prompts do not depend on the platform's default locale.
            with open('./prompt.txt', 'r', encoding='utf-8') as fh:
                self.initial_prompt = fh.read()
        else:
            self.initial_prompt = initial_prompt

        self.keep_context = keep_context
        # Dialog turns accumulated so far (only grows if `keep_context`)
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text: str) -> str:
        """
        Generates a response from the prompt (and optionally the context) where
        it adds the `input_text` as if it was part of the HUMAN dialog
        (identified by `self.human_identifier`), and prompts the bot
        (identified by `self.bot_identifier`) for a response. As the bot might
        continue the conversation beyond the scope, it trims the output so it
        only shows the first dialog given by the bot, following the idea
        presented in the Medium blog post for creating conversational agents
        (link above).

        If `self.keep_context` is set, the new human/bot exchange (and only
        that exchange) is appended to `self.context`.

        Parameters
        ----------
        input_text : str
            The question asked/phrase prompted by a human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """
        # The new dialog turn: the human phrase followed by the bot cue. The
        # trailing space after the colon gives the model room to continue.
        turn = f'{self.human_identifier}: {input_text}\n{self.bot_identifier}: '
        prompt = self.initial_prompt + self.context + turn

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        generation_kwargs = {
            'max_length': input_ids.shape[1] + self.max_tokens
        }
        if self.creative:
            # In case you want the bot to be creative, we sample using `top_k`
            # and `top_p`
            generation_kwargs.update(do_sample=True, top_k=50, top_p=0.95)
        output = self.model.generate(input_ids, **generation_kwargs)[0]

        # Decode the output, removing special tokens for the model (like
        # `[CLS]` and similar)
        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)

        # Trim the output, first by removing the original prompt
        trimmed_output = decoded_output[len(prompt):]

        # Then cut at the stop marker (the human identifier), if the model
        # generated one. `str.find` returns -1 when the marker is absent, and
        # slicing with -1 would drop the last character of the answer.
        stop_index = trimmed_output.find(f'{self.human_identifier}:')
        if stop_index != -1:
            trimmed_output = trimmed_output[:stop_index]

        answer = trimmed_output.strip()

        if self.keep_context:
            # Only store the new exchange: `prompt` already contains the
            # initial prompt and the previous context, so appending it whole
            # would duplicate them on every turn.
            self.context += f'{turn}{answer}\n'

        return answer  # we only return the trimmed output
if __name__ == '__main__':
    # Command line entry point: builds a ChatBot from the given options and
    # runs an interactive read/answer loop on the terminal.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', '-m',
                        default='bigscience/bloom-560m',
                        help="Name of the base model to use for the chatbot")
    parser.add_argument('--prompt', '-p',
                        default='./prompt.txt',
                        help="Path to the file with the prompt to use")
    parser.add_argument('--keep-context', '-k',
                        action='store_true',
                        help="Keep context of the conversation.")
    parser.add_argument('--creative', '-c',
                        action='store_true',
                        help="Make the bot creative when answering.")
    parser.add_argument('--random-seed', '-r',
                        default=42,
                        help="Seed number for the creative bot.",
                        type=int)
    parser.add_argument('--human-identifier', '-i',
                        default='HUMANO',
                        help="Name of the human identifier.")
    parser.add_argument('--bot-identifier', '-b',
                        default='EXPERTO',
                        help="Name of the bot identifier.")
    args = parser.parse_args()

    # Seed torch so that sampling (the `--creative` mode) is reproducible
    torch.manual_seed(args.random_seed)

    # Read the prompt with an explicit encoding so that non-ASCII prompts do
    # not depend on the platform's default locale.
    with open(args.prompt, 'r', encoding='utf-8') as fh:
        initial_prompt = fh.read()

    chatbot = ChatBot(
        base_model=args.model_name,
        initial_prompt=initial_prompt,
        keep_context=args.keep_context,
        creative=args.creative,
        human_identifier=args.human_identifier,
        bot_identifier=args.bot_identifier
    )

    print("Write `exit` or `quit` to quit")
    while True:
        try:
            input_text = input('> ')
        except EOFError:
            # End of input (e.g. Ctrl-D or piped stdin exhausted): finish the
            # session cleanly instead of crashing with a traceback.
            break
        if input_text in ('exit', 'quit'):
            break
        print(chatbot.chat(input_text))