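"""Gradio demo: a streaming medical-diagnosis chat assistant.

Loads Intelligent-Internet/II-Medical-8B, streams token-by-token responses
into a Chatbot component, and renders LaTeX formulas in answers via MathJax.
Not a replacement for a licensed medical professional.
"""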
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time
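# Load the model and tokenizer once at startup. device_map="auto" lets
# accelerate decide weight placement (the whole model on one GPU when
# available), and torch_dtype="auto" keeps the checkpoint's native dtype.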
model_path = "Intelligent-Internet/II-Medical-8B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Streaming generator: yields partial (chatbot, state) updates as tokens arrive
@spaces.GPU(duration=60)
def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
if not user_message.strip():
yield history, history
return
system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections:
<think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think>
In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.
Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).
Now, analyze the following case:"""
    # Build the conversation as role/content messages and let the tokenizer's
    # chat template render the prompt with the model's own special tokens
    # (II-Medical-8B ships a ChatML-style template).
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    # Add the current user message; add_generation_prompt=True appends the
    # assistant header so the model continues with its reply.
    messages.append({"role": "user", "content": user_message})
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
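    # TextIteratorStreamer turns generate() into a stream: it yields decoded
    # text chunks as they are produced, and skip_prompt=True drops the echoed
    # input prompt from the output.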
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
generation_kwargs = {
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"max_new_tokens": int(max_tokens),
"do_sample": True,
"temperature": float(temperature),
"top_k": int(top_k),
"top_p": float(top_p),
"repetition_penalty": float(repetition_penalty),
"streamer": streamer,
}
# Start generation in a separate thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
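    # generate() is now running in the background; iterating the streamer below
    # blocks until each newly decoded chunk of text arrives.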
# Create a new history with the current user message
new_history = history.copy() + [[user_message, ""]]
# Collect the generated response
assistant_response = ""
for new_token in streamer:
        # Strip any ChatML control tokens that leak into the decoded stream.
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_end|>", "")
assistant_response += cleaned_token
# Update the last message in history with the current response
new_history[-1][1] = assistant_response.strip()
yield new_history, new_history
# Add a small sleep to control the streaming rate
time.sleep(0.01)
    # Generation is finished once the streamer is exhausted; join the worker
    # thread and emit the final state.
    thread.join()
    yield new_history, new_history
# Non-streaming wrapper: exhausts the generator and returns only the final state.
# It is not wired to any UI event below, but is handy for blocking callers.
def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)
# Get the final result by exhausting the generator
result = None
for result in generator:
pass
return result
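# Example of blocking use (illustrative values only):
#   final_history, _ = process_input(
#       "A patient weighs 85 kg and is 1.75 m tall. Calculate the BMI.",
#       512, 0.8, 50, 0.95, 1.0, [])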
example_messages = {
"Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
"Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
"Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
"BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}
css = """
.markdown-body .katex {
font-size: 1.2em;
}
.markdown-body .katex-display {
margin: 1em 0;
overflow-x: auto;
overflow-y: hidden;
}
"""
# Scripts inside gr.HTML are injected via innerHTML and never execute, so the
# MathJax loader goes into the page <head> instead (Blocks' head parameter).
mathjax_head = """
<script>
if (typeof window.MathJax === 'undefined') {
    // MathJax 2.x reads this config object when the library loads.
    window.MathJax = {
        tex2jax: {
            inlineMath: [['$', '$']],
            displayMath: [['$$', '$$']],
            processEscapes: true
        },
        showProcessingMessages: false,
        messageStyle: 'none'
    };
    const script = document.createElement('script');
    script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
    script.async = true;
    document.head.appendChild(script);
}
// Re-typeset periodically so formulas in freshly streamed messages are rendered.
function rerender() {
    if (window.MathJax && window.MathJax.Hub) {
        window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
    }
}
setInterval(rerender, 1000);
</script>
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css, head=mathjax_head) as demo:
gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
gr.HTML("""
<script>
if (typeof window.MathJax === 'undefined') {
const script = document.createElement('script');
script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
script.async = true;
document.head.appendChild(script);
window.MathJax = {
tex2jax: {
inlineMath: [['$', '$']],
displayMath: [['$$', '$$']],
processEscapes: true
},
showProcessingMessages: false,
messageStyle: 'none'
};
}
function rerender() {
if (window.MathJax && window.MathJax.Hub) {
window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
}
}
setInterval(rerender, 1000);
</script>
""")
    # The Chatbot's built-in KaTeX renders $...$ / $$...$$ (this is the output
    # the .katex CSS rules above target).
    chatbot = gr.Chatbot(label="Chat", render_markdown=True, show_copy_button=True,
                         latex_delimiters=[{"left": "$$", "right": "$$", "display": True},
                                           {"left": "$", "right": "$", "display": False}])
history = gr.State([])
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=4096, label="Max Tokens")
with gr.Accordion("Advanced Settings", open=False):
temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
with gr.Column(scale=4):
with gr.Row():
user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
submit_button = gr.Button("Send", variant="primary", scale=1)
clear_button = gr.Button("Clear", scale=1)
gr.Markdown("**Try these examples:**")
with gr.Row():
example1 = gr.Button("Headache case")
example2 = gr.Button("Chest pain")
example3 = gr.Button("Abdominal pain")
example4 = gr.Button("BMI calculation")
# Set up the streaming interface
    def on_submit(message, history):
# Return the modified history that includes the new user message
modified_history = history + [[message, ""]]
return "", modified_history, modified_history
def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        if not history:
            # Yield (not return) so the unchanged history is still emitted.
            yield history
            return
# Get the last user message from history
user_message = history[-1][0]
# Start a fresh history without the last entry
prev_history = history[:-1]
# Generate streaming responses
for new_history, _ in generate_streaming_response(
user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
):
yield new_history
# Connect the submission event
submit_button.click(
fn=on_submit,
        inputs=[user_input, history],
outputs=[user_input, chatbot, history]
).then(
fn=on_stream,
inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
outputs=chatbot
)
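    # Also trigger on Enter in the textbox (a sketch mirroring the Send wiring):
    user_input.submit(
        fn=on_submit,
        inputs=[user_input, history],
        outputs=[user_input, chatbot, history]
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=chatbot
    )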
# Handle examples
def set_example(example_text):
return gr.update(value=example_text)
clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
example1.click(fn=lambda: set_example(example_messages["Headache case"]), inputs=None, outputs=user_input)
example2.click(fn=lambda: set_example(example_messages["Chest pain"]), inputs=None, outputs=user_input)
example3.click(fn=lambda: set_example(example_messages["Abdominal pain"]), inputs=None, outputs=user_input)
example4.click(fn=lambda: set_example(example_messages["BMI calculation"]), inputs=None, outputs=user_input)
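# ssr_mode=False disables Gradio's server-side rendering so the page (and the
# injected MathJax script) is built entirely in the browser.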
demo.launch(ssr_mode=False)