"""
Chatbot based on the notebook [How to generate text: using different decoding
methods for language generation with
Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
and the blog post [Create conversational agents using BLOOM:
Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).

This code needs testing; it is not fit for production use.

It is a very basic chatbot that uses causal language models from Transformers,
driven by a given prompt.

An example of a basic prompt, in Spanish, is given in the file `prompt.txt`.
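
For illustration, a prompt following the format this module builds could look
like this (the default `HUMAN`/`EXPERT` identifiers are used and the contents
are a made-up sketch; `chat` appends new turns at the end):

    The following is a conversation between a HUMAN and an EXPERT.

    HUMAN: Hola, ¿quién eres?
    EXPERT: Soy un asistente experto. ¿En qué puedo ayudarte?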
"""

import argparse
from typing import Optional, Union

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
                          PreTrainedTokenizerBase)


class ChatBot:
    """
    Main class: a wrapper around Transformers models used to build a basic
    chatbot application.

    Parameters
    ----------
    base_model : str | PreTrainedModel
        A model name (its path in the Hugging Face Hub), or an already-loaded
        model.
    tokenizer : PreTrainedTokenizerBase | None
        Required when `base_model` is an already-loaded model; otherwise the
        tokenizer is loaded from the same `base_model` path.
    initial_prompt : str
        A prompt for the model. Should follow the example given in the file
        `prompt.txt`.
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with a very basic
        configuration) or through greedy decoding. Check the notebook "How to
        generate text" (link above) for more information.
    max_tokens : int
        Maximum number of tokens to generate per chat turn.
    human_identifier : str
        The string that identifies the human speaker in the prompt (e.g.
        HUMAN).
    bot_identifier : str
        The string that identifies the bot speaker in the prompt (e.g.
        EXPERT).
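
    Example
    -------
    A minimal usage sketch (`bigscience/bloom-560m` is the default model of
    the command-line interface below; the calls are skipped under doctest
    since they download and run the model):

    >>> bot = ChatBot('bigscience/bloom-560m', max_tokens=30)  # doctest: +SKIP
    >>> bot.chat('Hola, ¿cómo estás?')  # doctest: +SKIP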
    """

    def __init__(self,
                 base_model: Union[str, PreTrainedModel],
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = 'HUMAN',
                 bot_identifier: str = 'EXPERT'):
        if isinstance(base_model, str):
            # Load the model and the tokenizer from the same Hub path.
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            assert isinstance(tokenizer, PreTrainedTokenizerBase), \
                "If the base model is given, the tokenizer must be given as well"
            self.model = base_model
            self.tokenizer = tokenizer

        if initial_prompt is None:
            # Fall back to the example Spanish prompt shipped with the code.
            with open('./prompt.txt', 'r') as fh:
                self.initial_prompt = fh.read()
        else:
            self.initial_prompt = initial_prompt

        self.keep_context = keep_context
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text):
        """
        Generates a response from the prompt (and, optionally, the context),
        appending `input_text` as if it were part of the human's dialog
        (identified by `self.human_identifier`) and prompting the bot
        (identified by `self.bot_identifier`) for a response. As the bot
        might continue the conversation beyond its own turn, the output is
        trimmed so that only the first dialog line given by the bot is kept,
        following the idea presented in the Medium blog post on creating
        conversational agents (link above).

        Parameters
        ----------
        input_text : str
            The question asked, or phrase prompted, by the human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
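
        Examples
        --------
        The trimming is purely string-based: if the raw completion after the
        prompt were, say, "Una respuesta." followed by a new line starting
        with "HUMAN:", only "Una respuesta." would be returned.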
        """
        # Build the full prompt: initial prompt, accumulated context (if any)
        # and the new human turn, leaving the bot identifier open for the
        # model to complete.
        prompt = self.initial_prompt + self.context
        prompt += f'{self.human_identifier}: {input_text}\n'
        prompt += f'{self.bot_identifier}: '

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        if self.creative:
            # Sample from the model, restricted to the 50 most likely tokens
            # (top-k) within the 95% probability mass (top-p).
            output = self.model.generate(
                input_ids,
                do_sample=True,
                max_length=input_ids.shape[1] + self.max_tokens,
                top_k=50,
                top_p=0.95
            )[0]
        else:
            # Greedy decoding: always pick the most likely next token.
            output = self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.max_tokens
            )[0]

        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)

        # Keep only the text generated after the prompt.
        trimmed_output = decoded_output[len(prompt):]

        # The model might keep the conversation going by itself; cut the
        # output at the first new human turn, if there is one.
        next_turn = trimmed_output.find(f'{self.human_identifier}:')
        if next_turn != -1:
            trimmed_output = trimmed_output[:next_turn]

        if self.keep_context:
            # Accumulate only the new exchange: the initial prompt is already
            # prepended on every call, so appending the whole prompt here
            # would duplicate it on each turn.
            self.context += f'{self.human_identifier}: {input_text}\n'
            self.context += f'{self.bot_identifier}: {trimmed_output.strip()}\n'

        return trimmed_output.strip()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', '-m',
                        default='bigscience/bloom-560m',
                        help="Name of the base model to use for the chatbot.")
    parser.add_argument('--prompt', '-p',
                        default='./prompt.txt',
                        help="Path to the file with the prompt to use.")
    parser.add_argument('--keep-context', '-k',
                        action='store_true',
                        help="Keep the context of the conversation.")
    parser.add_argument('--creative', '-c',
                        action='store_true',
                        help="Make the bot creative when answering.")
    parser.add_argument('--random-seed', '-r',
                        default=42,
                        type=int,
                        help="Seed number for the creative bot.")
    parser.add_argument('--human-identifier', '-i',
                        default='HUMANO',
                        help="Name of the human identifier.")
    parser.add_argument('--bot-identifier', '-b',
                        default='EXPERTO',
                        help="Name of the bot identifier.")

    args = parser.parse_args()

    # Fix the random seed so sampling-based ("creative") generation is
    # reproducible.
    torch.manual_seed(args.random_seed)

    with open(args.prompt, 'r') as fh:
        initial_prompt = fh.read()

    chatbot = ChatBot(
        base_model=args.model_name,
        initial_prompt=initial_prompt,
        keep_context=args.keep_context,
        creative=args.creative,
        human_identifier=args.human_identifier,
        bot_identifier=args.bot_identifier
    )

    print("Type `exit` or `quit` to leave")
    while True:
        input_text = input('> ')
        if input_text in ('exit', 'quit'):
            break
        print(chatbot.chat(input_text))