import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time

phi4_model_path = "Intelligent-Internet/II-Medical-8B"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)


# Streaming generator function that yields partial results as they are produced
@spaces.GPU(duration=60)
def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    if not user_message.strip():
        yield history, history
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step.

Structure your answer in two sections. In the first section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations; explain your thought process step-by-step. In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.

Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).

Now, analyze the following case:"""

    # Build the conversation in the prompt format the model expects
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"

    # Convert history from the Gradio Chatbot format to the prompt format
    for user_msg, bot_msg in history:
        if user_msg:
            prompt += f"{start_tag}user{sep_tag}{user_msg}{end_tag}"
        if bot_msg:
            prompt += f"{start_tag}assistant{sep_tag}{bot_msg}{end_tag}"

    # Add the current user message and open the assistant turn
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

    # Start generation in a separate thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Create a new history entry for the current user message
    new_history = history.copy() + [[user_message, ""]]

    # Collect the generated response token by token
    assistant_response = ""
    for new_token in streamer:
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token

        # Update the last message in history with the partial response
        new_history[-1][1] = assistant_response.strip()
        yield new_history, new_history

        # Small sleep to control the streaming rate
        time.sleep(0.01)

    # Yield the final state after streaming has completed
    yield new_history, new_history
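
# For reference, the prompt assembled above for a single turn looks roughly like the
# sketch below (illustrative only; the segments are concatenated directly with no
# newlines, and the system text is the full string defined inside the function):
#   <|im_start|>system<|im_sep|>You are a medical assistant AI...<|im_end|>
#   <|im_start|>user<|im_sep|>A 35-year-old female presents with...<|im_end|>
#   <|im_start|>assistant<|im_sep|>
# Earlier turns from `history` are spliced in as additional user/assistant blocks
# before the final, open assistant tag.
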
# Non-streaming wrapper for callers that cannot consume a streaming generator
def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)

    # Get the final result by exhausting the generator
    result = None
    for result in generator:
        pass
    return result


example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese.",
}

css = """
.markdown-body .katex { font-size: 1.2em; }
.markdown-body .katex-display { margin: 1em 0; overflow-x: auto; overflow-y: hidden; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provides preliminary diagnostic reasoning, using LaTeX-rendered medical formulas where needed.")
    gr.HTML(""" """)

    chatbot = gr.Chatbot(label="Chat", render_markdown=True, show_copy_button=True)
    history = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")

        with gr.Column(scale=4):
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")

    # Set up the streaming interface
    def on_submit(message, history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        # Return the modified history that includes the new user message
        modified_history = history + [[message, ""]]
        return "", modified_history, modified_history

    def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        if not history:
            yield history, history
            return

        # Get the last user message and start from the history without that entry
        user_message = history[-1][0]
        prev_history = history[:-1]

        # Generate streaming responses; also write the updated history back to the
        # State so that later turns include the assistant's earlier answers in the prompt
        for new_history, updated_history in generate_streaming_response(
            user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
        ):
            yield new_history, updated_history

    # Connect the submission event: record the user message, then stream the reply
    submit_button.click(
        fn=on_submit,
        inputs=[user_input, history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[user_input, chatbot, history],
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[chatbot, history],
    )
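
    # Optional variant (not part of the original app): pressing Enter in the textbox
    # could trigger the same submit-then-stream flow by mirroring the click wiring:
    #
    # user_input.submit(
    #     fn=on_submit,
    #     inputs=[user_input, history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
    #     outputs=[user_input, chatbot, history],
    # ).then(
    #     fn=on_stream,
    #     inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
    #     outputs=[chatbot, history],
    # )
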
    # Handle example and clear buttons
    def set_example(example_text):
        return gr.update(value=example_text)

    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
    example1.click(fn=lambda: set_example(example_messages["Headache case"]), inputs=None, outputs=user_input)
    example2.click(fn=lambda: set_example(example_messages["Chest pain"]), inputs=None, outputs=user_input)
    example3.click(fn=lambda: set_example(example_messages["Abdominal pain"]), inputs=None, outputs=user_input)
    example4.click(fn=lambda: set_example(example_messages["BMI calculation"]), inputs=None, outputs=user_input)

demo.launch(ssr_mode=False)