import os
import sys  # needed for the sys.path tweak below; was missing from the imports
import json
import shutil
import re
import gc
import time
from datetime import datetime
from typing import List, Tuple, Dict, Union, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware

import pandas as pd
import pdfplumber
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend so chart rendering works in a server with no display
import matplotlib.pyplot as plt
from fpdf import FPDF
import unicodedata
import uvicorn
# === Configuration ===
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 4096
MAX_CHUNK_TOKENS = 8192
BATCH_SIZE = 1
PROMPT_OVERHEAD = 300
SAFE_SLEEP = 0.5
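# Token budget at a glance: each text chunk is capped at MAX_CHUNK_TOKENS minus
# PROMPT_OVERHEAD (an estimate of the instruction wrapper's size), so a chunk
# plus its prompt plus MAX_NEW_TOKENS of generation (8192 + 4096) stays far
# below the MAX_MODEL_TOKENS context window.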
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing and summarizing unstructured medical files",
    version="1.0.0",
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
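# Note: wildcard origins combined with credentials is a wide-open CORS policy.
# Acceptable for a demo Space, but worth restricting to known origins in
# production.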
# Initialize agent at startup
agent = None

@app.on_event("startup")  # decorator was missing; without it the agent never loads
async def startup_event():
    global agent
    agent = init_agent()
def estimate_tokens(text: str) -> int:
    return len(text) // 4 + 1
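# A rough chars-per-token heuristic rather than a real tokenizer count:
# e.g. estimate_tokens("abcdefgh") == 8 // 4 + 1 == 3.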
def clean_response(text: str) -> str:
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
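# Strips bracketed tool traces and stray "None" tokens, then collapses runs of
# blank lines: e.g. clean_response("None[trace]\n\n\n\nDone") == "Done".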
def remove_duplicate_paragraphs(text: str) -> str:
    paragraphs = text.strip().split("\n\n")
    seen = set()
    unique_paragraphs = []
    for p in paragraphs:
        clean_p = p.strip()
        if clean_p and clean_p not in seen:
            unique_paragraphs.append(clean_p)
            seen.add(clean_p)
    return "\n\n".join(unique_paragraphs)
def extract_text_from_excel(path: str) -> str:
    all_text = []
    xls = pd.ExcelFile(path)
    for sheet_name in xls.sheet_names:
        try:
            df = xls.parse(sheet_name).astype(str).fillna("")
        except Exception:
            continue
        for _, row in df.iterrows():
            non_empty = [cell.strip() for cell in row if cell.strip()]
            if len(non_empty) >= 2:
                text_line = " | ".join(non_empty)
                if len(text_line) > 15:
                    all_text.append(f"[{sheet_name}] {text_line}")
    return "\n".join(all_text)
def extract_text_from_csv(path: str) -> str:
    all_text = []
    try:
        df = pd.read_csv(path).astype(str).fillna("")
    except Exception:
        return ""
    for _, row in df.iterrows():
        non_empty = [cell.strip() for cell in row if cell.strip()]
        if len(non_empty) >= 2:
            text_line = " | ".join(non_empty)
            if len(text_line) > 15:
                all_text.append(text_line)
    return "\n".join(all_text)
def extract_text_from_pdf(path: str) -> str:
    import logging
    logging.getLogger("pdfminer").setLevel(logging.ERROR)
    all_text = []
    try:
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                text = page.extract_text()
                if text:
                    all_text.append(text.strip())
    except Exception:
        return ""
    return "\n".join(all_text)
def extract_text(file_path: str) -> str:
    lower = file_path.lower()  # case-insensitive extension check
    if lower.endswith(".xlsx"):
        return extract_text_from_excel(file_path)
    elif lower.endswith(".csv"):
        return extract_text_from_csv(file_path)
    elif lower.endswith(".pdf"):
        return extract_text_from_pdf(file_path)
    else:
        return ""
def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
    effective_limit = max_tokens - PROMPT_OVERHEAD
    chunks, current, current_tokens = [], [], 0
    for line in text.split("\n"):
        tokens = estimate_tokens(line)
        if current_tokens + tokens > effective_limit:
            if current:
                chunks.append("\n".join(current))
            current, current_tokens = [line], tokens
        else:
            current.append(line)
            current_tokens += tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
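# Greedy line-based packing: each chunk stays under the effective limit of
# estimated tokens, so chunk plus prompt wrapper fits the budget. Note that a
# single line longer than the limit still becomes its own oversized chunk.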
def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
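# E.g. batch_chunks(["a", "b", "c"], batch_size=2) == [["a", "b"], ["c"]].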
def build_prompt(chunk: str) -> str:
    return (
        "### Unstructured Clinical Records\n\n"
        "Analyze the clinical notes below and summarize with:\n"
        "- Diagnostic Patterns\n"
        "- Medication Issues\n"
        "- Missed Opportunities\n"
        "- Inconsistencies\n"
        "- Follow-up Recommendations\n\n"
        "---\n\n"
        f"{chunk}\n\n"
        "---\n"
        "Respond concisely in bullet points with clinical reasoning."
    )
def init_agent() -> TxAgent:
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    return agent
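# The tool registry is seeded from the bundled data/new_tool.json into the
# persistent /data cache on first run, so later restarts reuse the cached copy.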
def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
    results = []
    for batch in batches:
        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
        try:
            batch_response = ""
            for r in agent.run_gradio_chat(
                message=prompt,
                history=[],
                temperature=0.0,
                max_new_tokens=MAX_NEW_TOKENS,
                max_token=MAX_MODEL_TOKENS,
                call_agent=False,
                conversation=[],
            ):
                if isinstance(r, str):
                    batch_response += r
                elif isinstance(r, list):
                    for m in r:
                        if hasattr(m, "content"):
                            batch_response += m.content
                elif hasattr(r, "content"):
                    batch_response += r.content
            results.append(clean_response(batch_response))
            time.sleep(SAFE_SLEEP)
        except Exception as e:
            results.append(f"❌ Batch failed: {str(e)}")
            time.sleep(SAFE_SLEEP * 2)
        # Per-batch memory cleanup; guard the CUDA call so CPU-only hosts don't crash
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    return results
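# run_gradio_chat streams heterogeneous updates (plain strings, lists of
# message objects, or single objects carrying a .content attribute); the
# accumulation loop above normalizes all three shapes into one string.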
def generate_final_summary(agent, combined: str) -> str:
    combined = remove_duplicate_paragraphs(combined)
    final_prompt = f"""
You are an expert clinical summarizer. Analyze the following summaries carefully and generate a **single final concise structured medical report**, avoiding any repetition or redundancy.

Summaries:
{combined}

Respond with:
- Diagnostic Patterns
- Medication Issues
- Missed Opportunities
- Inconsistencies
- Follow-up Recommendations

Avoid repeating the same points multiple times.
""".strip()
    final_response = ""
    for r in agent.run_gradio_chat(
        message=final_prompt,
        history=[],
        temperature=0.0,
        max_new_tokens=MAX_NEW_TOKENS,
        max_token=MAX_MODEL_TOKENS,
        call_agent=False,
        conversation=[],
    ):
        if isinstance(r, str):
            final_response += r
        elif isinstance(r, list):
            for m in r:
                if hasattr(m, "content"):
                    final_response += m.content
        elif hasattr(r, "content"):
            final_response += r.content
    final_response = clean_response(final_response)
    final_response = remove_duplicate_paragraphs(final_response)
    return final_response
def remove_non_ascii(text):
    return ''.join(c for c in text if ord(c) < 256)
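# Despite the name, this keeps the full Latin-1 range (ord < 256), which is
# what FPDF's built-in core fonts can encode: accented characters like "é"
# survive, while typographic dashes and emoji are dropped.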
def generate_pdf_report_with_charts(summary: str, report_path: str, detailed_batches: List[str] = None):
    chart_dir = os.path.join(os.path.dirname(report_path), "charts")
    os.makedirs(chart_dir, exist_ok=True)

    # Prepare static data
    categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
    values = [4, 2, 3, 1, 5]

    # === Static Charts ===
    chart_paths = []

    def save_chart(fig_func, filename):
        path = os.path.join(chart_dir, filename)
        fig_func()
        plt.tight_layout()
        plt.savefig(path)
        plt.close()
        chart_paths.append((filename.split('.')[0].replace('_', ' ').title(), path))

    save_chart(lambda: plt.bar(categories, values), "bar_chart.png")
    save_chart(lambda: plt.pie(values, labels=categories, autopct='%1.1f%%'), "pie_chart.png")
    save_chart(lambda: plt.plot(categories, values, marker='o'), "trend_chart.png")
    save_chart(lambda: plt.barh(categories, values), "horizontal_bar_chart.png")

    # Radar chart
    import numpy as np
    labels = np.array(categories)
    stats = np.array(values)
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    stats = np.concatenate((stats, [stats[0]]))
    angles += angles[:1]
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    ax.plot(angles, stats, marker='o')
    ax.fill(angles, stats, alpha=0.25)
    ax.set_yticklabels([])
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels)
    ax.set_title('Radar Chart: Clinical Focus')
    radar_path = os.path.join(chart_dir, "radar_chart.png")
    plt.tight_layout()
    plt.savefig(radar_path)
    plt.close()
    chart_paths.append(("Radar Chart: Clinical Focus", radar_path))

    # === Dynamic Chart: Drug Frequency ===
    drug_counter = {}
    if detailed_batches:
        for batch in detailed_batches:
            lines = batch.split("\n")
            for line in lines:
                match = re.search(r"(?i)medication[s]?:\s*(.+)", line)
                if match:
                    items = re.split(r"[,;]", match.group(1))
                    for item in items:
                        drug = item.strip().title()
                        if len(drug) > 2:
                            drug_counter[drug] = drug_counter.get(drug, 0) + 1
    if drug_counter:
        drugs, freqs = zip(*sorted(drug_counter.items(), key=lambda x: x[1], reverse=True)[:10])
        plt.figure(figsize=(6, 4))
        plt.bar(drugs, freqs)
        plt.xticks(rotation=45, ha='right')
        plt.title('Top Medications Frequency')
        drug_chart_path = os.path.join(chart_dir, "drug_frequency_chart.png")
        plt.tight_layout()
        plt.savefig(drug_chart_path)
        plt.close()
        chart_paths.append(("Top Medications Frequency", drug_chart_path))

    # === PDF ===
    pdf_path = report_path.replace('.md', '.pdf')
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=20)

    def add_section_title(pdf, title):
        pdf.set_fill_color(230, 230, 230)
        pdf.set_font("Arial", 'B', 14)
        pdf.cell(0, 10, remove_non_ascii(title), ln=True, fill=True)
        pdf.ln(3)

    def add_footer(pdf):
        pdf.set_y(-15)
        pdf.set_font('Arial', 'I', 8)
        pdf.set_text_color(150, 150, 150)
        pdf.cell(0, 10, f"Page {pdf.page_no()}", align='C')

    # Title Page
    pdf.add_page()
    pdf.set_font("Arial", 'B', 26)
    pdf.set_text_color(0, 70, 140)
    pdf.cell(0, 20, remove_non_ascii("Final Medical Report"), ln=True, align='C')
    pdf.set_text_color(0, 0, 0)
    pdf.set_font("Arial", '', 13)
    pdf.cell(0, 10, datetime.now().strftime("Generated on %B %d, %Y at %H:%M"), ln=True, align='C')
    pdf.ln(15)
    pdf.set_font("Arial", '', 11)
    pdf.set_fill_color(245, 245, 245)
    pdf.multi_cell(0, 9, remove_non_ascii(
        "This report contains a professional summary of clinical observations, potential "
        "inconsistencies, and follow-up recommendations based on the uploaded medical document."
    ), border=1, fill=True, align="J")
    add_footer(pdf)

    # Final Summary
    pdf.add_page()
    add_section_title(pdf, "Final Summary")
    pdf.set_font("Arial", '', 11)
    for line in summary.split("\n"):
        clean_line = remove_non_ascii(line.strip())
        if clean_line:
            pdf.multi_cell(0, 8, txt=clean_line)
    add_footer(pdf)

    # Charts Section
    pdf.add_page()
    add_section_title(pdf, "Statistical Overview")
    for title, path in chart_paths:
        pdf.set_font("Arial", 'B', 12)
        pdf.cell(0, 9, remove_non_ascii(title), ln=True)
        pdf.image(path, w=170)
        pdf.ln(6)
    add_footer(pdf)

    # Detailed Tool Outputs
    if detailed_batches:
        pdf.add_page()
        add_section_title(pdf, "Detailed Tool Insights")
        for idx, detail in enumerate(detailed_batches):
            pdf.set_font("Arial", 'B', 12)
            pdf.cell(0, 9, remove_non_ascii(f"Tool Output #{idx + 1}"), ln=True)
            pdf.set_font("Arial", '', 11)
            for line in remove_non_ascii(detail).split("\n"):
                pdf.multi_cell(0, 8, txt=line.strip())
            pdf.ln(3)
        add_footer(pdf)

    pdf.output(pdf_path)
    return pdf_path
@app.post("/analyze")  # route path assumed; the route decorator was missing in the source
async def analyze_document(file: UploadFile = File(...)):
    """
    Analyze a medical document (PDF, Excel, or CSV) and return a structured analysis.

    Args:
        file: The medical document to analyze (PDF, Excel, or CSV format)

    Returns:
        JSONResponse: Contains analysis results and report download path
    """
    start_time = time.time()
    try:
        # Save the uploaded file temporarily (basename guards against path traversal)
        temp_path = os.path.join(file_cache_dir, os.path.basename(file.filename))
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        extracted = extract_text(temp_path)
        if not extracted:
            raise HTTPException(status_code=400, detail="Could not extract text from the file")
        chunks = split_text(extracted)
        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
        batch_results = analyze_batches(agent, batches)
        all_tool_outputs = batch_results.copy()
        valid = [res for res in batch_results if not res.startswith("❌")]
        if not valid:
            raise HTTPException(status_code=400, detail="No valid analysis results were generated")
        summary = generate_final_summary(agent, "\n\n".join(valid))

        # Generate report files
        report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        report_path = os.path.join(report_dir, f"{report_filename}.md")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# Final Medical Report\n\n{summary}")
        pdf_path = generate_pdf_report_with_charts(summary, report_path, detailed_batches=all_tool_outputs)

        end_time = time.time()
        elapsed_time = end_time - start_time

        # Clean up temp file
        os.remove(temp_path)

        return JSONResponse({
            "status": "success",
            "summary": summary,
            "report_path": f"/reports/{os.path.basename(pdf_path)}",
            "processing_time": f"{elapsed_time:.2f} seconds",
            "detailed_outputs": all_tool_outputs
        })
    except HTTPException:
        # Re-raise intentional 4xx errors instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/reports/{filename}")  # route implied by the report_path returned above
async def download_report(filename: str):
    """
    Download a generated report PDF file.

    Args:
        filename: The name of the report file to download

    Returns:
        FileResponse: The PDF file for download
    """
    file_path = os.path.join(report_dir, os.path.basename(filename))
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Report not found")
    return FileResponse(file_path, media_type='application/pdf', filename=filename)
@app.get("/status")  # route path assumed; the route decorator was missing in the source
async def service_status():
    """
    Check the service status and version information.

    Returns:
        JSONResponse: Service status information
    """
    return JSONResponse({
        "status": "running",
        "version": "1.0.0",
        "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
        "rag_model": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        "max_tokens": MAX_MODEL_TOKENS,
        "supported_file_types": [".pdf", ".xlsx", ".csv"]
    })
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
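# Example usage once the server is running (routes as assumed in the decorators
# above; the upload and report filenames are hypothetical):
#   curl -F "file=@notes.pdf" http://localhost:7860/analyze
#   curl -O http://localhost:7860/reports/report_20250101_120000.pdf
#   curl http://localhost:7860/status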