Dev2_new / app.py
neuralleap's picture
Update app.py
6b9591f verified
raw
history blame
No virus
5.66 kB
import gradio as gr
import os
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import transformers
import torch
from peft import PeftModel, PeftConfig
import os
from transformers import (
BitsAndBytesConfig,
pipeline,
)
# Hugging Face access token read from the HF_TOKEN environment variable
# (None when the variable is unset). `access_token` is the name passed to the
# from_pretrained calls; `HF_TOKEN` holds the same value under the env-var name.
HF_TOKEN = os.environ.get("HF_TOKEN")
access_token = HF_TOKEN
# Header markup rendered above the chat via gr.Markdown(DESCRIPTION).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">PhysicianAI</h1>
</div>
'''
# Footer markup rendered below the chat via gr.Markdown(LICENSE).
LICENSE = """
<p/>
---
Built with Mistral 7B Model
"""
# HTML shown inside the empty Chatbot before the first message (Chatbot placeholder=).
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://www.thesmartcityjournal.com/images/Imagenes-Articulos/2023_09_Septiembre/AI_in_healthcare.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">PhysicianAI</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">I am a Medical ChatBot...</p>
</div>
"""
# Page stylesheet passed to gr.Blocks(css=...); "#duplicate-button" targets the
# DuplicateButton created with elem_id="duplicate-button".
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# BitsAndBytes 4-bit NF4 quantization for the base weights, with float16 as
# the compute dtype and double quantization disabled.
compute_dtype = torch.float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

# Hub repo holding the finetuned PEFT adapter and its tokenizer.
_ADAPTER_REPO = "physician-ai/mistral-finetuned1"

# Quantized base model, then wrapped with the PEFT adapter (same `model` name).
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    use_auth_token=access_token,
    quantization_config=quant_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, _ADAPTER_REPO, use_auth_token=access_token)
tokenizer = AutoTokenizer.from_pretrained(_ADAPTER_REPO, use_auth_token=access_token)

# Text-generation pipeline over the adapted model. NOTE(review): not referenced
# by the chat handler in the visible code — confirm whether it is still needed.
text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1024,
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.15,
)
# Token ids at which generation should stop. "<|eot_id|>" is a Llama-3 special
# token; a Mistral tokenizer presumably lacks it, and convert_tokens_to_ids then
# yields the unk id (or None) instead of a real terminator. Only keep it when it
# resolves to a genuine token so <unk> is not treated as an end-of-turn marker.
_eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id]
if _eot_id is not None and _eot_id != tokenizer.unk_token_id:
    terminators.append(_eot_id)
def generate_response(input_ids, generate_kwargs, output_queue):
    """Run ``model.generate`` and append its result to *output_queue*.

    Usable as a ``threading.Thread`` target: rather than returning, it appends
    exactly one item to the caller-supplied list.

    Args:
        input_ids: Unused; kept for backward compatibility. The prompt is taken
            from ``generate_kwargs["input_ids"]``.
        generate_kwargs: Keyword arguments forwarded verbatim to ``model.generate``.
        output_queue: List that receives the generate output, or ``None`` if
            generation raised.
    """
    import logging

    try:
        output = model.generate(**generate_kwargs)
    except Exception:
        # Best-effort by design: callers treat None as "generation failed".
        # Log the full traceback (not just the message) so failures are diagnosable.
        logging.getLogger(__name__).exception("Error during generation")
        output = None
    output_queue.append(output)
def _history_to_messages(message, history):
    """Build a chronological role/content message list ending with *message*.

    *history* is what ``gr.ChatInterface`` passes: ``[user, assistant]`` pairs
    ("tuples" format) or ``{"role": ..., "content": ...}`` dicts ("messages"
    format). Consecutive same-role turns are merged because the Mistral-Instruct
    chat template rejects non-alternating user/assistant roles.
    """
    messages = []

    # Append one turn, skipping empty slots and merging into a same-role predecessor.
    def add(role, content):
        if content is None:
            return
        content = str(content)
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n\n" + content
        else:
            messages.append({"role": role, "content": content})

    for turn in history or []:
        if isinstance(turn, dict):
            add(turn.get("role", "user"), turn.get("content"))
        else:
            user_msg, bot_msg = turn
            add("user", user_msg)
            add("assistant", bot_msg)
    add("user", message)
    return messages


def _encode_prompt(messages):
    """Tokenize a role/content message list into a ``(1, seq_len)`` id tensor on the model's device."""
    if getattr(tokenizer, "chat_template", None):
        encoded = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        # NOTE(review): some transformers versions return a BatchEncoding here
        # rather than a bare tensor; accept either.
        input_ids = encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]
    else:
        # No chat template on the tokenizer: fall back to the Mistral-Instruct
        # layout "[INST] user [/INST] reply</s>[INST] ...". The tokenizer
        # prepends the BOS token itself.
        prompt = "".join(
            f"[INST] {m['content']} [/INST]" if m["role"] == "user"
            else f" {m['content']}{tokenizer.eos_token}"
            for m in messages
        )
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    return input_ids.to(model.device)


@spaces.GPU(duration=120)
def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=512):
    """Generate the assistant's reply to *message* given the prior chat *history*.

    Args:
        message: The latest user message.
        history: Prior turns as supplied by ``gr.ChatInterface``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate for the reply.

    Returns:
        The decoded reply text (prompt excluded), or an error message string
        if generation failed.
    """
    # Previously the current message came first and every turn was tokenized as
    # a separate batch row, so the model never saw the history in order.
    input_ids = _encode_prompt(_history_to_messages(message, history))
    prompt_len = input_ids.shape[1]

    # generate() decodes greedily and ignores `temperature` unless sampling is
    # enabled; a temperature of 0 (allowed by the UI slider) is only valid greedily.
    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        # Single unpadded sequence: every position is attended to.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "num_return_sequences": 1,
        # The Mistral tokenizer typically has no pad token; reuse EOS explicitly.
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        generate_kwargs["temperature"] = temperature

    # Run synchronously: a worker thread that is joined immediately adds nothing.
    output_queue = []
    generate_response(input_ids, generate_kwargs, output_queue)
    output = output_queue[0] if output_queue else None
    if output is None:
        return "An error occurred during text generation."

    # generate() returns prompt + continuation; decode only the new tokens.
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
# Example questions offered under the chat box; each example is a one-item
# input list (only the message textbox, no additional inputs).
_EXAMPLE_PROMPTS = [
    "I've been experiencing persistent headaches, nausea, and sensitivity to light. What could be causing this?",
    "Based on my diagnosis of type 2 diabetes, what are the recommended treatment options? Should I consider medication, lifestyle changes, or both?",
    "I'm currently taking lisinopril for hypertension and atorvastatin for high cholesterol. Are there any potential interactions or side effects I should be aware of if I start taking ibuprofen for occasional pain relief?",
    "I'm in my early 40s and have a family history of heart disease. What are some preventive measures I can take to lower my risk, besides regular exercise and a healthy diet?",
    "Can you provide information on rheumatoid arthritis, including common symptoms, diagnostic tests, and available treatment options?",
]

# Chat display; handed to the ChatInterface below, which places it in the layout.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Used Finetuned Mistral 7B Model", elem_id="duplicate-button")

    # Generation controls, created with render=False so ChatInterface renders
    # them inside the collapsible accordion; they map to chat_llama3_8b's
    # `temperature` and `max_new_tokens` parameters in this order.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.95,
        label="Temperature", render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=512,
        label="Max new tokens", render=False,
    )

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()