# Mistral model module for chat interaction and model instance control

# external imports
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import gradio as gr

# internal imports
from utils import modelling as mdl
from utils import formatting as fmt

# global model and tokenizer instances (created on initial build)
# determine if a GPU is available and load the model accordingly
device = mdl.get_device()
if device == torch.device("cuda"):
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
    MODEL = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",
        max_memory={i: max_memory for i in range(n_gpus)},
    )

# otherwise, load the model on CPU
else:
    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    MODEL.to(device)

# load tokenizer
TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# default model config
CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
base_config_dict = {
    "temperature": 1,
    "max_new_tokens": 64,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "do_sample": True,
}
CONFIG.update(**base_config_dict)


# function to (re)set the generation config
def set_config(config_dict: dict):
    # if no config dict is given, fall back to the default
    if config_dict == {}:
        config_dict = base_config_dict
    CONFIG.update(**config_dict)


# advanced formatting function that takes the conversation history into account
# CREDIT: adapted from the Mistral AI Instruct chat template
# see https://github.com/chujiezheng/chat_templates/
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
    prompt = ""

    # notify the UI if knowledge is not empty, since it cannot be used
    if knowledge != "":
        gr.Info("""
            Mistral does not support additional knowledge, so it will be ignored.
            """)

    # if there is no history, use the system prompt and an example exchange
    if len(history) == 0:
        prompt = f"""
        [INST] {system_prompt} [/INST] How can I help you today?
        [INST] {message} [/INST]
        """
    else:
        # take the very first exchange and the system prompt as the base
        prompt = f"""
        [INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}
        """

        # add the remaining conversation history as context
        for conversation in history[1:]:
            prompt += f"\n[INST] {conversation[0]} [/INST] {conversation[1]}"

        # append the current message (note the f-string so {message} is interpolated)
        prompt += f"\n[INST] {message} [/INST]"

    # return the full prompt
    return prompt


# function to extract the actual answer, because Mistral always returns the full prompt
def format_answer(answer: str):
    formatted_answer = ""

    # split the answer at the instruction tokens using a regex pattern
    pattern = r"\[/INST\]|\[ / INST\]|\[ / INST \]|\[/ INST \]"
    segments = re.split(pattern, answer)

    # check whether a proper history was returned
    if len(segments) > 1:
        # return the text after the last [/INST] - the response to the last message
        formatted_answer = segments[-1].strip()
    else:
        # warn and return the full answer if not enough [/INST] tokens were found
        gr.Warning("""
            There was an issue with the answer formatting...\n
            Returning the full answer.
            """)
        formatted_answer = answer

    return formatted_answer


# response function calling the model and returning the model output message
# CREDIT: copied from the official inference example on Hugging Face
# see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
def respond(prompt: str):
    # reset the config to the default
    set_config({})

    # tokenize the input and move it to the model device
    input_ids = TOKENIZER(prompt, return_tensors="pt")["input_ids"].to(device)

    # generate text from the tokenized input and decode it
    output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
    output_text = TOKENIZER.batch_decode(output_ids)

    # clean up the raw output text
    output_text = fmt.format_output_text(output_text)

    # return the extracted answer string
    return format_answer(output_text)