stablemed2 / app.py
vaishakgkumar's picture
Update app.py
2241485
raw history blame
No virus
2.64 kB
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel, PeftConfig
import gradio as gr
import os
import huggingface
from huggingface_hub import login
# --- Hugging Face authentication -------------------------------------------
# The access token is read from the HUGGINGFACE_TOKEN environment variable.
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
if hf_token:
    login(hf_token)
# When the variable is unset we skip login entirely: calling login() with no
# token falls back to an interactive prompt, which cannot be answered in a
# headless deployment. Public repositories still load with token=None.

# --- Model identifiers ------------------------------------------------------
# NOTE(review): `device` is computed but nothing in this file moves the model
# to it, so the model stays wherever from_pretrained() places it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Base checkpoint the adapter is applied to.
base_model_id = "stabilityai/stablelm-3b-4e1t"
# Hub repository holding the PEFT adapter weights.
model_directory = "vaishakgkumar/stablemedv1"

# --- Tokenizer --------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    token=hf_token,
    trust_remote_code=True,
    padding_side="left",
)
# tokenizer = AutoTokenizer.from_pretrained("Tonic/stablemed", trust_remote_code=True, padding_side="left")
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

# --- Base model + PEFT adapter ----------------------------------------------
# The adapter configuration is loaded and kept as a module-level name.
peft_config = PeftConfig.from_pretrained(model_directory, token=hf_token)
# Load the base causal LM first, then wrap it with the adapter weights.
peft_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    token=hf_token,
    trust_remote_code=True,
)
peft_model = PeftModel.from_pretrained(peft_model, model_directory, token=hf_token)
class ChatBot:
    """Stateful text-generation wrapper around the module-level model.

    Relies on the module-level ``tokenizer`` and ``peft_model`` objects. The
    running conversation is kept as a tensor of token ids in ``history`` and
    is prepended to every new prompt.
    """

    # Upper bound on prompt + generated tokens handed to ``generate``.
    max_length = 1200
    # Tokens kept free for the reply when an over-long context is trimmed.
    reply_budget = 256

    def __init__(self):
        # Token ids of the conversation so far, or None before the first turn.
        self.history = None

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        """Generate a reply to ``user_input`` and record the turn in ``history``.

        Args:
            user_input: Text entered by the user.
            system_prompt: Text appended directly after ``user_input`` to
                form the prompt.

        Returns:
            The newly generated text only (the prompt and earlier turns are
            not echoed back), decoded with special tokens removed.
        """
        # NOTE(review): the user text comes first and the system prompt is
        # appended with no separator. Kept as-is because the trailing prompt
        # may be acting as a response cue for the fine-tuned model - confirm.
        formatted_input = f"{user_input}{system_prompt}"

        # Encode the prompt and place it on the same device as the model so
        # generation works regardless of where the model was loaded.
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
        user_input_ids = user_input_ids.to(peft_model.device)

        # Prepend earlier turns, if any. An empty list is also accepted as
        # "no history" for callers that reset the attribute that way.
        if self.history is not None and len(self.history) > 0:
            context_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            context_ids = user_input_ids

        # Keep only the most recent tokens when the context would leave no
        # room to generate within ``max_length``.
        prompt_limit = max(1, self.max_length - self.reply_budget)
        if context_ids.shape[-1] > prompt_limit:
            context_ids = context_ids[:, -prompt_limit:]

        response = peft_model.generate(
            input_ids=context_ids,
            max_length=self.max_length,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Store the full sequence (prompt + reply) so the next turn sees what
        # the model answered, not just what the user asked.
        self.history = response

        # Decode only the tokens produced after the prompt.
        new_tokens = response[0][context_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
# One ChatBot instance serves every request made through the interface.
bot = ChatBot()

# Static text shown by the Gradio page.
title = "👋🏻Welcome to StableLM MED chat"
description = "\n"

# Each example row supplies one value per input box, in order.
examples = [
    [
        "What is the proper treatment for buccal herpes?",
        "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes.",
    ],
]

# Two free-text inputs map onto predict(user_input, system_prompt); the
# returned string is rendered in a single text output.
_interface_options = {
    "fn": bot.predict,
    "inputs": ["text", "text"],
    "outputs": "text",
    "examples": examples,
    "title": title,
    "description": description,
    "theme": "ParityError/Anime",
}
iface = gr.Interface(**_interface_options)

# Start serving the app.
iface.launch()