# Hugging Face Space file-view header (web-UI residue, not code):
#   commit b3f457b (verified) — "Update UI, color of input box" — 7.71 kB
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
# --- Load the MedGemma pipeline once at import time ---
# On failure we only record the flag and keep the app alive; the handler
# below surfaces the error to the user instead of crashing the Space.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",  # let accelerate place the model on available devices
        token=os.environ.get("HF_TOKEN"),
    )
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")
else:
    model_loaded = True
    print("Model loaded successfully!")
# --- Core CONVERSATIONAL Logic (streaming generator) ---
@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """Run one consultation turn and stream the model's reply.

    Generator yielding ``(chatbot_history, state_history, image_clear)``
    tuples for Gradio. ``history`` is a list of ``(user, assistant)``
    tuples whose last entry is the pending turn ``(user_input, None)``;
    it is mutated in place. The third yielded value is always ``None``
    so the image upload widget is cleared after each turn.
    """
    if not model_loaded:
        history[-1] = (user_input, "Error: The AI model is not loaded.")
        yield history, history, None
        return
    try:
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. "
            "If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. "
            "Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. "
            "For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?'"
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )
        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""
        if user_image is not None:
            # Image turn: use the pipeline's structured chat format so the
            # image is attached to the latest user message.
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            for user_msg, assistant_msg in history[:-1]:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
            latest_user_content = [{"type": "text", "text": user_input}, {"type": "image", "image": user_image}]
            messages.append({"role": "user", "content": latest_user_content})
            output = pipe(text=messages, **generation_args)
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            # Text-only turn: build a flat Gemma-style prompt by hand.
            prompt_parts = [f"<start_of_turn>system\n{system_prompt}"]
            for user_msg, assistant_msg in history[:-1]:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}")
                if assistant_msg:
                    prompt_parts.append(f"<start_of_turn>model\n{assistant_msg}")
            prompt_parts.append(f"<start_of_turn>user\n{user_input}")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)
            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            ai_response = full_text.split("<start_of_turn>model")[-1].strip()
        # Stream the reply back to the chatbot. Yield once up front so the
        # assistant bubble appears even when the reply is empty (the original
        # never yielded in that case), then yield growing prefixes — slicing
        # avoids the quadratic `str + char` accumulation of the old loop.
        history[-1] = (user_input, "")
        yield history, history, None
        for end in range(1, len(ai_response) + 1):
            history[-1] = (user_input, ai_response[:end])
            yield history, history, None
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        history[-1] = (user_input, error_message)
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        yield history, history, None
# --- UI MODIFICATION: Professional CSS for the chat interface ---
# Injected into gr.Blocks(css=...) below: full-height app, scrollable chat
# column, and a sticky light-blue footer bar that holds the input controls.
# The bottom padding on #chat-container keeps messages from hiding behind
# the fixed footer.
css = """
/* Make the main app container fill the screen height */
div.gradio-container { height: 100vh !important; }
/* Main chat area styling */
#chat-container { flex-grow: 1; overflow-y: auto; padding-bottom: 120px; }
/* Sticky footer for the input bar */
#footer-container {
position: fixed !important; bottom: 0; left: 0; width: 100%;
background-color: #e0f2fe !important; /* Light Sky Blue background */
border-top: 1px solid #bae6fd !important;
padding: 10px; z-index: 1000;
}
/* White, rounded textbox */
#user-textbox textarea {
background-color: #ffffff !important;
border: 1px solid #cbd5e1 !important;
border-radius: 8px !important;
}
/* Style the image upload button */
#image-upload-btn { border: 1px dashed #9ca3af !important; border-radius: 8px !important; }
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation", css=css) as demo:
    # Chat history lives in session state as a list of (user, assistant) tuples.
    conversation_history = gr.State([])
    with gr.Column(elem_id="chat-container"):
        chatbot_display = gr.Chatbot(label="Consultation", show_copy_button=True, bubble_full_width=False)
    with gr.Column(elem_id="footer-container"):
        with gr.Row():
            image_input = gr.Image(elem_id="image-upload-btn", label="Image", type="pil", height=80, show_label=False, container=False, scale=1)
            user_textbox = gr.Textbox(
                elem_id="user-textbox",
                label="Your Message",
                placeholder="Type your message, or upload an image...",
                show_label=False, scale=4, container=False
            )
            send_button = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_button = gr.Button("🗑️ Start New Consultation")

    def submit_message_and_stream(user_input: str, user_image: Image.Image, history: list):
        """Echo the user's message instantly, then stream the AI reply.

        Generator yielding (chatbot, state, image) output tuples.
        """
        if not user_input.strip() and user_image is None:
            # BUG FIX: this function is a generator, so a bare
            # `return value` is swallowed as StopIteration and Gradio
            # never receives the outputs; yield the unchanged state
            # explicitly, then stop.
            yield history, history, None
            return
        # 1. Instantly add the user's message to the chat UI.
        history.append((user_input, None))
        yield history, history, None
        # 2. Delegate to the model generator and forward its stream.
        yield from handle_conversation_turn(user_input, user_image, history)

    # --- Event Handlers ---
    send_button.click(
        fn=submit_message_and_stream,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input],
    ).then(lambda: "", outputs=user_textbox)  # Clear textbox after submission
    user_textbox.submit(
        fn=submit_message_and_stream,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input],
    ).then(lambda: "", outputs=user_textbox)
    clear_button.click(
        lambda: ([], [], None, ""),
        outputs=[chatbot_display, conversation_history, image_input, user_textbox]
    )
# Script entry point: start the Gradio server with debug logging enabled.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)