import json
import re

import gradio as gr

from services.embeddings import configure_embeddings
from services.indexing import create_symptom_index
from services.llm import build_llm
from utils.model_configuration_utils import select_best_model, ensure_model
from utils.voice_input_utils import update_live_transcription, format_response_for_user
# ========== Model setup ==========
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}", flush=True)
print(f"Model path: {model_path}", flush=True)
# ========== LLM initialization ==========
print("\n<<< before build_llm", flush=True)
llm = build_llm(model_path)
print(">>> after build_llm", flush=True)

# ========== Embeddings & index setup ==========
print("\n<<< before configure_embeddings", flush=True)
configure_embeddings()
print(">>> after configure_embeddings", flush=True)
print("Embeddings configured and ready", flush=True)

print("\n<<< before create_symptom_index", flush=True)
symptom_index = create_symptom_index()
print(">>> after create_symptom_index", flush=True)
print("Symptom index built successfully. Ready for queries.", flush=True)
# ========== Prompt template ==========
SYSTEM_PROMPT = (
    "You are a medical assistant helping a user narrow down to the most likely ICD-10 code. "
    "At each turn, either ask one focused clarifying question (e.g. 'Is your cough dry or productive?') "
    "or, if you have enough information, provide a final JSON object with the fields "
    '{"diagnoses": [...], "confidences": [...], "follow_up": [...]}. '
    "Your output MUST be strictly valid JSON with no trailing commas, starting with '{' "
    "and ending with '}', with no extra text outside the JSON."
)
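# Illustrative example of the JSON shape the prompt requests (values made up):
# {
#   "diagnoses": ["J20.9 Acute bronchitis, unspecified"],
#   "confidences": [0.72],
#   "follow_up": ["Is your cough dry or productive?"]
# }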
# ========== Generator handler ==========
def on_submit(symptoms_text, history):
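    """Handle a symptom submission as a generator.

    Yields (chat history, diagnosis JSON, debug log) tuples so Gradio can
    stream intermediate progress into the chatbot, JSON panel, and debug box.
    """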
    log = []
    print("on_submit called", flush=True)

    # Acknowledge input and show a placeholder while processing
    msg = "📥 Received input"
    log.append(msg)
    print(msg, flush=True)
    history = history + [{"role": "assistant", "content": "Processing your request..."}]
    yield history, None, "\n".join(log)
    # Validate
    if not symptoms_text.strip():
        msg = "❌ No symptoms provided"
        log.append(msg)
        print(msg, flush=True)
        result = {"error": "No input provided", "diagnoses": [], "confidences": [], "follow_up": []}
        yield history, result, "\n".join(log)
        return
    # Clean input
    cleaned = symptoms_text.strip()
    msg = f"📝 Cleaned text: {cleaned}"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)
    # Semantic query
    msg = "🔍 Running semantic query"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)
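    # Retrieve the ICD-10 entries most similar to the cleaned symptom text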
    qe = symptom_index.as_query_engine(similarity_top_k=5)
    hits = qe.query(cleaned)
    msg = "📚 Retrieved context entries"
    log.append(msg)
    print(msg, flush=True)
    history = history + [{"role": "assistant", "content": msg}]
    yield history, None, "\n".join(log)
    # Build prompt with minimal context: top-3 hits, code and description only
    context_list = []
    for node in getattr(hits, "source_nodes", [])[:3]:
        md = getattr(node, "metadata", {}) or {}
        context_list.append(f"{md.get('code', '')}: {md.get('description', '')}")
    context_text = "\n".join(context_list)
    prompt = (
        f"{SYSTEM_PROMPT}\n\n"
        f"User symptoms: '{cleaned}'\n\n"
        f"Relevant ICD-10 context:\n{context_text}\n\n"
        "Respond with valid JSON."
    )
    msg = "✍️ Prompt built"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)
    # Call the LLM with a stop sequence to constrain output to a single JSON object.
    # The expected JSON is flat (arrays close with ']'), so the first '}' ends it.
    response = llm.complete(prompt, stop=["}"])
    raw = getattr(response, "text", str(response))
    # Most backends omit the stop string itself, so restore the closing brace;
    # otherwise truncate anything trailing after the final '}'.
    if "}" in raw:
        raw = raw[: raw.rfind("}") + 1]
    elif raw.strip():
        raw = raw.rstrip() + "}"
| msg = "π‘ Raw LLM response received" | |
| log.append(msg) | |
| print(msg, flush=True) | |
| yield history, None, "\n".join(log) | |
    # Parse JSON, first stripping trailing commas before '}' or ']' (a common
    # LLM formatting slip that would otherwise break json.loads)
    cleaned_raw = re.sub(r",\s*([}\]])", r"\1", raw)
    try:
        parsed = json.loads(cleaned_raw)
        msg = "✅ JSON parsed"
    except Exception as e:
        msg = f"❌ JSON parse error: {e}"
        parsed = {"error": str(e), "raw": raw}
    log.append(msg)
    print(msg, flush=True)
    yield history, parsed, "\n".join(log)
    # Final assistant message
    assistant_msg = format_response_for_user(parsed)
    history = history + [{"role": "assistant", "content": assistant_msg}]
    msg = "✅ Final response appended"
    log.append(msg)
    print(msg, flush=True)
    yield history, parsed, "\n".join(log)
# ========== Gradio UI ==========
with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
Describe symptoms by typing or speaking. The debug log updates live below.
""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Type your symptoms",
                placeholder="I'm feeling under the weather...",
                lines=3,
            )
            microphone = gr.Audio(
                sources=["microphone"],
                streaming=True,
                type="numpy",
                label="Or speak your symptoms...",
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                type="messages",
            )
            json_output = gr.JSON(label="Diagnosis JSON")
            debug_box = gr.Textbox(label="Debug log", lines=10)
        with gr.Column(scale=1):
            with gr.Accordion("API Keys (optional)", open=False):
                api_key = gr.Textbox(label="OpenAI Key", type="password")
            model_selector = gr.Dropdown(
                choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                value="OpenAI",
                label="Model Provider",
            )
            temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
    # Bindings
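    # on_submit is a generator, so queue=True is required for each yield to stream to the UI.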
    submit_btn.click(
        fn=on_submit,
        inputs=[text_input, chatbot],
        outputs=[chatbot, json_output, debug_box],
        queue=True,
    )
    clear_btn.click(
        lambda: (None, {}, ""),
        None,
        [chatbot, json_output, debug_box],
        queue=False,
    )
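    # Stream microphone chunks into live transcription, filling the textbox as the user speaks.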
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=[text_input],
        show_progress=False,
        queue=True,
    )
    # --- About the Creator ---
    gr.Markdown("""
---
### 👋 About the Creator
Hi! I'm Graham Paasch, an experienced technology professional.

🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
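# 0.0.0.0:7860 matches the standard Hugging Face Spaces host/port; share=True also
# requests a public Gradio link when running outside a Space.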
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_api=True)