# Source: Hugging Face Space file view - "Update app.py" by Tonic (commit 6ab3fad, 5.45 kB).
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import torch
import gradio as gr
import json
import os
import shutil
import requests
# Pick the compute device: "cuda" when torch reports a GPU, otherwise "cpu".
# NOTE(review): `device` is not referenced again in the code visible here, so
# nothing below is explicitly moved onto it - confirm whether that is intended.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model identifiers passed to the from_pretrained calls below.
base_model_id = "tiiuae/falcon-7b-instruct"
model_directory = "Tonic/GaiaMiniMed"
# Instantiate the tokenizer from the base model, padding on the left.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Repeats the padding_side="left" already passed to from_pretrained above.
tokenizer.padding_side = 'left'
# Read the model configuration from the base model identifier.
model_config = AutoConfig.from_pretrained(base_model_id)
# Load a causal LM from `model_directory` with the base configuration, then
# wrap that model with PeftModel, also loaded from `model_directory`.
# NOTE(review): the first call loads from `model_directory` rather than
# `base_model_id`; presumably the base weights should come from the base
# checkpoint - confirm against the contents of the adapter repository.
peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
peft_model = PeftModel.from_pretrained(peft_model, model_directory)
# Class to encapsulate the Falcon chatbot
class FalconChatBot:
    """Chat wrapper that prompts the module-level PEFT Falcon model.

    The instance only stores a default system prompt; tokenization and
    generation use the module-level ``tokenizer`` and ``peft_model`` objects.
    """

    def __init__(self, system_prompt="You are an expert medical analyst:"):
        """Store the system prompt used when a call does not supply one."""
        self.system_prompt = system_prompt

    def process_history(self, history):
        """Return the usable turns of *history* with special commands removed.

        Args:
            history: Ideally a list of ``{"user": ..., "assistant": ...}``
                dicts. ``None``, strings and any other non-list value are
                treated as "no history" because the UI may hand this
                parameter free text instead of a structured history.

        Returns:
            A new list of ``{"user", "assistant"}`` dicts. Entries that are
            not dicts, whose ``"user"`` value is not a string, or whose user
            message starts with ``"Falcon:"`` (a special command) are dropped.
        """
        if not isinstance(history, (list, tuple)):
            return []
        filtered_history = []
        for message in history:
            if not isinstance(message, dict):
                continue
            user_message = message.get("user")
            if not isinstance(user_message, str):
                continue
            # Messages starting with "Falcon:" are special commands, not chat turns.
            if user_message.startswith("Falcon:"):
                continue
            filtered_history.append({"user": user_message, "assistant": message.get("assistant")})
        return filtered_history

    def _build_prompt(self, system_prompt, user_message, assistant_message, history):
        """Assemble the prompt text: system prompt, prior turns, current turn.

        With an empty *history* the result is exactly the single-turn prompt
        ``"{system}\\nFalcon: {assistant} User: {user}\\nFalcon:\\n"``.
        """
        parts = [f"{system_prompt}\n"]
        for turn in history:
            previous_reply = turn["assistant"]
            reply_text = f" {previous_reply}" if previous_reply else ""
            parts.append(f"Falcon:{reply_text} User: {turn['user']}\n")
        parts.append(f"Falcon: {assistant_message if assistant_message else ''} User: {user_message}\nFalcon:\n")
        return "".join(parts)

    def predict(self, system_prompt, user_message, assistant_message, history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Generate the model's reply to *user_message* and return it as text.

        Args:
            system_prompt: Instruction placed at the top of the prompt; when
                empty, the instance's default ``self.system_prompt`` is used.
            user_message: The user's current message.
            assistant_message: Optional assistant text placed before the user
                message in the prompt.
            history: Optional prior turns; see :meth:`process_history`.
            temperature: Sampling temperature. Values <= 0 switch to greedy
                decoding, since sampling requires a strictly positive value.
            max_new_tokens: Upper bound on generated tokens (at least 1).
            top_p: Nucleus-sampling probability mass (used when sampling).
            repetition_penalty: Penalty applied to repeated tokens.

        Returns:
            The decoded, whitespace-stripped text generated after the prompt.
        """
        # Process the history to remove special commands
        processed_history = self.process_history(history)
        conversation = self._build_prompt(
            system_prompt or self.system_prompt,
            user_message,
            assistant_message,
            processed_history,
        )
        # Encode the conversation and place it on the same device as the model.
        input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
        input_ids = input_ids.to(next(peft_model.parameters()).device)

        generation_kwargs = {
            "input_ids": input_ids,
            # A request for 0 new tokens cannot produce a reply; clamp to 1.
            "max_new_tokens": max(1, int(max_new_tokens)),
            "use_cache": True,
            "bos_token_id": peft_model.config.bos_token_id,
            "eos_token_id": peft_model.config.eos_token_id,
            "pad_token_id": peft_model.config.eos_token_id,
            "repetition_penalty": float(repetition_penalty),
        }
        temperature = float(temperature)
        if temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
        else:
            generation_kwargs["do_sample"] = False

        output_ids = peft_model.generate(**generation_kwargs)
        # The output sequence starts with the prompt tokens; keep only the new ones.
        new_token_ids = output_ids[0][input_ids.shape[-1]:]
        return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
# Module-level chatbot instance, built with the default system prompt; its
# bound `predict` method is the callback handed to gr.Interface below.
falcon_bot = FalconChatBot()
# Define the Gradio interface
title = "👋🏻Welcome to Tonic's 🦅Falcon's Medical👨🏻‍⚕️Expert Chat🚀"
description = "You can use this Space to test out the GaiaMiniMed model [(Tonic/GaiaMiniMed)](https://huggingface.co/Tonic/GaiaMiniMed) or duplicate this Space and use it locally or on 🤗HuggingFace. [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
# Example rows are matched to the Interface inputs by position, so each row
# supplies plain strings for the first three inputs, in order:
# System Prompt, User Message, Assistant Message.
examples = [
    [
        "Assistant is a public health and medical expert ready to help the user.",
        "Hi there, I have a question!",
        "My name is Gaia, I'm a health and sanitation expert ready to answer your medical questions.",
    ],
    [
        "Assistant is a public health and medical expert ready to help the user.",
        "What is the proper treatment for buccal herpes?",
        "",
    ],
]
# Extra controls appended after the three main text inputs: one free-text box
# followed by four generation sliders, described as
# (label, value, minimum, maximum, step, info) rows.
_SLIDER_SPECS = (
    ("Temperature", 0.9, 0.0, 1.0, 0.05, "Higher values produce more diverse outputs"),
    ("Max new tokens", 256, 0, 3000, 64, "The maximum numbers of new tokens"),
    ("Top-p (nucleus sampling)", 0.90, 0.01, 0.99, 0.05, "Higher values sample more low-probability tokens"),
    ("Repetition penalty", 1.2, 1.0, 2.0, 0.05, "Penalize repeated tokens"),
)
additional_inputs = [gr.Textbox("", label="Optional system prompt")] + [
    gr.Slider(
        label=slider_label,
        value=slider_value,
        minimum=slider_min,
        maximum=slider_max,
        step=slider_step,
        interactive=True,
        info=slider_info,
    )
    for slider_label, slider_value, slider_min, slider_max, slider_step, slider_info in _SLIDER_SPECS
]
# Build the Gradio UI around falcon_bot.predict. Inputs are passed to the
# callback by position: the three text boxes below, then `additional_inputs`.
# NOTE(review): with that ordering the first entry of `additional_inputs`
# (the "Optional system prompt" textbox) lands on predict's `history`
# parameter - confirm this wiring is intended.
iface = gr.Interface(
    fn=falcon_bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=[
        # gr.Textbox replaces the deprecated gr.inputs.Textbox namespace and
        # matches the component style already used in `additional_inputs`.
        gr.Textbox(label="System Prompt", type="text", lines=2),
        gr.Textbox(label="User Message", type="text", lines=3),
        gr.Textbox(label="Assistant Message", type="text", lines=2),
    ] + additional_inputs,
    outputs="text",
    theme="ParityError/Anime",
)
# Launch the Gradio interface for the Falcon model
iface.launch()