# Hugging Face Space: PhysicianAI chat demo (the Space was paused when this file was captured).
import gradio as gr | |
import os | |
import spaces | |
from transformers import GemmaTokenizer, AutoModelForCausalLM | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from threading import Thread | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
import transformers | |
import torch | |
from peft import PeftModel, PeftConfig | |
import os | |
from transformers import ( | |
BitsAndBytesConfig, | |
pipeline, | |
) | |
# Hugging Face Hub token, read (not set) from the HF_TOKEN environment variable.
# Both names hold the same value; it is None when the variable is not defined.
HF_TOKEN = os.environ.get("HF_TOKEN")
access_token = HF_TOKEN
# HTML heading rendered at the top of the page through gr.Markdown.
DESCRIPTION = (
    "\n"
    "<div>\n"
    '<h1 style="text-align: center;">PhysicianAI</h1>\n'
    "</div>\n"
)
# Footer rendered at the bottom of the page through gr.Markdown.
# The attribution names the model family this app actually loads
# (Meta-Llama-3-8B plus a Llama 3 adapter); it previously credited Mistral 7B.
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""
# HTML shown inside the empty chat window before the first message is sent.
# Fixes the user-facing typo "Medicle CHatBot".
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://www.thesmartcityjournal.com/images/Imagenes-Articulos/2023_09_Septiembre/AI_in_healthcare.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">PhysicianAI</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">I am a Medical ChatBot...</p>
</div>
"""
# Page-level stylesheet passed to gr.Blocks: centres every <h1> and styles the
# element whose id is "duplicate-button".
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "display: block;\n"
    "}\n"
    "#duplicate-button {\n"
    "margin: auto;\n"
    "color: white;\n"
    "background: #1565c0;\n"
    "border-radius: 100vh;\n"
    "}\n"
)
# bitsandbytes settings used when loading the model: 4-bit NF4 weights,
# float16 compute, double quantization disabled.
compute_dtype = torch.float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)
# Load the 4-bit quantized base model, then attach the fine-tuned PEFT adapter
# on top of it. `token=` replaces the deprecated `use_auth_token=` keyword.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    token=access_token,
    quantization_config=quant_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    "physician-ai/llama3-8b-finetuned",
    token=access_token,
    quantization_config=quant_config,
    device_map="auto",
)

# The tokenizer is taken from the adapter repository, not from the base model.
tokenizer = AutoTokenizer.from_pretrained(
    "physician-ai/llama3-8b-finetuned",
    token=access_token,
)

# Ready-made text-generation pipeline with its own sampling defaults.
text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=4096,
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.15,
)
# Token ids at which generation should stop: the tokenizer's regular EOS token
# plus Llama 3's end-of-turn marker. The marker literal had been reduced to an
# empty string, so the lookup never resolved to the end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
def generate_response(input_ids, generate_kwargs, output_queue):
    """Run ``model.generate`` and append its result to ``output_queue``.

    Args:
        input_ids: Not used by this function; kept in the signature for
            existing callers.
        generate_kwargs: Keyword arguments forwarded to ``model.generate``.
        output_queue: List that receives exactly one item: the generation
            output on success, or ``None`` if generation raised.
    """
    result = None
    try:
        result = model.generate(**generate_kwargs)
    except Exception as e:
        # Best-effort boundary: report the failure and signal it with None.
        print(f"Error during generation: {e}")
    output_queue.append(result)
def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=4096):
    """Generate the assistant's reply for one chat turn.

    The system prompt, the earlier turns (oldest first) and the new user
    message are combined into a single prompt, which is tokenized as one
    sequence. Only the newly generated tokens are decoded, so the reply does
    not echo the prompt.

    Args:
        message: The user's latest message.
        history: Earlier turns as ``(user, assistant)`` pairs. Dicts with
            ``role``/``content`` keys are also accepted; may be empty or None.
        temperature: Sampling temperature. Values <= 0 select greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated reply text, or a fixed error message if generation
        failed.
    """
    system_prompt = (
        "You are an advanced medical assistant trained on extensive medical literature, "
        "including the highly authoritative Harrison Principles of Internal Medicine. "
        "Your role is to provide detailed and accurate medical information, diagnosis assistance, "
        "and guidance based on evidence and clinical best practices. "
        "Answer questions clearly, with a focus on providing practical and applicable medical advice. "
        "When relevant, cite specific chapters or sections from Harrison s Principles of Internal Medicine "
        "to support your responses."
    )

    # Rebuild the dialogue in chronological order: system prompt, past turns,
    # then the new user message.
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append({"role": turn.get("role", "user"), "content": content})
            continue
        user_text, assistant_text = turn
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    if getattr(tokenizer, "chat_template", None):
        # The template emits its own special tokens, so none are added again
        # when the rendered text is tokenized.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        add_special_tokens = False
    else:
        # No chat template on this tokenizer: fall back to a plain transcript.
        prompt = "\n\n".join(
            f"{m['role'].capitalize()}: {m['content']}" for m in messages
        ) + "\n\nAssistant:"
        add_special_tokens = True

    # One string in, one sequence out: no padding and no batch of fragments.
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special_tokens)
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)

    # Temperature only takes effect when sampling is enabled; a temperature
    # of 0 is not valid for sampling, so it maps to greedy decoding.
    do_sample = temperature is not None and temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "num_return_sequences": 1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        generate_kwargs["temperature"] = float(temperature)

    # Stop at any registered terminator; skip entries that failed to resolve.
    stop_ids = [tid for tid in terminators if tid is not None]
    if stop_ids:
        generate_kwargs["eos_token_id"] = stop_ids

    # The call is synchronous: the former start-then-join thread added nothing.
    output_queue = []
    generate_response(input_ids, generate_kwargs, output_queue)

    output = output_queue[0] if output_queue else None
    if output is None:
        return "An error occurred during text generation."

    # Drop the prompt tokens so only the model's continuation is returned.
    new_tokens = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# ---- Gradio UI ----
# The chat display is created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label="Gradio ChatInterface")

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Finetuned LLAMA 3 8B Model", elem_id="duplicate-button")

    # Extra generation controls; created unrendered and passed to
    # ChatInterface, inside an accordion that starts closed.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    # Example prompts offered in the UI; each entry is a one-element row.
    example_prompts = [
        ["I've been experiencing persistent headaches, nausea, and sensitivity to light. What could be causing this?"],
        ["Based on my diagnosis of type 2 diabetes, what are the recommended treatment options? Should I consider medication, lifestyle changes, or both?"],
        ["I'm currently taking lisinopril for hypertension and atorvastatin for high cholesterol. Are there any potential interactions or side effects I should be aware of if I start taking ibuprofen for occasional pain relief?"],
        ["I'm in my early 40s and have a family history of heart disease. What are some preventive measures I can take to lower my risk, besides regular exercise and a healthy diet?"],
        ["Can you provide information on rheumatoid arthritis, including common symptoms, diagnostic tests, and available treatment options?"],
    ]

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=example_prompts,
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()