from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Union


class ChatBot:
    """
    Chatbot based on the notebook [How to generate text: using different decoding
    methods for language generation with Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
    and the blog post [Create conversational agents using BLOOM: Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).

    This code needs testing and is not fit for production use. It's a very basic
    chatbot that uses Causal Language Models from Transformers given a PROMPT.
    An example of a basic PROMPT is given by the BASE_PROMPT attribute.

    Parameters
    ----------
    base_model : str | AutoModelForCausalLM
        A name (path in the Hugging Face Hub) of a model, or the model itself.
    tokenizer : AutoTokenizer | None
        Required when base_model is an already instantiated model; otherwise the
        tokenizer is loaded from the same path given by base_model.
    initial_prompt : str
        A prompt for the model. Should follow the example given in `BASE_PROMPT`.
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with some very basic config)
        or through greedy decoding. Check the notebook "How to generate text"
        (link above) for more information.
    max_tokens : int
        Max number of tokens to generate in the chat.
    human_identifier : str
        The string that will identify the human speaker in the prompt (e.g. HUMAN).
    bot_identifier : str
        The string that will identify the bot speaker in the prompt (e.g. EXPERT).
    """
    BASE_PROMPT = """
The following is a conversation with a movie EXPERT. The EXPERT helps the HUMAN
define their personal preferences and provides multiple options to select from;
it also helps in selecting the best option. The EXPERT is conversational,
optimistic, flexible, empathetic, creative and human-like in generating
responses.

HUMAN: Hello, how are you?
EXPERT: Fine, thanks. I am here to help you by recommending movies.
""".strip()

    def __init__(self,
                 base_model: Union[str, AutoModelForCausalLM],
                 tokenizer: Optional[AutoTokenizer] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = "HUMAN",
                 bot_identifier: str = "EXPERT"):
        if isinstance(base_model, str):
            # Load both the model and its tokenizer from the given Hub path
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            # An already instantiated model must come with its tokenizer
            assert tokenizer is not None, \
                "If the base model is given, the tokenizer should be given as well"
            self.model = base_model
            self.tokenizer = tokenizer

        self.initial_prompt = initial_prompt if initial_prompt is not None else self.BASE_PROMPT
        self.keep_context = keep_context
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text):
        """
        Generates a response from the prompt (and optionally the context) by
        appending `input_text` as part of the HUMAN dialog (identified by
        `self.human_identifier`) and prompting the bot (identified by
        `self.bot_identifier`) for a response. As the model might continue the
        conversation beyond its own turn, the output is trimmed so that only the
        first dialog line given by the bot is returned, following the idea
        presented in the Medium blog post for creating conversational agents
        (link above).

        Parameters
        ----------
        input_text : str
            The question asked/phrase prompted by a human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """
        # Build the prompt: base prompt, accumulated context (if any), then the
        # new human turn and the cue for the bot to answer
        prompt = self.initial_prompt + '\n' + self.context
        prompt += f'{self.human_identifier}: {input_text}\n'
        prompt += f'{self.bot_identifier}: '

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        if self.creative:
            # Sampling with top-k and nucleus (top-p) filtering
            output = self.model.generate(
                input_ids,
                do_sample=True,
                max_length=input_ids.shape[1] + self.max_tokens,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1
            )[0]
        else:
            # Greedy decoding
            output = self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.max_tokens
            )[0]

        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)
        # Drop the prompt and keep only the continuation generated by the model
        trimmed_output = decoded_output[len(prompt):]
        # Keep only the first bot dialog, cutting at the next human turn if the
        # model kept the conversation going
        next_turn = trimmed_output.find(f'{self.human_identifier}:')
        if next_turn != -1:
            trimmed_output = trimmed_output[:next_turn]

        if self.keep_context:
            # Record the full exchange so that follow-up prompts stay well formed
            self.context += f'{self.human_identifier}: {input_text}\n'
            self.context += f'{self.bot_identifier}: {trimmed_output.strip()}\n'

        return trimmed_output.strip()
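

# Minimal usage sketch (an illustration, not part of the original module): the
# checkpoint name "bigscience/bloom-560m" is only an assumed example; any causal
# LM checkpoint loadable with AutoModelForCausalLM should work the same way.
if __name__ == '__main__':
    bot = ChatBot('bigscience/bloom-560m', creative=True, keep_context=True)
    print(bot.chat('Can you recommend a science fiction movie?'))
    print(bot.chat('I would prefer something from the nineties.'))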