""" Chatbot based on the notebook [How to generate text: using different decoding methods for language generation with Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb) and the blog post [Create conversational agents using BLOOM: Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0). This code needs testing, as it is not fitted for a production model. It's a very basic chatbot that uses Causal Language Models from Transformers given an PROMPT. An example of a basic PROMPT is given in the file `prompt.txt` for a Spanish prompt. """ import argparse import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, ) from typing import Optional, Union class ChatBot: """ Main class wrapper around the transformers models in order to build a basic chatbot application. Parameters ---------- base_model : str | PreTrainedModel A name (path in hugging face hub) for a model, or the model itself. tokenizer : PreTrainedTokenizerBase | None Needed in case the base_model is a given model, otherwise it will load the same model given by the base_model path. initial_prompt : str A prompt for the model. Should follow the example given in `BASE_PROMPT` keep_context : bool Whether to accumulate the context as the chatbot is used. creative : bool Whether to generate text through sampling (with some very basic config) or to go with greedy algorithm. Check the notebook "How to generate text" (link above) for more information. max_tokens : int Max number of tokens to generate in the chat. human_identifier : str The string that will identify the human speaker in the prompt (e.g. HUMAN). bot_identifier : str The string that will identify the bot speaker in the prompt (e.g. EXPERT). device: torch.device Device to run the model """ def __init__( self, base_model: Union[str, PreTrainedModel], tokenizer: Optional[PreTrainedTokenizerBase] = None, initial_prompt: Optional[str] = None, keep_context: bool = False, creative: bool = False, max_tokens: int = 50, human_identifier: str = "HUMAN", bot_identifier: str = "EXPERT", device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): if isinstance(base_model, str): self.model = AutoModelForCausalLM.from_pretrained( base_model, low_cpu_mem_usage=True, torch_dtype="auto" ).to(device) self.tokenizer = AutoTokenizer.from_pretrained(base_model) else: assert isinstance( tokenizer, PreTrainedTokenizerBase ), "If the base model is given, the tokenizer should be given as well" self.model = base_model.to(device) self.tokenizer = tokenizer if initial_prompt is None: with open("./prompt.txt", "r") as fh: self.initial_prompt = fh.read() else: self.initial_prompt = initial_prompt self.keep_context = keep_context self.context = "" self.creative = creative self.max_tokens = max_tokens self.human_identifier = human_identifier self.bot_identifier = bot_identifier self.device = device def chat(self, input_text: str) -> str: """ Generates a response from the prompt (and optionally the context) where it adds the `input_text` as if it was part of the HUMAN dialog (identified by `self.human_identifier`), and prompts the bot (identified by `self.bot_identifier`) for a response. As the bot might continue the conversation beyond the scope, it trims the output so it only shows the first dialog given by the bot, following the idea presented in the Medium blog post for creating conversational agents (link above). Parameters ---------- input_text : str The question asked/phrase prompted by a human. Returns ------- str The output given by the bot, trimmed for better control. """ # Setup the prompt given the initial prompt and add the words that # start the dialog between the human and the bot. Give space for the # model to continue from the prompt prompt = self.initial_prompt + self.context prompt += f"{self.human_identifier}: {input_text}\n" prompt += f"{self.bot_identifier}: " # check the space after the colon input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) if self.creative: # In case you want the bot to be creative, we sample using `top_k` # and `top_p` output = self.model.generate( input_ids, do_sample=True, max_length=input_ids.shape[1] + self.max_tokens, top_k=50, top_p=0.95, )[0] else: # Otherwise we return the most probable token output = self.model.generate( input_ids, max_length=input_ids.shape[1] + self.max_tokens )[0] # Decode the output, removing special tokens for the model (like # `[CLS]` and similar) decoded_output = self.tokenizer.decode(output, skip_special_tokens=True) # Trim the output, first by removing the original prompt trimmed_output = decoded_output[len(prompt) :] # Then we find the stop token, in this case the human identifier, and # we get up to that point trimmed_output = trimmed_output[: trimmed_output.find(f"{self.human_identifier}:")] if self.keep_context: # If we want to keep the context of the conversation we add the # trimmed output so far self.context += prompt + trimmed_output return trimmed_output.strip() # we only return the trimmed output if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--model-name", "-m", default="bigscience/bloom-560m", help="Name of the base model to use for the chatbot", ) parser.add_argument( "--prompt", "-p", default="./prompt.txt", help="Path to the file with the prompt to use" ) parser.add_argument( "--keep-context", "-k", action="store_true", help="Keep context of the conversation." ) parser.add_argument( "--creative", "-c", action="store_true", help="Make the bot creative when answering." ) parser.add_argument( "--random-seed", "-r", default=42, help="Seed number for the creative bot.", type=int ) parser.add_argument( "--human-identifier", "-i", default="HUMANO", help="Name of the human identifier." ) parser.add_argument( "--bot-identifier", "-b", default="EXPERTO", help="Name of the bot identifier." ) args = parser.parse_args() torch.manual_seed(args.random_seed) with open(args.prompt, "r") as fh: initial_prompt = fh.read() chatbot = ChatBot( base_model=args.model_name, initial_prompt=initial_prompt, keep_context=args.keep_context, creative=args.creative, human_identifier=args.human_identifier, bot_identifier=args.bot_identifier, ) print("Write `exit` or `quit` to quit") while True: input_text = input("> ") if input_text == "exit" or input_text == "quit": break print(chatbot.chat(input_text))