import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time

phi4_model_path = "Intelligent-Internet/II-Medical-8B"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
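# Note: device_map="auto" requires the `accelerate` package, and torch_dtype="auto" loads the
# weights in the dtype stored in the checkpoint. On a single-GPU machine the model is placed on
# cuda:0, which matches `device` above, so tokenized inputs can be moved there directly.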
# Streaming generator function that yields partial results as tokens arrive
@spaces.GPU  # ZeroGPU Spaces: request a GPU for the duration of this call (no effect elsewhere)
def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    if not user_message.strip():
        yield history, history
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections:
<think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think>
In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.
Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).
Now, analyze the following case:"""
    # Build the conversation in the chat format the model expects
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"

    # Convert the Gradio Chatbot history into the prompt format
    for user_msg, bot_msg in history:
        if user_msg:
            prompt += f"{start_tag}user{sep_tag}{user_msg}{end_tag}"
        if bot_msg:
            prompt += f"{start_tag}assistant{sep_tag}{bot_msg}{end_tag}"

    # Add the current user message and open the assistant turn
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }
    # Run generation in a background thread so we can consume the streamer here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Append the current user message to a copy of the history
    new_history = history.copy() + [[user_message, ""]]

    # Collect the generated response token by token
    assistant_response = ""
    for new_token in streamer:
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        # Update the last message in history with the partial response
        new_history[-1][1] = assistant_response.strip()
        yield new_history, new_history
        # Small sleep to control the streaming rate
        time.sleep(0.01)

    # Yield the final state once streaming has completed
    yield new_history, new_history
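
# A minimal sketch of how the generator can be exercised outside Gradio (e.g. a quick local
# smoke test); the example prompt and sampling values below are illustrative, not part of the app:
#
#   for chat, _ in generate_streaming_response(
#       "A patient weighs 85 kg and is 1.75 m tall. Calculate the BMI.",
#       max_tokens=256, temperature=0.8, top_k=50, top_p=0.95,
#       repetition_penalty=1.0, history=[],
#   ):
#       partial_reply = chat[-1][1]  # grows as tokens stream in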
# Non-streaming wrapper for callers that do not support streaming updates
def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)
    # Exhaust the generator and keep only the final result
    result = None
    for result in generator:
        pass
    return result
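# process_input is not wired into the UI below; it is kept for callers that want a single final
# result instead of a stream. An illustrative call (arguments are example values only):
#
#   final_chat, final_history = process_input("Chest tightness on exertion", 512, 0.8, 50, 0.95, 1.0, [])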
example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}
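# For reference, the BMI example works out to 85 / 1.75**2 ≈ 27.8 kg/m², which falls in the
# conventional "overweight" range (25.0–29.9).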
css = """ | |
.markdown-body .katex { | |
font-size: 1.2em; | |
} | |
.markdown-body .katex-display { | |
margin: 1em 0; | |
overflow-x: auto; | |
overflow-y: hidden; | |
} | |
""" | |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
    gr.HTML("""
    <script>
    if (typeof window.MathJax === 'undefined') {
        const script = document.createElement('script');
        script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML';
        script.async = true;
        document.head.appendChild(script);
        window.MathJax = {
            tex2jax: {
                inlineMath: [['$', '$']],
                displayMath: [['$$', '$$']],
                processEscapes: true
            },
            showProcessingMessages: false,
            messageStyle: 'none'
        };
    }
    function rerender() {
        if (window.MathJax && window.MathJax.Hub) {
            window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
        }
    }
    setInterval(rerender, 1000);
    </script>
    """)

    chatbot = gr.Chatbot(label="Chat", render_markdown=True, show_copy_button=True)
    history = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")
    # Set up the streaming interface
    def on_submit(message, history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        # Return the modified history that includes the new user message
        modified_history = history + [[message, ""]]
        return "", modified_history, modified_history

    def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        if not history:
            yield history, history
            return
        # Get the pending user message from the last history entry
        user_message = history[-1][0]
        # Use the history without that pending entry as the generation context
        prev_history = history[:-1]
        # Stream partial responses, keeping the chatbot and the history state in sync so the
        # assistant's reply is available as context on the next turn
        for new_history, updated_history in generate_streaming_response(
            user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
        ):
            yield new_history, updated_history

    # Connect the submission event: first record the user turn, then stream the reply
    submit_button.click(
        fn=on_submit,
        inputs=[user_input, history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[user_input, chatbot, history]
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[chatbot, history]
    )

    # Handle examples and clearing
    def set_example(example_text):
        return gr.update(value=example_text)

    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
    example1.click(fn=lambda: set_example(example_messages["Headache case"]), inputs=None, outputs=user_input)
    example2.click(fn=lambda: set_example(example_messages["Chest pain"]), inputs=None, outputs=user_input)
    example3.click(fn=lambda: set_example(example_messages["Abdominal pain"]), inputs=None, outputs=user_input)
    example4.click(fn=lambda: set_example(example_messages["BMI calculation"]), inputs=None, outputs=user_input)

demo.launch(ssr_mode=False)
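
# To run this Space locally (a sketch; package versions are not pinned here):
#   pip install gradio spaces transformers torch accelerate
#   python app.py
# The `spaces` package and the @spaces.GPU decorator only take effect on Hugging Face ZeroGPU
# hardware; elsewhere the decorator is a no-op and the model runs on whatever `device` resolves to.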