Spaces:
Running
Running
| """ | |
| Internal Medicine Discharge Letter Error-Check — Backend | |
| Prospective study: AI-assisted error detection in ED discharge letters | |
| Flow: | |
| 1. Receive Croatian discharge letter from doctor | |
| 2. Translate to English (Gemini 3.1 Flash Lite) | |
| 3. Run concurrent error-detection analysis: | |
| - DeepSeek Reasoner (via DeepSeek API) | |
| - GPT-OSS-120B (via Groq) | |
| 4. Parse structured output and return errors + suggestions | |
| """ | |
| import os | |
| import json | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from google import genai | |
| from openai import OpenAI | |
| from groq import Groq | |
| load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env")) | |
| # --------------------------------------------------------------------------- | |
| # API clients | |
| # --------------------------------------------------------------------------- | |
def get_gemini_client() -> genai.Client:
    """Build a Gemini client from credentials in the environment.

    ``GOOGLE_API_KEY`` is preferred; ``GEMINI_API_KEY`` is consulted only when
    the former is unset or empty.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        api_key = os.environ.get("GEMINI_API_KEY")
    return genai.Client(api_key=api_key)
def get_deepseek_client() -> OpenAI:
    """Return an OpenAI-SDK client pointed at the DeepSeek endpoint.

    The key is read from ``DEEPSEEK_API_KEY``.
    """
    settings = {
        "api_key": os.environ.get("DEEPSEEK_API_KEY"),
        "base_url": "https://api.deepseek.com",
    }
    return OpenAI(**settings)
def get_groq_client() -> Groq:
    """Return a Groq client authenticated with ``GROQ_API_KEY_OSS``."""
    groq_key = os.environ.get("GROQ_API_KEY_OSS")
    return Groq(api_key=groq_key)
# Per-request timeout, passed as `timeout=` to the DeepSeek chat-completion call.
DEEPSEEK_TIMEOUT_SECONDS = 120
# Passed as `max_tokens=` to the DeepSeek chat-completion call.
DEEPSEEK_MAX_TOKENS = 8192
# Number of attempts call_model_a makes before giving up (first try included).
DEEPSEEK_MAX_ATTEMPTS = 2
# Pause (seconds) between a failed attempt and the next one.
DEEPSEEK_RETRY_SLEEP_SECONDS = 2
def _log_deepseek(event: str, **kwargs):
    """Print one ``[DeepSeek]`` diagnostic line to stdout and flush it.

    Each keyword argument is appended, in the order given, as a
    `` | name=repr(value)`` segment after the event name.
    """
    segments = [
        f"[DeepSeek] {event}",
        *(f"{name}={val!r}" for name, val in kwargs.items()),
    ]
    print(" | ".join(segments), flush=True)
def _deepseek_response_meta(response) -> dict:
    """Summarise the first choice of a chat-completion response for logging.

    Returns a dict with the choice's ``finish_reason`` (``None`` when the
    attribute is absent) and the lengths of the message's ``content`` and
    ``reasoning_content``; a missing or empty value counts as length 0.
    """
    first = response.choices[0]
    body = first.message.content
    thoughts = getattr(first.message, "reasoning_content", "")
    return {
        "finish_reason": getattr(first, "finish_reason", None),
        "content_len": len(body) if body else 0,
        "reasoning_len": len(thoughts) if thoughts else 0,
    }
| # --------------------------------------------------------------------------- | |
| # Prompts | |
| # --------------------------------------------------------------------------- | |
# Template for the Croatian -> English translation request. The letter is
# substituted for the {text} placeholder with str.format in
# translate_to_english.
TRANSLATION_PROMPT = """You are a medical translator. Translate the following Croatian clinical discharge letter to English.
Preserve ALL medical terminology, values, units, drug names, dosages, and clinical details exactly.
Output ONLY the English translation, nothing else.
Croatian text:
{text}"""
# System message shared by both analysis models (call_model_a / call_model_b).
# It is sent verbatim and never passed through str.format, so the literal JSON
# braces in the schema and example below need no escaping.
ERROR_CHECK_SYSTEM_PROMPT = """You are an expert internal medicine physician reviewing emergency department discharge letters for errors and quality issues.
Your task: carefully analyze the discharge letter and identify up to 3 ERRORS and up to 2 IMPROVEMENT SUGGESTIONS.
The goal is precision, not forcing findings.
ERRORS are factual, clinical, or documentation mistakes present in the letter, such as:
- Medication errors (wrong drug, wrong dose, drug interactions, contraindications)
- Diagnostic errors (incorrect diagnosis given the findings, missed diagnosis)
- Dosing errors (incorrect dose for patient weight/age/renal function)
- Lab interpretation errors (misinterpreted lab values, missed abnormal results)
- Documentation errors (inconsistencies, contradictions within the letter)
- Omissions (critical missing information that should be documented)
SUGGESTIONS are general quality improvements that are NOT necessarily errors, such as:
- Documentation completeness improvements
- Clinical workflow recommendations
- Patient safety enhancements
- Follow-up care suggestions
For every suggestion you MUST:
- Identify the specific part of the letter that could be improved
- Quote the relevant original text (or note what is missing)
- Provide the exact rewritten version or additional text you would use instead
This makes every suggestion concrete and immediately usable rather than vague or generic.
CRITICAL RULES:
- Only report genuine errors you are confident about. Do NOT fabricate errors.
- Do NOT force yourself to find 3 errors.
- If you find fewer than 3 errors, report only what you find.
- It is acceptable to find 0 errors. If no clear error is present, return "errors": [].
- When uncertain, prefer returning no error rather than a speculative one.
- You may still provide 0-2 useful improvement suggestions even when errors is empty.
- Be specific: quote the relevant part of the letter for each error and suggestion.
- Categorize each error and suggestion precisely.
- For every suggestion, always include both the original quote and your exact suggested rewrite.
You MUST respond in the following JSON format and NOTHING else:
{
"errors": [
{
"description": "Clear description of the error",
"category": "medication_error|diagnostic_error|dosing_error|documentation_error|lab_interpretation_error|contraindication|omission|other",
"severity": "low|medium|high|critical",
"quote": "Exact quote from the letter where the error appears"
}
],
"suggestions": [
{
"description": "Clear description of the improvement suggestion",
"category": "documentation_quality|clinical_workflow|patient_safety|completeness|other",
"quote": "Exact quote from the letter (or 'N/A' if adding entirely new content)",
"suggested_rewrite": "Exactly how you would have written it differently - the full improved text you recommend"
}
],
"summary": "One-sentence overall assessment of the discharge letter quality"
}
Valid zero-error example:
{
"errors": [],
"suggestions": [
{
"description": "Make the follow-up plan more explicit and actionable for the patient and primary care provider.",
"category": "documentation_quality",
"quote": "Follow up with primary care in 1 week.",
"suggested_rewrite": "Please follow up with your primary care physician within 7 days for repeat labs and clinical reassessment. If you experience worsening shortness of breath, chest pain, or fever, return to the emergency department immediately or call the 24-hour advice line at (555) 123-4567."
}
],
"summary": "No clear clinical or documentation errors were identified, but the discharge letter could be improved with more specific follow-up instructions."
}"""
# User message template for the analysis models. The letter text is
# substituted for the {clinical_text} placeholder with str.format, so any
# further literal braces added here would have to be doubled.
ERROR_CHECK_USER_PROMPT = """Analyze the following internal medicine emergency department discharge letter for errors and quality issues.
DISCHARGE LETTER:
{clinical_text}
Respond with the JSON format specified in your instructions.
Remember:
- up to 3 errors
- up to 2 suggestions
- only report genuine errors
- if no clear errors are present, return `"errors": []` and optionally provide suggestions"""
| # --------------------------------------------------------------------------- | |
| # Data classes | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ParsedError:
    """One error reported by a model for a discharge letter.

    The ``@dataclass`` decorator is required: the class is constructed with
    keyword arguments, which a bare annotated class does not accept.
    """

    description: str  # free-text explanation of the error
    category: str  # category label as returned by the model
    severity: str  # severity label as returned by the model
    quote: str  # excerpt of the letter the model attached to the error
@dataclass
class ParsedSuggestion:
    """One improvement suggestion reported by a model.

    The ``@dataclass`` decorator is required: the class is constructed with
    keyword arguments, which a bare annotated class does not accept.
    """

    description: str  # free-text explanation of the suggestion
    category: str  # category label as returned by the model
    quote: str = ""  # excerpt of the letter the suggestion refers to
    suggested_rewrite: str = ""  # replacement/additional text proposed by the model
@dataclass
class ModelResult:
    """Outcome of one model's analysis of a discharge letter.

    The ``@dataclass`` decorator is required: without it the class cannot be
    built with keyword arguments, and ``field(default_factory=list)`` would
    remain a raw ``Field`` object shared as a class attribute instead of
    giving each instance its own empty list.
    """

    model_name: str  # display label of the model
    raw_response: str  # unparsed model output ("" on failure)
    errors: list = field(default_factory=list)  # ParsedError items
    suggestions: list = field(default_factory=list)  # ParsedSuggestion items
    summary: str = ""  # model's overall assessment
    success: bool = True  # False when the call or parsing failed
    error_message: Optional[str] = None  # failure description when success is False
    latency_seconds: float = 0.0  # elapsed time recorded by the caller
@dataclass
class AnalysisResponse:
    """Result of the full pipeline for one letter.

    The ``@dataclass`` decorator is required: the class is constructed with
    keyword arguments, which a bare annotated class does not accept.
    """

    original_text: str  # letter as submitted
    translated_text: str  # English translation given to both models
    model_a_result: ModelResult  # result from call_model_a
    model_b_result: ModelResult  # result from call_model_b
    translation_latency: float = 0.0  # seconds spent in the translation step
| # --------------------------------------------------------------------------- | |
| # Translation | |
| # --------------------------------------------------------------------------- | |
def translate_to_english(text: str) -> str:
    """Translate a Croatian discharge letter to English with Gemini.

    Args:
        text: The letter to translate; substituted into ``TRANSLATION_PROMPT``.

    Returns:
        The translated text exactly as returned by the model.

    Raises:
        ValueError: If the model returns no text. Without this check a
            ``None`` result would be formatted into the analysis prompt as
            the literal string "None" and silently analysed as the letter.
        Exception: Any error raised by the Gemini client propagates unchanged.
    """
    client = get_gemini_client()
    response = client.models.generate_content(
        model="gemini-3.1-flash-lite-preview",
        contents=TRANSLATION_PROMPT.format(text=text),
    )
    translated = response.text
    # `response.text` is None when the response carries no text parts.
    if not translated or not translated.strip():
        raise ValueError("Gemini returned an empty translation.")
    return translated
| # --------------------------------------------------------------------------- | |
| # JSON parsing helper | |
| # --------------------------------------------------------------------------- | |
def parse_model_json(raw: str) -> dict:
    """Extract the JSON object from a model response.

    Handles a response wrapped in a markdown code fence (with or without a
    language tag, with or without a closing fence) and, failing a direct
    parse, falls back to the span between the first ``{`` and the last ``}``.

    Args:
        raw: The model's text output.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If no parsable JSON can be located.
        ValueError: If the decoded value is not a JSON object
            (``json.JSONDecodeError`` is itself a ``ValueError`` subclass).
    """
    text = raw.strip()
    if text.startswith("```"):
        # Skip the opening fence line; if the fence has no newline after it,
        # skip just the three backticks.
        first_newline = text.find("\n")
        body_start = first_newline + 1 if first_newline != -1 else 3
        # Only treat the last fence as a terminator when it is not the
        # opening fence itself (i.e. the block was actually closed).
        last_fence = text.rfind("```")
        body_end = last_fence if last_fence >= body_start else len(text)
        text = text[body_start:body_end].strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end])

    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object from the model, got {type(parsed).__name__}."
        )
    return parsed
| # --------------------------------------------------------------------------- | |
| # Model calls | |
| # --------------------------------------------------------------------------- | |
def _parse_to_result(model_label: str, raw: str, latency: float) -> ModelResult:
    """Convert a raw model response into a successful ``ModelResult``.

    Args:
        model_label: Display name stored as ``model_name``.
        raw: Unparsed model output; kept verbatim as ``raw_response``.
        latency: Elapsed seconds; stored rounded to two decimals.

    Returns:
        A ``ModelResult`` with ``success=True`` and the parsed errors,
        suggestions and summary.

    Raises:
        Exception: Whatever ``parse_model_json`` raises for unparsable output.
    """
    parsed = parse_model_json(raw)

    def _text(item: dict, key: str, default: str = "") -> str:
        # dict.get applies its default only when the key is absent; a model
        # may emit an explicit JSON null, which must not leak through as None.
        value = item.get(key)
        return default if value is None else value

    # `or []` also covers an explicit null for the whole list.
    errors = [
        ParsedError(
            description=_text(e, "description"),
            category=_text(e, "category", "other"),
            severity=_text(e, "severity", "medium"),
            quote=_text(e, "quote"),
        )
        for e in parsed.get("errors") or []
    ]
    suggestions = [
        ParsedSuggestion(
            description=_text(s, "description"),
            category=_text(s, "category", "other"),
            quote=_text(s, "quote"),
            suggested_rewrite=_text(s, "suggested_rewrite"),
        )
        for s in parsed.get("suggestions") or []
    ]
    return ModelResult(
        model_name=model_label,
        raw_response=raw,
        errors=errors,
        suggestions=suggestions,
        summary=_text(parsed, "summary"),
        success=True,
        latency_seconds=round(latency, 2),
    )
def call_model_a(clinical_text: str) -> ModelResult:
    """DeepSeek Reasoner via DeepSeek API.

    Makes up to ``DEEPSEEK_MAX_ATTEMPTS`` attempts, sleeping
    ``DEEPSEEK_RETRY_SLEEP_SECONDS`` between them. Any exception during an
    attempt (API error, empty body, unparsable JSON) triggers a retry.

    Args:
        clinical_text: The (English) discharge letter to analyse.

    Returns:
        The parsed ``ModelResult`` on success, otherwise a ``ModelResult``
        with ``success=False`` and ``error_message`` set. This function does
        not raise for client-construction or request failures, so one
        model's failure cannot abort the concurrent pipeline.
    """
    start = time.time()

    def _failure(message: str) -> ModelResult:
        return ModelResult(
            model_name="DeepSeek Reasoner",
            raw_response="",
            success=False,
            error_message=message,
            latency_seconds=round(time.time() - start, 2),
        )

    # Client construction can itself raise (e.g. missing API key); report it
    # as a failed result instead of letting it escape the worker thread.
    try:
        client = get_deepseek_client()
    except Exception as exc:
        _log_deepseek(
            "client_init_failed",
            error_type=type(exc).__name__,
            error=str(exc),
        )
        return _failure(str(exc))

    # Identical for every attempt, so build it once.
    messages = [
        {"role": "system", "content": ERROR_CHECK_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": ERROR_CHECK_USER_PROMPT.format(clinical_text=clinical_text),
        },
    ]

    last_error = None
    for attempt in range(1, DEEPSEEK_MAX_ATTEMPTS + 1):
        attempt_start = time.time()
        try:
            _log_deepseek("attempt_start", attempt=attempt)
            response = client.chat.completions.create(
                model="deepseek-reasoner",
                messages=messages,
                max_tokens=DEEPSEEK_MAX_TOKENS,
                timeout=DEEPSEEK_TIMEOUT_SECONDS,
            )
            meta = _deepseek_response_meta(response)
            _log_deepseek("attempt_response", attempt=attempt, **meta)
            raw = response.choices[0].message.content or ""
            if not raw.strip():
                raise ValueError(
                    "DeepSeek returned an empty response body "
                    f"(finish_reason={meta['finish_reason']}, "
                    f"reasoning_len={meta['reasoning_len']})."
                )
            result = _parse_to_result("DeepSeek Reasoner", raw, time.time() - start)
            _log_deepseek(
                "attempt_success",
                attempt=attempt,
                elapsed_total=round(time.time() - start, 2),
                errors_found=len(result.errors),
                suggestions_found=len(result.suggestions),
            )
            return result
        except Exception as exc:
            last_error = exc
            _log_deepseek(
                "attempt_failed",
                attempt=attempt,
                elapsed_attempt=round(time.time() - attempt_start, 2),
                error_type=type(exc).__name__,
                error=str(exc),
            )
            # No sleep after the final attempt.
            if attempt < DEEPSEEK_MAX_ATTEMPTS:
                time.sleep(DEEPSEEK_RETRY_SLEEP_SECONDS)

    if last_error is not None:
        return _failure(f"{last_error} after {DEEPSEEK_MAX_ATTEMPTS} attempts")
    return _failure("DeepSeek failed for an unknown reason.")
def call_model_b(clinical_text: str) -> ModelResult:
    """GPT-OSS-120B via Groq.

    Makes a single attempt. Any exception (client construction, API error,
    empty body, unparsable JSON) is converted into a failed result.

    Args:
        clinical_text: The (English) discharge letter to analyse.

    Returns:
        The parsed ``ModelResult`` on success, otherwise a ``ModelResult``
        with ``success=False`` and ``error_message`` set.
    """
    start = time.time()
    try:
        client = get_groq_client()
        response = client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[
                {"role": "system", "content": ERROR_CHECK_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": ERROR_CHECK_USER_PROMPT.format(
                        clinical_text=clinical_text
                    ),
                },
            ],
            temperature=0.2,
            max_tokens=4096,
        )
        choice = response.choices[0]
        # `content` may be None; report that explicitly rather than letting
        # the parser fail with an opaque AttributeError on None.
        raw = choice.message.content or ""
        if not raw.strip():
            raise ValueError(
                "GPT-OSS-120B returned an empty response body "
                f"(finish_reason={getattr(choice, 'finish_reason', None)})."
            )
        return _parse_to_result("GPT-OSS-120B", raw, time.time() - start)
    except Exception as exc:
        return ModelResult(
            model_name="GPT-OSS-120B",
            raw_response="",
            success=False,
            error_message=str(exc),
            latency_seconds=round(time.time() - start, 2),
        )
| # --------------------------------------------------------------------------- | |
| # Main pipeline | |
| # --------------------------------------------------------------------------- | |
def run_error_check(croatian_text: str) -> AnalysisResponse:
    """Run the whole pipeline for one letter.

    The letter is translated first (timed separately), then both analysis
    models are given the English text on a two-worker thread pool and their
    results are collected into an ``AnalysisResponse``.
    """
    started = time.time()
    english_text = translate_to_english(croatian_text)
    translation_latency = round(time.time() - started, 2)

    with ThreadPoolExecutor(max_workers=2) as executor:
        pending = [
            executor.submit(model_call, english_text)
            for model_call in (call_model_a, call_model_b)
        ]
        result_a, result_b = (job.result() for job in pending)

    return AnalysisResponse(
        original_text=croatian_text,
        translated_text=english_text,
        model_a_result=result_a,
        model_b_result=result_b,
        translation_latency=translation_latency,
    )
| # --------------------------------------------------------------------------- | |
| # CLI test | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Manual smoke test: run the pipeline on a short sample letter and print
    # what each model reported.
    import io
    import sys

    # Re-wrap stdout as UTF-8 so non-ASCII text prints without encoding
    # errors; unencodable characters are replaced.
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")

    sample = """Bolesnik 68 godina, dolazi zbog bolova u prsištu.
Dijagnoza: STEMI prednje stijenke.
Terapija: Aspirin 100mg, Klopidogrel 75mg, Ramipril 5mg, Atorvastatin 40mg.
Preporučen kontrolni pregled za 7 dana."""

    banner = "=" * 60
    print(banner)
    print("ERROR CHECK TEST")
    print(banner)

    result = run_error_check(sample)
    print(f"\nTranslation ({result.translation_latency}s):")
    print(result.translated_text)

    for outcome in (result.model_a_result, result.model_b_result):
        print(f"\n{banner}")
        print(f"{outcome.model_name} ({outcome.latency_seconds}s):")
        if not outcome.success:
            print(f"ERROR: {outcome.error_message}")
            continue
        print(f"Summary: {outcome.summary}")
        for idx, err in enumerate(outcome.errors, start=1):
            print(f" Error {idx}: [{err.category}/{err.severity}] {err.description}")
        for idx, sug in enumerate(outcome.suggestions, start=1):
            print(f" Suggestion {idx}: [{sug.category}] {sug.description}")