# NOTE(review): removed non-code export artifacts that preceded the shebang
# ("Spaces:" and two "Runtime error" lines) — paste residue, not program text.
| #!/usr/bin/env python3 | |
| import os | |
| import json | |
| import logging | |
| import re | |
| from typing import Dict, Any | |
| from pathlib import Path | |
| from unstructured.partition.pdf import partition_pdf | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from dotenv import load_dotenv | |
| from bloatectomy import bloatectomy | |
| from werkzeug.utils import secure_filename | |
| from langchain_groq import ChatGroq | |
| from typing_extensions import TypedDict, NotRequired | |
# --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("patient-assistant")

# --- Load environment ---
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    # raise SystemExit instead of exit(): exit() comes from the site module and
    # is not guaranteed to exist in every runtime; SystemExit(1) is equivalent.
    raise SystemExit(1)

# --- Flask app setup ---
BASE_DIR = Path(__file__).resolve().parent
# Patient reports live under REPORTS_ROOT; overridable via env for deployment.
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)

# Ensure the reports directory exists (pathlib idiom for os.makedirs(..., exist_ok=True)).
REPORTS_ROOT.mkdir(parents=True, exist_ok=True)
# --- LLM setup ---
# ChatGroq client used for all chat completions in chat(). The model name is
# env-overridable; temperature 0.0 keeps replies deterministic and max_tokens
# bounds the response length. GROQ_API_KEY was validated at startup above.
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Clean clinical note text with the bloatectomy library.

    Returns the joined cleaned tokens, or the original text unchanged when
    bloatectomy produces no tokens or raises for any reason (best-effort).
    """
    try:
        cleaner = bloatectomy(text, style=style, output="html")
        tokens = getattr(cleaner, "tokens", None)
        return "\n".join(tokens) if tokens else text
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
# --- Agent prompt instructions ---
# System prompt sent with every LLM call from chat(). It pins the assistant's
# tasks and forces a strict JSON-only output contract, which is then parsed by
# extract_json_from_llm_response(). Do not edit the wording casually: the
# downstream parser relies on the assistant_reply / patientDetails /
# conversationSummary key names promised here.
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your primary task is to get the patient ID (PID) from the user at the start of the conversation.
Once you have the PID, you will be provided with a summary of the patient's medical reports. Use this information, along with the conversation history, to provide a comprehensive response.
Your tasks include:
- **First, ask for the patient ID.** Do not proceed with any other task until you have the PID.
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments.
- Offering personalized lifestyle recommendations.
- Generating a natural, helpful reply to the user.
STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
- patientDetails: object // keys may include name, problem, pid (patient ID), city, contact (update if user shared info)
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs
Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""
# --- JSON extraction helper ---
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Safely extract a JSON object from LLM output that may contain extra text or markdown fences.

    Always returns a dict with at least a non-empty `assistant_reply`; falls
    back to a canned apology payload when nothing parseable is found.
    """
    fallback = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }
    if not isinstance(raw_response, str) or not raw_response:
        return fallback

    # Strip a ``` / ```json markdown fence if one wraps the payload.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    payload = fence.group(1).strip() if fence else raw_response

    # Locate the outermost brace pair.
    start = payload.find('{')
    end = payload.rfind('}')
    if start == -1 or end == -1 or start >= end:
        # No brace pair — last resort: try parsing the whole payload as-is.
        try:
            return json.loads(payload)
        except Exception:
            logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
            return fallback

    # Drop trailing commas before a closing brace/bracket, then parse.
    candidate = re.sub(r',\s*(?=[}\]])', '', payload[start:end + 1])
    try:
        parsed = json.loads(candidate)
    except Exception as e:
        logger.warning("Failed to parse JSON from LLM output: %s", e)
        return fallback

    # Accept only a dict carrying a non-empty string assistant_reply.
    reply = parsed.get("assistant_reply") if isinstance(parsed, dict) else None
    if isinstance(reply, str) and reply.strip():
        parsed.setdefault("patientDetails", {})
        parsed.setdefault("conversationSummary", "")
        return parsed

    logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
    return fallback
# --- Flask routes ---
def serve_frontend():
    """Serve the single-page frontend from the static folder.

    NOTE(review): no @app.route decorator is visible here — confirm the route
    is registered elsewhere (e.g. app.add_url_rule) or restore the decorator.
    """
    try:
        return app.send_static_file("frontend.html")
    except Exception:
        # Bug fix: the 404 message previously referred to "frontend2.html"
        # while the code serves "frontend.html"; keep message and code consistent.
        return "<h3>frontend.html not found in static/ — please add your frontend.html there.</h3>", 404
def upload_report():
    """Handle upload of a single PDF report for a specific patient.

    Expects multipart/form-data with a 'report' file and a 'patient_id' form
    field; saves the file under REPORTS_ROOT/p_<patient_id>/.

    NOTE(review): no @app.route decorator is visible here — confirm the route
    is registered elsewhere or restore the decorator.
    """
    if 'report' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['report']
    patient_id = request.form.get("patient_id")
    if file.filename == '' or not patient_id:
        return jsonify({"error": "No selected file or patient ID"}), 400
    if not file:
        # Bug fix: previously this path fell through and implicitly returned
        # None (a 500 in Flask); make the failure explicit.
        return jsonify({"error": "Empty file"}), 400
    filename = secure_filename(file.filename)
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    patient_folder.mkdir(parents=True, exist_ok=True)
    file.save(patient_folder / filename)
    # Bug fix: the success message contained the literal placeholder
    # '(unknown)' instead of the saved filename.
    return jsonify({"message": f"File '{filename}' uploaded successfully for patient ID '{patient_id}'."}), 200
def _collect_report_text(patient_id):
    """Load, parse, and clean every report stored for *patient_id*.

    PDFs are parsed with unstructured's partition_pdf; any other file is read
    as UTF-8 text (errors ignored). Each non-empty document is cleaned with
    bloatectomy and the cleaned texts are returned as a list of strings.
    Returns an empty list when the patient folder does not exist.
    """
    parts = []
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    if not (patient_folder.exists() and patient_folder.is_dir()):
        return parts
    for fname in sorted(os.listdir(patient_folder)):
        file_path = patient_folder / fname
        page_text = ""
        if partition_pdf is not None and str(file_path).lower().endswith('.pdf'):
            try:
                elements = partition_pdf(filename=str(file_path))
                page_text = "\n".join(el.text for el in elements if hasattr(el, 'text') and el.text)
            except Exception:
                logger.exception("Failed to parse PDF %s", file_path)
        else:
            try:
                page_text = file_path.read_text(encoding='utf-8', errors='ignore')
            except Exception:
                page_text = ""
        if page_text:
            cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
            if cleaned:
                parts.append(cleaned)
    return parts


def chat():
    """Handle one turn of the chat conversation with the assistant.

    Expects JSON {"chat_history": [...], "patient_state": {...}} and returns
    {"assistant_reply": str, "updated_state": dict} (or an error payload).

    NOTE(review): no @app.route decorator is visible here — confirm the route
    is registered elsewhere or restore the decorator.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400
    chat_history = data.get("chat_history") or []
    patient_state = data.get("patient_state") or {}
    patient_id = patient_state.get("patientDetails", {}).get("pid")

    # --- Prepare the state for the LLM ---
    state = patient_state.copy()
    state["lastUserMessage"] = ""
    for msg in reversed(chat_history):
        if msg.get("role") == "user" and msg.get("content"):
            state["lastUserMessage"] = msg["content"]
            break

    # When no PID is known yet, try to infer one from the last user message and
    # build guidance text for the agent. Bug fix: this guidance used to be
    # computed and then unconditionally overwritten by the final prompt below
    # (dead code); it is now prepended to the prompt so the agent sees it.
    pid_guidance = ""
    if not patient_id:
        last_message = state.get("lastUserMessage", "")
        pid_match = re.search(r'(\d+)', last_message)
        if pid_match:
            inferred_pid = pid_match.group(1)
            # Bug fix: merge the inferred pid into patientDetails rather than
            # replacing the whole dict (which dropped name/city/contact etc.).
            details = dict(state.get("patientDetails") or {})
            details["pid"] = inferred_pid
            state["patientDetails"] = details
            patient_id = inferred_pid
            pid_guidance = f"The user provided a patient ID: {inferred_pid}. Please access their reports and respond."
        else:
            pid_guidance = "The patient has not provided a patient ID. Please ask them to provide it to proceed."

    # With a PID in hand, fold the parsed report documents into the summary.
    combined_text_parts = _collect_report_text(patient_id) if patient_id else []
    base_summary = state.get("conversationSummary", "") or ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    # --- Direct LLM Invocation ---
    user_prompt = f"""
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
Current conversationSummary: {state.get("conversationSummary", "")}
Last user message: {state.get("lastUserMessage", "")}
Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
"""
    if pid_guidance:
        user_prompt = pid_guidance + "\n" + user_prompt
    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        raw_response = llm_response.content if hasattr(llm_response, "content") else str(llm_response)
        # Lazy %-formatting avoids building the string when INFO is disabled.
        logger.info("Raw LLM response: %s", raw_response)
        parsed_result = extract_json_from_llm_response(raw_response)
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    updated_state = parsed_result or {}
    assistant_reply = updated_state.get("assistant_reply")
    if not assistant_reply or not isinstance(assistant_reply, str) or not assistant_reply.strip():
        # Fallback to a polite message if the LLM response is invalid or empty.
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"
    return jsonify({
        "assistant_reply": assistant_reply,
        "updated_state": updated_state,
    })
# Extensions accepted by the bulk upload endpoint. Bug fix: the original code
# referenced ALLOWED_EXTENSIONS without ever defining it, so every upload that
# reached the extension check raised NameError (surfaced as a 500).
ALLOWED_EXTENSIONS = {"pdf", "txt", "md", "csv", "json"}


def upload_reports():
    """
    Upload one or more files for a patient.
    Expects multipart/form-data with:
      - patient_id (form field)
      - files (one or multiple files; use the same field name 'files' for each file)
    Example curl:
      curl -X POST http://localhost:7860/upload_reports \
        -F "patient_id=12345" \
        -F "files[]=@/path/to/report1.pdf" \
        -F "files[]=@/path/to/report2.pdf"

    NOTE(review): no @app.route decorator is visible here — confirm the route
    is registered elsewhere or restore the decorator.
    """
    try:
        # patient id can be in form or args (for convenience)
        patient_id = request.form.get("patient_id") or request.args.get("patient_id")
        if not patient_id:
            return jsonify({"error": "patient_id form field required"}), 400
        # get uploaded files (support both files and files[] naming)
        uploaded_files = request.files.getlist("files")
        if not uploaded_files:
            # fallback: single file under name 'file'
            single = request.files.get("file")
            if single:
                uploaded_files = [single]
        if not uploaded_files:
            return jsonify({"error": "no files uploaded (use form field 'files')"}), 400
        # Bug fix: save under p_<patient_id> so chat() and upload_report(),
        # which both use the p_ prefix, can find these files (this endpoint
        # previously saved to a bare <patient_id> folder that chat() never reads).
        patient_folder = REPORTS_ROOT / f"p_{patient_id}"
        patient_folder.mkdir(parents=True, exist_ok=True)
        saved = []
        skipped = []
        for file_storage in uploaded_files:
            orig_name = getattr(file_storage, "filename", "") or ""
            filename = secure_filename(orig_name)
            if not filename:
                skipped.append({"filename": orig_name, "reason": "invalid filename"})
                continue
            # Extension allow-list check.
            ext = filename.rsplit(".", 1)[1].lower() if "." in filename else ""
            if ext not in ALLOWED_EXTENSIONS:
                skipped.append({"filename": filename, "reason": f"extension '{ext}' not allowed"})
                continue
            # Avoid overwriting: on collision, append a numeric __N suffix.
            dest = patient_folder / filename
            if dest.exists():
                base, dot, extension = filename.rpartition(".")
                # if no base (e.g. ".bashrc") fall back to the full name
                base = base or filename
                i = 1
                while True:
                    candidate = f"{base}__{i}.{extension}" if extension else f"{base}__{i}"
                    dest = patient_folder / candidate
                    if not dest.exists():
                        filename = candidate
                        break
                    i += 1
            try:
                file_storage.save(str(dest))
                saved.append(filename)
            except Exception as e:
                logger.exception("Failed to save uploaded file %s: %s", filename, e)
                skipped.append({"filename": filename, "reason": f"save failed: {e}"})
        return jsonify({
            "patient_id": str(patient_id),
            "saved": saved,
            "skipped": skipped,
            "patient_folder": str(patient_folder),
        }), 200
    except Exception as exc:
        logger.exception("Upload failed: %s", exc)
        return jsonify({"error": "upload failed", "detail": str(exc)}), 500
def ping():
    """Health-check endpoint: always reports OK."""
    return jsonify(status="ok")
| if __name__ == "__main__": | |
| port = int(os.getenv("PORT", 5000)) | |
| app.run(host="0.0.0.0", port=port, debug=True) |