import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from textwrap import fill


# Helper to wrap long prompt lines to a fixed width
def wrap_text(text, width=90):
    lines = text.split('\n')
    wrapped_lines = [fill(line, width=width) for line in lines]
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text


def multimodal_prompt(user_input, system_prompt):
    """
    Generates text using a large language model, given a user input and a
    system prompt.

    Args:
        user_input: The user's input text to generate a response for.
        system_prompt: Optional system prompt.

    Returns:
        A string containing the generated text in the Falcon-like format.
    """
    # Combine the system prompt and user input
    formatted_input = f"{{{{ {system_prompt} }}}}\nUser: {user_input}\nFalcon:"

    # Encode the input text and move it to the same device as the model
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    model_inputs = encodeds.to(device)

    # Generate a response using the model
    output = peft_model.generate(
        **model_inputs,
        max_length=500,
        use_cache=True,
        early_stopping=False,
        bos_token_id=peft_model.config.bos_token_id,
        eos_token_id=peft_model.config.eos_token_id,
        pad_token_id=peft_model.config.eos_token_id,
        temperature=0.4,
        do_sample=True
    )

    # Decode the response
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return response_text


class ChatbotInterface:
    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical-related information."):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        self.chat_history = []

        with gr.Row():
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
        clear = gr.ClearButton([self.msg, self.chatbot])

        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def respond(self, msg, history):
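        """
        Gradio click handler: format the user's message with this bot's system
        prompt, generate a reply, and return a cleared textbox plus the
        updated chat history.
        """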
        # Build the prompt in the same Falcon-style format used above
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False
        ).to(device)

        response = peft_model.generate(
            input_ids=input_ids,
            max_length=900,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True
        )
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)

        self.chat_history.append([formatted_input, response_text])
        return "", self.chat_history


if __name__ == "__main__":
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Base model ID and the directory holding the fine-tuned PEFT adapter
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    # Instantiate the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")

    # Specify the configuration class for the model
    model_config = AutoConfig.from_pretrained(base_model_id)

    # Load the model with the specified configuration, wrap it with the PEFT
    # adapter, and move it to the target device
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)
    peft_model.to(device)

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(
                """
                ## MedChat
                Welcome to MedChat, a medical assistant chatbot!
                You can currently chat with three chatbots trained on the same medical dataset.
                To compare the models, type your question once: it is mirrored into every
                chatbot's textbox, so you can submit the same prompt to each model.
                """
            )
        with gr.Row():
            with gr.Column():
                with gr.Tab("GaiaMiniMed"):
                    gaia_bot = ChatbotInterface("GaiaMiniMed")
            with gr.Column():
                with gr.Tab("MistralMed"):
                    mistral_bot = ChatbotInterface("MistralMed")
                with gr.Tab("Falcon-7B"):
                    falcon_bot = ChatbotInterface("Falcon-7B")

        # Mirror whatever is typed in one textbox into the other two
        gaia_bot.msg.change(fn=lambda s: (s, s), inputs=gaia_bot.msg,
                            outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.change(fn=lambda s: (s, s), inputs=mistral_bot.msg,
                               outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.change(fn=lambda s: (s, s), inputs=falcon_bot.msg,
                              outputs=[gaia_bot.msg, mistral_bot.msg])

    demo.launch()
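# Note: Falcon-7B weights take roughly 14 GB in full precision, so loading on a
# single consumer GPU may fail. A common workaround (assuming your transformers
# version supports it) is to load the weights in reduced precision, e.g.:
#
#     peft_model = AutoModelForCausalLM.from_pretrained(
#         model_directory, config=model_config, torch_dtype=torch.bfloat16
#     )
#
# torch_dtype is a standard from_pretrained argument; use torch.float16 instead
# on GPUs without bfloat16 support.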