"""
Chatbot based on the notebook [How to generate text: using different decoding
methods for language generation with
Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
and the blog post [Create conversational agents using BLOOM:
Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).

This code needs further testing; it is not ready for production use.

It is a very basic chatbot that drives a Causal Language Model from
Transformers with a given prompt.

An example of a basic prompt (in Spanish) is given in the file `prompt.txt`.
"""

import argparse
from typing import Optional, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

class ChatBot:
    """
    Main class: a wrapper around Transformers models used to build a basic
    chatbot application.

    Parameters
    ----------
    base_model : str | PreTrainedModel
        A model name (i.e. a path in the Hugging Face Hub), or the model
        itself.
    tokenizer : PreTrainedTokenizerBase | None
        Required when `base_model` is an already instantiated model;
        otherwise the tokenizer is loaded from the same path given by
        `base_model`.
    initial_prompt : str
        A prompt for the model. It should follow the example given in
        `prompt.txt`.
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with some very basic
        configuration) or through greedy decoding. Check the notebook "How
        to generate text" (link above) for more information.
    max_tokens : int
        Maximum number of tokens to generate per chat turn.
    human_identifier : str
        The string that identifies the human speaker in the prompt (e.g.
        HUMAN).
    bot_identifier : str
        The string that identifies the bot speaker in the prompt (e.g.
        EXPERT).
    device : torch.device
        Device on which to run the model.
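
    Examples
    --------
    A minimal usage sketch (assumes `prompt.txt` is available in the working
    directory and the default model can be downloaded from the Hub; the
    identifiers match the Spanish prompt):

    >>> bot = ChatBot(
    ...     "bigscience/bloom-560m",
    ...     human_identifier="HUMANO",
    ...     bot_identifier="EXPERTO",
    ... )  # doctest: +SKIP
    >>> print(bot.chat("Hola, ¿quién eres?"))  # doctest: +SKIP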
    """

    def __init__(
        self,
        base_model: Union[str, PreTrainedModel],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        initial_prompt: Optional[str] = None,
        keep_context: bool = False,
        creative: bool = False,
        max_tokens: int = 50,
        human_identifier: str = "HUMAN",
        bot_identifier: str = "EXPERT",
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        if isinstance(base_model, str):
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model, low_cpu_mem_usage=True, torch_dtype="auto"
            ).to(device)
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            assert isinstance(
                tokenizer, PreTrainedTokenizerBase
            ), "If the base model is given, the tokenizer should be given as well"
            self.model = base_model.to(device)
            self.tokenizer = tokenizer

        if initial_prompt is None:
            with open("./prompt.txt", "r") as fh:
                self.initial_prompt = fh.read()
        else:
            self.initial_prompt = initial_prompt

        self.keep_context = keep_context
        self.context = ""
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier
        self.device = device

    def chat(self, input_text: str) -> str:
        """
        Generates a response from the prompt (and, optionally, the
        accumulated context). It appends `input_text` as if it were part of
        the human's dialog (identified by `self.human_identifier`) and then
        prompts the bot (identified by `self.bot_identifier`) for a
        response. Since the model may continue the conversation beyond its
        own turn, the output is trimmed so that only the first reply given
        by the bot is returned, following the idea presented in the Medium
        blog post on creating conversational agents (link above).
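
        For example (a hypothetical illustration of the trimming), if the
        model generates "Muy bien.\nHUMANO: ¿Y qué más?", only "Muy bien."
        is returned.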

        Parameters
        ----------
        input_text : str
            The question asked or phrase prompted by the human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """

        # Build the full prompt: the initial prompt, the accumulated context
        # (if any), the new human turn, and the bot identifier as a cue for
        # the model to answer.
        prompt = self.initial_prompt + self.context
        prompt += f"{self.human_identifier}: {input_text}\n"
        prompt += f"{self.bot_identifier}: "

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        if self.creative:
            # Sample the next tokens, restricted to the 50 most likely
            # candidates (top-k) and, among those, to the smallest set whose
            # cumulative probability exceeds 0.95 (top-p), as in the "How to
            # generate text" notebook.
            output = self.model.generate(
                input_ids,
                do_sample=True,
                max_length=input_ids.shape[1] + self.max_tokens,
                top_k=50,
                top_p=0.95,
            )[0]
        else:
            # Greedy decoding: always pick the single most likely next token.
            output = self.model.generate(
                input_ids, max_length=input_ids.shape[1] + self.max_tokens
            )[0]

        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)

        # Drop the prompt, keeping only the newly generated text (this
        # assumes the tokenizer round-trips the prompt verbatim).
        trimmed_output = decoded_output[len(prompt):]

        # Keep only the bot's first reply: if the model went on to invent
        # the human's next turn, cut the output right before it.
        human_turn = trimmed_output.find(f"{self.human_identifier}:")
        if human_turn != -1:
            trimmed_output = trimmed_output[:human_turn]

        if self.keep_context:
            # Accumulate only the new turns; `prompt` already contains the
            # initial prompt and the previous context, so appending it
            # verbatim would duplicate them on every call.
            self.context += f"{self.human_identifier}: {input_text}\n"
            self.context += f"{self.bot_identifier}: {trimmed_output.strip()}\n"

        return trimmed_output.strip()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name",
        "-m",
        default="bigscience/bloom-560m",
        help="Name of the base model to use for the chatbot.",
    )
    parser.add_argument(
        "--prompt", "-p", default="./prompt.txt", help="Path to the file with the prompt to use."
    )
    parser.add_argument(
        "--keep-context", "-k", action="store_true", help="Keep the context of the conversation."
    )
    parser.add_argument(
        "--creative", "-c", action="store_true", help="Make the bot creative when answering."
    )
    parser.add_argument(
        "--random-seed", "-r", default=42, type=int, help="Seed number for the creative bot."
    )
    parser.add_argument(
        "--human-identifier", "-i", default="HUMANO", help="Name of the human identifier."
    )
    parser.add_argument(
        "--bot-identifier", "-b", default="EXPERTO", help="Name of the bot identifier."
    )

    args = parser.parse_args()
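
    # Example invocation (hypothetical; assumes this file is saved as
    # `chatbot.py`):
    #   python chatbot.py -m bigscience/bloom-560m -p ./prompt.txt -c -k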

    torch.manual_seed(args.random_seed)

    with open(args.prompt, "r") as fh:
        initial_prompt = fh.read()

    chatbot = ChatBot(
        base_model=args.model_name,
        initial_prompt=initial_prompt,
        keep_context=args.keep_context,
        creative=args.creative,
        human_identifier=args.human_identifier,
        bot_identifier=args.bot_identifier,
    )

    print("Write `exit` or `quit` to quit")
    while True:
        input_text = input("> ")
        if input_text in ("exit", "quit"):
            break
        print(chatbot.chat(input_text))