Spaces:
Running
Running
Upload ALSARA app files (#1)
Browse files- Upload ALSARA app files (2076784d2a441a06677136e33fa32aa2e9bacb72)
- .env.example +37 -0
- README.md +469 -9
- als_agent_app.py +1832 -0
- custom_mcp_client.py +241 -0
- llm_client.py +439 -0
- llm_providers.py +357 -0
- parallel_tool_execution.py +235 -0
- query_classifier.py +202 -0
- refactored_helpers.py +200 -0
- requirements.txt +37 -0
- servers/aact_server.py +472 -0
- servers/biorxiv_server.py +440 -0
- servers/clinicaltrials_links.py +245 -0
- servers/elevenlabs_server.py +561 -0
- servers/fetch_server.py +206 -0
- servers/llamaindex_server.py +729 -0
- servers/pubmed_server.py +269 -0
- shared/__init__.py +34 -0
- shared/cache.py +94 -0
- shared/config.py +134 -0
- shared/http_client.py +68 -0
- shared/utils.py +194 -0
- smart_cache.py +458 -0
.env.example
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ALS Research Agent Environment Configuration
|
| 2 |
+
|
| 3 |
+
# Anthropic API Key (Required)
|
| 4 |
+
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
| 5 |
+
|
| 6 |
+
# Optional: Specify which Anthropic model to use
|
| 7 |
+
# Available models: claude-sonnet-4-5-20250929, claude-3-opus-20240229, claude-3-haiku-20240307
|
| 8 |
+
# Default: claude-sonnet-4-5-20250929
|
| 9 |
+
ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
|
| 10 |
+
|
| 11 |
+
# Optional: Gradio server configuration
|
| 12 |
+
GRADIO_SERVER_PORT=7860 # Default port for Gradio UI
|
| 13 |
+
|
| 14 |
+
# ElevenLabs Configuration (Optional - for voice capabilities)
|
| 15 |
+
ELEVENLABS_API_KEY=your_elevenlabs_api_key_here
|
| 16 |
+
ELEVENLABS_VOICE_ID=21m00Tcm4TlvDq8ikWAM # Rachel voice (clear for medical terms)
|
| 17 |
+
|
| 18 |
+
# LlamaIndex RAG Configuration (Optional - for research memory)
|
| 19 |
+
CHROMA_DB_PATH=./chroma_db # Path to persist vector database
|
| 20 |
+
LLAMAINDEX_EMBED_MODEL=dmis-lab/biobert-base-cased-v1.2 # Biomedical embedding model
|
| 21 |
+
LLAMAINDEX_CHUNK_SIZE=1024 # Text chunk size for indexing
|
| 22 |
+
LLAMAINDEX_CHUNK_OVERLAP=200 # Overlap between chunks
|
| 23 |
+
|
| 24 |
+
# Optional: Show agent thinking process in UI
|
| 25 |
+
SHOW_THINKING=false
|
| 26 |
+
|
| 27 |
+
# Optional: LLM provider preference
|
| 28 |
+
# Options: quality_optimize (best model), cost_optimize (cheaper model), auto (default)
|
| 29 |
+
LLM_PROVIDER_PREFERENCE=auto
|
| 30 |
+
|
| 31 |
+
# Research API Configuration (Optional)
|
| 32 |
+
# Configure these if you want to limit API usage
|
| 33 |
+
RATE_LIMIT_PUBMED_DELAY=1.0 # Delay between PubMed requests (seconds)
|
| 34 |
+
RATE_LIMIT_BIORXIV_DELAY=1.0 # Delay between bioRxiv requests (seconds)
|
| 35 |
+
|
| 36 |
+
# Optional: Max concurrent searches
|
| 37 |
+
MAX_CONCURRENT_SEARCHES=3
|
README.md
CHANGED
|
@@ -1,14 +1,474 @@
|
|
| 1 |
---
|
| 2 |
-
title: ALSARA
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.0.
|
| 8 |
-
app_file:
|
| 9 |
-
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
short_description:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ALSARA - ALS Agentic Research Agent
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "6.0.0"
|
| 8 |
+
app_file: als_agent_app.py
|
|
|
|
| 9 |
license: mit
|
| 10 |
+
short_description: AI research assistant for ALS research & trials
|
| 11 |
+
pinned: false
|
| 12 |
+
sponsors: Sambanova, Anthropic, ElevenLabs, LlamaIndex
|
| 13 |
+
tags:
|
| 14 |
+
- mcp-in-action-track-consumer
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ALSARA - ALS Agentic Research Agent
|
| 19 |
+
|
| 20 |
+
ALSARA (ALS Agentic Research Assistant) is an AI-powered research tool that intelligently orchestrates multiple biomedical databases to answer complex questions about ALS (Amyotrophic Lateral Sclerosis) research, treatments, and clinical trials in real-time.
|
| 21 |
+
|
| 22 |
+
Built with a 4-phase agentic workflow (Planning → Executing → Reflecting → Synthesis), ALSARA searches PubMed, 559,000+ clinical trials via AACT database, and provides voice accessibility for ALS patients - delivering comprehensive research in 5-10 seconds.
|
| 23 |
+
|
| 24 |
+
Built with Model Context Protocol (MCP), Gradio 6.x, and Anthropic Claude.
|
| 25 |
+
|
| 26 |
+
## Key Features
|
| 27 |
+
|
| 28 |
+
### Core Capabilities
|
| 29 |
+
- **4-Phase Agentic Workflow**: Intelligent planning, parallel execution, reflection with gap-filling, and comprehensive synthesis
|
| 30 |
+
- **Real-time Literature Search**: Query millions of PubMed peer-reviewed papers
|
| 31 |
+
- **Clinical Trial Discovery**: Access 559,000+ trials from AACT PostgreSQL database (primary) with ClinicalTrials.gov fallback
|
| 32 |
+
- **Voice Accessibility**: Text-to-speech using ElevenLabs for ALS patients with limited mobility
|
| 33 |
+
- **Smart Caching**: Query normalization with 24-hour TTL for instant similar query responses
|
| 34 |
+
- **Parallel Tool Execution**: 70% faster responses by running all searches simultaneously
|
| 35 |
+
|
| 36 |
+
### Advanced Features
|
| 37 |
+
- **Multi-Provider LLM Support**: Claude primary with SambaNova Llama 3.3 70B fallback
|
| 38 |
+
- **Query Classification**: Smart routing between simple answers and complex research
|
| 39 |
+
- **Rate Limiting**: 30 requests/minute per user with exponential backoff
|
| 40 |
+
- **Memory Management**: Automatic conversation truncation and garbage collection
|
| 41 |
+
- **Health Monitoring**: Uptime tracking, error rates, and tool usage statistics
|
| 42 |
+
- **Citation Tracking**: All responses include PMIDs, DOIs, NCT IDs, and source references
|
| 43 |
+
- **Web Scraping**: Fetch full-text articles with SSRF protection
|
| 44 |
+
- **Export Conversations**: Download chat history as markdown files
|
| 45 |
+
|
| 46 |
+
## Architecture
|
| 47 |
+
|
| 48 |
+
The system uses a sophisticated multi-layer architecture:
|
| 49 |
+
|
| 50 |
+
### 1. User Interface Layer
|
| 51 |
+
- **Gradio 6.x** web application with chat interface
|
| 52 |
+
- Real-time streaming responses
|
| 53 |
+
- Voice output controls
|
| 54 |
+
- Export and retry functionality
|
| 55 |
+
|
| 56 |
+
### 2. Agentic Orchestration Layer
|
| 57 |
+
**4-Phase Workflow:**
|
| 58 |
+
1. **PLANNING**: Agent strategizes which databases to query
|
| 59 |
+
2. **EXECUTING**: Parallel searches across all data sources
|
| 60 |
+
3. **REFLECTING**: Evaluates results, identifies gaps, runs additional searches
|
| 61 |
+
4. **SYNTHESIS**: Comprehensive answer with citations and confidence scoring
|
| 62 |
+
|
| 63 |
+
### 3. LLM Provider Layer
|
| 64 |
+
- **Primary**: Anthropic Claude (claude-sonnet-4-5-20250929)
|
| 65 |
+
- **Fallback**: SambaNova Llama 3.3 70B (free alternative)
|
| 66 |
+
- Smart routing based on query complexity
|
| 67 |
+
|
| 68 |
+
### 4. MCP Server Layer
|
| 69 |
+
Each server runs as a separate subprocess with JSON-RPC communication:
|
| 70 |
+
|
| 71 |
+
- **aact-server**: Primary clinical trials database (559,000+ trials)
|
| 72 |
+
- **pubmed-server**: PubMed literature search
|
| 73 |
+
- **fetch-server**: Web scraping with security hardening
|
| 74 |
+
- **elevenlabs-server**: Voice synthesis for accessibility
|
| 75 |
+
- **clinicaltrials_links**: Fallback trial links when AACT unavailable
|
| 76 |
+
- **llamaindex-server**: RAG/semantic search (optional)
|
| 77 |
+
|
| 78 |
+
**Technical Note:** Uses custom MCP client (`custom_mcp_client.py`) to bypass SDK bugs with proper async/await handling, line-buffered I/O, and automatic retry logic.
|
| 79 |
+
|
| 80 |
+
## Available Tools
|
| 81 |
+
|
| 82 |
+
The agent has access to specialized tools across 6 MCP servers:
|
| 83 |
+
|
| 84 |
+
### AACT Clinical Trials Database Tools (PRIMARY)
|
| 85 |
+
|
| 86 |
+
#### 1. `aact__search_aact_trials`
|
| 87 |
+
Search 559,000+ clinical trials from the AACT PostgreSQL database.
|
| 88 |
+
|
| 89 |
+
**Parameters:**
|
| 90 |
+
- `condition` (string, optional): Medical condition (default: "ALS")
|
| 91 |
+
- `status` (string, optional): Trial status - "recruiting", "active", "completed", "all"
|
| 92 |
+
- `intervention` (string, optional): Treatment/drug name
|
| 93 |
+
- `sponsor` (string, optional): Trial sponsor organization
|
| 94 |
+
- `phase` (string, optional): Trial phase (1, 2, 3, 4)
|
| 95 |
+
- `max_results` (integer, optional): Maximum results (default: 10)
|
| 96 |
+
|
| 97 |
+
**Returns:** Comprehensive trial data with NCT IDs, titles, status, phases, enrollment, and locations.
|
| 98 |
+
|
| 99 |
+
#### 2. `aact__get_aact_trial`
|
| 100 |
+
Get complete details for a specific clinical trial.
|
| 101 |
+
|
| 102 |
+
**Parameters:**
|
| 103 |
+
- `nct_id` (string, required): ClinicalTrials.gov NCT ID
|
| 104 |
+
|
| 105 |
+
**Returns:** Full trial information including eligibility, outcomes, interventions, and contacts.
|
| 106 |
+
|
| 107 |
---
|
| 108 |
|
| 109 |
+
### PubMed Literature Tools
|
| 110 |
+
|
| 111 |
+
#### 3. `pubmed__search_pubmed`
|
| 112 |
+
Search PubMed for peer-reviewed research papers.
|
| 113 |
+
|
| 114 |
+
**Parameters:**
|
| 115 |
+
- `query` (string, required): Search query (e.g., "ALS SOD1 therapy")
|
| 116 |
+
- `max_results` (integer, optional): Maximum results (default: 10)
|
| 117 |
+
- `sort` (string, optional): Sort by "relevance" or "date"
|
| 118 |
+
|
| 119 |
+
**Returns:** Papers with titles, abstracts, PMIDs, authors, and publication dates.
|
| 120 |
+
|
| 121 |
+
#### 4. `pubmed__get_paper_details`
|
| 122 |
+
Get complete details for a specific PubMed paper.
|
| 123 |
+
|
| 124 |
+
**Parameters:**
|
| 125 |
+
- `pmid` (string, required): PubMed ID
|
| 126 |
+
|
| 127 |
+
**Returns:** Full paper information including abstract, journal, DOI, and PubMed URL.
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
### Web Fetching Tools
|
| 132 |
+
|
| 133 |
+
#### 5. `fetch__fetch_url`
|
| 134 |
+
Fetch and extract content from web URLs with security hardening.
|
| 135 |
+
|
| 136 |
+
**Parameters:**
|
| 137 |
+
- `url` (string, required): URL to fetch
|
| 138 |
+
- `extract_text_only` (boolean, optional): Extract only text content (default: true)
|
| 139 |
+
|
| 140 |
+
**Returns:** Extracted webpage content with SSRF protection.
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
### Voice Accessibility Tools
|
| 145 |
+
|
| 146 |
+
#### 6. `elevenlabs__text_to_speech`
|
| 147 |
+
Convert research findings to audio for accessibility.
|
| 148 |
+
|
| 149 |
+
**Parameters:**
|
| 150 |
+
- `text` (string, required): Text to convert (max 2500 chars)
|
| 151 |
+
- `voice_id` (string, optional): Voice selection (default: Rachel - medical-friendly)
|
| 152 |
+
- `speed` (number, optional): Speech speed (0.5-2.0)
|
| 153 |
+
|
| 154 |
+
**Returns:** Audio stream for playback.
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
### Fallback Tools
|
| 159 |
+
|
| 160 |
+
#### 7. `clinicaltrials_links__get_known_als_trials`
|
| 161 |
+
Returns curated list of important ALS trials when AACT is unavailable.
|
| 162 |
+
|
| 163 |
+
#### 8. `clinicaltrials_links__get_search_link`
|
| 164 |
+
Generates direct ClinicalTrials.gov search URLs.
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
### Tool Usage Notes
|
| 169 |
+
|
| 170 |
+
- **Rate Limiting**: All tools respect API rate limits (PubMed: 3 req/sec)
|
| 171 |
+
- **Caching**: Results cached for 24 hours with smart query normalization
|
| 172 |
+
- **Connection Pooling**: AACT uses async PostgreSQL with 2-10 connections
|
| 173 |
+
- **Timeout Protection**: 90-second timeout with automatic retry
|
| 174 |
+
- **Security**: SSRF protection, input validation, content size limits
|
| 175 |
+
|
| 176 |
+
## Quick Start
|
| 177 |
+
|
| 178 |
+
### Prerequisites
|
| 179 |
+
|
| 180 |
+
- Python 3.10+ (3.12 recommended)
|
| 181 |
+
- Anthropic API key
|
| 182 |
+
- Git
|
| 183 |
+
|
| 184 |
+
### Installation
|
| 185 |
+
|
| 186 |
+
1. Clone the repository
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
git clone https://github.com/yourusername/als-research-agent.git
|
| 190 |
+
cd als-research-agent
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
2. Create virtual environment
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
python3.12 -m venv venv
|
| 197 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
3. Install dependencies
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
pip install -r requirements.txt
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
4. Set up environment variables
|
| 207 |
+
|
| 208 |
+
Create a `.env` file:
|
| 209 |
+
|
| 210 |
+
```bash
|
| 211 |
+
# Required
|
| 212 |
+
ANTHROPIC_API_KEY=sk-ant-xxx
|
| 213 |
+
|
| 214 |
+
# Recommended
|
| 215 |
+
ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
|
| 216 |
+
ELEVENLABS_API_KEY=xxx
|
| 217 |
+
ELEVENLABS_VOICE_ID=21m00Tcm4TlvDq8ikWAM # Rachel voice
|
| 218 |
+
|
| 219 |
+
# Optional Features
|
| 220 |
+
ENABLE_RAG=false # Enable semantic search (requires setup)
|
| 221 |
+
USE_FALLBACK_LLM=true # Enable free SambaNova fallback
|
| 222 |
+
DISABLE_CACHE=false # Disable smart caching
|
| 223 |
+
|
| 224 |
+
# Configuration
|
| 225 |
+
GRADIO_SERVER_PORT=7860
|
| 226 |
+
MAX_CONCURRENT_SEARCHES=3
|
| 227 |
+
RATE_LIMIT_PUBMED_DELAY=1.0
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
5. Run the application
|
| 231 |
+
|
| 232 |
+
```bash
|
| 233 |
+
python als_agent_app.py
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
or
|
| 237 |
+
|
| 238 |
+
```bash
|
| 239 |
+
./venv/bin/python3.12 als_agent_app.py 2>&1
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
The app will launch at http://localhost:7860
|
| 243 |
+
|
| 244 |
+
## Project Structure
|
| 245 |
+
|
| 246 |
+
```
|
| 247 |
+
als-research-agent/
|
| 248 |
+
├── README.md
|
| 249 |
+
├── requirements.txt
|
| 250 |
+
├── .env.example
|
| 251 |
+
├── als_agent_app.py # Main Gradio application (1832 lines)
|
| 252 |
+
├── custom_mcp_client.py # Custom MCP client implementation
|
| 253 |
+
├── llm_client.py # Multi-provider LLM abstraction
|
| 254 |
+
├── query_classifier.py # Research vs simple query detection
|
| 255 |
+
├── smart_cache.py # Query normalization and caching
|
| 256 |
+
├── refactored_helpers.py # Streaming and tool execution
|
| 257 |
+
├── parallel_tool_execution.py # Concurrent search management
|
| 258 |
+
├── servers/
|
| 259 |
+
│ ├── aact_server.py # AACT clinical trials database (PRIMARY)
|
| 260 |
+
│ ├── pubmed_server.py # PubMed literature search
|
| 261 |
+
│ ├── fetch_server.py # Web scraping with security
|
| 262 |
+
│ ├── elevenlabs_server.py # Voice synthesis
|
| 263 |
+
│ ├── clinicaltrials_links.py # Fallback trial links
|
| 264 |
+
│ └── llamaindex_server.py # RAG/semantic search (optional)
|
| 265 |
+
├── shared/
|
| 266 |
+
│ ├── __init__.py
|
| 267 |
+
│ ├── config.py # Centralized configuration
|
| 268 |
+
│ ├── cache.py # TTL-based caching
|
| 269 |
+
│ └── utils.py # Rate limiting and formatting
|
| 270 |
+
└── tests/
|
| 271 |
+
├── test_pubmed_server.py
|
| 272 |
+
├── test_aact_server.py
|
| 273 |
+
├── test_fetch_server.py
|
| 274 |
+
├── test_elevenlabs.py
|
| 275 |
+
├── test_integration.py
|
| 276 |
+
├── test_llm_client.py
|
| 277 |
+
├── test_performance.py
|
| 278 |
+
└── test_workflow_*.py
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
## Usage Examples
|
| 282 |
+
|
| 283 |
+
### Example Queries
|
| 284 |
+
|
| 285 |
+
**Complex Research Questions:**
|
| 286 |
+
- "What are the latest gene therapy trials for SOD1 mutations with recent biomarker data?"
|
| 287 |
+
- "Compare antisense oligonucleotide therapies in Phase 2 or 3 trials"
|
| 288 |
+
- "Find recent PubMed papers on ALS protein aggregation from Japanese researchers"
|
| 289 |
+
|
| 290 |
+
**Clinical Trial Discovery:**
|
| 291 |
+
- "Active trials in Germany for bulbar onset ALS"
|
| 292 |
+
- "Recruiting trials for ALS patients under 40 with slow progression"
|
| 293 |
+
- "Phase 3 trials sponsored by Biogen or Ionis"
|
| 294 |
+
|
| 295 |
+
**Treatment Information:**
|
| 296 |
+
- "Compare efficacy of riluzole, edaravone, and AMX0035"
|
| 297 |
+
- "What combination therapies showed promise in 2024?"
|
| 298 |
+
- "Latest developments in stem cell therapy for ALS"
|
| 299 |
+
|
| 300 |
+
**Accessibility Features:**
|
| 301 |
+
- Click the voice icon to hear research summaries
|
| 302 |
+
- Adjustable speech speed for comfort
|
| 303 |
+
- Medical-friendly voice optimized for clarity
|
| 304 |
+
|
| 305 |
+
## Performance Characteristics
|
| 306 |
+
|
| 307 |
+
- **Typical Response Time**: 5-10 seconds for complex queries
|
| 308 |
+
- **Parallel Speedup**: 70% faster than sequential searching
|
| 309 |
+
- **Cache Hit Time**: <100ms for similar queries (24-hour TTL)
|
| 310 |
+
- **Concurrent Handling**: 4 requests in ~8 seconds
|
| 311 |
+
- **Tool Call Timeout**: 90 seconds with automatic retry
|
| 312 |
+
- **Memory Limit**: 50 messages per conversation (~8-50KB per message)
|
| 313 |
+
|
| 314 |
+
## Development
|
| 315 |
+
|
| 316 |
+
### Running Tests
|
| 317 |
+
|
| 318 |
+
```bash
|
| 319 |
+
# All tests
|
| 320 |
+
pytest tests/ -v
|
| 321 |
+
|
| 322 |
+
# Unit tests only
|
| 323 |
+
pytest tests/ -m "not integration"
|
| 324 |
+
|
| 325 |
+
# With coverage
|
| 326 |
+
pytest --cov=servers --cov-report=html
|
| 327 |
+
|
| 328 |
+
# Quick tests
|
| 329 |
+
./run_quick_tests.sh
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
### Adding New MCP Servers
|
| 333 |
+
|
| 334 |
+
1. Create new server file in `servers/`
|
| 335 |
+
2. Use FastMCP API to implement tools:
|
| 336 |
+
|
| 337 |
+
```python
|
| 338 |
+
from mcp.server.fastmcp import FastMCP
|
| 339 |
+
|
| 340 |
+
mcp = FastMCP("my-server")
|
| 341 |
+
|
| 342 |
+
@mcp.tool()
|
| 343 |
+
async def my_tool(param: str) -> str:
|
| 344 |
+
"""Tool description"""
|
| 345 |
+
return f"Result: {param}"
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
|
| 348 |
+
mcp.run(transport="stdio")
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
3. Add server to `als_agent_app.py` in `setup_mcp_servers()`
|
| 352 |
+
4. Write tests in `tests/`
|
| 353 |
+
|
| 354 |
+
## Deployment
|
| 355 |
+
|
| 356 |
+
### Hugging Face Spaces
|
| 357 |
+
|
| 358 |
+
1. Create a Gradio Space
|
| 359 |
+
2. Push your code
|
| 360 |
+
3. Add secrets:
|
| 361 |
+
- `ANTHROPIC_API_KEY` (required)
|
| 362 |
+
- `ELEVENLABS_API_KEY` (for voice features)
|
| 363 |
+
|
| 364 |
+
### Docker
|
| 365 |
+
|
| 366 |
+
```bash
|
| 367 |
+
docker build -t als-research-agent .
|
| 368 |
+
docker run -p 7860:7860 \
|
| 369 |
+
-e ANTHROPIC_API_KEY=your_key \
|
| 370 |
+
-e ELEVENLABS_API_KEY=your_key \
|
| 371 |
+
als-research-agent
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
### Cloud Deployment (Azure/AWS/GCP)
|
| 375 |
+
|
| 376 |
+
The application is containerized and ready for deployment on any cloud platform supporting Docker containers. See deployment guides for specific platforms.
|
| 377 |
+
|
| 378 |
+
## Troubleshooting
|
| 379 |
+
|
| 380 |
+
**MCP server not responding**
|
| 381 |
+
- Check Python path and virtual environment activation
|
| 382 |
+
- Verify all dependencies installed: `pip install -r requirements.txt`
|
| 383 |
+
|
| 384 |
+
**Rate limit exceeded**
|
| 385 |
+
- Add delays between requests
|
| 386 |
+
- Check Anthropic API quota
|
| 387 |
+
- Use `USE_FALLBACK_LLM=true` for free alternative
|
| 388 |
+
|
| 389 |
+
**Voice synthesis not working**
|
| 390 |
+
- Verify `ELEVENLABS_API_KEY` is set
|
| 391 |
+
- Check API quota at ElevenLabs dashboard
|
| 392 |
+
- Text may be too long (max 2500 chars)
|
| 393 |
+
|
| 394 |
+
**AACT database connection issues**
|
| 395 |
+
- Database may be under maintenance (Sunday 7 AM ET)
|
| 396 |
+
- Fallback to `clinicaltrials_links` server activates automatically
|
| 397 |
+
|
| 398 |
+
**Cache not working**
|
| 399 |
+
- Check `DISABLE_CACHE` is not set to true
|
| 400 |
+
- Verify `.cache/` directory has write permissions
|
| 401 |
+
|
| 402 |
+
## Resources
|
| 403 |
+
|
| 404 |
+
### ALS Research Organizations
|
| 405 |
+
- ALS Association: https://www.als.org/
|
| 406 |
+
- ALS Therapy Development Institute: https://www.als.net/
|
| 407 |
+
- Answer ALS Data Portal: https://dataportal.answerals.org/
|
| 408 |
+
- International Alliance of ALS/MND Associations: https://www.als-mnd.org/
|
| 409 |
+
|
| 410 |
+
### Data Sources
|
| 411 |
+
- PubMed E-utilities: https://www.ncbi.nlm.nih.gov/books/NBK25501/
|
| 412 |
+
- AACT Database: https://aact.ctti-clinicaltrials.org/
|
| 413 |
+
- ClinicalTrials.gov: https://clinicaltrials.gov/
|
| 414 |
+
|
| 415 |
+
### Technologies
|
| 416 |
+
- Model Context Protocol: https://modelcontextprotocol.io/
|
| 417 |
+
- Gradio Documentation: https://www.gradio.app/docs/
|
| 418 |
+
- Anthropic Claude: https://www.anthropic.com/
|
| 419 |
+
- ElevenLabs API: https://elevenlabs.io/
|
| 420 |
+
|
| 421 |
+
## Security & Privacy
|
| 422 |
+
|
| 423 |
+
- **No Patient Data Storage**: Conversations are not permanently stored
|
| 424 |
+
- **SSRF Protection**: Blocks access to private IPs and localhost
|
| 425 |
+
- **Input Validation**: Injection pattern detection and length limits
|
| 426 |
+
- **Rate Limiting**: Per-user request throttling
|
| 427 |
+
- **API Key Security**: All keys stored as environment variables
|
| 428 |
+
|
| 429 |
+
## License
|
| 430 |
+
|
| 431 |
+
MIT License - See LICENSE file for details
|
| 432 |
+
|
| 433 |
+
## Contributing
|
| 434 |
+
|
| 435 |
+
1. Fork the repository
|
| 436 |
+
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
|
| 437 |
+
3. Write tests for your changes
|
| 438 |
+
4. Ensure all tests pass (`pytest`)
|
| 439 |
+
5. Commit your changes (`git commit -m 'Add amazing feature'`)
|
| 440 |
+
6. Push to the branch (`git push origin feature/amazing-feature`)
|
| 441 |
+
7. Open a Pull Request
|
| 442 |
+
|
| 443 |
+
## Future Enhancements
|
| 444 |
+
|
| 445 |
+
### In Development
|
| 446 |
+
- **NCBI Gene Database**: Gene information and mutations
|
| 447 |
+
- **OMIM Integration**: Genetic disorder phenotypes
|
| 448 |
+
- **Protein Data Bank**: 3D protein structures
|
| 449 |
+
- **AlphaFold Database**: AI-predicted protein structures
|
| 450 |
+
|
| 451 |
+
### Planned Features
|
| 452 |
+
- **Voice Input**: Speech recognition for queries
|
| 453 |
+
- **Patient Trial Matching**: Personalized eligibility assessment
|
| 454 |
+
- **Research Trend Analysis**: Track emerging themes
|
| 455 |
+
- **Alert System**: Notifications for new trials/papers
|
| 456 |
+
- **Enhanced Export**: BibTeX, CSV, PDF formats
|
| 457 |
+
- **Multi-language Support**: Global accessibility
|
| 458 |
+
- **Drug Repurposing Module**: Identify potential ALS treatments
|
| 459 |
+
- **arXiv Integration**: Computational biology papers
|
| 460 |
+
|
| 461 |
+
## Acknowledgments
|
| 462 |
+
|
| 463 |
+
Built for the global ALS research community to accelerate the path to a cure.
|
| 464 |
+
|
| 465 |
+
Special thanks to:
|
| 466 |
+
- The MCP team for the Model Context Protocol
|
| 467 |
+
- Anthropic for Claude AI
|
| 468 |
+
- The open-source community for invaluable contributions
|
| 469 |
+
|
| 470 |
+
---
|
| 471 |
+
|
| 472 |
+
**ALSARA - Accelerating ALS research, one query at a time.**
|
| 473 |
+
|
| 474 |
+
For questions, issues, or contributions, please open an issue on GitHub.
|
als_agent_app.py
ADDED
|
@@ -0,0 +1,1832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# als_agent_app.py
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator, Union
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
import httpx
|
| 15 |
+
import base64
|
| 16 |
+
import tempfile
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
# Load environment variables from .env file
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
# Add current directory to path for shared imports
|
| 23 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 24 |
+
from shared import SimpleCache
|
| 25 |
+
from custom_mcp_client import MCPClientManager
|
| 26 |
+
from llm_client import UnifiedLLMClient
|
| 27 |
+
from smart_cache import SmartCache, DEFAULT_PREWARM_QUERIES
|
| 28 |
+
|
| 29 |
+
# Helper function imports for refactored code
|
| 30 |
+
from refactored_helpers import (
|
| 31 |
+
stream_with_retry,
|
| 32 |
+
execute_tool_calls,
|
| 33 |
+
build_assistant_message
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Configure logging
|
| 37 |
+
logging.basicConfig(
|
| 38 |
+
level=logging.INFO,
|
| 39 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 40 |
+
handlers=[
|
| 41 |
+
logging.StreamHandler(),
|
| 42 |
+
logging.FileHandler('app.log', mode='a', encoding='utf-8')
|
| 43 |
+
]
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
# Rate Limiter Class
|
| 48 |
+
class RateLimiter:
    """Rate limiter to prevent API overload.

    Keeps a per-key list of request timestamps and allows at most
    ``max_requests_per_minute`` requests within any rolling one-minute window.
    """

    def __init__(self, max_requests_per_minute: int = 30):
        self.max_requests_per_minute = max_requests_per_minute
        # key -> list of datetimes of recent requests
        self.request_times = defaultdict(list)

    async def check_rate_limit(self, key: str = "default") -> bool:
        """Check if request is within rate limit.

        Returns True (and records the request) when under the limit,
        False when the key has exhausted its per-minute budget.
        """
        current = datetime.now()
        cutoff = current - timedelta(minutes=1)

        # Drop timestamps older than one minute, then test remaining headroom.
        recent = [stamp for stamp in self.request_times[key] if stamp > cutoff]
        self.request_times[key] = recent

        if len(recent) >= self.max_requests_per_minute:
            return False

        recent.append(current)
        return True

    async def wait_if_needed(self, key: str = "default"):
        """Block (cooperatively) until the key is allowed another request."""
        while True:
            if await self.check_rate_limit(key):
                return
            await asyncio.sleep(2)  # back off briefly before re-checking
|
| 78 |
+
|
| 79 |
+
# Initialize rate limiter
|
| 80 |
+
rate_limiter = RateLimiter(max_requests_per_minute=30)
|
| 81 |
+
|
| 82 |
+
# Memory management settings
|
| 83 |
+
MAX_CONVERSATION_LENGTH = 50 # Maximum messages to keep in history
|
| 84 |
+
MEMORY_CLEANUP_INTERVAL = 300 # Cleanup every 5 minutes
|
| 85 |
+
|
| 86 |
+
async def cleanup_memory():
    """Periodic memory cleanup task.

    Runs forever: every ``MEMORY_CLEANUP_INTERVAL`` seconds it expires stale
    entries from ``tool_cache`` and ``smart_cache`` and forces a garbage
    collection pass.  All errors are logged and swallowed so the background
    task never dies.
    """
    # Hoisted out of the loop: no need to re-run the import machinery
    # on every cleanup cycle.
    import gc

    while True:
        try:
            # Clean up expired cache entries
            tool_cache.cleanup_expired()
            # Original used a conditional *expression* purely for its side
            # effect; an explicit statement is clearer and equivalent.
            if smart_cache:
                smart_cache.cleanup()

            # Force garbage collection for large cleanups
            collected = gc.collect()
            if collected > 0:
                logger.debug(f"Memory cleanup: collected {collected} objects")

        except Exception as e:
            logger.error(f"Error during memory cleanup: {e}")

        await asyncio.sleep(MEMORY_CLEANUP_INTERVAL)
|
| 104 |
+
|
| 105 |
+
# Start memory cleanup task
|
| 106 |
+
cleanup_task = None
|
| 107 |
+
|
| 108 |
+
# Track whether last response used research workflow (for voice button)
|
| 109 |
+
last_response_was_research = False
|
| 110 |
+
|
| 111 |
+
# Health monitoring
|
| 112 |
+
class HealthMonitor:
    """Monitor system health and performance.

    Collects lightweight counters (requests, errors, per-tool call counts)
    and a bounded rolling window of response times.
    """

    def __init__(self):
        self.start_time = datetime.now()
        self.request_count = 0
        self.error_count = 0
        self.tool_call_count = defaultdict(int)
        self.response_times = []
        self.last_error = None

    def record_request(self):
        # Count every incoming user request.
        self.request_count += 1

    def record_error(self, error: str):
        # Remember the failure plus a truncated copy of its message.
        self.error_count += 1
        self.last_error = {"time": datetime.now(), "error": str(error)[:500]}

    def record_tool_call(self, tool_name: str):
        self.tool_call_count[tool_name] += 1

    def record_response_time(self, duration: float):
        self.response_times.append(duration)
        # Cap the rolling window at 100 samples so memory stays bounded.
        if len(self.response_times) > 100:
            del self.response_times[:-100]

    def get_health_status(self) -> Dict[str, Any]:
        """Get current health status as a plain dict (relies on the
        module-level ``tool_cache`` and ``rate_limiter`` globals)."""
        uptime = (datetime.now() - self.start_time).total_seconds()
        if self.response_times:
            avg_response_time = sum(self.response_times) / len(self.response_times)
        else:
            avg_response_time = 0

        top_tools = sorted(self.tool_call_count.items(), key=lambda x: x[1], reverse=True)[:5]
        return {
            "status": "healthy" if self.error_count < 10 else "degraded",
            "uptime_seconds": uptime,
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(1, self.request_count),
            "avg_response_time": avg_response_time,
            "cache_size": tool_cache.size(),
            "rate_limit_status": f"{len(rate_limiter.request_times)} active keys",
            "most_used_tools": dict(top_tools),
            "last_error": self.last_error,
        }
|
| 156 |
+
|
| 157 |
+
# Initialize health monitor
|
| 158 |
+
health_monitor = HealthMonitor()
|
| 159 |
+
|
| 160 |
+
# Error message formatter
|
| 161 |
+
def format_error_message(error: Exception, context: str = "") -> str:
    """Format error messages with helpful suggestions.

    Builds a markdown error panel containing the exception type, a truncated
    detail string, the optional context, and an advice block chosen by
    keyword-matching the error text.
    """
    error_str = str(error)
    error_type = type(error).__name__
    lowered = error_str.lower()

    # Keyword groups are checked in order; the first match selects the advice.
    advice_table = (
        (("timeout",), """
**Suggestions:**
- Try simplifying your search query
- Break complex questions into smaller parts
- Check your internet connection
- The service may be temporarily overloaded - try again in a moment
"""),
        (("rate limit",), """
**Suggestions:**
- Wait a moment before trying again
- Reduce the number of simultaneous searches
- Consider using cached results when available
"""),
        (("connection", "network"), """
**Suggestions:**
- Check your internet connection
- The external service may be temporarily unavailable
- Try again in a few moments
"""),
        (("invalid", "validation"), """
**Suggestions:**
- Check your query for special characters or formatting issues
- Ensure your question is clear and well-formed
- Avoid using HTML or script tags in your query
"""),
        (("memory", "resource"), """
**Suggestions:**
- The system may be under heavy load
- Try a simpler query
- Clear your browser cache and refresh the page
"""),
    )

    # Generic fallback advice when no keyword group matches.
    suggestion = """
**Suggestions:**
- Try rephrasing your question
- Break complex queries into simpler parts
- If the error persists, please report it to support
"""
    for keywords, advice in advice_table:
        if any(keyword in lowered for keyword in keywords):
            suggestion = advice
            break

    formatted = f"""
❌ **Error Encountered**

**Type:** {error_type}
**Details:** {error_str[:500]}
{f"**Context:** {context}" if context else ""}

{suggestion}

**Need Help?**
- Try the example queries in the sidebar
- Check the System Health tab for service status
- Report persistent issues on GitHub
"""
    return formatted.strip()
|
| 228 |
+
|
| 229 |
+
# Initialize the unified LLM client
|
| 230 |
+
# All provider logic is now handled inside UnifiedLLMClient
|
| 231 |
+
client = None # Initialize to None for proper cleanup handling
|
| 232 |
+
try:
|
| 233 |
+
client = UnifiedLLMClient()
|
| 234 |
+
logger.info(f"LLM client initialized: {client.get_provider_display_name()}")
|
| 235 |
+
except ValueError as e:
|
| 236 |
+
# Re-raise configuration errors with clear instructions
|
| 237 |
+
logger.error(f"LLM configuration error: {e}")
|
| 238 |
+
raise
|
| 239 |
+
|
| 240 |
+
# Global MCP client manager
|
| 241 |
+
mcp_manager = MCPClientManager()
|
| 242 |
+
|
| 243 |
+
# Internal thinking tags are always filtered for cleaner output
|
| 244 |
+
|
| 245 |
+
# Model configuration
|
| 246 |
+
# Use Claude 3.5 Sonnet with correct model ID that works with the API key
|
| 247 |
+
ANTHROPIC_MODEL = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929")
|
| 248 |
+
logger.info(f"Using model: {ANTHROPIC_MODEL}")
|
| 249 |
+
|
| 250 |
+
# Configuration for max tokens in responses
|
| 251 |
+
# Set MAX_RESPONSE_TOKENS in .env to control response length
|
| 252 |
+
# Claude 3.5 Sonnet supports up to 8192 tokens
|
| 253 |
+
MAX_RESPONSE_TOKENS = min(int(os.getenv("MAX_RESPONSE_TOKENS", "8192")), 8192)
|
| 254 |
+
logger.info(f"Max response tokens set to: {MAX_RESPONSE_TOKENS}")
|
| 255 |
+
|
| 256 |
+
# Global smart cache (24 hour TTL for research queries)
|
| 257 |
+
smart_cache = SmartCache(cache_dir=".cache", ttl_hours=24)
|
| 258 |
+
|
| 259 |
+
# Keep tool cache for MCP tool results
|
| 260 |
+
tool_cache = SimpleCache(ttl=3600)
|
| 261 |
+
|
| 262 |
+
# Cache for tool definitions to avoid repeated fetching
|
| 263 |
+
_cached_tools = None
|
| 264 |
+
_tools_cache_time = None
|
| 265 |
+
TOOLS_CACHE_TTL = 86400 # 24 hour cache for tool definitions (tools rarely change)
|
| 266 |
+
|
| 267 |
+
async def setup_mcp_servers() -> MCPClientManager:
    """Initialize all MCP servers using custom client.

    Locates the ``servers/`` directory next to this script, registers every
    configured server script with the global ``mcp_manager``, and starts them
    all concurrently.

    Returns:
        The global ``MCPClientManager`` with all servers initialized.

    Raises:
        FileNotFoundError: if the ``servers/`` directory is missing.
        Exception: the first server-initialization failure is re-raised.
    """
    logger.info("Setting up MCP servers...")

    # Get the directory where this script is located
    script_dir = Path(__file__).parent.resolve()
    servers_dir = script_dir / "servers"

    logger.info(f"Script directory: {script_dir}")
    logger.info(f"Servers directory: {servers_dir}")

    # Verify servers directory exists
    if not servers_dir.exists():
        logger.error(f"Servers directory not found: {servers_dir}")
        raise FileNotFoundError(f"Servers directory not found: {servers_dir}")

    # Add all servers to manager.  Keys become the tool-name prefixes
    # (e.g. "pubmed__search") used elsewhere in this module.
    servers = {
        "pubmed": servers_dir / "pubmed_server.py",
        "aact": servers_dir / "aact_server.py",  # PRIMARY: AACT database for comprehensive clinical trials data
        "trials_links": servers_dir / "clinicaltrials_links.py",  # FALLBACK: Direct links and known ALS trials
        "fetch": servers_dir / "fetch_server.py",
        "elevenlabs": servers_dir / "elevenlabs_server.py",  # Voice capabilities for accessibility
    }

    # bioRxiv temporarily disabled - commenting out to hide from users
    # enable_biorxiv = os.getenv("ENABLE_BIORXIV", "true").lower() == "true"
    # if enable_biorxiv:
    #     servers["biorxiv"] = servers_dir / "biorxiv_server.py"
    # else:
    #     logger.info("⚠️ bioRxiv/medRxiv disabled for faster searches (set ENABLE_BIORXIV=true to enable)")

    # Conditionally add LlamaIndex RAG based on environment variable
    enable_rag = os.getenv("ENABLE_RAG", "false").lower() == "true"
    if enable_rag:
        logger.info("📚 RAG/LlamaIndex enabled (will add ~10s to startup for semantic search)")
        servers["llamaindex"] = servers_dir / "llamaindex_server.py"
    else:
        logger.info("🚀 RAG/LlamaIndex disabled for faster startup (set ENABLE_RAG=true to enable)")

    # Parallelize server initialization for faster startup
    async def init_server(name: str, script_path: Path):
        # Helper: start one server; log and re-raise on failure so gather()
        # captures the exception below.
        try:
            await mcp_manager.add_server(name, str(script_path))
            logger.info(f"✓ MCP server {name} initialized")
        except Exception as e:
            logger.error(f"Failed to initialize MCP server {name}: {e}")
            raise

    # Start all servers concurrently
    tasks = [init_server(name, script_path) for name, script_path in servers.items()]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Check for any failures.  gather() preserves input order, so results[i]
    # corresponds to the i-th entry of `servers` (dicts preserve insertion
    # order), letting us recover the failing server's name.
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            name = list(servers.keys())[i]
            logger.error(f"Failed to initialize MCP server {name}: {result}")
            raise result

    logger.info("All MCP servers initialized successfully")
    return mcp_manager
|
| 329 |
+
|
| 330 |
+
async def cleanup_mcp_servers() -> None:
    """Cleanup MCP server sessions.

    Closes every session owned by the global ``mcp_manager``.  Intended to be
    awaited once during application shutdown.
    """
    logger.info("Cleaning up MCP server sessions...")
    await mcp_manager.close_all()
    logger.info("MCP cleanup complete")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def export_conversation(history: Optional[List[Any]]) -> Optional[Path]:
    """Export conversation to markdown format.

    Args:
        history: List of (user_message, assistant_message) pairs from the
            chat UI.  May be None or empty.

    Returns:
        Path to the markdown file written in the current working directory,
        or None when there is nothing to export.
    """
    if not history:
        return None

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"als_conversation_{timestamp}.md"

    content = f"""# ALS Research Conversation
**Exported:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

---

"""

    # Each history entry is rendered as its own numbered Query section.
    for i, (user_msg, assistant_msg) in enumerate(history, 1):
        content += f"## Query {i}\n\n**User:** {user_msg}\n\n**Assistant:**\n{assistant_msg}\n\n---\n\n"

    content += f"""
*Generated by ALSARA - ALS Agentic Research Agent*
*Total interactions: {len(history)}*
"""

    filepath = Path(filename)
    filepath.write_text(content, encoding='utf-8')
    # Bug fix: the log line previously did not interpolate the output path.
    logger.info(f"Exported conversation to {filepath}")

    return filepath
|
| 365 |
+
|
| 366 |
+
async def get_all_tools() -> List[Dict[str, Any]]:
    """Retrieve all available tools from MCP servers with caching.

    Tool definitions are fetched once from every server via the global
    ``mcp_manager`` and converted to the Anthropic function-calling schema,
    then memoized in module globals for ``TOOLS_CACHE_TTL`` seconds.

    Returns:
        List of dicts with ``name`` (``"<server>__<tool>"``),
        ``description`` and ``input_schema`` keys.
    """
    global _cached_tools, _tools_cache_time

    # Check if cache is valid
    if _cached_tools and _tools_cache_time:
        if time.time() - _tools_cache_time < TOOLS_CACHE_TTL:
            logger.debug("Using cached tool definitions")
            return _cached_tools

    # Fetch fresh tool definitions
    logger.info("Fetching fresh tool definitions from MCP servers")
    all_tools = []

    # Get tools from all servers
    server_tools = await mcp_manager.list_all_tools()

    for server_name, tools in server_tools.items():
        for tool in tools:
            # Convert MCP tool to Anthropic function format.  The server name
            # is namespaced into the tool name so call_mcp_tool can route the
            # call back to the right server.
            all_tools.append({
                "name": f"{server_name}__{tool['name']}",
                "description": tool.get('description', ''),
                "input_schema": tool.get('inputSchema', {})
            })

    # Update cache
    _cached_tools = all_tools
    _tools_cache_time = time.time()
    logger.info(f"Cached {len(all_tools)} tool definitions")

    return all_tools
|
| 398 |
+
|
| 399 |
+
async def call_mcp_tool(tool_name: str, arguments: Dict[str, Any], max_retries: int = 3) -> str:
    """Execute an MCP tool call with caching, rate limiting, retry logic, and error handling.

    Args:
        tool_name: Namespaced tool identifier ``"<server>__<tool>"`` as
            produced by get_all_tools().
        arguments: Tool arguments forwarded to the server.
        max_retries: Attempts for transient failures (timeouts, generic
            errors) with exponential backoff.

    Returns:
        The tool's string result, or a formatted error message string —
        this function never raises; failures are reported in-band.
    """

    # Check cache first (no retries needed for cached results)
    cached_result = tool_cache.get(tool_name, arguments)
    if cached_result:
        return cached_result

    last_error = None

    for attempt in range(max_retries):
        try:
            # Apply rate limiting, keyed by server name.  Note this runs
            # before the "__" format check: split("__")[0] is safe either way
            # (it yields the whole string when no separator is present).
            await rate_limiter.wait_if_needed(tool_name.split("__")[0])

            # Parse tool name
            if "__" not in tool_name:
                logger.error(f"Invalid tool name format: {tool_name}")
                return f"Error: Invalid tool name format: {tool_name}"

            server_name, tool_method = tool_name.split("__", 1)

            if attempt > 0:
                logger.info(f"Retry {attempt}/{max_retries} for tool: {tool_method} on server: {server_name}")
            else:
                logger.info(f"Calling tool: {tool_method} on server: {server_name}")

            # Call tool with timeout using custom client
            result = await asyncio.wait_for(
                mcp_manager.call_tool(server_name, tool_method, arguments),
                timeout=90.0  # 90 second timeout for complex tool calls (BioRxiv searches can be slow)
            )

            # Result is already a string from custom client
            final_result = result if result else "No content returned from tool"

            # Cache the result
            tool_cache.set(tool_name, arguments, final_result)

            # Record successful tool call
            health_monitor.record_tool_call(tool_name)

            return final_result

        except asyncio.TimeoutError as e:
            last_error = e
            logger.warning(f"Tool call timed out (attempt {attempt + 1}/{max_retries}): {tool_name}")
            if attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
                continue
            # Last attempt failed
            timeout_error = TimeoutError(f"Tool timeout after {max_retries} attempts - the {server_name} server may be overloaded")
            return format_error_message(timeout_error, context=f"Calling {tool_name}")

        except ValueError as e:
            # ValueError (e.g. unknown server/tool) is not retried — it will
            # not succeed on a second attempt.
            logger.error(f"Invalid tool/server: {tool_name} - {e}")
            return format_error_message(e, context=f"Invalid tool: {tool_name}")

        except Exception as e:
            last_error = e
            logger.warning(f"Error calling tool {tool_name} (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)  # Exponential backoff
                continue
            # Last attempt failed
            return format_error_message(e, context=f"Tool {tool_name} failed after {max_retries} attempts")

    # Should not reach here, but handle just in case
    if last_error:
        return f"Tool failed after {max_retries} attempts: {str(last_error)[:200]}"
    return "Unexpected error in tool execution"
|
| 470 |
+
|
| 471 |
+
def filter_internal_tags(text: str) -> str:
    """Remove all internal processing tags from the output.

    Strips <thinking>/<search_quality_*> blocks (closed or trailing
    unclosed), unwraps <result>/<answer> tags, normalizes the agent's
    phase headers (PLANNING/EXECUTING/REFLECTING/SYNTHESIS) into bold
    markers on their own lines, and collapses excess whitespace.

    Args:
        text: Raw model output possibly containing internal tags.

    Returns:
        The cleaned, display-ready text.
    """
    # Note: the redundant function-local `import re` was removed — `re`
    # is already imported at module level.

    # Remove internal tags and their content with single regex.
    # The second alternative catches an unclosed tag running to end-of-text.
    text = re.sub(r'<(thinking|search_quality_reflection|search_quality_score)>.*?</\1>|<(thinking|search_quality_reflection|search_quality_score)>.*$', '', text, flags=re.DOTALL)

    # Remove wrapper tags but keep content
    text = re.sub(r'</?(result|answer)>', '', text)

    # Fix phase formatting - ensure consistent formatting
    # Add proper line breaks around phase headers
    # First normalize any existing phase markers to be on their own line
    phase_patterns = [
        # Fix incorrect formats (missing asterisks) first
        (r'(?<!\*)🎯\s*PLANNING:(?!\*)', r'**🎯 PLANNING:**'),
        (r'(?<!\*)🔧\s*EXECUTING:(?!\*)', r'**🔧 EXECUTING:**'),
        (r'(?<!\*)🤔\s*REFLECTING:(?!\*)', r'**🤔 REFLECTING:**'),
        (r'(?<!\*)✅\s*SYNTHESIS:(?!\*)', r'**✅ SYNTHESIS:**'),

        # Then ensure the markers are on new lines (if not already)
        (r'(?<!\n)(\*\*🎯\s*PLANNING:\*\*)', r'\n\n\1'),
        (r'(?<!\n)(\*\*🔧\s*EXECUTING:\*\*)', r'\n\n\1'),
        (r'(?<!\n)(\*\*🤔\s*REFLECTING:\*\*)', r'\n\n\1'),
        (r'(?<!\n)(\*\*✅\s*SYNTHESIS:\*\*)', r'\n\n\1'),

        # Then add spacing after them
        (r'(\*\*🎯\s*PLANNING:\*\*)', r'\1\n'),
        (r'(\*\*🔧\s*EXECUTING:\*\*)', r'\1\n'),
        (r'(\*\*🤔\s*REFLECTING:\*\*)', r'\1\n'),
        (r'(\*\*✅\s*SYNTHESIS:\*\*)', r'\1\n'),
    ]

    for pattern, replacement in phase_patterns:
        text = re.sub(pattern, replacement, text)

    # Clean up excessive whitespace while preserving intentional formatting
    text = re.sub(r'[ \t]+', ' ', text)  # Multiple spaces to single space
    text = re.sub(r'\n{4,}', '\n\n\n', text)  # Maximum 3 newlines
    text = re.sub(r'^\n+', '', text)  # Remove leading newlines
    text = re.sub(r'\n+$', '\n', text)  # Single trailing newline

    return text.strip()
|
| 514 |
+
|
| 515 |
+
def is_complex_query(message: str) -> bool:
    """Detect complex queries that might need more iterations.

    A query is "complex" when it contains any indicator phrase suggesting
    a broad, comparative, or exhaustive request (case-insensitive
    substring match).
    """
    indicators = (
        "genotyping", "genetic testing", "multiple", "comprehensive",
        "all", "compare", "versus", "difference between", "systematic",
        "gene-targeted", "gene targeted", "list the main", "what are all",
        "complete overview", "detailed analysis", "in-depth",
    )
    lowered = message.lower()
    for phrase in indicators:
        if phrase in lowered:
            return True
    return False
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def validate_query(message: str) -> Tuple[bool, str]:
    """Validate and sanitize user input to prevent injection and abuse.

    Returns:
        (True, "") when the query is acceptable, otherwise
        (False, <user-facing reason>).
    """
    # Guard: reject empty or whitespace-only input.
    if not message or not message.strip():
        return False, "Please enter a query"

    # Guard: hard cap on query size.
    if len(message) > 2000:
        return False, "Query too long (maximum 2000 characters). Please shorten your question."

    # Patterns that look like script injection or prompt-override attempts.
    suspicious_patterns = (
        r'<script', r'javascript:', r'onclick', r'onerror',
        r'\bignore\s+previous\s+instructions\b',
        r'\bsystem\s+prompt\b',
        r'\bforget\s+everything\b',
        r'\bdisregard\s+all\b',
    )
    for pattern in suspicious_patterns:
        if re.search(pattern, message, re.IGNORECASE):
            logger.warning(f"Suspicious pattern detected in query: {pattern}")
            return False, "Invalid query format. Please rephrase your question."

    # Spam heuristic: reject longer queries dominated by a single word.
    words = message.lower().split()
    if len(words) > 10:
        word_freq = {}
        for word in words:
            word_freq[word] = word_freq.get(word, 0) + 1
        if max(word_freq.values()) > len(words) * 0.5:
            return False, "Query appears to contain excessive repetition. Please rephrase."

    return True, ""
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
async def als_research_agent(message: str, history: Optional[List[Dict[str, Any]]]) -> AsyncGenerator[str, None]:
|
| 565 |
+
"""Main agent logic with streaming response and error handling"""
|
| 566 |
+
|
| 567 |
+
global last_response_was_research
|
| 568 |
+
|
| 569 |
+
start_time = time.time()
|
| 570 |
+
health_monitor.record_request()
|
| 571 |
+
|
| 572 |
+
try:
|
| 573 |
+
# Validate input first
|
| 574 |
+
valid, error_msg = validate_query(message)
|
| 575 |
+
if not valid:
|
| 576 |
+
yield f"⚠️ **Input Validation Error:** {error_msg}"
|
| 577 |
+
return
|
| 578 |
+
|
| 579 |
+
logger.info(f"Received valid query: {message[:100]}...") # Log first 100 chars
|
| 580 |
+
|
| 581 |
+
# Truncate history to prevent memory bloat
|
| 582 |
+
if history and len(history) > MAX_CONVERSATION_LENGTH:
|
| 583 |
+
logger.info(f"Truncating conversation history from {len(history)} to {MAX_CONVERSATION_LENGTH} messages")
|
| 584 |
+
history = history[-MAX_CONVERSATION_LENGTH:]
|
| 585 |
+
|
| 586 |
+
# System prompt
|
| 587 |
+
base_prompt = """You are ALSARA, an expert ALS (Amyotrophic Lateral Sclerosis) research assistant with agentic capabilities for planning, execution, and reflection.
|
| 588 |
+
|
| 589 |
+
CRITICAL CONTEXT: ALL queries should be interpreted in the context of ALS research unless explicitly stated otherwise.
|
| 590 |
+
|
| 591 |
+
MANDATORY SEARCH QUERY RULES:
|
| 592 |
+
1. ALWAYS include "ALS" or "amyotrophic lateral sclerosis" in EVERY search query
|
| 593 |
+
2. If the user's query doesn't mention ALS, ADD IT to your search terms
|
| 594 |
+
3. This prevents irrelevant results from other conditions
|
| 595 |
+
|
| 596 |
+
Examples:
|
| 597 |
+
- User: "genotyping for gene targeted treatments" → Search: "genotyping ALS gene targeted treatments"
|
| 598 |
+
- User: "psilocybin clinical trials" → Search: "psilocybin ALS clinical trials"
|
| 599 |
+
- User: "stem cell therapy" → Search: "stem cell therapy ALS"
|
| 600 |
+
- User: "gene therapy trials" → Search: "gene therapy ALS trials"
|
| 601 |
+
|
| 602 |
+
Your capabilities:
|
| 603 |
+
- Search PubMed for peer-reviewed research papers
|
| 604 |
+
- Find active clinical trials in the AACT database"""
|
| 605 |
+
|
| 606 |
+
# Add RAG capability only if enabled
|
| 607 |
+
enable_rag = os.getenv("ENABLE_RAG", "false").lower() == "true"
|
| 608 |
+
if enable_rag:
|
| 609 |
+
base_prompt += """
|
| 610 |
+
- **Semantic search using RAG**: Instantly search cached ALS research papers using AI-powered semantic matching"""
|
| 611 |
+
|
| 612 |
+
base_prompt += """
|
| 613 |
+
- Fetch and analyze web content
|
| 614 |
+
- Synthesize information from multiple sources
|
| 615 |
+
- Provide citations with PMIDs, DOIs, and NCT IDs
|
| 616 |
+
|
| 617 |
+
=== AGENTIC WORKFLOW (REQUIRED) ===
|
| 618 |
+
|
| 619 |
+
You MUST follow ALL FOUR phases for EVERY query - no exceptions:
|
| 620 |
+
|
| 621 |
+
1. **🎯 PLANNING PHASE** (MANDATORY - ALWAYS FIRST):
|
| 622 |
+
Before using any tools, you MUST explicitly outline your search strategy:"""
|
| 623 |
+
|
| 624 |
+
if enable_rag:
|
| 625 |
+
base_prompt += """
|
| 626 |
+
- FIRST check semantic cache using RAG for instant results from indexed papers"""
|
| 627 |
+
|
| 628 |
+
base_prompt += """
|
| 629 |
+
- State what databases you will search and in what order
|
| 630 |
+
- ALWAYS plan to search PubMed for peer-reviewed research
|
| 631 |
+
- For clinical questions, also include AACT trials database
|
| 632 |
+
- Identify key search terms and variations
|
| 633 |
+
- Explain your prioritization approach
|
| 634 |
+
- Format: MUST start on a NEW LINE with "**🎯 PLANNING:**" followed by your strategy
|
| 635 |
+
|
| 636 |
+
2. **🔧 EXECUTION PHASE** (MANDATORY - AFTER PLANNING):
|
| 637 |
+
- MUST mark this phase on a NEW LINE with "**🔧 EXECUTING:**"
|
| 638 |
+
- Execute your planned searches systematically"""
|
| 639 |
+
|
| 640 |
+
if enable_rag:
|
| 641 |
+
base_prompt += """
|
| 642 |
+
- START with semantic search using RAG for instant cached results"""
|
| 643 |
+
|
| 644 |
+
base_prompt += """
|
| 645 |
+
- MINIMUM requirement: Search PubMed for peer-reviewed literature
|
| 646 |
+
- For clinical questions, search AACT trials database
|
| 647 |
+
- Gather initial results from each source
|
| 648 |
+
- Show tool calls and results
|
| 649 |
+
- This phase is for INITIAL searches only (as planned)
|
| 650 |
+
|
| 651 |
+
3. **🤔 REFLECTION PHASE** (MANDATORY - AFTER EXECUTION):
|
| 652 |
+
After tool execution, you MUST ALWAYS reflect before synthesizing:
|
| 653 |
+
|
| 654 |
+
CRITICAL FORMAT REQUIREMENTS:
|
| 655 |
+
- MUST be EXACTLY: **🤔 REFLECTING:**
|
| 656 |
+
- MUST include the asterisks (**) for bold formatting
|
| 657 |
+
- MUST start on a NEW LINE (never inline with other text)
|
| 658 |
+
- WRONG: "🤔 REFLECTING:" (missing asterisks)
|
| 659 |
+
- WRONG: "search completed🤔 REFLECTING:" (inline, not on new line)
|
| 660 |
+
- CORRECT: New line, then **🤔 REFLECTING:**
|
| 661 |
+
|
| 662 |
+
Content requirements:
|
| 663 |
+
- Evaluate: "Do I have sufficient information to answer comprehensively?"
|
| 664 |
+
- Identify gaps: "What aspects of the query remain unaddressed?"
|
| 665 |
+
- Decide: "Should I refine my search or proceed to synthesis?"
|
| 666 |
+
|
| 667 |
+
CRITICAL: If you need more searches:
|
| 668 |
+
- DO NOT start a new PLANNING phase
|
| 669 |
+
- DO NOT write new phase markers
|
| 670 |
+
- Stay WITHIN the REFLECTION phase
|
| 671 |
+
- Simply continue searching and analyzing while in REFLECTING mode
|
| 672 |
+
- Additional searches are part of reflection, not a new workflow
|
| 673 |
+
|
| 674 |
+
- NEVER skip this phase - it ensures answer quality
|
| 675 |
+
|
| 676 |
+
4. **✅ SYNTHESIS PHASE** (MANDATORY - FINAL PHASE):
|
| 677 |
+
- MUST start on a NEW LINE with "**✅ SYNTHESIS:**"
|
| 678 |
+
- Provide comprehensive synthesis of all findings
|
| 679 |
+
- Include all citations with URLs
|
| 680 |
+
- Summarize key insights
|
| 681 |
+
- **CONFIDENCE SCORING**: Include confidence level for key claims:
|
| 682 |
+
• High confidence (🟢): Multiple peer-reviewed studies or systematic reviews
|
| 683 |
+
• Moderate confidence (🟡): Limited studies or preprints with consistent findings
|
| 684 |
+
• Low confidence (🔴): Single study, conflicting evidence, or theoretical basis
|
| 685 |
+
- This phase MUST appear in EVERY response
|
| 686 |
+
|
| 687 |
+
FORMATTING RULES:
|
| 688 |
+
- Each phase marker MUST appear on its own line
|
| 689 |
+
- Never put phase markers inline with other text
|
| 690 |
+
- Always use the exact format: **[emoji] PHASE_NAME:**
|
| 691 |
+
- MUST include asterisks for bold: **🤔 REFLECTING:** not just 🤔 REFLECTING:
|
| 692 |
+
- Each phase should appear EXACTLY ONCE per response - never repeat the workflow
|
| 693 |
+
|
| 694 |
+
CRITICAL WORKFLOW RULES:
|
| 695 |
+
- You MUST include ALL FOUR phases in your response
|
| 696 |
+
- Each phase appears EXACTLY ONCE (never repeat Planning→Executing→Reflecting→Synthesis)
|
| 697 |
+
- Missing any phase is unacceptable
|
| 698 |
+
- Duplicating phases is unacceptable
|
| 699 |
+
- The workflow is a SINGLE CYCLE:
|
| 700 |
+
1. PLANNING (once at start)
|
| 701 |
+
2. EXECUTING (initial searches)
|
| 702 |
+
3. REFLECTING (evaluate AND do additional searches if needed - all within this phase)
|
| 703 |
+
4. SYNTHESIS (final answer)
|
| 704 |
+
- NEVER restart the workflow - additional searches happen WITHIN reflection
|
| 705 |
+
|
| 706 |
+
CRITICAL SYNTHESIS RULES:
|
| 707 |
+
- You MUST ALWAYS end with a ✅ SYNTHESIS phase
|
| 708 |
+
- If searches fail, state "Despite search limitations..." and provide knowledge-based answer
|
| 709 |
+
- If you reach iteration limits, immediately provide synthesis
|
| 710 |
+
- NEVER end without synthesis - this is a MANDATORY requirement
|
| 711 |
+
- If uncertain, start synthesis with: "Based on available information..."
|
| 712 |
+
|
| 713 |
+
SYNTHESIS MUST INCLUDE:
|
| 714 |
+
1. Direct answer to the user's question
|
| 715 |
+
2. Key findings from successful searches (if any)
|
| 716 |
+
3. Citations with clickable URLs
|
| 717 |
+
4. If searches failed: explanation + knowledge-based answer
|
| 718 |
+
5. Suggested follow-up questions or alternative approaches
|
| 719 |
+
|
| 720 |
+
=== SELF-CORRECTION BEHAVIOR ===
|
| 721 |
+
|
| 722 |
+
If your searches return zero or insufficient results:
|
| 723 |
+
- Try broader search terms (remove qualifiers)
|
| 724 |
+
- Try alternative terminology or synonyms
|
| 725 |
+
- Search for related concepts
|
| 726 |
+
- Explicitly state what you tried and what you found
|
| 727 |
+
|
| 728 |
+
When answering:
|
| 729 |
+
1. Be concise in explanations while maintaining clarity
|
| 730 |
+
2. Focus on presenting search results efficiently
|
| 731 |
+
3. Always cite sources with specific identifiers AND URLs:
|
| 732 |
+
- PubMed: Include PMID and URL (https://pubmed.ncbi.nlm.nih.gov/PMID/)
|
| 733 |
+
- Preprints: Include DOI and URL (https://doi.org/DOI)
|
| 734 |
+
- Clinical Trials: Include NCT ID and URL (https://clinicaltrials.gov/study/NCTID)
|
| 735 |
+
4. Use numbered citations [1], [2] with a references section at the end
|
| 736 |
+
5. Prioritize recent research (2023-2025)
|
| 737 |
+
6. When discussing preprints, note they are NOT peer-reviewed
|
| 738 |
+
7. Explain complex concepts clearly
|
| 739 |
+
8. Acknowledge uncertainty when appropriate
|
| 740 |
+
9. Suggest related follow-up questions
|
| 741 |
+
|
| 742 |
+
CRITICAL CITATION RULES:
|
| 743 |
+
- ONLY cite papers, preprints, and trials that you have ACTUALLY found using the search tools
|
| 744 |
+
- NEVER make up or invent citations, PMIDs, DOIs, or NCT IDs
|
| 745 |
+
- NEVER cite papers from your training data unless you have verified them through search
|
| 746 |
+
- If you cannot find specific research on a topic, explicitly state "No studies found" rather than inventing citations
|
| 747 |
+
- Every citation must come from actual search results obtained through the available tools
|
| 748 |
+
- If asked about a topic you know from training but haven't searched, you MUST search first before citing
|
| 749 |
+
|
| 750 |
+
IMPORTANT: When referencing papers in your final answer, ALWAYS include clickable URLs alongside citations to make it easy for users to access the sources.
|
| 751 |
+
|
| 752 |
+
Available tools:
|
| 753 |
+
- pubmed__search_pubmed: Search peer-reviewed research literature
|
| 754 |
+
- pubmed__get_paper_details: Get full paper details from PubMed (USE SPARINGLY - only for most relevant papers)
|
| 755 |
+
# - biorxiv__search_preprints: (temporarily unavailable)
|
| 756 |
+
# - biorxiv__get_preprint_details: (temporarily unavailable)
|
| 757 |
+
- aact__search_aact_trials: Search clinical trials (PRIMARY - use this first)
|
| 758 |
+
- aact__get_aact_trial: Get specific trial details from AACT database
|
| 759 |
+
- trials_links__get_known_als_trials: Get curated list of important ALS trials (FALLBACK)
|
| 760 |
+
- trials_links__get_search_link: Generate direct ClinicalTrials.gov search URLs
|
| 761 |
+
- fetch__fetch_url: Retrieve web content
|
| 762 |
+
|
| 763 |
+
PERFORMANCE OPTIMIZATION:
|
| 764 |
+
- Search results already contain abstracts - use these for initial synthesis
|
| 765 |
+
- Only fetch full details for papers that are DIRECTLY relevant to the query
|
| 766 |
+
- Limit detail fetches to 5-7 most relevant items per database
|
| 767 |
+
- Prioritize based on: recency, relevance to query, impact/importance
|
| 768 |
+
|
| 769 |
+
Search strategy:
|
| 770 |
+
1. Search all relevant databases (PubMed, AACT clinical trials)
|
| 771 |
+
2. ALWAYS supplement with web fetching to:
|
| 772 |
+
- Find additional information not in databases
|
| 773 |
+
- Access sponsor/institution websites
|
| 774 |
+
- Get recent news and updates
|
| 775 |
+
- Retrieve full-text content when needed
|
| 776 |
+
- Verify and expand on database results
|
| 777 |
+
3. Synthesize all sources for comprehensive answers
|
| 778 |
+
|
| 779 |
+
For clinical trials - NEW ARCHITECTURE:
|
| 780 |
+
PRIMARY SOURCE - AACT Database:
|
| 781 |
+
- Use search_aact_trials FIRST - provides comprehensive clinical trials data from AACT database
|
| 782 |
+
- 559,000+ trials available with no rate limits
|
| 783 |
+
- Use uppercase status values: RECRUITING, ACTIVE_NOT_RECRUITING, NOT_YET_RECRUITING, COMPLETED
|
| 784 |
+
- For ALS searches, the condition "ALS" will automatically match related terms
|
| 785 |
+
|
| 786 |
+
FALLBACK - Links Server (when AACT unavailable):
|
| 787 |
+
- Use get_known_als_trials for curated list of 8 important ALS trials
|
| 788 |
+
- Use get_search_link to generate search URLs for clinical trials
|
| 789 |
+
- Use get_trial_link to generate direct links to specific trials
|
| 790 |
+
|
| 791 |
+
ADDITIONAL SOURCES:
|
| 792 |
+
- If specific NCT IDs are mentioned, can also use fetch__fetch_url with:
|
| 793 |
+
https://clinicaltrials.gov/study/{NCT_ID}
|
| 794 |
+
- Search sponsor websites, medical news, and university pages for updates
|
| 795 |
+
|
| 796 |
+
ARCHITECTURE FLOW:
|
| 797 |
+
User Query → AACT Database (Primary)
|
| 798 |
+
↓
|
| 799 |
+
If AACT unavailable
|
| 800 |
+
↓
|
| 801 |
+
Links Server (Fallback)
|
| 802 |
+
↓
|
| 803 |
+
Direct links to trial websites
|
| 804 |
+
|
| 805 |
+
Note: Direct API access is unavailable - using AACT database instead
|
| 806 |
+
"""
|
| 807 |
+
|
| 808 |
+
# Add enhanced instructions for Llama models to improve thoroughness
|
| 809 |
+
if client.is_using_llama_primary():
|
| 810 |
+
llama_enhancement = """
|
| 811 |
+
|
| 812 |
+
ENHANCED SEARCH REQUIREMENTS FOR COMPREHENSIVE RESULTS:
|
| 813 |
+
You MUST follow this structured approach for EVERY research query:
|
| 814 |
+
|
| 815 |
+
=== MANDATORY SEARCH PHASES ===
|
| 816 |
+
Phase 1 - Comprehensive Database Search (ALL databases REQUIRED):
|
| 817 |
+
□ Search PubMed with multiple keyword variations
|
| 818 |
+
□ Search AACT database for clinical trials
|
| 819 |
+
□ Use at least 3-5 different search queries per database
|
| 820 |
+
|
| 821 |
+
Phase 2 - Strategic Detail Fetching (BE SELECTIVE):
|
| 822 |
+
□ Get paper details for the TOP 5-7 most relevant PubMed results
|
| 823 |
+
□ Get trial details for the TOP 3-4 most relevant clinical trials
|
| 824 |
+
□ ONLY fetch details for papers that are DIRECTLY relevant to the query
|
| 825 |
+
□ Use search result abstracts to prioritize which papers need full details
|
| 826 |
+
|
| 827 |
+
Phase 3 - Synthesis Requirements:
|
| 828 |
+
□ Include ALL relevant papers found (not just top 3-5)
|
| 829 |
+
□ Organize by subtopic or treatment approach
|
| 830 |
+
□ Provide complete citations with URLs
|
| 831 |
+
|
| 832 |
+
MINIMUM SEARCH STANDARDS:
|
| 833 |
+
- For general queries: At least 10-15 total searches across all databases
|
| 834 |
+
- For specific treatments: At least 5-7 searches per database
|
| 835 |
+
- For comprehensive reviews: At least 15-20 total searches
|
| 836 |
+
- NEVER stop after finding just 2-3 results
|
| 837 |
+
|
| 838 |
+
EXAMPLE SEARCH PATTERN for "gene therapy ALS":
|
| 839 |
+
1. pubmed__search_pubmed: "gene therapy ALS"
|
| 840 |
+
2. pubmed__search_pubmed: "AAV ALS treatment"
|
| 841 |
+
3. pubmed__search_pubmed: "SOD1 gene therapy"
|
| 842 |
+
4. pubmed__search_pubmed: "C9orf72 gene therapy"
|
| 843 |
+
5. pubmed__search_pubmed: "viral vector ALS"
|
| 844 |
+
# 6. biorxiv__search_preprints: (temporarily unavailable)
|
| 845 |
+
# 7. biorxiv__search_preprints: (temporarily unavailable)
|
| 846 |
+
6. aact__search_aact_trials: condition="ALS", intervention="gene therapy"
|
| 847 |
+
7. aact__search_aact_trials: condition="ALS", intervention="AAV"
|
| 848 |
+
10. [Get details for ALL results found]
|
| 849 |
+
11. [Web fetch for recent developments]
|
| 850 |
+
|
| 851 |
+
CRITICAL: Thoroughness is MORE important than speed. Users expect comprehensive results."""
|
| 852 |
+
|
| 853 |
+
system_prompt = base_prompt + llama_enhancement
|
| 854 |
+
logger.info("Using enhanced prompting for Llama model to improve search thoroughness")
|
| 855 |
+
else:
|
| 856 |
+
# Use base prompt directly for Claude
|
| 857 |
+
system_prompt = base_prompt
|
| 858 |
+
|
| 859 |
+
# Import query classifier
|
| 860 |
+
from query_classifier import QueryClassifier
|
| 861 |
+
|
| 862 |
+
# Classify the query to determine processing mode
|
| 863 |
+
classification = QueryClassifier.classify_query(message)
|
| 864 |
+
processing_hint = QueryClassifier.get_processing_hint(classification)
|
| 865 |
+
logger.info(f"Query classification: {classification}")
|
| 866 |
+
|
| 867 |
+
# Check smart cache for similar queries first
|
| 868 |
+
cached_result = smart_cache.find_similar_cached(message)
|
| 869 |
+
if cached_result:
|
| 870 |
+
logger.info(f"Smart cache hit for query: {message[:50]}...")
|
| 871 |
+
yield "🎯 **Using cached result** (similar query found)\n\n"
|
| 872 |
+
yield cached_result
|
| 873 |
+
return
|
| 874 |
+
|
| 875 |
+
# Check if this is a high-frequency query with special config
|
| 876 |
+
high_freq_config = smart_cache.get_high_frequency_config(message)
|
| 877 |
+
if high_freq_config:
|
| 878 |
+
logger.info(f"High-frequency query detected with config: {high_freq_config}")
|
| 879 |
+
# Note: We could use optimized search terms or Claude here
|
| 880 |
+
# For now, just log it and continue with normal processing
|
| 881 |
+
|
| 882 |
+
# Get available tools
|
| 883 |
+
tools = await get_all_tools()
|
| 884 |
+
|
| 885 |
+
# Check if this is a simple query that doesn't need research
|
| 886 |
+
if not classification['requires_research']:
|
| 887 |
+
# Simple query - skip the full research workflow
|
| 888 |
+
logger.info(f"Simple query detected - using direct response mode: {classification['reason']}")
|
| 889 |
+
|
| 890 |
+
# Mark that this response won't use research workflow (disable voice button)
|
| 891 |
+
global last_response_was_research
|
| 892 |
+
last_response_was_research = False
|
| 893 |
+
|
| 894 |
+
# Use a simplified prompt for non-research queries
|
| 895 |
+
simple_prompt = """You are an AI assistant for ALS research questions.
|
| 896 |
+
For this query, provide a helpful, conversational response without using research tools.
|
| 897 |
+
Keep your response friendly and informative."""
|
| 898 |
+
|
| 899 |
+
# For simple queries, just make one API call without tools
|
| 900 |
+
messages = [
|
| 901 |
+
{"role": "system", "content": simple_prompt},
|
| 902 |
+
{"role": "user", "content": message}
|
| 903 |
+
]
|
| 904 |
+
|
| 905 |
+
# Display processing hint
|
| 906 |
+
yield f"{processing_hint}\n\n"
|
| 907 |
+
|
| 908 |
+
# Single API call for simple response (no tools)
|
| 909 |
+
async for response_text, tool_calls, provider_used in stream_with_retry(
|
| 910 |
+
client=client,
|
| 911 |
+
messages=messages,
|
| 912 |
+
tools=None, # No tools for simple queries
|
| 913 |
+
system_prompt=simple_prompt,
|
| 914 |
+
max_retries=2,
|
| 915 |
+
model=ANTHROPIC_MODEL,
|
| 916 |
+
max_tokens=2000, # Shorter responses for simple queries
|
| 917 |
+
stream_name="simple response"
|
| 918 |
+
):
|
| 919 |
+
yield response_text
|
| 920 |
+
|
| 921 |
+
# Return early - skip all the research phases
|
| 922 |
+
return
|
| 923 |
+
|
| 924 |
+
# Research query - use full workflow with tools
|
| 925 |
+
logger.info(f"Research query detected - using full workflow: {classification['reason']}")
|
| 926 |
+
|
| 927 |
+
# Mark that this response will use research workflow (enable voice button)
|
| 928 |
+
last_response_was_research = True
|
| 929 |
+
yield f"{processing_hint}\n\n"
|
| 930 |
+
|
| 931 |
+
# Build messages for research workflow
|
| 932 |
+
messages = [
|
| 933 |
+
{"role": "system", "content": system_prompt}
|
| 934 |
+
]
|
| 935 |
+
|
| 936 |
+
# Add history (remove Gradio metadata)
|
| 937 |
+
if history:
|
| 938 |
+
# Only keep 'role' and 'content' fields from messages
|
| 939 |
+
for msg in history:
|
| 940 |
+
if isinstance(msg, dict):
|
| 941 |
+
messages.append({
|
| 942 |
+
"role": msg.get("role"),
|
| 943 |
+
"content": msg.get("content")
|
| 944 |
+
})
|
| 945 |
+
else:
|
| 946 |
+
messages.append(msg)
|
| 947 |
+
|
| 948 |
+
# Add current message
|
| 949 |
+
messages.append({"role": "user", "content": message})
|
| 950 |
+
|
| 951 |
+
# Initial API call with streaming using helper function
|
| 952 |
+
full_response = ""
|
| 953 |
+
tool_calls = []
|
| 954 |
+
|
| 955 |
+
# Use the stream_with_retry helper to handle all retry logic
|
| 956 |
+
provider_used = "Anthropic Claude" # Track which provider
|
| 957 |
+
async for response_text, current_tool_calls, provider_used in stream_with_retry(
|
| 958 |
+
client=client,
|
| 959 |
+
messages=messages,
|
| 960 |
+
tools=tools,
|
| 961 |
+
system_prompt=system_prompt,
|
| 962 |
+
max_retries=2, # Increased from 0 to allow retries
|
| 963 |
+
model=ANTHROPIC_MODEL,
|
| 964 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 965 |
+
stream_name="initial API call"
|
| 966 |
+
):
|
| 967 |
+
full_response = response_text
|
| 968 |
+
tool_calls = current_tool_calls
|
| 969 |
+
# Apply single-pass filtering when yielding
|
| 970 |
+
# Optionally show provider info when using fallback
|
| 971 |
+
if provider_used != "Anthropic Claude" and response_text:
|
| 972 |
+
yield f"[Using {provider_used}]\n{filter_internal_tags(full_response)}"
|
| 973 |
+
else:
|
| 974 |
+
yield filter_internal_tags(full_response)
|
| 975 |
+
|
| 976 |
+
# Handle recursive tool calls (agent may need multiple searches)
|
| 977 |
+
tool_iteration = 0
|
| 978 |
+
|
| 979 |
+
# Adjust iteration limit based on query complexity
|
| 980 |
+
if is_complex_query(message):
|
| 981 |
+
max_tool_iterations = 5
|
| 982 |
+
logger.info("Complex query detected - allowing up to 5 iterations")
|
| 983 |
+
else:
|
| 984 |
+
max_tool_iterations = 3
|
| 985 |
+
logger.info("Standard query - allowing up to 3 iterations")
|
| 986 |
+
|
| 987 |
+
while tool_calls and tool_iteration < max_tool_iterations:
|
| 988 |
+
tool_iteration += 1
|
| 989 |
+
logger.info(f"Tool iteration {tool_iteration}: processing {len(tool_calls)} tool calls")
|
| 990 |
+
|
| 991 |
+
# No need to re-yield the planning phase - it was already shown
|
| 992 |
+
|
| 993 |
+
# Build assistant message using helper
|
| 994 |
+
assistant_content = build_assistant_message(
|
| 995 |
+
text_content=full_response,
|
| 996 |
+
tool_calls=tool_calls
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
messages.append({
|
| 1000 |
+
"role": "assistant",
|
| 1001 |
+
"content": assistant_content
|
| 1002 |
+
})
|
| 1003 |
+
|
| 1004 |
+
# Show working indicator for long searches
|
| 1005 |
+
num_tools = len(tool_calls)
|
| 1006 |
+
if num_tools > 0:
|
| 1007 |
+
working_text = f"\n⏳ **Searching {num_tools} database{'s' if num_tools > 1 else ''} in parallel...** "
|
| 1008 |
+
if num_tools > 2:
|
| 1009 |
+
working_text += f"(this typically takes 30-45 seconds)\n"
|
| 1010 |
+
elif num_tools > 1:
|
| 1011 |
+
working_text += f"(this typically takes 15-30 seconds)\n"
|
| 1012 |
+
else:
|
| 1013 |
+
working_text += f"\n"
|
| 1014 |
+
full_response += working_text
|
| 1015 |
+
yield filter_internal_tags(full_response) # Show working indicator immediately
|
| 1016 |
+
|
| 1017 |
+
# Execute tool calls in parallel for better performance
|
| 1018 |
+
from parallel_tool_execution import execute_tool_calls_parallel
|
| 1019 |
+
progress_text, tool_results_content = await execute_tool_calls_parallel(
|
| 1020 |
+
tool_calls=tool_calls,
|
| 1021 |
+
call_mcp_tool_func=call_mcp_tool
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
# Add progress text to full response and yield accumulated content
|
| 1025 |
+
full_response += progress_text
|
| 1026 |
+
if progress_text:
|
| 1027 |
+
yield filter_internal_tags(full_response) # Yield full accumulated response
|
| 1028 |
+
|
| 1029 |
+
# Add single user message with ALL tool results
|
| 1030 |
+
messages.append({
|
| 1031 |
+
"role": "user",
|
| 1032 |
+
"content": tool_results_content
|
| 1033 |
+
})
|
| 1034 |
+
|
| 1035 |
+
# Smart reflection: Only add reflection prompt if results seem incomplete
|
| 1036 |
+
if tool_iteration == 1:
|
| 1037 |
+
# First iteration - use normal workflow with reflection
|
| 1038 |
+
# Check confidence indicators in tool results
|
| 1039 |
+
results_text = str(tool_results_content).lower()
|
| 1040 |
+
|
| 1041 |
+
# Indicators of low confidence/incomplete results
|
| 1042 |
+
low_confidence_indicators = [
|
| 1043 |
+
'no results found', '0 results', 'no papers',
|
| 1044 |
+
'no trials', 'limited', 'insufficient', 'few results'
|
| 1045 |
+
]
|
| 1046 |
+
|
| 1047 |
+
# Indicators of high confidence/complete results
|
| 1048 |
+
high_confidence_indicators = [
|
| 1049 |
+
'recent study', 'multiple studies', 'clinical trial',
|
| 1050 |
+
'systematic review', 'meta-analysis', 'significant results'
|
| 1051 |
+
]
|
| 1052 |
+
|
| 1053 |
+
# Count confidence indicators
|
| 1054 |
+
low_conf_count = sum(1 for ind in low_confidence_indicators if ind in results_text)
|
| 1055 |
+
high_conf_count = sum(1 for ind in high_confidence_indicators if ind in results_text)
|
| 1056 |
+
|
| 1057 |
+
# Calculate total results found across all tools
|
| 1058 |
+
import re
|
| 1059 |
+
result_numbers = re.findall(r'(\d+)\s+(?:results?|papers?|studies|trials?)', results_text)
|
| 1060 |
+
total_results = sum(int(n) for n in result_numbers) if result_numbers else 0
|
| 1061 |
+
|
| 1062 |
+
# Decide if reflection is needed - more aggressive skipping for performance
|
| 1063 |
+
needs_reflection = (
|
| 1064 |
+
low_conf_count > 1 or # Only if multiple low-confidence indicators
|
| 1065 |
+
(high_conf_count == 0 and total_results < 10) or # No high confidence AND few results
|
| 1066 |
+
total_results < 3 # Almost no results at all
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
if needs_reflection:
|
| 1070 |
+
reflection_prompt = [
|
| 1071 |
+
{"type": "text", "text": "\n\n**SMART REFLECTION:** Based on the results so far, please evaluate:\n\n1. Do you have sufficient high-quality information to answer comprehensively?\n2. Are there important aspects that need more investigation?\n3. Would refining search terms or trying different databases help?\n\nIf confident with current information (found relevant studies/trials), proceed to synthesis with (**✅ ANSWER:**). Otherwise, use reflection markers (**🤔 REFLECTING:**) and search for missing information."}
|
| 1072 |
+
]
|
| 1073 |
+
messages.append({
|
| 1074 |
+
"role": "user",
|
| 1075 |
+
"content": reflection_prompt
|
| 1076 |
+
})
|
| 1077 |
+
logger.info(f"Smart reflection triggered (low_conf:{low_conf_count}, high_conf:{high_conf_count}, results:{total_results})")
|
| 1078 |
+
else:
|
| 1079 |
+
# High confidence - skip reflection and go straight to synthesis
|
| 1080 |
+
logger.info(f"Skipping reflection - high confidence (low_conf:{low_conf_count}, high_conf:{high_conf_count}, results:{total_results})")
|
| 1081 |
+
# Add a synthesis-only prompt
|
| 1082 |
+
synthesis_prompt = [
|
| 1083 |
+
{"type": "text", "text": "\n\n**HIGH CONFIDENCE RESULTS:** The search returned comprehensive information. Please proceed directly to synthesis with (**✅ SYNTHESIS:**) and provide a complete answer based on the findings."}
|
| 1084 |
+
]
|
| 1085 |
+
messages.append({
|
| 1086 |
+
"role": "user",
|
| 1087 |
+
"content": synthesis_prompt
|
| 1088 |
+
})
|
| 1089 |
+
else:
|
| 1090 |
+
# Subsequent iterations (tool_iteration > 1) - UPDATE existing synthesis without repeating workflow phases
|
| 1091 |
+
logger.info(f"Iteration {tool_iteration}: Updating synthesis with additional results")
|
| 1092 |
+
update_prompt = [
|
| 1093 |
+
{"type": "text", "text": "\n\n**ADDITIONAL RESULTS:** You have gathered more information. Please UPDATE your previous synthesis by integrating these new findings. Do NOT repeat the planning/executing/reflecting phases - just provide an updated synthesis that incorporates both the previous and new information. Continue directly with the updated content, no phase markers needed."}
|
| 1094 |
+
]
|
| 1095 |
+
messages.append({
|
| 1096 |
+
"role": "user",
|
| 1097 |
+
"content": update_prompt
|
| 1098 |
+
})
|
| 1099 |
+
|
| 1100 |
+
# Second API call with tool results (with retry logic)
|
| 1101 |
+
logger.info("Starting second streaming API call with tool results...")
|
| 1102 |
+
logger.info(f"Messages array has {len(messages)} messages")
|
| 1103 |
+
logger.info(f"Last 3 messages: {json.dumps([{'role': m.get('role'), 'content_type': type(m.get('content')).__name__, 'content_len': len(str(m.get('content')))} for m in messages[-3:]], indent=2)}")
|
| 1104 |
+
# Log the actual tool results content
|
| 1105 |
+
logger.info(f"Tool results content ({len(tool_results_content)} items): {json.dumps(tool_results_content[:1], indent=2) if tool_results_content else 'EMPTY'}") # Log first item only to avoid spam
|
| 1106 |
+
|
| 1107 |
+
# Second streaming call for synthesis
|
| 1108 |
+
synthesis_response = ""
|
| 1109 |
+
additional_tool_calls = []
|
| 1110 |
+
|
| 1111 |
+
# For subsequent iterations, use modified system prompt that doesn't require all phases
|
| 1112 |
+
iteration_system_prompt = system_prompt
|
| 1113 |
+
if tool_iteration > 1:
|
| 1114 |
+
iteration_system_prompt = """You are an AI assistant specializing in ALS (Amyotrophic Lateral Sclerosis) research.
|
| 1115 |
+
|
| 1116 |
+
You are continuing your research with additional results. Please integrate the new findings into an updated response.
|
| 1117 |
+
|
| 1118 |
+
IMPORTANT: Do NOT repeat the workflow phases (Planning/Executing/Reflecting/Synthesis) - you've already done those.
|
| 1119 |
+
Simply provide updated content that incorporates both previous and new information.
|
| 1120 |
+
Start your response directly with the updated information, no phase markers needed."""
|
| 1121 |
+
|
| 1122 |
+
# Limit tools on subsequent iterations to prevent endless loops
|
| 1123 |
+
available_tools = tools if tool_iteration == 1 else [] # No more tools after first iteration
|
| 1124 |
+
|
| 1125 |
+
async for response_text, current_tool_calls, provider_used in stream_with_retry(
|
| 1126 |
+
client=client,
|
| 1127 |
+
messages=messages,
|
| 1128 |
+
tools=available_tools,
|
| 1129 |
+
system_prompt=iteration_system_prompt,
|
| 1130 |
+
max_retries=2,
|
| 1131 |
+
model=ANTHROPIC_MODEL,
|
| 1132 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 1133 |
+
stream_name="synthesis API call"
|
| 1134 |
+
):
|
| 1135 |
+
synthesis_response = response_text
|
| 1136 |
+
additional_tool_calls = current_tool_calls
|
| 1137 |
+
|
| 1138 |
+
full_response += synthesis_response
|
| 1139 |
+
# Yield the full accumulated response including planning, execution, and synthesis
|
| 1140 |
+
yield filter_internal_tags(full_response)
|
| 1141 |
+
|
| 1142 |
+
# Check for additional tool calls
|
| 1143 |
+
if additional_tool_calls:
|
| 1144 |
+
logger.info(f"Found {len(additional_tool_calls)} recursive tool calls")
|
| 1145 |
+
|
| 1146 |
+
# Check if we're about to hit the iteration limit
|
| 1147 |
+
if tool_iteration >= (max_tool_iterations - 1): # Last iteration before limit
|
| 1148 |
+
# We're on the last allowed iteration
|
| 1149 |
+
logger.info(f"Approaching iteration limit ({max_tool_iterations}), wrapping up with current results")
|
| 1150 |
+
|
| 1151 |
+
# Don't execute more tools, instead trigger final synthesis
|
| 1152 |
+
# Add a user message to force final synthesis without tools
|
| 1153 |
+
messages.append({
|
| 1154 |
+
"role": "user",
|
| 1155 |
+
"content": [{"type": "text", "text": "Please provide a complete synthesis of all the information you've found so far. No more searches are available - summarize what you've discovered."}]
|
| 1156 |
+
})
|
| 1157 |
+
|
| 1158 |
+
# Make one final API call to synthesize all the results
|
| 1159 |
+
final_synthesis = ""
|
| 1160 |
+
async for response_text, _, provider_used in stream_with_retry(
|
| 1161 |
+
client=client,
|
| 1162 |
+
messages=messages,
|
| 1163 |
+
tools=[], # No tools for final synthesis
|
| 1164 |
+
system_prompt=system_prompt,
|
| 1165 |
+
max_retries=1,
|
| 1166 |
+
model=ANTHROPIC_MODEL,
|
| 1167 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 1168 |
+
stream_name="final synthesis"
|
| 1169 |
+
):
|
| 1170 |
+
final_synthesis = response_text
|
| 1171 |
+
|
| 1172 |
+
full_response += final_synthesis
|
| 1173 |
+
# Yield the full accumulated response
|
| 1174 |
+
yield filter_internal_tags(full_response)
|
| 1175 |
+
|
| 1176 |
+
# Clear tool_calls to exit the loop gracefully
|
| 1177 |
+
tool_calls = []
|
| 1178 |
+
else:
|
| 1179 |
+
# We have room for more iterations, proceed normally
|
| 1180 |
+
# Build assistant message for recursive calls
|
| 1181 |
+
assistant_content = build_assistant_message(
|
| 1182 |
+
text_content=synthesis_response,
|
| 1183 |
+
tool_calls=additional_tool_calls
|
| 1184 |
+
)
|
| 1185 |
+
|
| 1186 |
+
messages.append({
|
| 1187 |
+
"role": "assistant",
|
| 1188 |
+
"content": assistant_content
|
| 1189 |
+
})
|
| 1190 |
+
|
| 1191 |
+
# Execute recursive tool calls
|
| 1192 |
+
progress_text, tool_results_content = await execute_tool_calls(
|
| 1193 |
+
tool_calls=additional_tool_calls,
|
| 1194 |
+
call_mcp_tool_func=call_mcp_tool
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
full_response += progress_text
|
| 1198 |
+
# Yield the full accumulated response
|
| 1199 |
+
if progress_text:
|
| 1200 |
+
yield filter_internal_tags(full_response)
|
| 1201 |
+
|
| 1202 |
+
# Add results and continue loop
|
| 1203 |
+
messages.append({
|
| 1204 |
+
"role": "user",
|
| 1205 |
+
"content": tool_results_content
|
| 1206 |
+
})
|
| 1207 |
+
|
| 1208 |
+
# Set tool_calls for next iteration
|
| 1209 |
+
tool_calls = additional_tool_calls
|
| 1210 |
+
else:
|
| 1211 |
+
# No more tool calls, exit loop
|
| 1212 |
+
tool_calls = []
|
| 1213 |
+
|
| 1214 |
+
if tool_iteration >= max_tool_iterations:
|
| 1215 |
+
logger.warning(f"Reached maximum tool iterations ({max_tool_iterations})")
|
| 1216 |
+
|
| 1217 |
+
# Force synthesis if we haven't provided one yet
|
| 1218 |
+
if tool_iteration > 0 and "✅ SYNTHESIS:" not in full_response:
|
| 1219 |
+
logger.warning(f"No synthesis found after {tool_iteration} iterations, forcing synthesis")
|
| 1220 |
+
|
| 1221 |
+
# Add a forced synthesis prompt
|
| 1222 |
+
synthesis_prompt_content = [{"type": "text", "text": "You MUST now provide a ✅ SYNTHESIS phase. Synthesize whatever information you've gathered, even if searches were limited. If you couldn't find specific research, provide knowledge-based answers with appropriate caveats."}]
|
| 1223 |
+
messages.append({
|
| 1224 |
+
"role": "user",
|
| 1225 |
+
"content": synthesis_prompt_content
|
| 1226 |
+
})
|
| 1227 |
+
|
| 1228 |
+
# Make final synthesis call without tools
|
| 1229 |
+
forced_synthesis = ""
|
| 1230 |
+
async for response_text, _, _ in stream_with_retry(
|
| 1231 |
+
client=client,
|
| 1232 |
+
messages=messages,
|
| 1233 |
+
tools=[], # No tools - just synthesize
|
| 1234 |
+
system_prompt=system_prompt,
|
| 1235 |
+
max_retries=1,
|
| 1236 |
+
model=ANTHROPIC_MODEL,
|
| 1237 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 1238 |
+
stream_name="forced synthesis"
|
| 1239 |
+
):
|
| 1240 |
+
forced_synthesis = response_text
|
| 1241 |
+
|
| 1242 |
+
full_response += "\n\n" + forced_synthesis
|
| 1243 |
+
# Yield the full accumulated response with forced synthesis
|
| 1244 |
+
yield filter_internal_tags(full_response)
|
| 1245 |
+
|
| 1246 |
+
# No final yield needed - response has already been yielded incrementally
|
| 1247 |
+
|
| 1248 |
+
# Record successful response time
|
| 1249 |
+
response_time = time.time() - start_time
|
| 1250 |
+
health_monitor.record_response_time(response_time)
|
| 1251 |
+
logger.info(f"Request completed in {response_time:.2f} seconds")
|
| 1252 |
+
|
| 1253 |
+
except Exception as e:
|
| 1254 |
+
logger.error(f"Error in als_research_agent: {e}", exc_info=True)
|
| 1255 |
+
health_monitor.record_error(str(e))
|
| 1256 |
+
error_message = format_error_message(e, context=f"Processing query: {message[:100]}...")
|
| 1257 |
+
yield error_message
|
| 1258 |
+
|
| 1259 |
+
# Gradio Interface
|
| 1260 |
+
async def main() -> None:
|
| 1261 |
+
"""Main function to setup and launch the Gradio interface"""
|
| 1262 |
+
global cleanup_task
|
| 1263 |
+
|
| 1264 |
+
try:
|
| 1265 |
+
# Setup MCP servers
|
| 1266 |
+
logger.info("Setting up MCP servers...")
|
| 1267 |
+
await setup_mcp_servers()
|
| 1268 |
+
logger.info("MCP servers initialized successfully")
|
| 1269 |
+
|
| 1270 |
+
# Start memory cleanup task
|
| 1271 |
+
cleanup_task = asyncio.create_task(cleanup_memory())
|
| 1272 |
+
logger.info("Memory cleanup task started")
|
| 1273 |
+
|
| 1274 |
+
except Exception as e:
|
| 1275 |
+
logger.error(f"Failed to initialize MCP servers: {e}", exc_info=True)
|
| 1276 |
+
raise
|
| 1277 |
+
|
| 1278 |
+
# Create Gradio interface with export button
|
| 1279 |
+
with gr.Blocks() as demo:
|
| 1280 |
+
gr.Markdown("# 🧬 ALSARA - ALS Agentic Research Assistant ")
|
| 1281 |
+
gr.Markdown("Ask questions about ALS research, treatments, and clinical trials. This agent searches PubMed, AACT clinical trials database, and other sources in real-time.")
|
| 1282 |
+
|
| 1283 |
+
# Show LLM configuration status using unified client
|
| 1284 |
+
llm_status = f"🤖 **LLM Provider:** {client.get_provider_display_name()}"
|
| 1285 |
+
gr.Markdown(llm_status)
|
| 1286 |
+
|
| 1287 |
+
with gr.Tabs():
|
| 1288 |
+
with gr.TabItem("Chat"):
|
| 1289 |
+
chatbot = gr.Chatbot(
|
| 1290 |
+
height=600,
|
| 1291 |
+
show_label=False,
|
| 1292 |
+
allow_tags=True, # Allow custom HTML tags from LLMs (Gradio 6 default)
|
| 1293 |
+
elem_classes="chatbot-container"
|
| 1294 |
+
)
|
| 1295 |
+
|
| 1296 |
+
with gr.TabItem("System Health"):
|
| 1297 |
+
gr.Markdown("## 📊 System Health Monitor")
|
| 1298 |
+
|
| 1299 |
+
def format_health_status():
    """Render the health monitor's current snapshot as a markdown report."""
    # NOTE(review): relies on module-level `health_monitor`; the snapshot keys
    # below are assumed from its get_health_status() contract.
    snapshot = health_monitor.get_health_status()
    badge = '✅' if snapshot['status'] == 'healthy' else '⚠️'
    tool_lines = chr(10).join([f"- {tool}: {count} calls" for tool, count in snapshot['most_used_tools'].items()])
    last_err = snapshot['last_error']['error'] if snapshot['last_error'] else 'None'
    return f"""
**Status:** {snapshot['status'].upper()} {badge}

**Uptime:** {snapshot['uptime_seconds'] / 3600:.1f} hours
**Total Requests:** {snapshot['request_count']}
**Error Rate:** {snapshot['error_rate']:.1%}
**Avg Response Time:** {snapshot['avg_response_time']:.2f}s

**Cache Status:**
- Cache Size: {snapshot['cache_size']} items
- Rate Limiter: {snapshot['rate_limit_status']}

**Most Used Tools:**
{tool_lines}

**Last Error:** {last_err}
"""
|
| 1319 |
+
|
| 1320 |
+
health_display = gr.Markdown(format_health_status())
|
| 1321 |
+
refresh_btn = gr.Button("🔄 Refresh Health Status")
|
| 1322 |
+
refresh_btn.click(fn=format_health_status, outputs=health_display)
|
| 1323 |
+
|
| 1324 |
+
with gr.Row():
|
| 1325 |
+
with gr.Column(scale=6):
|
| 1326 |
+
msg = gr.Textbox(
|
| 1327 |
+
placeholder="Ask about ALS research, treatments, or clinical trials...",
|
| 1328 |
+
container=False,
|
| 1329 |
+
label="Type your question or use voice input"
|
| 1330 |
+
)
|
| 1331 |
+
with gr.Column(scale=1):
|
| 1332 |
+
audio_input = gr.Audio(
|
| 1333 |
+
sources=["microphone"],
|
| 1334 |
+
type="filepath",
|
| 1335 |
+
label="🎤 Voice Input"
|
| 1336 |
+
)
|
| 1337 |
+
export_btn = gr.DownloadButton("💾 Export", scale=1)
|
| 1338 |
+
|
| 1339 |
+
with gr.Row():
|
| 1340 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
| 1341 |
+
retry_btn = gr.Button("🔄 Retry")
|
| 1342 |
+
undo_btn = gr.Button("↩️ Undo")
|
| 1343 |
+
clear_btn = gr.Button("🗑️ Clear")
|
| 1344 |
+
speak_btn = gr.Button("🔊 Read Last Response", variant="secondary", interactive=False)
|
| 1345 |
+
|
| 1346 |
+
# Audio output component (initially hidden)
|
| 1347 |
+
with gr.Row(visible=False) as audio_row:
|
| 1348 |
+
audio_output = gr.Audio(
|
| 1349 |
+
label="🔊 Voice Output",
|
| 1350 |
+
type="filepath",
|
| 1351 |
+
autoplay=True,
|
| 1352 |
+
visible=True
|
| 1353 |
+
)
|
| 1354 |
+
|
| 1355 |
+
gr.Examples(
|
| 1356 |
+
examples=[
|
| 1357 |
+
"Psilocybin trials and use in therapy",
|
| 1358 |
+
"Role of Omega-3 and omega-6 fatty acids in ALS treatment",
|
| 1359 |
+
"List the main genes that should be tested for ALS gene therapy eligibility",
|
| 1360 |
+
"What are the latest SOD1-targeted therapies in recent preprints?",
|
| 1361 |
+
"Find recruiting clinical trials for bulbar-onset ALS",
|
| 1362 |
+
"Explain the role of TDP-43 in ALS pathology",
|
| 1363 |
+
"What is the current status of tofersen clinical trials?",
|
| 1364 |
+
"Are there any new combination therapies being studied?",
|
| 1365 |
+
"What's the latest research on ALS biomarkers from the past 60 days?",
|
| 1366 |
+
"Search PubMed for recent ALS gene therapy research"
|
| 1367 |
+
],
|
| 1368 |
+
inputs=msg
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
# Chat interface logic with improved error handling
|
| 1372 |
+
async def respond(message: str, history: Optional[List[Dict[str, str]]]) -> AsyncGenerator[List[Dict[str, str]], None]:
    """Stream the agent's answer into the chat history.

    Appends the user turn and an empty assistant turn, then re-yields the
    growing history as `als_research_agent` streams partial responses.
    """
    history = history or []
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": ""})

    try:
        # The agent receives the history *without* the two turns just added.
        async for response in als_research_agent(message, history[:-2]):
            # Overwrite the placeholder assistant turn in place.
            history[-1]['content'] = response
            yield history
    except Exception as e:
        logger.error(f"Error in respond: {e}", exc_info=True)
        error_msg = f"❌ Error: {str(e)}"
        history[-1]['content'] = error_msg
        yield history
|
| 1390 |
+
|
| 1391 |
+
def update_speak_button():
    """Enable the 🔊 button only when the last answer came from the research workflow."""
    # Reading the module-level flag; `global` is only required for writes,
    # so no declaration is needed here.
    return gr.update(interactive=last_response_was_research)
|
| 1395 |
+
|
| 1396 |
+
def undo_last(history: Optional[List[Dict[str, str]]]) -> Optional[List[Dict[str, str]]]:
    """Drop the most recent user/assistant message pair from the history.

    Returns the history unchanged when there is nothing (or only a lone
    message) to undo; `None` passes through untouched.
    """
    if not history or len(history) < 2:
        return history
    return history[:-2]
|
| 1402 |
+
|
| 1403 |
+
async def retry_last(history: Optional[List[Dict[str, str]]]) -> AsyncGenerator[List[Dict[str, str]], None]:
    """Re-run the most recent user query, replacing the old assistant answer.

    Yields the history unchanged if there is no retryable user turn.
    """
    # Guard: need at least one full user/assistant exchange to retry.
    if not history or len(history) < 2:
        yield history
        return

    last_user_msg = history[-2]["content"] if history[-2]["role"] == "user" else None
    if not last_user_msg:
        # Second-to-last turn wasn't a user message; nothing to resubmit.
        yield history
        return

    # Drop the stale assistant turn and stream a fresh one in its place.
    history = history[:-1]
    history.append({"role": "assistant", "content": ""})
    try:
        # Resubmit (pass history without the last user and assistant turns).
        async for response in als_research_agent(last_user_msg, history[:-2]):
            history[-1]['content'] = response
            yield history
    except Exception as e:
        logger.error(f"Error in retry_last: {e}", exc_info=True)
        error_msg = f"❌ Error during retry: {str(e)}"
        history[-1]['content'] = error_msg
        yield history
|
| 1428 |
+
|
| 1429 |
+
async def process_voice_input(audio_file):
    """Transcribe a recorded audio file to text.

    Returns the transcript, or "" when there is no audio, the optional
    `speech_recognition` dependency is missing, or transcription fails.
    """
    try:
        if audio_file is None:
            return ""

        try:
            # Optional dependency — degrade gracefully when absent.
            import speech_recognition as sr
            recognizer = sr.Recognizer()

            # Capture the whole recording from the file.
            with sr.AudioFile(audio_file) as source:
                audio_data = recognizer.record(source)

            # Google's free web speech API (no key required).
            try:
                text = recognizer.recognize_google(audio_data)
                logger.info(f"Voice input transcribed: {text[:50]}...")
                return text
            except sr.UnknownValueError:
                logger.warning("Could not understand audio")
                return ""
            except sr.RequestError as e:
                logger.error(f"Speech recognition service error: {e}")
                return ""

        except ImportError:
            logger.warning("speech_recognition not available")
            return ""

    except Exception as e:
        logger.error(f"Error processing voice input: {e}")
        return ""
|
| 1463 |
+
|
| 1464 |
+
async def speak_last_response(history: Optional[List[Dict[str, str]]]) -> Tuple[gr.update, gr.update]:
    """Convert the last assistant response to speech using ElevenLabs.

    Pipeline: find the newest assistant message (handling several Gradio
    history shapes), extract its ✅ SYNTHESIS section, strip reference/footer
    sections and markdown, then call the `elevenlabs__text_to_speech` MCP
    tool and save the returned base64 audio to a temp .mp3 file.

    Returns:
        A pair of gr.update objects: (audio row visibility, audio component
        value/label). On any failure the row is shown with a warning label
        and no audio value.
    """
    try:
        # Check if the last response was from research workflow
        global last_response_was_research
        if not last_response_was_research:
            # This shouldn't happen since button is disabled, but handle it gracefully
            logger.info("Last response was not research-based, voice synthesis not available")
            return gr.update(visible=False), gr.update(value=None)

        # Check ELEVENLABS_API_KEY
        api_key = os.getenv("ELEVENLABS_API_KEY")
        if not api_key:
            logger.warning("No ELEVENLABS_API_KEY configured")
            return gr.update(visible=True), gr.update(
                value=None,
                label="⚠️ Voice service unavailable - Please set ELEVENLABS_API_KEY"
            )

        if not history or len(history) < 1:
            logger.warning("No history available for text-to-speech")
            return gr.update(visible=True), gr.update(
                value=None,
                label="⚠️ No conversation history to read"
            )

        # Get the last assistant response
        last_response = None

        # Detect and handle different history formats
        if isinstance(history, list) and len(history) > 0:
            # Check if history is a list of lists (Gradio chatbot format)
            if isinstance(history[0], list) and len(history[0]) == 2:
                # Format: [[user_msg, assistant_msg], ...]
                logger.info("Detected Gradio list-of-lists history format")
                for i, exchange in enumerate(reversed(history)):
                    if len(exchange) == 2 and exchange[1]:  # assistant message is second
                        last_response = exchange[1]
                        break
            elif isinstance(history[0], dict):
                # Format: [{"role": "user", "content": "..."}, ...]
                logger.info("Detected dict-based history format")
                for i, msg in enumerate(reversed(history)):
                    if msg.get("role") == "assistant" and msg.get("content"):
                        content = msg["content"]
                        # CRITICAL FIX: Handle Claude API content blocks
                        if isinstance(content, list):
                            # Extract text from content blocks
                            text_parts = []
                            for block in content:
                                if isinstance(block, dict):
                                    # Handle text block
                                    if block.get("type") == "text" and "text" in block:
                                        text_parts.append(block["text"])
                                    # Handle string content in dict
                                    elif "content" in block and isinstance(block["content"], str):
                                        text_parts.append(block["content"])
                                elif isinstance(block, str):
                                    text_parts.append(block)
                            last_response = "\n".join(text_parts)
                        else:
                            # Content is already a string
                            last_response = content
                        break
            elif isinstance(history[0], str):
                # Simple string list - take the last one
                logger.info("Detected simple string list history format")
                last_response = history[-1] if history else None
            else:
                # Unknown format - try to extract what we can
                logger.warning(f"Unknown history format: {type(history[0])}")
                # Try to convert to string as last resort
                try:
                    last_response = str(history[-1]) if history else None
                except Exception as e:
                    logger.error(f"Failed to extract last response: {e}")

        if not last_response:
            logger.warning("No assistant response found in history")
            return gr.update(visible=True), gr.update(
                value=None,
                label="⚠️ No assistant response found to read"
            )

        # Clean the response text (remove markdown, internal tags, etc.)
        # Convert to string if not already (safety check)
        last_response = str(last_response)

        # IMPORTANT: Extract only the synthesis/main answer, skip references and "for more information"
        # Find where to cut off the response
        # NOTE: the character classes [::] match both an ASCII colon and a
        # fullwidth colon (U+FF1A), since model output may contain either.
        cutoff_patterns = [
            # Clear section headers with colons - most reliable indicators
            r'\n\s*(?:For (?:more|additional|further) (?:information|details|reading))\s*[::]',
            r'\n\s*(?:References?|Sources?|Citations?|Bibliography)\s*[::]',
            r'\n\s*(?:Additional (?:resources?|information|reading|materials?))\s*[::]',

            # Markdown headers for reference sections (must be on their own line)
            r'\n\s*#{1,6}\s+(?:References?|Sources?|Citations?|Bibliography)\s*$',
            r'\n\s*#{1,6}\s+(?:For (?:more|additional|further) (?:information|details))\s*$',
            r'\n\s*#{1,6}\s+(?:Additional (?:Resources?|Information|Reading))\s*$',
            r'\n\s*#{1,6}\s+(?:Further Reading|Learn More)\s*$',

            # Bold headers for reference sections (with newline after)
            r'\n\s*\*\*(?:References?|Sources?|Citations?)\*\*\s*[::]?\s*\n',
            r'\n\s*\*\*(?:For (?:more|additional) information)\*\*\s*[::]?\s*\n',

            # Phrases that clearly introduce reference lists
            r'\n\s*(?:Here are|Below are|The following are)\s+(?:the |some |additional )?(?:references|sources|citations|papers cited|studies referenced)',
            r'\n\s*(?:References used|Sources consulted|Papers cited|Studies referenced)\s*[::]',
            r'\n\s*(?:Key|Recent|Selected|Relevant)\s+(?:references?|publications?|citations)\s*[::]',

            # Clinical trials section headers with clear separators
            r'\n\s*(?:Clinical trials?|Studies|Research papers?)\s+(?:referenced|cited|mentioned|used)\s*[::]',
            r'\n\s*(?:AACT|ClinicalTrials\.gov)\s+(?:database entries?|trial IDs?|references?)\s*[::]',

            # Web link sections
            r'\n\s*(?:Links?|URLs?|Websites?|Web resources?)\s*[::]',
            r'\n\s*(?:Visit|See|Check out)\s+(?:these|the following)\s+(?:links?|websites?|resources?)',
            r'\n\s*(?:Learn more|Read more|Find out more|Get more information)\s+(?:at|here|below)\s*[::]',

            # Academic citation lists (only when preceded by double newline or clear separator)
            r'\n\n\s*\d+\.\s+[A-Z][a-z]+.*?et al\..*?(?:PMID|DOI|Journal)',
            r'\n\n\s*\[1\]\s+[A-Z][a-z]+.*?(?:et al\.|https?://)',

            # Direct ID listings (clearly separate from main content)
            r'\n\s*(?:PMID|DOI|NCT)\s*[::]\s*\d+',
            r'\n\s*(?:Trial IDs?|Study IDs?)\s*[::]',

            # Footer sections
            r'\n\s*(?:Note|Notes|Disclaimer|Important notice)\s*[::]',
            r'\n\s*(?:Data (?:source|from)|Database|Repository)\s*[::]',
            r'\n\s*(?:Retrieved from|Accessed via|Source database)\s*[::]',
        ]

        # FIRST: Extract ONLY the synthesis section (after ✅ SYNTHESIS:)
        # More robust pattern that handles various formatting
        synthesis_patterns = [
            r'✅\s*\*{0,2}SYNTHESIS\*{0,2}\s*:?\s*\n+(.*)',  # Standard format with newline
            r'\*\*✅\s*SYNTHESIS:\*\*\s*(.*)',  # Bold format
            r'✅\s*SYNTHESIS:\s*(.*)',  # Simple format
            r'SYNTHESIS:\s*(.*)',  # Fallback without emoji
        ]

        synthesis_text = None
        for pattern in synthesis_patterns:
            synthesis_match = re.search(pattern, last_response, re.IGNORECASE | re.DOTALL)
            if synthesis_match:
                synthesis_text = synthesis_match.group(1)
                logger.info(f"Extracted synthesis section using pattern: {pattern[:30]}...")
                break

        if synthesis_text:
            logger.info("Extracted synthesis section for voice reading")
        else:
            # Fallback: if no synthesis marker found, use the whole response
            synthesis_text = last_response
            logger.info("No synthesis marker found, using full response")

        # THEN: Remove references and footer sections
        # Only the FIRST matching cutoff truncates; later patterns are ignored.
        for pattern in cutoff_patterns:
            match = re.search(pattern, synthesis_text, re.IGNORECASE | re.MULTILINE)
            if match:
                synthesis_text = synthesis_text[:match.start()]
                logger.info(f"Truncated response at pattern: {pattern[:50]}...")
                break

        # Now clean the synthesis text
        clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', synthesis_text)  # Remove bold
        clean_text = re.sub(r'\*(.*?)\*', r'\1', clean_text)  # Remove italic
        clean_text = re.sub(r'#{1,6}\s*(.*?)\n', r'\1. ', clean_text)  # Remove headers
        clean_text = re.sub(r'```.*?```', '', clean_text, flags=re.DOTALL)  # Remove code blocks
        clean_text = re.sub(r'`(.*?)`', r'\1', clean_text)  # Remove inline code
        clean_text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', clean_text)  # Remove links
        clean_text = re.sub(r'<[^>]+>', '', clean_text)  # Remove HTML tags
        clean_text = re.sub(r'\n{3,}', '\n\n', clean_text)  # Reduce multiple newlines

        # Strip leading/trailing whitespace
        clean_text = clean_text.strip()

        # Ensure we have something to read
        if not clean_text or len(clean_text) < 10:
            logger.warning("Synthesis text too short after cleaning, using original")
            clean_text = last_response[:2500]  # Fallback to first 2500 chars
        # Check if ElevenLabs server is available
        try:
            server_tools = await mcp_manager.list_all_tools()
            elevenlabs_available = any('elevenlabs' in tool for tool in server_tools.keys())
            if not elevenlabs_available:
                logger.error("ElevenLabs server not available in MCP tools")
                return gr.update(visible=True), gr.update(
                    value=None,
                    label="⚠️ Voice service not available - Please set ELEVENLABS_API_KEY"
                )
        except Exception as e:
            logger.error(f"Failed to check ElevenLabs availability: {e}", exc_info=True)
            return gr.update(visible=True), gr.update(
                value=None,
                label="⚠️ Voice service not available"
            )

        # Remove phase markers from text
        clean_text = re.sub(r'\*\*[🎯🔧🤔✅].*?:\*\*', '', clean_text)
        # Call ElevenLabs text-to-speech through MCP
        logger.info(f"Calling ElevenLabs text-to-speech with {len(clean_text)} characters...")
        try:
            result = await call_mcp_tool(
                "elevenlabs__text_to_speech",
                {"text": clean_text, "speed": 0.95}  # Slightly slower for clarity
            )
        except Exception as e:
            logger.error(f"MCP tool call failed: {e}", exc_info=True)
            raise

        # Parse the result
        try:
            result_data = json.loads(result) if isinstance(result, str) else result
            # Check for API key error
            if "ELEVENLABS_API_KEY not configured" in str(result):
                logger.error("ElevenLabs API key not configured - found in result string")
                return gr.update(visible=True), gr.update(
                    value=None,
                    label="⚠️ Voice service unavailable - Please set ELEVENLABS_API_KEY environment variable"
                )

            if result_data.get("status") == "success" and result_data.get("audio_base64"):
                # Save audio to temporary file
                # NOTE: delete=False — the file must outlive this function so
                # Gradio can serve it; presumably cleaned up elsewhere (verify).
                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                    audio_data = base64.b64decode(result_data["audio_base64"])
                    tmp_file.write(audio_data)
                    audio_path = tmp_file.name

                logger.info(f"Audio successfully generated and saved to: {audio_path}")
                return gr.update(visible=True), gr.update(
                    value=audio_path,
                    visible=True,
                    label="🔊 Click to play voice output"
                )
            elif result_data.get("status") == "error":
                error_msg = result_data.get("message", "Unknown error")
                error_type = result_data.get("error", "Unknown")
                logger.error(f"ElevenLabs error - Type: {error_type}, Message: {error_msg}")
                return gr.update(visible=True), gr.update(
                    value=None,
                    label=f"⚠️ Voice service error: {error_msg}"
                )
            else:
                logger.error(f"Unexpected result structure")
                return gr.update(visible=True), gr.update(
                    value=None,
                    label="⚠️ Voice service returned no audio"
                )
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {e}")
            logger.error(f"Failed to parse ElevenLabs response, first 500 chars: {str(result)[:500]}")
            return gr.update(visible=True), gr.update(
                value=None,
                label="⚠️ Voice service response error"
            )
        except Exception as e:
            logger.error(f"Unexpected error in result parsing: {e}", exc_info=True)
            raise

    except Exception as e:
        logger.error(f"Error in speak_last_response: {e}", exc_info=True)
        return gr.update(visible=True), gr.update(
            value=None,
            label=f"⚠️ Voice service error: {str(e)}"
        )
|
| 1732 |
+
|
| 1733 |
+
msg.submit(
|
| 1734 |
+
respond, [msg, chatbot], [chatbot],
|
| 1735 |
+
api_name="chat"
|
| 1736 |
+
).then(
|
| 1737 |
+
update_speak_button, None, [speak_btn]
|
| 1738 |
+
).then(
|
| 1739 |
+
lambda: "", None, [msg]
|
| 1740 |
+
)
|
| 1741 |
+
|
| 1742 |
+
# Add event handler for audio input
|
| 1743 |
+
audio_input.stop_recording(
|
| 1744 |
+
process_voice_input,
|
| 1745 |
+
inputs=[audio_input],
|
| 1746 |
+
outputs=[msg]
|
| 1747 |
+
).then(
|
| 1748 |
+
lambda: None,
|
| 1749 |
+
outputs=[audio_input] # Clear audio after processing
|
| 1750 |
+
)
|
| 1751 |
+
|
| 1752 |
+
submit_btn.click(
|
| 1753 |
+
respond, [msg, chatbot], [chatbot],
|
| 1754 |
+
api_name="chat_button"
|
| 1755 |
+
).then(
|
| 1756 |
+
update_speak_button, None, [speak_btn]
|
| 1757 |
+
).then(
|
| 1758 |
+
lambda: "", None, [msg]
|
| 1759 |
+
)
|
| 1760 |
+
|
| 1761 |
+
retry_btn.click(
|
| 1762 |
+
retry_last, [chatbot], [chatbot],
|
| 1763 |
+
api_name="retry"
|
| 1764 |
+
).then(
|
| 1765 |
+
update_speak_button, None, [speak_btn]
|
| 1766 |
+
)
|
| 1767 |
+
|
| 1768 |
+
undo_btn.click(
|
| 1769 |
+
undo_last, [chatbot], [chatbot],
|
| 1770 |
+
api_name="undo"
|
| 1771 |
+
)
|
| 1772 |
+
|
| 1773 |
+
clear_btn.click(
|
| 1774 |
+
lambda: None, None, chatbot,
|
| 1775 |
+
queue=False,
|
| 1776 |
+
api_name="clear"
|
| 1777 |
+
).then(
|
| 1778 |
+
lambda: gr.update(interactive=False), None, [speak_btn]
|
| 1779 |
+
)
|
| 1780 |
+
|
| 1781 |
+
export_btn.click(
|
| 1782 |
+
export_conversation, chatbot, export_btn,
|
| 1783 |
+
api_name="export"
|
| 1784 |
+
)
|
| 1785 |
+
|
| 1786 |
+
speak_btn.click(
|
| 1787 |
+
speak_last_response, [chatbot], [audio_row, audio_output],
|
| 1788 |
+
api_name="speak"
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
# Enable queue for streaming to work
|
| 1792 |
+
demo.queue()
|
| 1793 |
+
|
| 1794 |
+
try:
|
| 1795 |
+
# Use environment variable for port, default to 7860 for HuggingFace
|
| 1796 |
+
port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
|
| 1797 |
+
demo.launch(
|
| 1798 |
+
server_name="0.0.0.0",
|
| 1799 |
+
server_port=port,
|
| 1800 |
+
share=False
|
| 1801 |
+
)
|
| 1802 |
+
except KeyboardInterrupt:
|
| 1803 |
+
logger.info("Received keyboard interrupt, shutting down...")
|
| 1804 |
+
except Exception as e:
|
| 1805 |
+
logger.error(f"Error during launch: {e}", exc_info=True)
|
| 1806 |
+
finally:
|
| 1807 |
+
# Cleanup
|
| 1808 |
+
logger.info("Cleaning up resources...")
|
| 1809 |
+
await cleanup_mcp_servers()
|
| 1810 |
+
|
| 1811 |
+
if __name__ == "__main__":
|
| 1812 |
+
try:
|
| 1813 |
+
asyncio.run(main())
|
| 1814 |
+
except KeyboardInterrupt:
|
| 1815 |
+
logger.info("Application terminated by user")
|
| 1816 |
+
except Exception as e:
|
| 1817 |
+
logger.error(f"Application error: {e}", exc_info=True)
|
| 1818 |
+
raise
|
| 1819 |
+
finally:
|
| 1820 |
+
# Cancel cleanup task if running
|
| 1821 |
+
if cleanup_task and not cleanup_task.done():
|
| 1822 |
+
cleanup_task.cancel()
|
| 1823 |
+
logger.info("Cancelled memory cleanup task")
|
| 1824 |
+
|
| 1825 |
+
# Cleanup unified LLM client
|
| 1826 |
+
if client is not None:
|
| 1827 |
+
try:
|
| 1828 |
+
asyncio.run(client.cleanup())
|
| 1829 |
+
logger.info("LLM client cleanup completed")
|
| 1830 |
+
except Exception as e:
|
| 1831 |
+
logger.warning(f"LLM client cleanup error: {e}")
|
| 1832 |
+
pass
|
custom_mcp_client.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom MCP client using direct subprocess communication.
|
| 3 |
+
This bypasses the buggy stdio_client from mcp.client.stdio.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import subprocess
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MCPClient:
    """Custom MCP client using direct subprocess communication.

    Speaks newline-delimited JSON-RPC 2.0 over the child process's
    stdin/stdout (the MCP stdio transport). This bypasses the buggy
    stdio_client from mcp.client.stdio.
    """

    def __init__(self, server_script: str, server_name: str):
        """
        Args:
            server_script: Path to the Python script implementing the server.
            server_name: Human-readable name used in logs and error messages.
        """
        self.server_script = server_script
        self.server_name = server_name
        self.process: Optional[subprocess.Popen] = None
        self.message_id = 0          # monotonically increasing JSON-RPC request id
        self._initialized = False    # True once the MCP handshake has succeeded
        self.script_path = server_script  # Store for potential restart

    async def start(self):
        """Start the MCP server subprocess and perform the MCP handshake.

        Raises:
            Exception: If the server rejects the initialize request.
        """
        logger.info(f"Starting MCP server: {self.server_name}")

        self.process = subprocess.Popen(
            [sys.executable, self.server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1  # Line-buffered I/O to prevent 8KB truncation
        )

        # Initialize the session
        await self._initialize()
        logger.info(f"Successfully started MCP server: {self.server_name}")

    async def _initialize(self):
        """Initialize the MCP session via the JSON-RPC 'initialize' method."""
        init_message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "als-research-agent",
                    "version": "1.0.0"
                }
            }
        }

        response = await self._send_request(init_message)
        if "result" in response:
            self._initialized = True
            logger.info(f"Initialized {self.server_name}: {response['result'].get('serverInfo', {})}")
        else:
            raise Exception(f"Initialization failed: {response}")

    def _next_id(self) -> int:
        """Get next message ID"""
        self.message_id += 1
        return self.message_id

    async def _send_request(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON-RPC request and wait for the next response line.

        Args:
            message: Complete JSON-RPC message dict (must contain an "id").

        Returns:
            The parsed JSON-RPC response.

        Raises:
            RuntimeError: If the server was never started or has died.
            Exception: On timeout or if the server closes stdout.
        """
        if not self.process:
            raise RuntimeError("Server not started")

        # Check if process is still alive
        if self.process.poll() is not None:
            # Process has terminated
            raise RuntimeError(f"Server {self.server_name} has terminated unexpectedly")

        # Send request
        request_json = json.dumps(message) + "\n"
        self.process.stdin.write(request_json)
        self.process.stdin.flush()

        # Read response with timeout; readline runs in a worker thread so the
        # blocking pipe read does not stall the event loop.
        try:
            response_line = await asyncio.wait_for(
                asyncio.to_thread(self.process.stdout.readline),
                timeout=60.0  # Extended timeout for LlamaIndex/RAG server initialization
            )

            if not response_line:
                raise Exception("Server closed stdout")

            return json.loads(response_line)
        except asyncio.TimeoutError:
            raise Exception("Request timed out")

    async def list_tools(self) -> List[Dict[str, Any]]:
        """List available tools"""
        if not self._initialized:
            raise RuntimeError("Client not initialized")

        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/list",
            "params": {}
        }

        response = await self._send_request(message)
        if "result" in response:
            return response["result"].get("tools", [])
        else:
            raise Exception(f"List tools failed: {response}")

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Call a tool and return its textual result.

        Args:
            tool_name: Name of the tool as reported by list_tools().
            arguments: JSON-serializable tool arguments.

        Raises:
            RuntimeError: If the client is not initialized.
            Exception: If the server returns a JSON-RPC error.
        """
        if not self._initialized:
            raise RuntimeError("Client not initialized")

        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments
            }
        }

        response = await self._send_request(message)
        if "result" in response:
            # Extract result from response
            result = response["result"]

            # Handle different response formats
            if isinstance(result, dict):
                # New format with 'result' field
                if "result" in result:
                    return result["result"]
                # Content array format
                elif "content" in result:
                    content = result["content"]
                    if isinstance(content, list) and len(content) > 0:
                        return content[0].get("text", str(content))
                    return str(content)
                else:
                    return str(result)
            else:
                return str(result)
        else:
            error = response.get("error", {})
            raise Exception(f"Tool call failed: {error.get('message', response)}")

    async def close(self):
        """Close the MCP client and terminate server.

        Fix: also close the stdin/stdout/stderr pipes. The previous version
        left them open, leaking three file descriptors every time a server
        was closed or restarted (visible as ResourceWarning).
        """
        if self.process:
            logger.info(f"Closing MCP server: {self.server_name}")
            # Close our ends of the pipes first: the child sees EOF on stdin
            # (a chance to exit cleanly) and no descriptors are leaked.
            for pipe in (self.process.stdin, self.process.stdout, self.process.stderr):
                if pipe:
                    try:
                        pipe.close()
                    except OSError:
                        pass
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()
            self.process = None
            self._initialized = False
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class MCPClientManager:
    """Owns a registry of named MCPClient instances and routes requests.

    Offers per-server tool invocation plus aggregate operations:
    catalogue discovery across every server and bulk shutdown.
    """

    def __init__(self):
        # Maps server name -> running MCPClient.
        self.clients: Dict[str, "MCPClient"] = {}

    async def add_server(self, name: str, script_path: str):
        """Construct, start, and register an MCP server under `name`."""
        new_client = MCPClient(script_path, name)
        await new_client.start()
        self.clients[name] = new_client
        logger.info(f"Added MCP server: {name}")

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Invoke `tool_name` with `arguments` on the named server."""
        target = self.clients.get(server_name)
        if target is None:
            raise ValueError(f"Server not found: {server_name}")
        return await target.call_tool(tool_name, arguments)

    async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """Collect the tool catalogue of every server, tolerating failures.

        A server that fails to answer gets an empty tool list and one restart
        attempt; a server that cannot be restarted is dropped from the
        registry so subsequent calls fail fast.
        """
        catalogue: Dict[str, List[Dict[str, Any]]] = {}
        unresponsive: List[str] = []

        for server, client in self.clients.items():
            try:
                found = await client.list_tools()
            except Exception as e:
                logger.error(f"Failed to list tools from server {server}: {e}")
                unresponsive.append(server)
                # Keep going; one broken server must not hide the others.
                catalogue[server] = []
            else:
                for entry in found:
                    entry['server'] = server  # tag each tool with its origin
                catalogue[server] = found

        if unresponsive:
            logger.warning(f"Some servers failed to respond: {', '.join(unresponsive)}")
            # One restart attempt per failed server.
            for server in unresponsive:
                try:
                    broken = self.clients[server]
                    path = getattr(broken, 'script_path', None)
                    if path:
                        logger.info(f"Attempting to restart {server} server...")
                        await broken.close()
                        # add_server spawns a fresh process and re-registers it.
                        await self.add_server(server, path)
                        refreshed = await self.clients[server].list_tools()
                        for entry in refreshed:
                            entry['server'] = server
                        catalogue[server] = refreshed
                        logger.info(f"Successfully restarted {server} server")
                except Exception as restart_error:
                    logger.error(f"Failed to restart {server}: {restart_error}")
                    # Drop the dead client so it cannot keep erroring.
                    self.clients.pop(server, None)

        return catalogue

    async def close_all(self):
        """Close all MCP clients"""
        for running in self.clients.values():
            await running.close()
        self.clients.clear()
        logger.info("All MCP servers closed")
|
llm_client.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Unified LLM Client - Single interface for all LLM providers
|
| 4 |
+
Handles Anthropic, SambaNova, and automatic fallback logic internally
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
import asyncio
|
| 10 |
+
import httpx
|
| 11 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
|
| 12 |
+
from anthropic import AsyncAnthropic
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UnifiedLLMClient:
    """
    Unified client that abstracts all LLM provider logic.
    Provides a single, clean interface to the application.

    Anthropic Claude is the primary provider when an API key is present;
    the llm_providers router (SambaNova-backed) serves as fallback, or as
    primary when cost optimization / smart routing selects it.
    """

    def __init__(self):
        """Initialize the unified client with automatic provider selection"""
        self.primary_client = None      # AsyncAnthropic instance, or None
        self.fallback_router = None     # llm_providers.llm_router, or None
        self.provider_name = None       # human-readable label of the active provider
        self.config = self._load_configuration()
        self._initialize_providers()

    def _load_configuration(self) -> Dict[str, Any]:
        """Load configuration from environment variables"""
        return {
            "anthropic_api_key": os.getenv("ANTHROPIC_API_KEY"),
            "use_fallback": os.getenv("USE_FALLBACK_LLM", "false").lower() == "true",
            "provider_preference": os.getenv("LLM_PROVIDER_PREFERENCE", "auto"),
            "default_model": os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929"),
            "max_retries": int(os.getenv("LLM_MAX_RETRIES", "2")),
            # SPACE_ID is set by Hugging Face Spaces; used for error wording only
            "is_hf_space": os.getenv("SPACE_ID") is not None,
            "enable_smart_routing": os.getenv("ENABLE_SMART_ROUTING", "false").lower() == "true"
        }

    def _initialize_providers(self):
        """Initialize LLM providers based on configuration.

        Raises:
            ValueError: If neither Anthropic nor the fallback is available.
        """
        # Try to initialize Anthropic first
        if self.config["anthropic_api_key"]:
            try:
                self.primary_client = AsyncAnthropic(api_key=self.config["anthropic_api_key"])
                self.provider_name = "Anthropic Claude"
                logger.info("Anthropic client initialized successfully")
            except Exception as e:
                logger.warning(f"Failed to initialize Anthropic client: {e}")
                self.primary_client = None

        # Initialize fallback if needed
        if self.config["use_fallback"] or not self.primary_client:
            try:
                from llm_providers import llm_router
                self.fallback_router = llm_router

                if not self.primary_client:
                    self.provider_name = "SambaNova Llama 3.3 70B"
                    logger.info("Using SambaNova as primary provider")
                else:
                    logger.info("SambaNova fallback configured for automatic failover")

            except ImportError:
                logger.warning("Fallback LLM provider not available")

        if not self.primary_client and not self.fallback_router:
            self._raise_configuration_error()

    def _raise_configuration_error(self):
        """Raise appropriate error for missing configuration"""
        if self.config["is_hf_space"]:
            raise ValueError(
                "🚨 No LLM provider configured!\n\n"
                "Option 1: Add your Anthropic API key as a Space secret:\n"
                "1. Go to your Space Settings\n"
                "2. Add secret: ANTHROPIC_API_KEY = your_key\n\n"
                "Option 2: Enable free SambaNova fallback:\n"
                "Add secret: USE_FALLBACK_LLM = true"
            )
        else:
            raise ValueError(
                "No LLM provider configured.\n\n"
                "Option 1: Add to .env file:\n"
                "ANTHROPIC_API_KEY=your_api_key_here\n\n"
                "Option 2: Enable free SambaNova:\n"
                "USE_FALLBACK_LLM=true"
            )

    async def stream(
        self,
        messages: List[Dict],
        tools: List[Dict] = None,
        system_prompt: str = None,
        model: str = None,
        max_tokens: int = 8192,
        temperature: float = 0.7
    ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
        """
        Stream responses from the LLM with automatic fallback.

        This is the main interface - it handles all provider selection,
        retries, and fallback logic internally.

        Yields: (response_text, tool_calls, provider_used)
        """

        # Use default model if not specified
        if model is None:
            model = self.config["default_model"]

        # Determine provider order based on preference
        use_anthropic_first = True
        if self.config["provider_preference"] == "cost_optimize" and self.fallback_router:
            # With cost_optimize, prefer SambaNova first
            use_anthropic_first = False

        # Apply smart routing if enabled (only meaningful when both providers exist)
        if self.config.get("enable_smart_routing", False) and self.primary_client and self.fallback_router:
            # Extract the last user message for analysis
            last_message = ""
            for msg in reversed(messages):
                if msg.get("role") == "user":
                    if isinstance(msg.get("content"), str):
                        last_message = msg["content"]
                    elif isinstance(msg.get("content"), list):
                        # Extract text from content blocks
                        for block in msg["content"]:
                            if isinstance(block, dict) and block.get("type") == "text":
                                last_message = block.get("text", "")
                                break
                    break

            if last_message:
                # Classify the query
                query_type = self.classify_query_complexity(
                    last_message,
                    len(tools) if tools else 0
                )

                # Override provider preference based on classification
                if query_type == "simple":
                    if use_anthropic_first:
                        logger.info(f"Smart routing: Directing simple query to Llama for cost savings: '{last_message[:80]}...'")
                        use_anthropic_first = False
                elif query_type == "complex":
                    if not use_anthropic_first:
                        logger.info(f"Smart routing: Directing complex query to Claude for better quality: '{last_message[:80]}...'")
                        use_anthropic_first = True

        # Try first provider based on preference
        if use_anthropic_first and self.primary_client:
            try:
                async for result in self._stream_anthropic(
                    messages, tools, system_prompt, model, max_tokens, temperature
                ):
                    yield result
                return  # Success, exit
            except Exception as e:
                logger.warning(f"Primary provider failed: {e}")

                # Fall through to fallback if available
                if not self.fallback_router:
                    raise

        # Try fallback provider
        if self.fallback_router:
            if not use_anthropic_first or not self.primary_client:
                logger.info("Using SambaNova as primary provider (cost_optimize mode)" if not use_anthropic_first else "Using fallback LLM provider")

            try:
                # Override provider preference to force SambaNova when smart routing decided to use it
                effective_preference = "cost_optimize" if not use_anthropic_first else self.config["provider_preference"]

                async for text, tool_calls, provider in self.fallback_router.stream_with_fallback(
                    messages=messages,
                    tools=tools or [],
                    system_prompt=system_prompt,
                    model=model,
                    max_tokens=max_tokens,
                    provider_preference=effective_preference
                ):
                    yield (text, tool_calls, provider)

                # If we used SambaNova first successfully with cost_optimize, we're done
                if not use_anthropic_first:
                    return

            except Exception as e:
                if not use_anthropic_first and self.primary_client:
                    # SambaNova failed in cost_optimize mode, try Anthropic
                    logger.warning(f"SambaNova failed in cost_optimize mode: {e}, falling back to Anthropic")
                    try:
                        async for result in self._stream_anthropic(
                            messages, tools, system_prompt, model, max_tokens, temperature
                        ):
                            yield result
                        return  # Success, exit
                    except Exception as anthropic_error:
                        logger.error(f"All LLM providers failed: SambaNova: {e}, Anthropic: {anthropic_error}")
                        raise RuntimeError("All LLM providers failed. Please check configuration.")
                else:
                    logger.error(f"All LLM providers failed: {e}")
                    raise RuntimeError("All LLM providers failed. Please check configuration.")
        else:
            raise RuntimeError("No LLM providers available")

    async def _stream_anthropic(
        self,
        messages: List[Dict],
        tools: List[Dict],
        system_prompt: str,
        model: str,
        max_tokens: int,
        temperature: float
    ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
        """Stream from Anthropic with retry logic.

        Network errors are retried with exponential backoff; any other
        exception propagates immediately so stream() can fall back.
        """

        retry_delay = 1
        last_error = None

        # Skip system message if it's in messages array (Anthropic wants it separate)
        api_messages = messages[1:] if messages and messages[0].get("role") == "system" else messages

        # Use system prompt or extract from messages
        if not system_prompt and messages and messages[0].get("role") == "system":
            system_prompt = messages[0].get("content", "")

        for attempt in range(self.config["max_retries"] + 1):
            try:
                logger.info(f"Streaming from Anthropic (attempt {attempt + 1})")

                accumulated_text = ""
                tool_calls = []

                # Create the stream
                stream_params = {
                    "model": model,
                    "max_tokens": max_tokens,
                    "messages": api_messages,
                    "temperature": temperature
                }

                if system_prompt:
                    stream_params["system"] = system_prompt

                if tools:
                    stream_params["tools"] = tools

                async with self.primary_client.messages.stream(**stream_params) as stream:
                    async for event in stream:
                        if event.type == "content_block_start":
                            if event.content_block.type == "tool_use":
                                tool_calls.append({
                                    "id": event.content_block.id,
                                    "name": event.content_block.name,
                                    "input": {}
                                })

                        elif event.type == "content_block_delta":
                            if event.delta.type == "text_delta":
                                accumulated_text += event.delta.text
                                yield (accumulated_text, tool_calls, "Anthropic Claude")

                    # Get final message
                    final_message = await stream.get_final_message()

                    # Rebuild tool calls from final message (inputs are complete here)
                    tool_calls.clear()
                    for block in final_message.content:
                        if block.type == "tool_use":
                            tool_calls.append({
                                "id": block.id,
                                "name": block.name,
                                "input": block.input
                            })
                        elif block.type == "text" and block.text:
                            if block.text not in accumulated_text:
                                accumulated_text += block.text

                yield (accumulated_text, tool_calls, "Anthropic Claude")
                return  # Success

            except (httpx.RemoteProtocolError, httpx.ReadError) as e:
                last_error = e
                logger.warning(f"Network error on attempt {attempt + 1}: {e}")

                if attempt < self.config["max_retries"]:
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                else:
                    raise

            except Exception as e:
                logger.error(f"Anthropic streaming error: {e}")
                raise

    def get_status(self) -> Dict[str, Any]:
        """Get current client status and configuration"""
        return {
            "primary_provider": "Anthropic" if self.primary_client else None,
            "fallback_enabled": bool(self.fallback_router),
            "current_provider": self.provider_name,
            "provider_preference": self.config["provider_preference"],
            "max_retries": self.config["max_retries"]
        }

    def is_using_llama_primary(self) -> bool:
        """Check if Llama/SambaNova is the primary provider"""
        # Check if cost_optimize preference is set and fallback is available
        if self.config.get("provider_preference") == "cost_optimize" and self.fallback_router:
            return True
        # Check if we have no Anthropic client and are using SambaNova
        if not self.primary_client and self.fallback_router:
            return True
        return False

    def classify_query_complexity(self, message: str, tools_count: int = 0) -> str:
        """
        Classify query as 'simple' or 'complex' based on content analysis.

        Args:
            message: The user's query text
            tools_count: Number of tools available for this query

        Returns:
            'simple' | 'complex' - The query classification
        """
        message_lower = message.lower()

        # Simple query indicators (good for Llama)
        simple_patterns = [
            "what is", "define", "when was", "who is", "list of",
            "how many", "name the", "what does", "explain what",
            "is there", "are there", "can you list", "tell me about",
            "what are the symptoms", "side effects of", "list the",
            "symptoms of", "treatment for", "causes of"
        ]

        # Complex query indicators (better for Claude)
        complex_patterns = [
            "analyze", "compare", "evaluate", "synthesize", "comprehensive",
            "all", "every", "detailed", "mechanism", "pathophysiology",
            "genotyping", "gene therapy", "combination therapy",
            "latest research", "recent studies", "cutting-edge",
            "molecular", "genetic mutation", "therapeutic pipeline",
            "clinical trial results", "meta-analysis", "systematic review",
            # Enhanced trial-related patterns
            "trials", "clinical trials", "studies", "clinical study",
            "NCT", "recruiting", "enrollment", "study protocol",
            "phase 1", "phase 2", "phase 3", "phase 4", "early phase",
            "investigational", "experimental", "novel treatment",
            "treatment pipeline", "research pipeline", "drug development"
        ]

        # Count pattern matches
        simple_score = sum(1 for pattern in simple_patterns if pattern in message_lower)
        complex_score = sum(1 for pattern in complex_patterns if pattern in message_lower)

        # Decision logic
        if complex_score > 0:
            # Any complex indicator suggests complex query
            return "complex"
        elif simple_score > 0 and len(message) < 150:
            # Simple pattern and short query
            return "simple"
        elif len(message) > 300:
            # Long queries are likely complex
            return "complex"
        elif tools_count > 8:
            # Many tools suggest complex analysis needed
            return "complex"
        else:
            # Default to complex for safety (better quality)
            return "complex" if self.primary_client else "simple"

    def get_provider_display_name(self) -> str:
        """Get a user-friendly provider status string"""
        if self.primary_client and self.fallback_router:
            # Both providers available
            if self.config["provider_preference"] == "cost_optimize":
                status = "SambaNova Llama 3.3 70B (primary, cost-optimized) with Anthropic Claude fallback"
            elif self.config["provider_preference"] == "quality_first":
                status = "Anthropic Claude (primary, quality-first) with SambaNova fallback"
            else:  # auto
                status = "Anthropic Claude (with SambaNova fallback)"
        elif self.primary_client:
            status = "Anthropic Claude"
        elif self.fallback_router:
            status = f"SambaNova Llama 3.3 70B ({self.config['provider_preference']} mode)"
        else:
            status = "Not configured"

        return status

    async def cleanup(self):
        """Clean up resources.

        Best-effort: a failing fallback-router cleanup is logged rather than
        propagated. (Fix: the previous bare `except: pass` silently swallowed
        every exception, including KeyboardInterrupt/SystemExit.)
        """
        if self.fallback_router:
            try:
                await self.fallback_router.cleanup()
            except Exception as e:
                logger.warning(f"Fallback router cleanup failed: {e}")

    async def __aenter__(self):
        """Async context manager entry"""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        await self.cleanup()
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
# Global instance (optional - can be created per request instead)
_global_client: Optional[UnifiedLLMClient] = None


def get_llm_client() -> UnifiedLLMClient:
    """Return the process-wide LLM client, creating it on first use."""
    global _global_client
    if _global_client is not None:
        return _global_client
    _global_client = UnifiedLLMClient()
    return _global_client


async def cleanup_global_client():
    """Release the global client's resources and forget the instance."""
    global _global_client
    client = _global_client
    if client:
        await client.cleanup()
        _global_client = None
|
llm_providers.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Multi-LLM provider support with fallback logic.
|
| 4 |
+
Includes SambaNova free tier as primary fallback option.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import httpx
|
| 11 |
+
import asyncio
|
| 12 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
|
| 13 |
+
from anthropic import AsyncAnthropic
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SambaNovaProvider:
    """
    SambaNova Cloud provider - requires API key for access.
    Get your API key at https://cloud.sambanova.ai/
    Includes $5-30 free credits for new accounts.

    Speaks SambaNova's OpenAI-compatible chat-completions API while exposing
    a streaming interface shaped like the Anthropic SDK's, so callers can
    switch providers without changing their consumption code.
    """

    # Root of the OpenAI-compatible REST endpoint.
    BASE_URL = "https://api.sambanova.ai/v1"

    # Available models: short alias -> full SambaNova model identifier.
    MODELS = {
        "llama-3.3-70b": "Meta-Llama-3.3-70B-Instruct",  # Latest and best!
        "llama-3.1-405b": "Meta-Llama-3.1-405B-Instruct",
        "llama-3.1-70b": "Meta-Llama-3.1-70B-Instruct",
        "llama-3.1-8b": "Meta-Llama-3.1-8B-Instruct",
        "llama-3.2-11b": "Llama-3.2-11B-Vision-Instruct",
        "llama-3.2-3b": "Llama-3.2-3B-Instruct",
        "llama-3.2-1b": "Llama-3.2-1B-Instruct"
    }

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize SambaNova provider.
        API key is REQUIRED - get yours at https://cloud.sambanova.ai/

        Args:
            api_key: Explicit key; falls back to the SAMBANOVA_API_KEY
                environment variable when omitted.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("SAMBANOVA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "SAMBANOVA_API_KEY is required for SambaNova API access.\n"
                "Get your API key at: https://cloud.sambanova.ai/\n"
                "Then set it in your .env file: SAMBANOVA_API_KEY=your_key_here"
            )
        # Reused async HTTP client for all requests; released via close().
        self.client = httpx.AsyncClient(timeout=60.0)

    async def stream(
        self,
        messages: List[Dict],
        system: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        model: str = "llama-3.1-70b",
        max_tokens: int = 4096,
        temperature: float = 0.7
    ) -> AsyncGenerator[Tuple[str, List[Dict]], None]:
        """
        Stream responses from SambaNova API.
        Compatible interface with Anthropic streaming.

        Yields (accumulated_text, tool_calls) tuples. Note the text is
        cumulative (the full response so far), not a per-chunk delta, and a
        final tuple is yielded once the stream completes.

        Raises:
            RuntimeError: On 410 (endpoint gone) or 401 (bad key) responses.
            httpx.HTTPError: On other transport or status failures.
        """

        # Select the full model name (unknown aliases fall back to 70B)
        full_model = self.MODELS.get(model, self.MODELS["llama-3.1-70b"])

        # Convert messages to OpenAI format (SambaNova uses OpenAI-compatible API)
        formatted_messages = []

        # Add system message if provided
        if system:
            formatted_messages.append({
                "role": "system",
                "content": system
            })

        # Convert Anthropic message format to OpenAI format
        for msg in messages:
            if msg["role"] == "user":
                formatted_messages.append({
                    "role": "user",
                    "content": msg.get("content", "")
                })
            elif msg["role"] == "assistant":
                # Handle assistant messages with potential tool calls
                content = msg.get("content", "")
                if isinstance(content, list):
                    # Extract text from content blocks. Anthropic stores
                    # assistant content as typed blocks; non-text blocks
                    # (e.g. tool_use) are intentionally dropped here.
                    text_parts = []
                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                    content = "\n".join(text_parts)

                formatted_messages.append({
                    "role": "assistant",
                    "content": content
                })

        # Prepare request payload
        payload = {
            "model": full_model,
            "messages": formatted_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
            "stream_options": {"include_usage": True}
        }

        # Add tools if provided (only for the larger models that support it)
        if tools and model in ["llama-3.3-70b", "llama-3.1-405b", "llama-3.1-70b"]:
            # Convert Anthropic tool format to OpenAI format
            openai_tools = []
            for tool in tools:
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool.get("input_schema", {})
                    }
                })
            payload["tools"] = openai_tools
            payload["tool_choice"] = "auto"

        # Headers - API key is always required now
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        try:
            # Make streaming request
            accumulated_text = ""
            tool_calls = []

            async with self.client.stream(
                "POST",
                f"{self.BASE_URL}/chat/completions",
                json=payload,
                headers=headers
            ) as response:
                response.raise_for_status()

                # Server-sent events: each payload line is "data: <json>".
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data = line[6:]  # Remove "data: " prefix

                        if data == "[DONE]":
                            break

                        try:
                            chunk = json.loads(data)

                            # Handle usage-only chunks (sent at end of stream)
                            if "usage" in chunk and ("choices" not in chunk or len(chunk.get("choices", [])) == 0):
                                # This is a usage statistics chunk, skip it
                                logger.debug(f"Received usage chunk: {chunk.get('usage', {})}")
                                continue

                            # Extract content from chunk
                            if "choices" in chunk and len(chunk["choices"]) > 0:
                                choice = chunk["choices"][0]
                                delta = choice.get("delta", {})

                                # Handle text content
                                if "content" in delta and delta["content"]:
                                    accumulated_text += delta["content"]
                                    yield (accumulated_text, tool_calls)

                                # Handle tool calls (if supported)
                                # NOTE(review): OpenAI-style streams typically
                                # deliver function arguments in fragments
                                # across several chunks; this code appends one
                                # entry per fragment, and json.loads of a
                                # partial fragment is caught by the
                                # JSONDecodeError handler below (dropping the
                                # chunk). Confirm against SambaNova's actual
                                # streaming behavior for tool calls.
                                if "tool_calls" in delta:
                                    for tc in delta["tool_calls"]:
                                        # Convert OpenAI tool call format to Anthropic format
                                        tool_calls.append({
                                            "id": tc.get("id", f"tool_{len(tool_calls)}"),
                                            "name": tc.get("function", {}).get("name", ""),
                                            "input": json.loads(tc.get("function", {}).get("arguments", "{}"))
                                        })

                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse SSE data: {data}")
                            continue

            # Final yield with complete results
            yield (accumulated_text, tool_calls)

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 410:
                logger.error("SambaNova API endpoint has been discontinued (410 GONE)")
                raise RuntimeError(
                    "SambaNova API endpoint no longer exists. "
                    "Make sure you have a valid API key set in SAMBANOVA_API_KEY."
                )
            elif e.response.status_code == 401:
                logger.error("SambaNova API authentication failed")
                raise RuntimeError(
                    "SambaNova authentication failed. Please check your API key."
                )
            else:
                logger.error(f"SambaNova API error: {e}")
                raise
        except httpx.HTTPError as e:
            logger.error(f"SambaNova API error: {e}")
            raise

    async def close(self):
        """Close the HTTP client (call once when the provider is retired)."""
        await self.client.aclose()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class LLMRouter:
    """
    Routes LLM requests to appropriate providers with fallback logic.

    Anthropic is the preferred provider when its key is configured; the
    SambaNova Llama models serve as fallback. Providers are initialized once
    at construction and reused for every request.
    """

    def __init__(self):
        # Mapping of provider name ("anthropic" / "sambanova") -> client.
        self.providers = {}
        self._setup_providers()

    def _setup_providers(self):
        """Initialize available providers.

        A provider whose credentials are missing is skipped (with a warning)
        rather than allowed to raise, so importing this module never crashes
        when only one of the two API keys is configured.
        """

        # Primary: Anthropic (if API key available)
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if anthropic_key:
            self.providers["anthropic"] = AsyncAnthropic(api_key=anthropic_key)
            logger.info("Anthropic provider initialized")

        # Fallback: SambaNova (requires SAMBANOVA_API_KEY).
        # FIX: SambaNovaProvider() raises ValueError when its key is absent,
        # and this module constructs a global LLMRouter at import time -- the
        # previous unconditional construction therefore crashed the entire
        # app whenever SAMBANOVA_API_KEY was unset, even with Anthropic
        # fully configured. Degrade gracefully instead.
        try:
            self.providers["sambanova"] = SambaNovaProvider()
            logger.info("SambaNova provider initialized")
        except Exception as e:
            logger.warning(f"SambaNova provider unavailable: {e}")

    async def stream_with_fallback(
        self,
        messages: List[Dict],
        tools: List[Dict],
        system_prompt: str,
        model: str = None,
        max_tokens: int = 4096,
        provider_preference: str = "auto"
    ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
        """
        Stream from LLM with automatic fallback.
        Returns (text, tool_calls, provider_used) tuples.

        Args:
            messages: Conversation history in Anthropic message format.
            tools: Tool definitions in Anthropic schema.
            system_prompt: System instructions passed to the model.
            model: Anthropic model override; ignored for SambaNova, which
                picks a Llama model from max_tokens instead.
            max_tokens: Response token budget.
            provider_preference: "quality_first" (Anthropic first),
                "cost_optimize" (SambaNova first), or "auto".

        Raises:
            Exception: When every configured provider fails; carries the
                last provider error.
        """

        # Determine provider order based on preference
        if provider_preference == "cost_optimize":
            # Prefer the cheaper SambaNova models first
            provider_order = ["sambanova", "anthropic"]
        elif provider_preference == "quality_first":
            # Prefer Anthropic first
            provider_order = ["anthropic", "sambanova"]
        else:  # auto
            # Use Anthropic if available, fall back to SambaNova
            provider_order = ["anthropic", "sambanova"] if "anthropic" in self.providers else ["sambanova"]

        last_error = None

        for provider_name in provider_order:
            if provider_name not in self.providers:
                continue

            try:
                logger.info(f"Attempting to use {provider_name} provider...")

                if provider_name == "anthropic":
                    # Use existing Anthropic streaming
                    provider = self.providers["anthropic"]

                    # Stream from Anthropic
                    accumulated_text = ""
                    tool_calls = []

                    async with provider.messages.stream(
                        model=model or "claude-sonnet-4-5-20250929",
                        max_tokens=max_tokens,
                        messages=messages,
                        system=system_prompt,
                        tools=tools
                    ) as stream:
                        async for event in stream:
                            if event.type == "content_block_start":
                                if event.content_block.type == "tool_use":
                                    # Register the tool call now; its input
                                    # arrives incrementally and is filled in
                                    # from the final message below.
                                    tool_calls.append({
                                        "id": event.content_block.id,
                                        "name": event.content_block.name,
                                        "input": {}
                                    })

                            elif event.type == "content_block_delta":
                                if event.delta.type == "text_delta":
                                    accumulated_text += event.delta.text
                                    yield (accumulated_text, tool_calls, "Anthropic Claude")

                        # Get final message
                        final_message = await stream.get_final_message()

                        # Rebuild tool calls from the final message, which
                        # now carries complete input payloads.
                        tool_calls.clear()
                        for block in final_message.content:
                            if block.type == "tool_use":
                                tool_calls.append({
                                    "id": block.id,
                                    "name": block.name,
                                    "input": block.input
                                })

                    yield (accumulated_text, tool_calls, "Anthropic Claude")
                    return  # Success!

                elif provider_name == "sambanova":
                    # Use SambaNova streaming
                    provider = self.providers["sambanova"]

                    # Determine which Llama model to use
                    if max_tokens > 8192:
                        samba_model = "llama-3.1-405b"  # Largest model for complex tasks
                    else:
                        # Default to Llama 3.3 70B - newest and best for most tasks
                        samba_model = "llama-3.3-70b"

                    async for text, tool_calls in provider.stream(
                        messages=messages,
                        system=system_prompt,
                        tools=tools,
                        model=samba_model,
                        max_tokens=max_tokens
                    ):
                        yield (text, tool_calls, f"SambaNova {samba_model}")

                    return  # Success!

            except Exception as e:
                logger.warning(f"Provider {provider_name} failed: {e}")
                last_error = e
                continue

        # All providers failed
        error_msg = f"All LLM providers failed. Last error: {last_error}"
        logger.error(error_msg)
        raise Exception(error_msg)

    async def cleanup(self):
        """Clean up provider resources (closes SambaNova's HTTP client)."""
        if "sambanova" in self.providers:
            await self.providers["sambanova"].close()


# Global router instance. Safe to build at import time: _setup_providers
# degrades gracefully when a provider's API key is missing.
llm_router = LLMRouter()
|
parallel_tool_execution.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Parallel tool execution optimization for ALS Research Agent
|
| 4 |
+
This module replaces sequential tool execution with parallel execution
|
| 5 |
+
to reduce response time by ~60-70% for multi-tool queries.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import List, Dict, Tuple, Any
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def execute_single_tool(
    tool_call: Dict,
    call_mcp_tool_func,
    index: int
) -> Tuple[int, str, Dict]:
    """
    Execute a single tool call asynchronously.

    Returns (index, progress_text, result_dict) so callers can restore the
    original ordering after a parallel gather.
    """
    name = tool_call["name"]
    args = tool_call["input"]

    # Human-readable label plus a short echo of the search term, if any.
    display = name.replace('__', ' → ')
    hint = ""
    for key in ("query", "condition"):
        if key in args:
            term = args[key]
            hint = f" `{term[:50]}{'...' if len(term) > 50 else ''}`"
            break

    try:
        loop = asyncio.get_event_loop()
        started = loop.time()
        output = await call_mcp_tool_func(name, args)
        duration = loop.time() - started

        logger.info(f"Tool {name} completed in {duration:.2f}s")

        # Inspect textual output for a result count / "no results" markers.
        found = True
        hits = 0

        if isinstance(output, str):
            lowered = output.lower()

            import re
            counted = re.findall(r'found (\d+) (?:papers?|trials?|preprints?|results?)', lowered)
            if counted:
                hits = int(counted[0])

            empty_markers = (
                "no results found", "0 results", "no papers found",
                "no trials found", "no preprints found", "not found",
                "zero results", "no matches"
            )
            if hits == 0 or any(marker in lowered for marker in empty_markers):
                found = False

        # Build a clear success / no-result indicator for the UI.
        if found:
            status = (
                f"\n✅ **Found {hits} results:** {display}{hint}"
                if hits > 0
                else f"\n✅ **Success:** {display}{hint}"
            )
        else:
            status = f"\n⚠️ **No results:** {display}{hint} - will try alternatives"

        # Surface timing only for slow calls.
        if duration > 5:
            status += f" (took {duration:.1f}s)"

        # Empty result sets get a self-correction hint appended so the model
        # can retry with broader terms.
        if not found:
            output += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
            output += "1. Broadening search terms (remove qualifiers)\n"
            output += "2. Using alternative terminology or synonyms\n"
            output += "3. Searching related concepts\n"
            output += "4. Checking for typos in search terms"

        return index, status, {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": output
        }

    except Exception as e:
        logger.error(f"Error executing tool {name}: {e}")

        # Clear failure indicator with a truncated error message.
        failure = f"\n❌ **Failed:** {display}{hint} - {str(e)[:50]}"

        return index, failure, {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": f"Error executing tool: {str(e)}"
        }
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
async def execute_tool_calls_parallel(
    tool_calls: List[Dict],
    call_mcp_tool_func
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls in parallel and collect results.
    Maintains the original order of tool calls in results.

    Returns: (progress_text, tool_results_content)
    """
    if not tool_calls:
        return "", []

    loop = asyncio.get_event_loop()
    started = loop.time()

    logger.info(f"Executing {len(tool_calls)} tools in parallel")

    # Fan out all tool calls at once; exceptions are captured, not raised.
    outcomes = await asyncio.gather(
        *(execute_single_tool(tc, call_mcp_tool_func, pos)
          for pos, tc in enumerate(tool_calls)),
        return_exceptions=True
    )

    # Successful outcomes, restored to submission order via their index.
    successes = sorted(
        (o for o in outcomes if not isinstance(o, Exception)),
        key=lambda item: item[0]
    )

    elapsed = loop.time() - started
    timing = f" in {elapsed:.1f}s" if elapsed > 5 else ""

    summary = (
        f"\n📊 **Search Progress:** Completed "
        f"{len(successes)}/{len(tool_calls)} searches{timing}\n"
    )

    payloads = []
    for _pos, line, payload in successes:
        summary += line
        payloads.append(payload)

    # Splice in error results for any task that raised, keeping each one at
    # its original position among the successes.
    for pos, outcome in enumerate(outcomes):
        if isinstance(outcome, Exception):
            logger.error(f"Task {pos} failed with exception: {outcome}")
            if pos < len(tool_calls):
                payloads.insert(pos, {
                    "type": "tool_result",
                    "tool_use_id": tool_calls[pos]["id"],
                    "content": f"Tool execution failed: {str(outcome)}"
                })

    return summary, payloads
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Backward compatibility wrapper
|
| 178 |
+
async def execute_tool_calls_optimized(
    tool_calls: List[Dict],
    call_mcp_tool_func,
    parallel: bool = True
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls with optional parallel execution.

    Args:
        tool_calls: List of tool calls to execute
        call_mcp_tool_func: Function to call MCP tools
        parallel: If True, execute tools in parallel; if False, execute sequentially

    Returns: (progress_text, tool_results_content)
    """
    # A single call (or an explicit opt-out) gains nothing from the
    # parallel path, so delegate to the original sequential helper.
    if not (parallel and len(tool_calls) > 1):
        from refactored_helpers import execute_tool_calls
        return await execute_tool_calls(tool_calls, call_mcp_tool_func)

    return await execute_tool_calls_parallel(tool_calls, call_mcp_tool_func)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def estimate_time_savings(num_tools: int, avg_tool_time: float = 3.5) -> Dict[str, float]:
    """
    Estimate time savings from parallel execution.

    Args:
        num_tools: Number of tools to execute
        avg_tool_time: Average time per tool in seconds

    Returns: Dictionary with timing estimates
    """
    serial = num_tools * avg_tool_time
    # Parallel wall time ~ one (average) tool plus 0.5s coordination overhead.
    concurrent = avg_tool_time + 0.5

    saved = serial - concurrent
    pct = (saved / serial) * 100 if serial > 0 else 0

    return {
        "sequential_time": serial,
        "parallel_time": concurrent,
        "time_saved": saved,
        "savings_percent": pct,
    }
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Test the optimization
|
| 228 |
+
if __name__ == "__main__":
    # Manual smoke check: print estimated sequential vs parallel timings
    # for a few tool counts (uses the default 3.5s average tool time).
    for n in [2, 3, 4, 5]:
        estimates = estimate_time_savings(n)
        print(f"\n{n} tools:")
        print(f" Sequential: {estimates['sequential_time']:.1f}s")
        print(f" Parallel: {estimates['parallel_time']:.1f}s")
        print(f" Savings: {estimates['time_saved']:.1f}s ({estimates['savings_percent']:.0f}%)")
|
query_classifier.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Query Classification Module for ALS Research Agent
|
| 4 |
+
Determines whether a query requires full research workflow or simple response
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
import re
from typing import Any, Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QueryClassifier:
    """Classify queries as research-required or simple questions.

    Purely keyword/regex based, so routing is cheap and deterministic.
    The heuristics are deliberately biased toward research mode: for an ALS
    research agent, running an unnecessary search is cheaper than missing a
    medical question.
    """

    # Keywords that indicate ALS research is needed.
    # NOTE(review): matching is plain substring, so short keys like 'als'
    # also match inside words such as "trials" or "also" -- this only pushes
    # borderline queries toward research mode, which is the intended bias,
    # but confirm before reusing these lists elsewhere.
    RESEARCH_KEYWORDS = [
        # Disease-specific terms
        'als', 'amyotrophic lateral sclerosis', 'motor neuron disease',
        'mnd', 'lou gehrig', 'ftd', 'frontotemporal dementia',

        # Medical research terms
        'clinical trial', 'treatment', 'therapy', 'drug', 'medication',
        'gene therapy', 'stem cell', 'biomarker', 'diagnosis',
        'prognosis', 'survival', 'progression', 'symptom',
        'cure', 'breakthrough', 'research', 'study', 'paper',
        'latest', 'recent', 'new findings', 'advances',

        # Specific ALS-related
        'riluzole', 'edaravone', 'radicava', 'relyvrio', 'qalsody',
        'tofersen', 'sod1', 'c9orf72', 'tdp-43', 'fus',

        # Research actions
        'find studies', 'search papers', 'what research',
        'clinical evidence', 'scientific literature'
    ]

    # Keywords that indicate simple/general questions
    SIMPLE_KEYWORDS = [
        'hello', 'hi', 'hey', 'thanks', 'thank you',
        'how are you', "what's your name", 'who are you',
        'what can you do', 'help', 'test', 'testing',
        'explain', 'define', 'what is', 'what are',
        'how does', 'why', 'when', 'where', 'who'
    ]

    # Exclusion patterns for non-research queries (anchored at the start)
    NON_RESEARCH_PATTERNS = [
        r'^(hi|hello|hey|thanks|thank you)',
        r'^test\s',
        r'^how (are you|do you)',
        r'^what (is|are) (the|a|an)\s+\w+$',  # Simple definitions
        r'^(explain|define)\s+\w+$',  # Simple explanations
        r'^\w{1,3}$',  # Very short queries
    ]

    @classmethod
    def classify_query(cls, query: str) -> Dict[str, Any]:
        """
        Classify a query and determine processing strategy.

        Fixed: the return annotation previously used the builtin ``any``
        function instead of ``typing.Any``. Also removed dead locals
        (simple-keyword score and complexity flags) that were computed but
        never read.

        Returns:
            Dict with:
            - requires_research: bool - Whether to use full research workflow
            - confidence: float - Confidence in classification (0-1)
            - reason: str - Explanation of classification
            - suggested_mode: str - 'research' or 'simple'
        """
        query_lower = query.lower().strip()

        # Very short queries cannot be meaningful research requests.
        if len(query_lower) < 5:
            return {
                'requires_research': False,
                'confidence': 0.9,
                'reason': 'Query too short for research',
                'suggested_mode': 'simple'
            }

        # Check exclusion patterns first (greetings, tests, trivial defs).
        for pattern in cls.NON_RESEARCH_PATTERNS:
            if re.match(pattern, query_lower):
                return {
                    'requires_research': False,
                    'confidence': 0.85,
                    'reason': 'Matches non-research pattern',
                    'suggested_mode': 'simple'
                }

        # Count research keywords present in the query.
        research_score = sum(
            1 for keyword in cls.RESEARCH_KEYWORDS
            if keyword in query_lower
        )

        # Decision logic - Conservative approach for ALS research agent

        # FIRST: Check if this is truly just a greeting/thanks (only these skip research)
        greeting_only = query_lower in ['hi', 'hello', 'hey', 'thanks', 'thank you', 'bye', 'goodbye', 'test']
        if greeting_only and research_score == 0:
            return {
                'requires_research': False,
                'confidence': 0.95,
                'reason': 'Pure greeting or acknowledgment',
                'suggested_mode': 'simple'
            }

        # SECOND: If ANY research keyword is present, use research mode
        # This includes "ALS", "treatment", "therapy", etc.
        if research_score >= 1:
            return {
                'requires_research': True,
                'confidence': min(0.95, 0.7 + research_score * 0.1),
                'reason': f'Contains research-related terms ({research_score} keywords)',
                'suggested_mode': 'research'
            }

        # THIRD: Check for questions about the agent itself
        about_agent = any(phrase in query_lower for phrase in [
            'who are you', 'what can you do', 'how do you work',
            'what are you', 'your capabilities'
        ])
        if about_agent:
            return {
                'requires_research': False,
                'confidence': 0.85,
                'reason': 'Question about the agent itself',
                'suggested_mode': 'simple'
            }

        # DEFAULT: For an ALS research agent, when in doubt, use research mode
        # This is safer than potentially missing important medical queries
        return {
            'requires_research': True,
            'confidence': 0.6,
            'reason': 'Default to research mode for potential medical queries',
            'suggested_mode': 'research'
        }

    @classmethod
    def should_use_tools(cls, query: str) -> bool:
        """Quick check if query needs research tools (confidence > 0.65)."""
        classification = cls.classify_query(query)
        return classification['requires_research'] and classification['confidence'] > 0.65

    @classmethod
    def get_processing_hint(cls, classification: Dict) -> str:
        """Get a user-facing hint for how the query will be processed."""
        if classification['requires_research']:
            return "🔬 Using full research workflow ..."
        else:
            return "💬 Providing direct response without research tools"
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def test_classifier():
    """Run the classifier over a fixed set of example queries and print results."""
    example_queries = [
        # Expected to trigger the research workflow
        "What are the latest gene therapy trials for ALS?",
        "Compare riluzole and edaravone effectiveness",
        "Find recent studies on SOD1 mutations",
        "What breakthroughs in ALS treatment happened in 2024?",
        "Are there any promising stem cell therapies for motor neuron disease?",
        # Expected to skip research
        "Hello, how are you?",
        "What is your name?",
        "Test",
        "Thanks for your help",
        "Explain what a database is",
        "What time is it?",
        "How do I use this app?",
    ]

    print("Query Classification Test Results")
    print("=" * 60)

    for q in example_queries:
        outcome = QueryClassifier.classify_query(q)
        print(f"\nQuery: {q[:50]}...")
        print(f"Requires Research: {outcome['requires_research']}")
        print(f"Confidence: {outcome['confidence']:.2f}")
        print(f"Reason: {outcome['reason']}")
        print(f"Mode: {outcome['suggested_mode']}")
        print("-" * 40)


if __name__ == "__main__":
    test_classifier()
|
refactored_helpers.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Helper functions to consolidate duplicate code in als_agent_app.py
|
| 4 |
+
Refactored to improve efficiency and reduce redundancy
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
import httpx
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
|
| 12 |
+
from llm_client import UnifiedLLMClient
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
async def stream_with_retry(
    client,
    messages: List[Dict],
    tools: List[Dict],
    system_prompt: str,
    max_retries: int = 2,
    model: str = None,
    max_tokens: int = 8192,
    stream_name: str = "API call",
    temperature: float = 0.7
) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
    """
    Simplified wrapper that delegates streaming to UnifiedLLMClient.

    Args:
        client: One of:
            - An Anthropic client (backward compatibility; detected by the
              absence of a ``stream`` attribute and replaced by a fresh
              UnifiedLLMClient owned by this function)
            - A UnifiedLLMClient instance (used as-is; caller owns cleanup)
            - None (a UnifiedLLMClient is created and cleaned up internally)
        messages: Conversation messages to send.
        tools: Tool definitions exposed to the model.
        system_prompt: System prompt for the request.
        max_retries: Retained for backward compatibility; not consulted here
            (retry behavior, if any, lives inside UnifiedLLMClient).
        model: Optional model override; provider default when None.
        max_tokens: Maximum tokens to generate.
        stream_name: Human-readable label used only in log messages.
        temperature: Sampling temperature.

    Yields:
        (response_text, tool_calls, provider_used) tuples.
    """
    # A client without a `stream` attribute (raw Anthropic client, or None)
    # cannot be streamed directly: create a UnifiedLLMClient that we own and
    # must clean up. A caller-provided UnifiedLLMClient is never cleaned here.
    owns_client = client is None or not hasattr(client, 'stream')
    llm_client = UnifiedLLMClient() if owns_client else client

    # Preserve the two distinct log messages of the original branches.
    if owns_client:
        logger.info(f"Using {llm_client.get_provider_display_name()} for {stream_name}")
    else:
        logger.info(f"Using provided {llm_client.get_provider_display_name()} for {stream_name}")

    try:
        # Single streaming loop (previously duplicated in both branches).
        async for text, tool_calls, provider in llm_client.stream(
            messages=messages,
            tools=tools,
            system_prompt=system_prompt,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature
        ):
            yield (text, tool_calls, provider)
    finally:
        # Only dispose of clients created here; caller-owned clients stay open.
        if owns_client:
            await llm_client.cleanup()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
async def execute_tool_calls(
    tool_calls: List[Dict],
    call_mcp_tool_func
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls and collect results.

    Consolidates duplicate tool execution logic and appends a self-correction
    hint to any string result that reports zero matches, so the model can
    retry with broader terms.

    Args:
        tool_calls: List of dicts with "name", "input", and "id" keys.
        call_mcp_tool_func: Async callable ``(tool_name, tool_args) -> result``.

    Returns:
        (progress_text, tool_results_content) where progress_text is
        user-facing markdown and tool_results_content is a list of
        tool_result content blocks.
    """

    def _truncate(value: str, limit: int = 50) -> str:
        # Shorten long search strings for display, marking the cut with "...".
        return f"{value[:limit]}{'...' if len(value) > limit else ''}"

    # Phrases that indicate an empty result set (case-insensitive match).
    _ZERO_RESULT_PHRASES = (
        "no results found", "0 results", "no papers found",
        "no trials found", "no preprints found", "not found",
        "zero results", "no matches"
    )

    progress_text = ""
    tool_results_content = []

    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["input"]

        # Single, clean execution marker with search info.
        tool_display = tool_name.replace('__', ' → ')

        # Show the key search parameter, whichever one the tool uses.
        search_info = ""
        if "query" in tool_args:
            search_info = f" `{_truncate(tool_args['query'])}`"
        elif "condition" in tool_args:
            search_info = f" `{_truncate(tool_args['condition'])}`"

        progress_text += f"\n🔧 **Searching:** {tool_display}{search_info}\n"

        # Call the MCP tool.
        tool_result = await call_mcp_tool_func(tool_name, tool_args)

        # Check for zero results to enable self-correction.
        if isinstance(tool_result, str):
            result_lower = tool_result.lower()
            if any(phrase in result_lower for phrase in _ZERO_RESULT_PHRASES):
                # Append a hint so the model reformulates the next attempt.
                tool_result += (
                    "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
                    "1. Broadening search terms (remove qualifiers)\n"
                    "2. Using alternative terminology or synonyms\n"
                    "3. Searching related concepts\n"
                    "4. Checking for typos in search terms"
                )

        tool_results_content.append({
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": tool_result
        })

    return progress_text, tool_results_content
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def build_assistant_message(
    text_content: str,
    tool_calls: List[Dict],
    strip_markers: List[str] = None
) -> List[Dict]:
    """
    Build assistant message content with text and tool uses.

    Args:
        text_content: Text content to include (skipped when blank).
        tool_calls: Tool calls to append as tool_use blocks.
        strip_markers: Optional substrings removed from the text first.

    Returns:
        List of content blocks for the assistant message.
    """
    blocks = []

    # Text block: strip requested markers, then surrounding whitespace;
    # only include the block if anything remains.
    if text_content and text_content.strip():
        cleaned = text_content
        for marker in (strip_markers or []):
            cleaned = cleaned.replace(marker, "")
        cleaned = cleaned.strip()
        if cleaned:
            blocks.append({"type": "text", "text": cleaned})

    # One tool_use block per call, preserving order.
    blocks.extend(
        {
            "type": "tool_use",
            "id": call["id"],
            "name": call["name"],
            "input": call["input"],
        }
        for call in tool_calls
    )

    return blocks
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def should_continue_iterations(
    iteration_count: int,
    max_iterations: int,
    tool_calls: List[Dict]
) -> bool:
    """
    Decide whether another tool-use round should run.

    Returns False when no tool calls are pending, or when the iteration cap
    has been reached (logged as a warning). Otherwise returns True.
    """
    if not tool_calls:
        return False
    if iteration_count < max_iterations:
        return True
    # Cap reached with work still pending — warn so the cutoff is visible.
    logger.warning(f"Reached maximum tool iterations ({max_iterations})")
    return False
|
requirements.txt
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Dependencies
|
| 2 |
+
# REQUIRES PYTHON 3.10+ (Recommended: Python 3.12)
|
| 3 |
+
# Compatible with Gradio 5.x and 6.x
|
| 4 |
+
gradio>=6.0.0
|
| 5 |
+
anthropic>=0.34.0
|
| 6 |
+
mcp>=1.21.2 # FastMCP API required
|
| 7 |
+
httpx>=0.25.0
|
| 8 |
+
|
| 9 |
+
# Web Scraping
|
| 10 |
+
beautifulsoup4>=4.12.0
|
| 11 |
+
lxml>=4.9.0
|
| 12 |
+
|
| 13 |
+
# Database Access (for AACT clinical trials database)
|
| 14 |
+
psycopg2-binary>=2.9.0
|
| 15 |
+
asyncpg>=0.29.0 # For async PostgreSQL with connection pooling
|
| 16 |
+
|
| 17 |
+
# Configuration
|
| 18 |
+
python-dotenv>=1.0.0
|
| 19 |
+
|
| 20 |
+
# Testing
|
| 21 |
+
pytest>=7.4.0
|
| 22 |
+
pytest-asyncio>=0.21.0
|
| 23 |
+
pytest-cov>=4.1.0
|
| 24 |
+
pytest-mock>=3.12.0
|
| 25 |
+
|
| 26 |
+
# RAG and Research Memory (LlamaIndex)
|
| 27 |
+
llama-index-core>=0.11.0
|
| 28 |
+
llama-index-vector-stores-chroma>=0.2.0
|
| 29 |
+
llama-index-embeddings-huggingface>=0.3.0
|
| 30 |
+
chromadb>=0.5.0
|
| 31 |
+
sentence-transformers>=3.0.0
|
| 32 |
+
transformers>=4.30.0
|
| 33 |
+
|
| 34 |
+
# Development
|
| 35 |
+
black>=23.0.0
|
| 36 |
+
flake8>=6.1.0
|
| 37 |
+
mypy>=1.7.0
|
servers/aact_server.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AACT Database MCP Server with Connection Pooling
|
| 4 |
+
Provides access to ClinicalTrials.gov data through the AACT PostgreSQL database.
|
| 5 |
+
|
| 6 |
+
AACT (Aggregate Analysis of ClinicalTrials.gov) is maintained by Duke University
|
| 7 |
+
and FDA, providing complete ClinicalTrials.gov data updated daily.
|
| 8 |
+
|
| 9 |
+
Database access: aact-db.ctti-clinicaltrials.org
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import asyncio
|
| 17 |
+
from typing import Optional, Dict, Any, List
|
| 18 |
+
from datetime import datetime, timedelta
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
# Add parent directory to path for shared imports
|
| 22 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 23 |
+
from shared.config import config
|
| 24 |
+
|
| 25 |
+
# Load .env file if it exists
|
| 26 |
+
try:
|
| 27 |
+
from dotenv import load_dotenv
|
| 28 |
+
env_path = Path(__file__).parent.parent / '.env'
|
| 29 |
+
if env_path.exists():
|
| 30 |
+
load_dotenv(env_path)
|
| 31 |
+
logging.info(f"Loaded .env file from {env_path}")
|
| 32 |
+
except ImportError:
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
# Try both asyncpg (preferred) and psycopg2 (fallback)
|
| 36 |
+
try:
|
| 37 |
+
import asyncpg
|
| 38 |
+
ASYNCPG_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
ASYNCPG_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
import psycopg2
|
| 44 |
+
from psycopg2.extras import RealDictCursor
|
| 45 |
+
POSTGRES_AVAILABLE = True
|
| 46 |
+
except ImportError:
|
| 47 |
+
POSTGRES_AVAILABLE = False
|
| 48 |
+
|
| 49 |
+
from mcp.server.fastmcp import FastMCP
|
| 50 |
+
|
| 51 |
+
# Setup logging
|
| 52 |
+
logging.basicConfig(level=logging.INFO)
|
| 53 |
+
logger = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
# Initialize MCP server
|
| 56 |
+
mcp = FastMCP("aact-database")
|
| 57 |
+
|
| 58 |
+
# Database configuration
|
| 59 |
+
AACT_HOST = os.getenv("AACT_HOST", "aact-db.ctti-clinicaltrials.org")
|
| 60 |
+
AACT_PORT = os.getenv("AACT_PORT", "5432")
|
| 61 |
+
AACT_DB = os.getenv("AACT_DB", "aact")
|
| 62 |
+
AACT_USER = os.getenv("AACT_USER", "aact")
|
| 63 |
+
AACT_PASSWORD = os.getenv("AACT_PASSWORD", "")
|
| 64 |
+
|
| 65 |
+
# Global connection pool (initialized once)
|
| 66 |
+
_connection_pool: Optional[asyncpg.Pool] = None
|
| 67 |
+
|
| 68 |
+
async def get_connection_pool() -> "asyncpg.Pool":
    """Get or create the global connection pool.

    Lazily builds an asyncpg pool on first use (or after the previous pool was
    closed) and caches it in the module-level ``_connection_pool``.

    The return annotation is a string on purpose: a bare ``asyncpg.Pool`` is
    evaluated when the ``def`` executes and raises NameError when asyncpg is
    not installed, which would break the psycopg2 fallback path.

    Returns:
        The shared asyncpg connection pool.

    Raises:
        Exception: re-raises whatever ``asyncpg.create_pool`` raises.
    """
    global _connection_pool

    # Recreate the pool if it was never built or has been closed.
    # NOTE(review): `_closed` is a private asyncpg attribute — confirm it is
    # stable across the asyncpg versions in use.
    if _connection_pool is None or _connection_pool._closed:
        logger.info("Creating new database connection pool...")

        # Build connection URL; omit the password segment when blank.
        if AACT_PASSWORD:
            dsn = f"postgresql://{AACT_USER}:{AACT_PASSWORD}@{AACT_HOST}:{AACT_PORT}/{AACT_DB}"
        else:
            dsn = f"postgresql://{AACT_USER}@{AACT_HOST}:{AACT_PORT}/{AACT_DB}"

        try:
            _connection_pool = await asyncpg.create_pool(
                dsn=dsn,
                min_size=2,  # Minimum connections in pool
                max_size=10,  # Maximum connections in pool
                max_queries=50000,  # Max queries per connection before recycling
                max_inactive_connection_lifetime=300,  # Close idle connections after 5 min
                command_timeout=60.0,  # Query timeout (seconds)
                statement_cache_size=20,  # Cache prepared statements
            )
            logger.info("Connection pool created successfully")
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    return _connection_pool
|
| 97 |
+
|
| 98 |
+
async def execute_query_pooled(query: str, params: tuple = ()) -> List[Dict]:
    """Run *query* through the shared asyncpg pool; return rows as plain dicts."""
    pool = await get_connection_pool()

    async with pool.acquire() as connection:
        # fetch() materializes every row before the connection is released.
        records = await connection.fetch(query, *params)
        return [dict(record) for record in records]
|
| 106 |
+
|
| 107 |
+
def execute_query_sync(query: str, params: tuple = ()) -> List[Dict]:
    """Fallback: run *query* synchronously through psycopg2.

    Opens a fresh connection per call and returns rows as plain dicts.
    Resources are always released, even when the query fails.
    """
    connection = None
    cursor = None
    try:
        # Assemble connection parameters; only send a password when one is set.
        connect_kwargs = {
            "host": AACT_HOST,
            "port": AACT_PORT,
            "dbname": AACT_DB,
            "user": AACT_USER,
        }
        if AACT_PASSWORD:
            connect_kwargs["password"] = AACT_PASSWORD

        connection = psycopg2.connect(**connect_kwargs)
        cursor = connection.cursor(cursor_factory=RealDictCursor)
        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    except Exception as e:
        logger.error(f"Database query failed: {e}")
        raise
    finally:
        # Close cursor before connection; either may be None on early failure.
        if cursor:
            cursor.close()
        if connection:
            connection.close()
|
| 138 |
+
|
| 139 |
+
async def execute_query(query: str, params: tuple = ()) -> List[Dict]:
    """Dispatch *query* to the best available PostgreSQL driver."""
    # Preferred path: asyncpg with connection pooling.
    if ASYNCPG_AVAILABLE:
        return await execute_query_pooled(query, params)
    # Fallback: run the blocking psycopg2 call off the event loop.
    if POSTGRES_AVAILABLE:
        return await asyncio.to_thread(execute_query_sync, query, params)
    raise RuntimeError("No PostgreSQL driver available (install asyncpg or psycopg2)")
|
| 149 |
+
|
| 150 |
+
@mcp.tool()
async def search_als_trials(
    status: Optional[str] = "RECRUITING",
    phase: Optional[str] = None,
    intervention: Optional[str] = None,
    location: Optional[str] = None,
    max_results: int = 20
) -> str:
    """Search for ALS clinical trials in the AACT database.

    Returns a JSON string containing a summary message, a total count, and a
    list of matching trials (NCT id, title, status, phase, sponsor,
    interventions, conditions, location count, dates, and a
    ClinicalTrials.gov URL).

    Args:
        status: Trial status (RECRUITING, ENROLLING_BY_INVITATION, ACTIVE_NOT_RECRUITING, COMPLETED)
        phase: Trial phase (PHASE_1, PHASE_2, PHASE_3, PHASE_4, EARLY_PHASE_1)
        intervention: Type of intervention to search for
        location: Country or region
        max_results: Maximum number of results to return
    """

    if not (ASYNCPG_AVAILABLE or POSTGRES_AVAILABLE):
        # No driver installed: return a structured error instead of raising.
        return json.dumps({
            "error": "Database not available",
            "message": "PostgreSQL driver not installed. Install asyncpg or psycopg2-binary."
        })

    logger.info(f"🔎 AACT Search: status={status}, phase={phase}, intervention={intervention}, location={location}")

    try:
        # Base query: match ALS by condition name or trial title. The LIKE
        # patterns '%als %' / '% als' require a surrounding space so that
        # words merely containing "als" don't match.
        # NOTE(review): placeholders use asyncpg's $N style; the psycopg2
        # fallback in execute_query_sync expects %s — confirm the fallback
        # path is actually exercised with these queries.
        base_query = """
            SELECT DISTINCT
                s.nct_id,
                s.brief_title,
                s.overall_status,
                s.phase,
                s.enrollment,
                s.start_date,
                s.completion_date,
                s.study_type,
                s.official_title,
                d.name as sponsor,
                STRING_AGG(DISTINCT i.name, ', ') as interventions,
                STRING_AGG(DISTINCT c.name, ', ') as conditions,
                COUNT(DISTINCT f.id) as num_locations
            FROM studies s
            LEFT JOIN sponsors sp ON s.nct_id = sp.nct_id AND sp.lead_or_collaborator = 'lead'
            LEFT JOIN responsible_parties d ON sp.nct_id = d.nct_id
            LEFT JOIN interventions i ON s.nct_id = i.nct_id
            LEFT JOIN conditions c ON s.nct_id = c.nct_id
            LEFT JOIN facilities f ON s.nct_id = f.nct_id
            WHERE (
                LOWER(c.name) LIKE '%amyotrophic lateral sclerosis%' OR
                LOWER(c.name) LIKE '%als %' OR
                LOWER(c.name) LIKE '% als' OR
                LOWER(c.name) LIKE '%motor neuron disease%' OR
                LOWER(c.name) LIKE '%lou gehrig%' OR
                LOWER(s.brief_title) LIKE '%amyotrophic lateral sclerosis%' OR
                LOWER(s.brief_title) LIKE '%als %' OR
                LOWER(s.brief_title) LIKE '% als' OR
                LOWER(s.official_title) LIKE '%amyotrophic lateral sclerosis%' OR
                LOWER(s.official_title) LIKE '%als %' OR
                LOWER(s.official_title) LIKE '% als'
            )
        """

        # Apply filters; param_count tracks the next $N placeholder index.
        conditions = []
        params = []
        param_count = 1

        if status:
            conditions.append(f"UPPER(s.overall_status) = ${param_count}")
            params.append(status.upper())
            param_count += 1

        if phase:
            # Map API-style phase names to the values stored in AACT.
            phase_map = {
                'PHASE_1': 'Phase 1',
                'PHASE_2': 'Phase 2',
                'PHASE_3': 'Phase 3',
                'PHASE_4': 'Phase 4',
                'EARLY_PHASE_1': 'Early Phase 1'
            }
            mapped_phase = phase_map.get(phase.upper(), phase)
            conditions.append(f"s.phase = ${param_count}")
            params.append(mapped_phase)
            param_count += 1

        if intervention:
            conditions.append(f"LOWER(i.name) LIKE ${param_count}")
            params.append(f"%{intervention.lower()}%")
            param_count += 1

        if location:
            # Location filtering requires a facility row, so force the join.
            base_query = base_query.replace("LEFT JOIN facilities f", "INNER JOIN facilities f")
            conditions.append(f"(LOWER(f.country) LIKE ${param_count} OR LOWER(f.state) LIKE ${param_count})")
            params.append(f"%{location.lower()}%")
            param_count += 1

        # Add conditions to query
        if conditions:
            base_query += " AND " + " AND ".join(conditions)

        # Add GROUP BY and ORDER BY.
        # BUG FIX: this tail was a plain string, so "LIMIT ${param_count}"
        # reached the database literally and the query could never execute.
        # It must be an f-string so the placeholder index is interpolated.
        base_query += f"""
            GROUP BY s.nct_id, s.brief_title, s.overall_status, s.phase,
                     s.enrollment, s.start_date, s.completion_date,
                     s.study_type, s.official_title, d.name
            ORDER BY
                CASE s.overall_status
                    WHEN 'Recruiting' THEN 1
                    WHEN 'Enrolling by invitation' THEN 2
                    WHEN 'Active, not recruiting' THEN 3
                    WHEN 'Not yet recruiting' THEN 4
                    ELSE 5
                END,
                s.start_date DESC NULLS LAST
            LIMIT ${param_count}
        """
        params.append(max_results)

        # Execute query
        logger.debug(f"📊 Executing query with {len(params)} parameters")
        results = await execute_query(base_query, tuple(params))

        logger.info(f"✅ AACT Results: Found {len(results) if results else 0} trials")

        if not results:
            return json.dumps({
                "message": "No ALS trials found matching your criteria",
                "total": 0,
                "trials": []
            })

        # Format results into JSON-serializable dicts (dates become strings).
        trials = []
        for row in results:
            trial = {
                "nct_id": row['nct_id'],
                "title": row['brief_title'],
                "status": row['overall_status'],
                "phase": row['phase'],
                "enrollment": row['enrollment'],
                "sponsor": row['sponsor'],
                "interventions": row['interventions'],
                "conditions": row['conditions'],
                "locations_count": row['num_locations'],
                "start_date": str(row['start_date']) if row['start_date'] else None,
                "completion_date": str(row['completion_date']) if row['completion_date'] else None,
                "url": f"https://clinicaltrials.gov/study/{row['nct_id']}"
            }
            trials.append(trial)

        return json.dumps({
            "message": f"Found {len(trials)} ALS clinical trials",
            "total": len(trials),
            "trials": trials
        }, indent=2)

    except Exception as e:
        logger.error(f"❌ AACT Database query failed: {e}")
        logger.error(f"   Query type: search_als_trials")
        logger.error(f"   Parameters: status={status}, phase={phase}, intervention={intervention}")
        return json.dumps({
            "error": "Database query failed",
            "message": str(e)
        })
|
| 316 |
+
|
| 317 |
+
@mcp.tool()
async def get_trial_details(nct_id: str) -> str:
    """Get detailed information about a specific clinical trial.

    Runs four queries against AACT (core study record, outcomes, interventions,
    facilities) and assembles them into one JSON string. On any failure a
    structured JSON error is returned instead of raising.

    Args:
        nct_id: The NCT ID of the trial (e.g., 'NCT04856982')
    """

    if not (ASYNCPG_AVAILABLE or POSTGRES_AVAILABLE):
        # No driver installed: surface a structured error rather than raising.
        return json.dumps({
            "error": "Database not available",
            "message": "PostgreSQL driver not installed."
        })

    try:
        # Main trial information.
        # NOTE(review): queries use asyncpg's $1 placeholder style; the
        # psycopg2 fallback in execute_query_sync expects %s — confirm the
        # fallback path works with these queries.
        # NOTE: why_stopped and responsible_party_type are selected but not
        # included in the JSON response below.
        main_query = """
            SELECT
                s.nct_id,
                s.brief_title,
                s.official_title,
                s.overall_status,
                s.phase,
                s.study_type,
                s.enrollment,
                s.start_date,
                s.primary_completion_date,
                s.completion_date,
                s.first_posted_date,
                s.last_update_posted_date,
                s.why_stopped,
                b.description as brief_summary,
                dd.description as detailed_description,
                e.criteria as eligibility_criteria,
                e.gender,
                e.minimum_age,
                e.maximum_age,
                e.healthy_volunteers,
                rp.name as sponsor,
                rp.responsible_party_type
            FROM studies s
            LEFT JOIN brief_summaries b ON s.nct_id = b.nct_id
            LEFT JOIN detailed_descriptions dd ON s.nct_id = dd.nct_id
            LEFT JOIN eligibilities e ON s.nct_id = e.nct_id
            LEFT JOIN responsible_parties rp ON s.nct_id = rp.nct_id
            WHERE s.nct_id = $1
        """

        results = await execute_query(main_query, (nct_id,))

        if not results:
            return json.dumps({
                "error": "Trial not found",
                "message": f"No trial found with NCT ID: {nct_id}"
            })

        trial_info = results[0]

        # Get outcomes (capped at 20 rows).
        outcomes_query = """
            SELECT outcome_type, measure, time_frame, description
            FROM outcomes
            WHERE nct_id = $1
            ORDER BY outcome_type, id
            LIMIT 20
        """
        outcomes = await execute_query(outcomes_query, (nct_id,))

        # Get interventions
        interventions_query = """
            SELECT intervention_type, name, description
            FROM interventions
            WHERE nct_id = $1
        """
        interventions = await execute_query(interventions_query, (nct_id,))

        # Get locations (capped at 50 facilities).
        locations_query = """
            SELECT name, city, state, country, status
            FROM facilities
            WHERE nct_id = $1
            LIMIT 50
        """
        locations = await execute_query(locations_query, (nct_id,))

        # Format the response; dates are stringified for JSON serialization.
        return json.dumps({
            "nct_id": trial_info['nct_id'],
            "title": trial_info['brief_title'],
            "official_title": trial_info['official_title'],
            "status": trial_info['overall_status'],
            "phase": trial_info['phase'],
            "study_type": trial_info['study_type'],
            "enrollment": trial_info['enrollment'],
            "sponsor": trial_info['sponsor'],
            "dates": {
                "start": str(trial_info['start_date']) if trial_info['start_date'] else None,
                "primary_completion": str(trial_info['primary_completion_date']) if trial_info['primary_completion_date'] else None,
                "completion": str(trial_info['completion_date']) if trial_info['completion_date'] else None,
                "first_posted": str(trial_info['first_posted_date']) if trial_info['first_posted_date'] else None,
                "last_updated": str(trial_info['last_update_posted_date']) if trial_info['last_update_posted_date'] else None
            },
            "summary": trial_info['brief_summary'],
            "detailed_description": trial_info['detailed_description'],
            "eligibility": {
                "criteria": trial_info['eligibility_criteria'],
                "gender": trial_info['gender'],
                "age_range": f"{trial_info['minimum_age'] or 'N/A'} - {trial_info['maximum_age'] or 'N/A'}",
                "healthy_volunteers": trial_info['healthy_volunteers']
            },
            "outcomes": [
                {
                    "type": o['outcome_type'],
                    "measure": o['measure'],
                    "time_frame": o['time_frame'],
                    "description": o['description']
                } for o in outcomes
            ],
            "interventions": [
                {
                    "type": i['intervention_type'],
                    "name": i['name'],
                    "description": i['description']
                } for i in interventions
            ],
            "locations": [
                {
                    "name": l['name'],
                    "city": l['city'],
                    "state": l['state'],
                    "country": l['country'],
                    "status": l['status']
                } for l in locations
            ],
            "url": f"https://clinicaltrials.gov/study/{nct_id}"
        }, indent=2)

    except Exception as e:
        logger.error(f"Failed to get trial details: {e}")
        return json.dumps({
            "error": "Database query failed",
            "message": str(e)
        })

# Cleanup on shutdown
# Note: FastMCP doesn't have a built-in shutdown handler
# The connection pool will be closed when the process ends
# async def cleanup():
#     """Close the connection pool on shutdown"""
#     global _connection_pool
#     if _connection_pool:
#         await _connection_pool.close()
#         logger.info("Connection pool closed")

if __name__ == "__main__":
    mcp.run()
|
servers/biorxiv_server.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# biorxiv_server_fixed.py
|
| 2 |
+
from mcp.server.fastmcp import FastMCP
|
| 3 |
+
import httpx
|
| 4 |
+
import logging
|
| 5 |
+
from datetime import datetime, timedelta
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
# Add parent directory to path for shared imports
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 12 |
+
|
| 13 |
+
from shared import (
|
| 14 |
+
config,
|
| 15 |
+
RateLimiter,
|
| 16 |
+
format_authors,
|
| 17 |
+
ErrorFormatter,
|
| 18 |
+
truncate_text
|
| 19 |
+
)
|
| 20 |
+
from shared.http_client import get_http_client, CustomHTTPClient
|
| 21 |
+
|
| 22 |
+
# Configure logging with DEBUG for detailed troubleshooting
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.DEBUG,
|
| 25 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 26 |
+
)
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
mcp = FastMCP("biorxiv-server")
|
| 30 |
+
|
| 31 |
+
# Rate limiting using shared utility
|
| 32 |
+
rate_limiter = RateLimiter(config.rate_limits.biorxiv_delay)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def preprocess_query(query: str) -> tuple[list[str], list[str]]:
    """Preprocess a free-text query into search terms, expanding known synonyms.

    Args:
        query: Raw user query, e.g. "ALS TDP-43".

    Returns:
        Tuple of (primary_terms, all_search_terms):
        - primary_terms: the user's own terms, lowercased and deduplicated
        - all_search_terms: primary terms plus known ALS-domain synonyms
    """
    # Convert to lowercase for matching
    query_lower = query.lower()

    # Common ALS-related synonyms and variations (keys are lowercased terms)
    synonyms = {
        'als': ['amyotrophic lateral sclerosis', 'motor neuron disease', 'motor neurone disease', 'lou gehrig'],
        'amyotrophic lateral sclerosis': ['als', 'motor neuron disease'],
        'mnd': ['motor neuron disease', 'motor neurone disease', 'als'],
        'sod1': ['superoxide dismutase 1', 'cu/zn superoxide dismutase'],
        'tdp-43': ['tdp43', 'tardbp', 'tar dna binding protein'],
        'c9orf72': ['c9', 'chromosome 9 open reading frame 72'],
        'fus': ['fused in sarcoma', 'tls'],
    }

    # Split on whitespace only, so hyphenated tokens (like "tdp-43") stay intact
    terms = re.split(r'\s+', query_lower.strip())

    # Build comprehensive search term list
    all_terms = []
    primary_terms = []

    for term in terms:
        # Skip very short terms unless they're known abbreviations
        if len(term) < 3 and term not in ['als', 'mnd', 'fus', 'c9']:
            continue

        primary_terms.append(term)
        all_terms.append(term)

        # Add synonyms if they exist
        if term in synonyms:
            all_terms.extend(synonyms[term])

    # Remove duplicates while preserving order.
    # BUG FIX: the original reused a single `seen` set for both passes. Since
    # every primary term is also in all_terms, the first pass left `seen`
    # containing all primary terms, so the second pass filtered primary_terms
    # down to an EMPTY list on every call. Each list needs its own seen-set.
    seen_all = set()
    all_terms = [t for t in all_terms if not (t in seen_all or seen_all.add(t))]
    seen_primary = set()
    primary_terms = [t for t in primary_terms if not (t in seen_primary or seen_primary.add(t))]

    return primary_terms, all_terms
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def matches_query(paper: dict, primary_terms: list[str], all_terms: list[str], require_all: bool = False) -> bool:
    """Decide whether a preprint record satisfies the search terms.

    Args:
        paper: Paper dictionary from bioRxiv API
        primary_terms: Main search terms from user query
        all_terms: All search terms including synonyms
        require_all: If True, require ALL primary terms. If False, require ANY term.

    Returns:
        True if paper matches search criteria
    """
    # Build one lowercase haystack; surrounding spaces keep word-boundary
    # matching reliable at the very start/end of the text.
    title = paper.get("title", "").lower()
    abstract = paper.get("abstract", "").lower()
    searchable_text = f" {title} {abstract} "

    paper_doi = paper.get("doi", "unknown")
    logger.debug(f"🔍 Checking paper: {title[:60]}... (DOI: {paper_doi})")

    if not searchable_text.strip():
        logger.debug(f" ❌ Rejected: No title/abstract")
        return False

    def _term_hits(term: str) -> bool:
        # Short tokens (e.g. "als") must match as whole words to avoid false
        # positives inside longer words; longer terms may match as substrings.
        if len(term) <= 3:
            boundary_pattern = r'\b' + re.escape(term) + r'\b'
            return re.search(boundary_pattern, searchable_text, re.IGNORECASE) is not None
        return term.lower() in searchable_text

    # Lenient pass: the first term (synonyms included) found in the text wins.
    matched_term = next((t for t in all_terms if _term_hits(t)), None)

    if matched_term is None:
        logger.debug(f" ❌ Rejected: No term match. Terms searched: {all_terms[:3]}...")
        return False

    logger.debug(f" ✅ Matched on term: '{matched_term}'")

    # Any-term mode is satisfied as soon as one term matched.
    if not require_all:
        return True

    # Strict mode: every primary term must appear as a whole word, or as the
    # head of a hyphenated token (handles cases like "TDP-43" or "SOD1").
    return all(
        re.search(r'\b' + re.escape(t) + r'(?:\b|[-])', searchable_text, re.IGNORECASE)
        for t in primary_terms
    )
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@mcp.tool()
async def search_preprints(
    query: str,
    server: str = "both",
    max_results: int = 10,
    days_back: int = 365
) -> str:
    """Search bioRxiv and medRxiv for ALS preprints. Returns recent preprints before peer review.

    Args:
        query: Search query (e.g., 'ALS TDP-43')
        server: Which server to search - one of: biorxiv, medrxiv, both (default: both)
        max_results: Maximum number of results (default: 10)
        days_back: Number of days to look back (default: 365 - about 1 year)
    """
    try:
        logger.info(f"🔎 Searching bioRxiv/medRxiv for: '{query}'")
        logger.info(f" Parameters: server={server}, max_results={max_results}, days_back={days_back}")

        # Preprocess query for better matching (expands ALS-domain synonyms)
        primary_terms, all_terms = preprocess_query(query)
        logger.info(f"📝 Search terms: primary={primary_terms}, all={all_terms}")

        # Calculate date range
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days_back)

        # Format dates for API (YYYY-MM-DD)
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")
        logger.info(f"📅 Date range: {start_date_str} to {end_date_str}")

        # bioRxiv/medRxiv API endpoint (the API only supports listing by date
        # range, so filtering happens client-side via matches_query)
        base_url = "https://api.biorxiv.org/details"

        all_results = []
        servers_to_search = []

        if server in ["biorxiv", "both"]:
            servers_to_search.append("biorxiv")
        if server in ["medrxiv", "both"]:
            servers_to_search.append("medrxiv")

        # Use a custom HTTP client with proper timeout for bioRxiv
        # Don't use shared client as it may have conflicting timeout settings
        async with CustomHTTPClient(timeout=15.0) as client:
            for srv in servers_to_search:
                try:
                    cursor = 0
                    found_in_server = []
                    max_iterations = 1  # Only check first page (100 papers) for much faster response
                    iteration = 0

                    while iteration < max_iterations:
                        # Rate limiting
                        await rate_limiter.wait()

                        # Search by date range with cursor for pagination
                        url = f"{base_url}/{srv}/{start_date_str}/{end_date_str}/{cursor}"

                        logger.info(f"🌐 Querying {srv} API (page {iteration+1}, cursor={cursor})")
                        logger.info(f" URL: {url}")
                        response = await client.get(url)
                        response.raise_for_status()
                        data = response.json()

                        # Extract collection
                        collection = data.get("collection", [])

                        if not collection:
                            logger.info(f"📭 No more results from {srv}")
                            break

                        logger.info(f"📦 Fetched {len(collection)} papers from API")

                        # Show first few papers for debugging
                        if iteration == 0 and collection:
                            logger.info(" Sample papers from API:")
                            for i, paper in enumerate(collection[:3]):
                                logger.info(f" {i+1}. {paper.get('title', 'No title')[:60]}...")

                        # Filter papers using improved matching
                        # Start with lenient matching (ANY term)
                        logger.debug(f"🔍 Starting to filter {len(collection)} papers...")
                        filtered = [
                            paper for paper in collection
                            if matches_query(paper, primary_terms, all_terms, require_all=False)
                        ]

                        logger.info(f"✅ Filtered results: {len(filtered)}/{len(collection)} papers matched")

                        if len(filtered) > 0:
                            logger.info(" Matched papers:")
                            for i, paper in enumerate(filtered[:3]):
                                logger.info(f" {i+1}. {paper.get('title', 'No title')[:60]}...")

                        found_in_server.extend(filtered)
                        logger.info(f"📊 Running total for {srv}: {len(found_in_server)} papers")

                        # Check if we have enough results
                        if len(found_in_server) >= max_results:
                            logger.info(f"Reached max_results limit ({max_results})")
                            break

                        # Continue searching if we haven't found enough
                        if len(found_in_server) < 5 and iteration < max_iterations - 1:
                            # Keep searching for more results
                            pass
                        elif len(found_in_server) > 0 and iteration >= 3:
                            # Found some results after reasonable search.
                            # NOTE(review): unreachable while max_iterations == 1;
                            # kept for when pagination depth is raised again.
                            logger.info(f"Found {len(found_in_server)} results after {iteration+1} pages")
                            break

                        # Check for more pages
                        messages = data.get("messages", [])

                        # The API returns "cursor" in messages for next page
                        has_more = False
                        for msg in messages:
                            if "cursor=" in str(msg):
                                try:
                                    cursor_str = str(msg).split("cursor=")[1].split()[0]
                                    next_cursor = int(cursor_str)
                                    if next_cursor > cursor:
                                        cursor = next_cursor
                                        has_more = True
                                        break
                                except (IndexError, ValueError):
                                    # BUG FIX: was a bare `except:` which would also
                                    # swallow KeyboardInterrupt/SystemExit; only these
                                    # parse failures are expected here.
                                    pass

                        # Alternative: increment by collection size
                        if not has_more:
                            if len(collection) >= 100:
                                cursor += len(collection)
                            else:
                                # Less than full page means we've reached the end
                                break

                        iteration += 1

                    all_results.extend(found_in_server[:max_results])
                    logger.info(f"🏁 Total results from {srv}: {len(found_in_server)} papers found")

                except httpx.HTTPStatusError as e:
                    logger.warning(f"Error searching {srv}: {e}")
                    continue
                except Exception as e:
                    logger.warning(f"Unexpected error searching {srv}: {e}")
                    continue

        # If no results with lenient matching, provide helpful message
        if not all_results:
            logger.warning(f"⚠️ No preprints found for query: {query}")

            # Provide suggestions for improving search
            suggestions = []
            if len(primary_terms) > 3:
                suggestions.append("Try using fewer search terms")
            if not any(term in ['als', 'amyotrophic lateral sclerosis', 'motor neuron'] for term in all_terms):
                suggestions.append("Add 'ALS' or 'motor neuron disease' to your search")
            if days_back < 365:
                suggestions.append(f"Expand the time range beyond {days_back} days")

            suggestion_text = ""
            if suggestions:
                suggestion_text = "\n\nSuggestions:\n" + "\n".join(f"- {s}" for s in suggestions)

            return f"No preprints found for query: '{query}' in the last {days_back} days{suggestion_text}"

        # Sort by date (most recent first)
        all_results.sort(key=lambda x: x.get("date", ""), reverse=True)

        # Limit results
        all_results = all_results[:max_results]

        logger.info(f"🎯 FINAL RESULTS: Returning {len(all_results)} preprints for '{query}'")
        if all_results:
            logger.info(" Top results:")
            for i, paper in enumerate(all_results[:3], 1):
                logger.info(f" {i}. {paper.get('title', 'No title')[:60]}...")
                logger.info(f" DOI: {paper.get('doi', 'unknown')}, Date: {paper.get('date', 'unknown')}")

        # Format results
        result = f"Found {len(all_results)} preprints for query: '{query}'\n\n"

        for i, paper in enumerate(all_results, 1):
            title = paper.get("title", "No title")
            doi = paper.get("doi", "Unknown")
            date = paper.get("date", "Unknown")
            authors = paper.get("authors", "Unknown authors")
            authors_str = format_authors(authors, max_authors=3)

            abstract = paper.get("abstract", "No abstract available")
            category = paper.get("category", "")
            # BUG FIX: the original tested `"biorxiv" in doi`, but these DOIs are
            # plain "10.1101/..." strings, so every paper was labeled medRxiv.
            # Prefer the API's own "server" field (also read by
            # get_preprint_details); fall back to the original heuristic.
            server_name = paper.get("server") or ("bioRxiv" if "biorxiv" in doi else "medRxiv")

            result += f"{i}. **{title}**\n"
            result += f" DOI: {doi} | {server_name} | Posted: {date}\n"
            result += f" Authors: {authors_str}\n"
            if category:
                result += f" Category: {category}\n"
            result += f" Abstract: {truncate_text(abstract, max_chars=300, suffix='')}\n"
            result += f" URL: https://doi.org/{doi}\n\n"

        logger.info(f"Successfully retrieved {len(all_results)} preprints")
        return result

    except httpx.TimeoutException:
        logger.error("bioRxiv/medRxiv API request timed out")
        return "Error: bioRxiv/medRxiv API request timed out. Please try again."
    except httpx.HTTPStatusError as e:
        logger.error(f"bioRxiv/medRxiv API error: {e}")
        return f"Error: bioRxiv/medRxiv API returned status code {e.response.status_code}"
    except Exception as e:
        logger.error(f"Unexpected error in search_preprints: {e}")
        return f"Error searching preprints: {str(e)}"
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@mcp.tool()
async def get_preprint_details(doi: str) -> str:
    """Get full details for a specific bioRxiv/medRxiv preprint by DOI.

    Args:
        doi: The DOI of the preprint (e.g., '10.1101/2024.01.01.123456')

    Returns:
        A markdown-formatted summary of the preprint, or an error message string.
    """
    try:
        logger.info(f"Getting details for DOI: {doi}")

        # Ensure DOI is properly formatted
        # (all bioRxiv/medRxiv DOIs share the 10.1101/ prefix)
        if not doi.startswith("10.1101/"):
            doi = f"10.1101/{doi}"

        # Determine server from DOI
        # bioRxiv DOIs typically have format: 10.1101/YYYY.MM.DD.NNNNNN
        # medRxiv DOIs are similar but the content determines the server

        # Use shared HTTP client for connection pooling
        client = get_http_client(timeout=30.0)
        # Try the DOI endpoint
        url = f"https://api.biorxiv.org/details/{doi}"

        response = await client.get(url)

        if response.status_code == 404:
            # Try with both servers explicitly.
            # NOTE(review): the for/else below returns "not found" only when
            # NEITHER server endpoint answered HTTP 200.
            for server in ["biorxiv", "medrxiv"]:
                url = f"https://api.biorxiv.org/details/{server}/{doi}"
                response = await client.get(url)
                if response.status_code == 200:
                    break
            else:
                return f"Preprint with DOI {doi} not found"

        # Any remaining non-2xx status is surfaced via the except handler below.
        response.raise_for_status()
        data = response.json()

        collection = data.get("collection", [])
        if not collection:
            return f"No details found for DOI: {doi}"

        # Get the first (and should be only) paper
        paper = collection[0]

        title = paper.get("title", "No title")
        date = paper.get("date", "Unknown")
        authors = paper.get("authors", "Unknown authors")
        abstract = paper.get("abstract", "No abstract available")
        category = paper.get("category", "")
        server_name = paper.get("server", "Unknown")

        # Render the record as a markdown block for the agent/UI.
        result = f"**{title}**\n\n"
        result += f"**DOI:** {doi}\n"
        result += f"**Server:** {server_name}\n"
        result += f"**Posted:** {date}\n"
        if category:
            result += f"**Category:** {category}\n"
        result += f"**Authors:** {authors}\n\n"
        result += f"**Abstract:**\n{abstract}\n\n"
        result += f"**Full Text URL:** https://doi.org/{doi}\n"

        return result

    except httpx.HTTPStatusError as e:
        logger.error(f"Error fetching preprint details: {e}")
        return f"Error fetching preprint details: HTTP {e.response.status_code}"
    except Exception as e:
        logger.error(f"Unexpected error getting preprint details: {e}")
        return f"Error getting preprint details: {str(e)}"
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# Entry point: run the MCP server over stdio (used when launched as a subprocess).
if __name__ == "__main__":
    mcp.run(transport="stdio")
|
servers/clinicaltrials_links.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simplified ClinicalTrials.gov Link Generator
|
| 4 |
+
Provides direct links and known trials as fallback when AACT is unavailable
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from mcp.server.fastmcp import FastMCP
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from urllib.parse import quote_plus
|
| 11 |
+
|
| 12 |
+
# Configure logging
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
mcp = FastMCP("clinicaltrials-links")
|
| 17 |
+
|
| 18 |
+
# Known important ALS trials (updated periodically)
|
| 19 |
+
KNOWN_ALS_TRIALS = {
|
| 20 |
+
"NCT05112094": {
|
| 21 |
+
"title": "Tofersen (ATLAS)",
|
| 22 |
+
"description": "SOD1-targeted antisense therapy for SOD1-ALS",
|
| 23 |
+
"status": "Active",
|
| 24 |
+
"sponsor": "Biogen"
|
| 25 |
+
},
|
| 26 |
+
"NCT04856982": {
|
| 27 |
+
"title": "HEALEY ALS Platform Trial",
|
| 28 |
+
"description": "Multiple drugs tested simultaneously",
|
| 29 |
+
"status": "Recruiting",
|
| 30 |
+
"sponsor": "Massachusetts General Hospital"
|
| 31 |
+
},
|
| 32 |
+
"NCT04768972": {
|
| 33 |
+
"title": "Ravulizumab",
|
| 34 |
+
"description": "Complement C5 inhibition",
|
| 35 |
+
"status": "Active",
|
| 36 |
+
"sponsor": "Alexion"
|
| 37 |
+
},
|
| 38 |
+
"NCT05370950": {
|
| 39 |
+
"title": "Pridopidine",
|
| 40 |
+
"description": "Sigma-1 receptor agonist",
|
| 41 |
+
"status": "Recruiting",
|
| 42 |
+
"sponsor": "Prilenia"
|
| 43 |
+
},
|
| 44 |
+
"NCT04632225": {
|
| 45 |
+
"title": "NurOwn",
|
| 46 |
+
"description": "MSC-NTF cells (mesenchymal stem cells)",
|
| 47 |
+
"status": "Active",
|
| 48 |
+
"sponsor": "BrainStorm Cell"
|
| 49 |
+
},
|
| 50 |
+
"NCT07204977": {
|
| 51 |
+
"title": "Acamprosate",
|
| 52 |
+
"description": "C9orf72 hexanucleotide repeat expansion treatment",
|
| 53 |
+
"status": "Recruiting",
|
| 54 |
+
"sponsor": "Mayo Clinic"
|
| 55 |
+
},
|
| 56 |
+
"NCT07161999": {
|
| 57 |
+
"title": "COYA 302",
|
| 58 |
+
"description": "Regulatory T-cell therapy",
|
| 59 |
+
"status": "Recruiting",
|
| 60 |
+
"sponsor": "Coya Therapeutics"
|
| 61 |
+
},
|
| 62 |
+
"NCT07023835": {
|
| 63 |
+
"title": "Usnoflast",
|
| 64 |
+
"description": "Anti-inflammatory for ALS",
|
| 65 |
+
"status": "Recruiting",
|
| 66 |
+
"sponsor": "Seelos Therapeutics"
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@mcp.tool()
async def get_trial_link(nct_id: str) -> str:
    """Generate direct link to a ClinicalTrials.gov trial page.

    Args:
        nct_id: NCT identifier (e.g., 'NCT05112094')
    """
    # NCT ids are case-insensitive; ClinicalTrials.gov URLs use uppercase.
    normalized_id = nct_id.upper()
    study_url = f"https://clinicaltrials.gov/study/{normalized_id}"

    parts = [f"**Direct link to trial {normalized_id}:**\n{study_url}\n\n"]

    # Enrich the output when the trial is in our curated offline list.
    known = KNOWN_ALS_TRIALS.get(normalized_id)
    if known is not None:
        parts.append(f"**{known['title']}**\n")
        parts.append(f"Description: {known['description']}\n")
        parts.append(f"Status: {known['status']}\n")
        parts.append(f"Sponsor: {known['sponsor']}\n")

    return "".join(parts)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@mcp.tool()
async def get_search_link(
    condition: str = "ALS",
    status: Optional[str] = None,
    intervention: Optional[str] = None,
    location: Optional[str] = None
) -> str:
    """Generate direct search link for ClinicalTrials.gov.

    Args:
        condition: Medical condition (default: ALS)
        status: Trial status (recruiting, active, completed)
        intervention: Treatment/drug name
        location: Country or city
    """
    # Query string is assembled from URL-encoded ClinicalTrials.gov parameters.
    query_params = [f"cond={quote_plus(condition)}"]

    # Map the human-readable status onto the site's recrs status codes.
    if status:
        normalized_status = status.lower()
        if "recruit" in normalized_status:
            query_params.append("recrs=a")  # Recruiting
        elif "active" in normalized_status:
            query_params.append("recrs=d")  # Active, not recruiting
        elif "complet" in normalized_status:
            query_params.append("recrs=e")  # Completed

    if intervention:
        query_params.append(f"intr={quote_plus(intervention)}")
    if location:
        query_params.append(f"locn={quote_plus(location)}")

    search_url = "https://clinicaltrials.gov/search?" + "&".join(query_params)

    # Echo the chosen parameters back to the user alongside the link.
    lines = [
        f"**Direct search on ClinicalTrials.gov:**\n\n",
        f"Search parameters:\n",
        f"- Condition: {condition}\n",
    ]
    if status:
        lines.append(f"- Status: {status}\n")
    if intervention:
        lines.append(f"- Intervention: {intervention}\n")
    if location:
        lines.append(f"- Location: {location}\n")
    lines.append(f"\n🔗 **Search URL:** {search_url}\n")
    lines.append(f"\nClick the link above to see results on ClinicalTrials.gov")

    return "".join(lines)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@mcp.tool()
async def get_known_als_trials(
    status_filter: Optional[str] = None
) -> str:
    """Get list of known important ALS trials.

    Args:
        status_filter: Filter by status (recruiting, active, all)
    """
    result = "**Important ALS Clinical Trials:**\n\n"

    if not KNOWN_ALS_TRIALS:
        return "No known trials available in offline database."

    # Normalize the filter once; substring matching tolerates variants such as
    # "Recruiting" vs "recruiting". Unrecognized filter values pass everything.
    wanted = status_filter.lower() if status_filter else None

    count = 0
    for nct_id, trial in KNOWN_ALS_TRIALS.items():
        if wanted:
            current_status = trial['status'].lower()
            excluded = (
                (wanted == "recruiting" and "recruit" not in current_status)
                or (wanted == "active" and "active" not in current_status)
                or (wanted == "completed" and "complet" not in current_status)
            )
            if excluded:
                continue

        count += 1
        result += f"{count}. **{trial['title']}** ({nct_id})\n"
        result += f" {trial['description']}\n"
        result += f" Status: {trial['status']} | Sponsor: {trial['sponsor']}\n"
        result += f" 🔗 https://clinicaltrials.gov/study/{nct_id}\n\n"

    if count == 0:
        result += f"No trials found with status filter: {status_filter}\n"
    else:
        result += f"\n📌 *This is a curated list. For comprehensive search, use AACT database server.*"

    return result
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@mcp.tool()
async def get_trial_resources() -> str:
    """Get helpful resources for finding clinical trials.

    Returns:
        A static, curated markdown document of trial databases, ALS-specific
        registries, major clinical centers, and search tips. No network calls.
    """

    resources = """**Clinical Trials Resources for ALS:**

**Official Databases:**
1. **ClinicalTrials.gov**: https://clinicaltrials.gov/search?cond=ALS
 - Official US trials registry
 - Most comprehensive for US trials

2. **WHO ICTRP**: https://trialsearch.who.int/
 - International trials from all countries
 - Includes non-US trials

3. **EU Clinical Trials Register**: https://www.clinicaltrialsregister.eu/
 - European trials database

**ALS-Specific Resources:**
1. **Northeast ALS Consortium (NEALS)**: https://www.neals.org/
 - Network of ALS clinical trial sites
 - Trial matching service

2. **ALS Therapy Development Institute**: https://www.als.net/clinical-trials/
 - Independent ALS research organization
 - Trial tracker and updates

3. **I AM ALS Registry**: https://iamals.org/get-help/clinical-trials/
 - Patient-focused trial information
 - Trial matching assistance

**Major ALS Clinical Centers:**
- Massachusetts General Hospital (Healey Center)
- Johns Hopkins ALS Clinic
- Mayo Clinic ALS Center
- Cleveland Clinic Lou Ruvo Center
- UCSF ALS Center

**Tips for Finding Trials:**
1. Use condition terms: "ALS", "Amyotrophic Lateral Sclerosis", "Motor Neuron Disease"
2. Check recruiting AND not-yet-recruiting trials
3. Consider trials at different phases (1, 2, 3)
4. Look for platform trials testing multiple drugs
5. Contact trial coordinators directly for eligibility

**Note:** For programmatic access to trial data, use the AACT database server which provides complete ClinicalTrials.gov data without API restrictions.
"""

    return resources
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Entry point: run the link-generator MCP server over stdio.
if __name__ == "__main__":
    mcp.run(transport="stdio")
|
servers/elevenlabs_server.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ElevenLabs MCP Server for Voice Capabilities
|
| 4 |
+
Provides text-to-speech and speech-to-text for ALS Research Agent
|
| 5 |
+
|
| 6 |
+
This server enables voice accessibility features crucial for ALS patients
|
| 7 |
+
who may have limited mobility but retain cognitive function.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from mcp.server.fastmcp import FastMCP
|
| 11 |
+
import httpx
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import base64
|
| 15 |
+
import json
|
| 16 |
+
from typing import Optional, Dict, Any
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
# Add parent directory to path for shared imports
|
| 21 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 22 |
+
|
| 23 |
+
from shared import config
|
| 24 |
+
from shared.http_client import get_http_client
|
| 25 |
+
|
| 26 |
+
# Configure logging
|
| 27 |
+
logging.basicConfig(level=logging.INFO)
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# Initialize MCP server
|
| 31 |
+
mcp = FastMCP("elevenlabs-voice")
|
| 32 |
+
|
| 33 |
+
# ElevenLabs API configuration
|
| 34 |
+
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
|
| 35 |
+
ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
|
| 36 |
+
|
| 37 |
+
# Default voice settings optimized for clarity (important for ALS patients)
|
| 38 |
+
DEFAULT_VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM") # Rachel voice (clear and calm)
|
| 39 |
+
DEFAULT_MODEL = "eleven_turbo_v2_5" # Turbo v2.5 - Fastest model available (40% faster than v2)
|
| 40 |
+
|
| 41 |
+
# Voice settings for accessibility
|
| 42 |
+
VOICE_SETTINGS = {
|
| 43 |
+
"stability": 0.5, # Balanced for speed and clarity (turbo model)
|
| 44 |
+
"similarity_boost": 0.5, # Balanced setting for faster processing
|
| 45 |
+
"style": 0.0, # Neutral style for clarity
|
| 46 |
+
"use_speaker_boost": True # Enhanced clarity
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@mcp.tool()
async def text_to_speech(
    text: str,
    voice_id: Optional[str] = None,
    output_format: str = "mp3_44100_128",
    speed: float = 1.0
) -> str:
    """Convert text to speech optimized for ALS patients.

    Args:
        text: Text to convert to speech (research findings, paper summaries, etc.)
        voice_id: ElevenLabs voice ID (defaults to clear, calm voice)
        output_format: Audio format (mp3_44100_128, mp3_44100_192, pcm_16000, etc.)
        speed: Speech rate (0.5-2.0, default 1.0 - can be slower for clarity)

    Returns:
        JSON string with base64 encoded audio data and metadata.
    """
    try:
        if not ELEVENLABS_API_KEY:
            return json.dumps({
                "status": "error",
                "error": "ELEVENLABS_API_KEY not configured",
                "message": "Please set your ElevenLabs API key in .env file"
            }, indent=2)

        # Limit text length to avoid ElevenLabs API timeouts
        # Testing shows 2500 chars is safe, 5000 chars times out
        max_length = 2500
        if len(text) > max_length:
            logger.warning(f"Text truncated from {len(text)} to {max_length} characters to avoid timeout")
            # Try to truncate at a sentence/paragraph boundary for natural audio
            truncated = text[:max_length]
            boundary = max(truncated.rfind('.'), truncated.rfind('\n'))
            if boundary > max_length - 500:  # a boundary exists in the last 500 chars
                text = truncated[:boundary + 1]
            else:
                text = truncated + "..."

        voice_id = voice_id or DEFAULT_VOICE_ID

        # BUG FIX: output_format was previously accepted but never sent to the
        # API; ElevenLabs expects it as a query parameter on the TTS endpoint.
        url = f"{ELEVENLABS_API_BASE}/text-to-speech/{voice_id}?output_format={output_format}"

        headers = {
            "xi-api-key": ELEVENLABS_API_KEY,
            "Content-Type": "application/json"
        }

        # Adjust voice settings for speed
        adjusted_settings = VOICE_SETTINGS.copy()
        if speed < 1.0:
            # Slower speech - increase stability for clarity
            adjusted_settings["stability"] = min(1.0, adjusted_settings["stability"] + 0.1)

        payload = {
            "text": text,
            "model_id": DEFAULT_MODEL,
            "voice_settings": adjusted_settings
        }

        logger.info(f"Converting text to speech: {len(text)} characters")

        # With the 2500-char cap above, 45s is a comfortable upper bound
        timeout = 45.0
        logger.info(f"Using timeout of {timeout} seconds")

        # Use shared HTTP client for connection pooling
        client = get_http_client(timeout=timeout)
        response = await client.post(url, json=payload, headers=headers)
        response.raise_for_status()

        # Get the audio data and encode to base64 for transmission
        audio_data = response.content
        audio_base64 = base64.b64encode(audio_data).decode('utf-8')

        # BUG FIX: the previous estimate divided *characters* by a 150
        # words-per-minute rate; use the actual word count instead.
        word_count = len(text.split())
        result = {
            "status": "success",
            "audio_base64": audio_base64,
            "format": output_format,
            "duration_estimate": word_count / 150 * 60,  # ~150 words/min
            "text_length": len(text),
            "voice_id": voice_id,
            "message": "Audio generated successfully. Use the audio_base64 field to play the audio."
        }

        logger.info(f"Successfully generated {len(audio_data)} bytes of audio")
        return json.dumps(result, indent=2)

    except httpx.HTTPStatusError as e:
        logger.error(f"ElevenLabs API error: {e}")
        if e.response.status_code == 401:
            return json.dumps({
                "status": "error",
                "error": "Authentication failed",
                "message": "Check your ELEVENLABS_API_KEY"
            }, indent=2)
        elif e.response.status_code == 429:
            return json.dumps({
                "status": "error",
                "error": "Rate limit exceeded",
                "message": "Please wait before trying again"
            }, indent=2)
        else:
            return json.dumps({
                "status": "error",
                "error": f"API error: {e.response.status_code}",
                "message": str(e)
            }, indent=2)

    except Exception as e:
        logger.error(f"Unexpected error in text_to_speech: {e}")
        return json.dumps({
            "status": "error",
            "error": "Text-to-speech error",
            "message": str(e)
        }, indent=2)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@mcp.tool()
async def create_audio_summary(
    content: str,
    summary_type: str = "research",
    max_duration: int = 60
) -> str:
    """Create an audio summary of research content optimized for listening.

    Reformats technical content into a more listenable form before speech
    synthesis - important for complex medical research.

    Args:
        content: Research content to summarize (paper abstract, findings, etc.)
        summary_type: Type of summary - "research", "clinical", "patient-friendly"
        max_duration: Target duration in seconds (affects summary length)

    Returns:
        JSON string with both the text summary and the audio version.
    """
    try:
        # Target word count, assuming a speaking rate of ~150 words/minute
        target_words = int((max_duration / 60) * 150)

        # Dispatch on summary type; anything unrecognized falls back to the
        # standard research summary.
        processors = {
            "patient-friendly": _simplify_medical_content,  # simplify jargon
            "clinical": _extract_clinical_relevance,        # clinical focus
        }
        processor = processors.get(summary_type, _create_research_summary)
        processed_text = processor(content, target_words)

        # Prepend a short intro for context
        final_text = "Here's your audio research summary: " + processed_text

        # Convert to speech, slightly slower for complex content
        tts_result = await text_to_speech(text=final_text, speed=0.95)
        tts_data = json.loads(tts_result)

        if tts_data.get("status") != "success":
            return tts_result  # propagate the TTS error as-is

        summary_payload = {
            "status": "success",
            "audio_base64": tts_data["audio_base64"],
            "text_summary": processed_text,
            "summary_type": summary_type,
            "word_count": len(processed_text.split()),
            "estimated_duration": tts_data["duration_estimate"],
            "format": tts_data["format"],
            "message": f"Audio summary created: {summary_type} format, ~{int(tts_data['duration_estimate'])} seconds"
        }
        return json.dumps(summary_payload, indent=2)

    except Exception as e:
        logger.error(f"Error creating audio summary: {e}")
        return json.dumps({
            "status": "error",
            "error": "Summary creation error",
            "message": str(e)
        }, indent=2)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@mcp.tool()
async def list_voices() -> str:
    """List available voices optimized for medical/research content.

    Returns voices suitable for clear pronunciation of medical terminology.
    """
    try:
        if not ELEVENLABS_API_KEY:
            return json.dumps({
                "status": "error",
                "error": "ELEVENLABS_API_KEY not configured",
                "message": "Please set your ElevenLabs API key in .env file"
            }, indent=2)

        url = f"{ELEVENLABS_API_BASE}/voices"
        headers = {"xi-api-key": ELEVENLABS_API_KEY}

        # Use shared HTTP client for connection pooling
        client = get_http_client(timeout=10.0)
        response = await client.get(url, headers=headers)
        response.raise_for_status()

        data = response.json()
        voices = data.get("voices", [])

        # Filter and rank voices for medical content: prefer voices labeled
        # as clear / professional / narration.
        recommended_voices = []
        for voice in voices:
            labels = voice.get("labels", {})
            if any(label in ["clear", "professional", "narration"] for label in labels.values()):
                recommended_voices.append({
                    "voice_id": voice["voice_id"],
                    "name": voice["name"],
                    "preview_url": voice.get("preview_url"),
                    "description": voice.get("description", ""),
                    "recommended_for": "medical_content"
                })

        # PERF FIX: membership test against a set built once, instead of
        # rebuilding a list of recommended IDs on every iteration (was O(n^2)).
        recommended_ids = {v["voice_id"] for v in recommended_voices}
        other_voices = [
            {
                "voice_id": voice["voice_id"],
                "name": voice["name"],
                "preview_url": voice.get("preview_url"),
                "description": voice.get("description", "")
            }
            for voice in voices
            if voice["voice_id"] not in recommended_ids
        ]

        result = {
            "status": "success",
            "recommended_voices": recommended_voices[:5],  # Top 5 recommended
            "other_voices": other_voices[:10],  # Limit for clarity
            "total_voices": len(voices),
            "message": "Recommended voices are optimized for clear medical terminology pronunciation"
        }

        return json.dumps(result, indent=2)

    except Exception as e:
        logger.error(f"Error listing voices: {e}")
        return json.dumps({
            "status": "error",
            "error": "Failed to list voices",
            "message": str(e)
        }, indent=2)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@mcp.tool()
async def pronunciation_guide(
    medical_terms: list[str],
    include_audio: bool = True
) -> str:
    """Generate pronunciation guide for medical terms.

    Critical for ALS patients/caregivers learning about complex terminology.

    Args:
        medical_terms: List of medical terms to pronounce
        include_audio: Whether to include audio pronunciation

    Returns:
        Pronunciation guide with optional audio
    """
    try:
        guide_entries = []

        # Cap at 10 terms to keep processing time reasonable
        for term in medical_terms[:10]:
            phonetic = _get_phonetic_spelling(term)

            # Say the term, its phonetic breakdown, then the term again
            spoken_text = f"{term}. {phonetic}. {term}."

            entry = {
                "term": term,
                "phonetic": phonetic
            }

            if include_audio:
                # Generate audio, slowed down for clarity
                tts_payload = json.loads(
                    await text_to_speech(text=spoken_text, speed=0.8)
                )
                if tts_payload.get("status") == "success":
                    entry["audio_base64"] = tts_payload["audio_base64"]

            guide_entries.append(entry)

        return json.dumps({
            "status": "success",
            "pronunciations": guide_entries,
            "message": f"Generated pronunciation guide for {len(guide_entries)} terms"
        }, indent=2)

    except Exception as e:
        logger.error(f"Error creating pronunciation guide: {e}")
        return json.dumps({
            "status": "error",
            "error": "Pronunciation guide error",
            "message": str(e)
        }, indent=2)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Helper functions for content processing
|
| 377 |
+
|
| 378 |
+
def _simplify_medical_content(content: str, target_words: int) -> str:
    """Simplify medical content for patient understanding.

    Replaces common medical jargon with plain-language equivalents and
    truncates the result to roughly ``target_words`` words.
    (This would ideally use NLP; for now it is a basic substitution pass.)
    """
    # First, strip references for cleaner audio
    content = _strip_references(content)

    # Common medical term replacements. Keys must be lowercase because the
    # content is lowercased before substitution.
    replacements = {
        "amyotrophic lateral sclerosis": "ALS or Lou Gehrig's disease",
        "motor neurons": "nerve cells that control muscles",
        "neurodegeneration": "nerve cell damage",
        "pathogenesis": "disease development",
        "etiology": "cause",
        "prognosis": "expected outcome",
        "therapeutic": "treatment",
        "pharmacological": "drug-based",
        "intervention": "treatment",
        "mortality": "death rate",
        "morbidity": "illness rate"
    }

    simplified = content.lower()
    for term, replacement in replacements.items():
        simplified = simplified.replace(term, replacement)

    # Truncate to target length
    words = simplified.split()
    if len(words) > target_words:
        simplified = " ".join(words[:target_words]) + "..."

    # BUG FIX: str.capitalize() lowercases everything after the first
    # character, which undid replacements such as "ALS or Lou Gehrig's
    # disease". Only uppercase the leading character instead.
    return simplified[:1].upper() + simplified[1:]
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _extract_clinical_relevance(content: str, target_words: int) -> str:
    """Extract clinically relevant information.

    Keeps only sentences mentioning treatment, outcomes, or other practical
    clinical markers, then truncates to roughly ``target_words`` words.
    """
    # First, strip references for cleaner audio
    content = _strip_references(content)

    # Key phrases that signal clinical relevance
    clinical_markers = [
        "treatment", "therapy", "outcome", "survival", "progression",
        "clinical trial", "efficacy", "safety", "adverse", "benefit",
        "patient", "dose", "administration"
    ]

    # Keep only sentences containing at least one marker
    result = ". ".join(
        sentence
        for sentence in content.split(". ")
        if any(marker in sentence.lower() for marker in clinical_markers)
    )

    # Truncate to target length
    words = result.split()
    if len(words) > target_words:
        result = " ".join(words[:target_words]) + "..."

    return result
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _create_research_summary(content: str, target_words: int) -> str:
    """Create a research-focused summary.

    Strips any references section, then truncates to roughly
    ``target_words`` words (could be enhanced with NLP).
    """
    # Strip references section if present, for cleaner audio
    content = _strip_references(content)

    words = content.split()
    if len(words) <= target_words:
        return content
    return " ".join(words[:target_words]) + "..."
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def _strip_references(content: str) -> str:
|
| 462 |
+
"""Remove references section and citations from content for audio reading."""
|
| 463 |
+
import re
|
| 464 |
+
|
| 465 |
+
# Extract only synthesis content if it's marked
|
| 466 |
+
synthesis_match = re.search(r'✅\s*SYNTHESIS:?\s*(.*?)(?=##?\s*References|##?\s*Bibliography|$)',
|
| 467 |
+
content, flags=re.DOTALL | re.IGNORECASE)
|
| 468 |
+
if synthesis_match:
|
| 469 |
+
content = synthesis_match.group(1)
|
| 470 |
+
|
| 471 |
+
# Remove References section (multiple possible formats)
|
| 472 |
+
patterns_to_remove = [
|
| 473 |
+
r'##?\s*References.*$', # ## References or # References to end
|
| 474 |
+
r'##?\s*Bibliography.*$', # Bibliography section
|
| 475 |
+
r'##?\s*Citations.*$', # Citations section
|
| 476 |
+
r'##?\s*Works Cited.*$', # Works Cited section
|
| 477 |
+
r'##?\s*Key References.*$', # Key References section
|
| 478 |
+
]
|
| 479 |
+
|
| 480 |
+
for pattern in patterns_to_remove:
|
| 481 |
+
content = re.sub(pattern, '', content, flags=re.DOTALL | re.IGNORECASE)
|
| 482 |
+
|
| 483 |
+
# Remove inline citations like [1], [2,3], [PMID: 12345678]
|
| 484 |
+
content = re.sub(r'\[[\d,\s]+\]', '', content) # [1], [2,3], etc.
|
| 485 |
+
content = re.sub(r'\[PMID:\s*\d+\]', '', content) # [PMID: 12345678]
|
| 486 |
+
content = re.sub(r'\[NCT\d+\]', '', content) # [NCT12345678]
|
| 487 |
+
|
| 488 |
+
# Remove URLs for cleaner audio
|
| 489 |
+
content = re.sub(r'https?://[^\s\)]+', '', content)
|
| 490 |
+
content = re.sub(r'www\.[^\s\)]+', '', content)
|
| 491 |
+
|
| 492 |
+
# Remove PMID/DOI/NCT references
|
| 493 |
+
content = re.sub(r'PMID:\s*\d+', '', content)
|
| 494 |
+
content = re.sub(r'DOI:\s*[^\s]+', '', content)
|
| 495 |
+
content = re.sub(r'NCT\d{8}', '', content)
|
| 496 |
+
|
| 497 |
+
# Remove markdown formatting that sounds awkward in audio
|
| 498 |
+
content = re.sub(r'\*\*(.*?)\*\*', r'\1', content) # Remove bold
|
| 499 |
+
content = re.sub(r'\*(.*?)\*', r'\1', content) # Remove italic
|
| 500 |
+
content = re.sub(r'`(.*?)`', r'\1', content) # Remove inline code
|
| 501 |
+
content = re.sub(r'#{1,6}\s*', '', content) # Remove headers
|
| 502 |
+
content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE) # Remove bullet points
|
| 503 |
+
content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE) # Remove numbered lists
|
| 504 |
+
|
| 505 |
+
# Replace markdown links with just the text
|
| 506 |
+
content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
|
| 507 |
+
|
| 508 |
+
# Clean up extra whitespace
|
| 509 |
+
content = re.sub(r'\s+', ' ', content)
|
| 510 |
+
content = re.sub(r'\n{3,}', '\n\n', content)
|
| 511 |
+
|
| 512 |
+
return content.strip()
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _get_phonetic_spelling(term: str) -> str:
|
| 516 |
+
"""Generate phonetic spelling for medical terms."""
|
| 517 |
+
# Basic phonetic rules for medical terms
|
| 518 |
+
# This could be enhanced with a medical pronunciation dictionary
|
| 519 |
+
|
| 520 |
+
phonetic_map = {
|
| 521 |
+
"amyotrophic": "AM-ee-oh-TROH-fik",
|
| 522 |
+
"lateral": "LAT-er-al",
|
| 523 |
+
"sclerosis": "skleh-ROH-sis",
|
| 524 |
+
"tdp-43": "T-D-P forty-three",
|
| 525 |
+
"riluzole": "RIL-you-zole",
|
| 526 |
+
"edaravone": "ed-AR-a-vone",
|
| 527 |
+
"tofersen": "TOE-fer-sen",
|
| 528 |
+
"neurofilament": "NUR-oh-FIL-a-ment",
|
| 529 |
+
"astrocyte": "AS-tro-site",
|
| 530 |
+
"oligodendrocyte": "oh-li-go-DEN-dro-site"
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
term_lower = term.lower()
|
| 534 |
+
if term_lower in phonetic_map:
|
| 535 |
+
return phonetic_map[term_lower]
|
| 536 |
+
|
| 537 |
+
# Basic syllable breakdown for unknown terms
|
| 538 |
+
# This is very simplified and could be improved
|
| 539 |
+
syllables = []
|
| 540 |
+
current = ""
|
| 541 |
+
for char in term:
|
| 542 |
+
if char in "aeiouAEIOU" and current:
|
| 543 |
+
syllables.append(current + char)
|
| 544 |
+
current = ""
|
| 545 |
+
else:
|
| 546 |
+
current += char
|
| 547 |
+
if current:
|
| 548 |
+
syllables.append(current)
|
| 549 |
+
|
| 550 |
+
return "-".join(syllables).upper()
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
if __name__ == "__main__":
    # Check for API key: warn but do not fail, so the server still starts
    # (tools return structured errors when the key is missing).
    if not ELEVENLABS_API_KEY:
        logger.warning("ELEVENLABS_API_KEY not set in environment")
        logger.warning("Voice features will be limited without API key")
        logger.info("Get your API key at: https://elevenlabs.io")

    # Run the MCP server over stdio transport
    mcp.run(transport="stdio")
|
servers/fetch_server.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# fetch_server.py
|
| 2 |
+
from mcp.server.fastmcp import FastMCP
|
| 3 |
+
import httpx
|
| 4 |
+
from bs4 import BeautifulSoup
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
import logging
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add parent directory to path for shared imports
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 12 |
+
|
| 13 |
+
from shared import (
|
| 14 |
+
config,
|
| 15 |
+
clean_whitespace,
|
| 16 |
+
truncate_text
|
| 17 |
+
)
|
| 18 |
+
from shared.http_client import get_http_client
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
mcp = FastMCP("fetch-server")
|
| 25 |
+
|
| 26 |
+
def validate_url(url: str) -> tuple[bool, str]:
    """Validate URL for security concerns. Returns (is_valid, error_message)"""
    try:
        parsed = urlparse(url)

        # Scheme allow-list comes from the shared security config
        allowed = config.security.allowed_schemes
        if parsed.scheme not in allowed:
            return False, f"Invalid URL scheme. Only {', '.join(allowed)} are allowed."

        # A URL without a hostname cannot be safely fetched
        host = parsed.hostname
        if not host:
            return False, "Invalid URL: no hostname found."

        # SSRF protection: reject localhost / private-range addresses
        if config.security.is_private_ip(host):
            return False, "Access to localhost/private IPs is not allowed."

        return True, ""

    except Exception as e:
        return False, f"Invalid URL: {str(e)}"
|
| 48 |
+
|
| 49 |
+
def parse_clinical_trial_page(soup: BeautifulSoup, url: str) -> str | None:
    """Parse ClinicalTrials.gov trial detail page for structured data.

    Args:
        soup: Parsed HTML of the fetched page.
        url: URL the page was fetched from; used to detect CT.gov pages and
            to extract the NCT ID.

    Returns:
        A markdown-formatted trial summary, or None when the page is not a
        ClinicalTrials.gov page or no meaningful data was found.
        (BUG FIX: the annotation previously claimed ``str`` even though the
        function returns None on non-trial pages.)
    """
    # Check if this is a ClinicalTrials.gov page
    if "clinicaltrials.gov" not in url.lower():
        return None

    # Extract NCT ID from URL
    import re
    nct_match = re.search(r'NCT\d{8}', url)
    nct_id = nct_match.group() if nct_match else "Unknown"

    # Try to extract key trial information
    trial_info = []
    trial_info.append(f"**NCT ID:** {nct_id}")
    trial_info.append(f"**URL:** {url}")

    # Look for title
    title = soup.find('h1')
    if title:
        trial_info.append(f"**Title:** {title.get_text(strip=True)}")

    # Look for status (various page layouts use different markup)
    status_patterns = [
        soup.find('span', string=re.compile(r'Recruiting|Active|Completed|Enrolling', re.I)),
        soup.find('div', string=re.compile(r'Recruitment Status', re.I))
    ]
    for pattern in status_patterns:
        if pattern:
            status_text = pattern.get_text(strip=True) if hasattr(pattern, 'get_text') else str(pattern)
            trial_info.append(f"**Status:** {status_text}")
            break

    # Look for study description (truncated for readability)
    desc_section = soup.find('div', {'class': re.compile('description', re.I)})
    if desc_section:
        desc_text = desc_section.get_text(strip=True)[:500]
        trial_info.append(f"**Description:** {desc_text}...")

    # Look for conditions (first mention only)
    conditions = soup.find_all(string=re.compile(r'Condition', re.I))
    if conditions:
        for cond in conditions[:1]:
            parent = cond.parent
            if parent:
                trial_info.append(f"**Condition:** {parent.get_text(strip=True)[:200]}")
                break

    # Look for interventions (first mention only)
    interventions = soup.find_all(string=re.compile(r'Intervention', re.I))
    if interventions:
        for inter in interventions[:1]:
            parent = inter.parent
            if parent:
                trial_info.append(f"**Intervention:** {parent.get_text(strip=True)[:200]}")
                break

    # Look for sponsor
    sponsor = soup.find(string=re.compile(r'Sponsor', re.I))
    if sponsor and sponsor.parent:
        trial_info.append(f"**Sponsor:** {sponsor.parent.get_text(strip=True)[:100]}")

    # Locations/Sites (first 3, heavily truncated)
    locations = soup.find_all(string=re.compile(r'Location|Site', re.I))
    if locations:
        location_texts = []
        for loc in locations[:3]:
            if loc.parent:
                location_texts.append(loc.parent.get_text(strip=True)[:50])
        if location_texts:
            trial_info.append(f"**Locations:** {', '.join(location_texts)}")

    if len(trial_info) > 2:  # If we found meaningful data beyond the ID/URL
        return "\n\n".join(trial_info) + "\n\n**Note:** This is extracted from the trial webpage. Some details may be incomplete due to page structure variations."

    return None
|
| 124 |
+
|
| 125 |
+
@mcp.tool()
async def fetch_url(url: str, extract_text_only: bool = True) -> str:
    """Fetch content from a URL (paper abstract page, news article, etc.).

    Args:
        url: URL to fetch
        extract_text_only: Extract only main text content (default: True)
    """
    try:
        logger.info(f"Fetching URL: {url}")

        # Reject bad URLs before doing any network work.
        is_valid, error_msg = validate_url(url)
        if not is_valid:
            logger.warning(f"URL validation failed: {error_msg}")
            return f"Error: {error_msg}"

        # The shared HTTP client provides connection pooling across tool calls.
        http = get_http_client(timeout=config.api.timeout)
        response = await http.get(url, headers={
            "User-Agent": config.api.user_agent
        })
        response.raise_for_status()

        # Size guard, twice: first on the declared Content-Length header
        # (cheap), then on the actual downloaded payload (header may lie).
        declared_size = response.headers.get('content-length')
        if declared_size and int(declared_size) > config.content_limits.max_content_size:
            logger.warning(f"Content too large: {declared_size} bytes")
            return f"Error: Content size ({declared_size} bytes) exceeds maximum allowed size of {config.content_limits.max_content_size} bytes"

        if len(response.content) > config.content_limits.max_content_size:
            logger.warning(f"Content too large: {len(response.content)} bytes")
            return f"Error: Content size exceeds maximum allowed size of {config.content_limits.max_content_size} bytes"

        if not extract_text_only:
            # Raw-HTML path: only clamp the size before returning.
            html = truncate_text(response.text, max_chars=config.content_limits.max_text_chars)
            logger.info(f"Successfully fetched raw HTML from {url}")
            return html

        soup = BeautifulSoup(response.text, 'html.parser')

        # Clinical-trial pages get a structured extraction when recognizable.
        trial_data = parse_clinical_trial_page(soup, url)
        if trial_data:
            logger.info(f"Successfully parsed clinical trial page: {url}")
            return trial_data

        # Generic path: drop non-content tags, then flatten to plain text.
        for tag in soup(["script", "style", "meta", "link"]):
            tag.decompose()

        text = clean_whitespace(soup.get_text())
        # Keep the payload small enough for an LLM context window.
        text = truncate_text(text, max_chars=config.content_limits.max_text_chars)

        logger.info(f"Successfully fetched and extracted text from {url}")
        return text

    except httpx.TimeoutException:
        logger.error(f"Request to {url} timed out")
        return f"Error: Request timed out after {config.api.timeout} seconds"
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error fetching {url}: {e}")
        return f"Error: HTTP {e.response.status_code} - {e.response.reason_phrase}"
    except httpx.RequestError as e:
        logger.error(f"Request error fetching {url}: {e}")
        return f"Error: Failed to fetch URL - {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error fetching {url}: {e}")
        return f"Error: {str(e)}"
|
| 204 |
+
|
| 205 |
+
# Run the MCP server over stdio transport when executed directly
# (how the agent process launches each tool server as a subprocess).
if __name__ == "__main__":
    mcp.run(transport="stdio")
|
servers/llamaindex_server.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LlamaIndex MCP Server for Research Memory and RAG
|
| 4 |
+
Provides persistent memory and semantic search capabilities for ALS Research Agent
|
| 5 |
+
|
| 6 |
+
This server enables the agent to remember all research it encounters, build
|
| 7 |
+
knowledge over time, and discover connections between papers.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from mcp.server.fastmcp import FastMCP
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
import hashlib
|
| 15 |
+
from typing import Optional, List, Dict, Any
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import sys
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
import asyncio
|
| 20 |
+
|
| 21 |
+
# Add parent directory to path for shared imports
|
| 22 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 23 |
+
|
| 24 |
+
from shared import config
|
| 25 |
+
|
| 26 |
+
# Configure logging
|
| 27 |
+
logging.basicConfig(level=logging.INFO)
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# Initialize MCP server
|
| 31 |
+
mcp = FastMCP("llamaindex-rag")
|
| 32 |
+
|
| 33 |
+
# Import LlamaIndex components (will be installed)
|
| 34 |
+
try:
|
| 35 |
+
from llama_index.core import (
|
| 36 |
+
VectorStoreIndex,
|
| 37 |
+
Document,
|
| 38 |
+
StorageContext,
|
| 39 |
+
Settings,
|
| 40 |
+
load_index_from_storage
|
| 41 |
+
)
|
| 42 |
+
from llama_index.core.node_parser import SentenceSplitter
|
| 43 |
+
from llama_index.vector_stores.chroma import ChromaVectorStore
|
| 44 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 45 |
+
import chromadb
|
| 46 |
+
LLAMAINDEX_AVAILABLE = True
|
| 47 |
+
except ImportError:
|
| 48 |
+
LLAMAINDEX_AVAILABLE = False
|
| 49 |
+
logger.warning("LlamaIndex not installed. Install with: pip install llama-index chromadb sentence-transformers")
|
| 50 |
+
|
| 51 |
+
# Configuration
|
| 52 |
+
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH", "./chroma_db")
|
| 53 |
+
EMBED_MODEL = os.getenv("LLAMAINDEX_EMBED_MODEL", "dmis-lab/biobert-base-cased-v1.2")
|
| 54 |
+
CHUNK_SIZE = int(os.getenv("LLAMAINDEX_CHUNK_SIZE", "1024"))
|
| 55 |
+
CHUNK_OVERLAP = int(os.getenv("LLAMAINDEX_CHUNK_OVERLAP", "200"))
|
| 56 |
+
|
| 57 |
+
# Global index storage
|
| 58 |
+
research_index = None
|
| 59 |
+
chroma_client = None
|
| 60 |
+
collection = None
|
| 61 |
+
papers_metadata = {} # Store paper metadata separately
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ResearchMemoryManager:
    """Manages persistent research memory using LlamaIndex and ChromaDB.

    Papers are embedded with a HuggingFace model and stored in a persistent
    Chroma collection under CHROMA_DB_PATH. Lightweight per-paper metadata is
    mirrored in the module-level ``papers_metadata`` dict and persisted as
    JSON next to the vector database so it survives restarts.
    """

    def __init__(self):
        # All of these stay None when LlamaIndex is unavailable or
        # initialization fails; callers must check ``self.index``.
        self.index = None
        self.chroma_client = None
        self.collection = None
        self.metadata_path = Path(CHROMA_DB_PATH) / "metadata.json"

        if LLAMAINDEX_AVAILABLE:
            self._initialize_index()

    def _initialize_index(self):
        """Initialize or load the existing index from ChromaDB.

        On any failure, logs the error and leaves ``self.index`` as None so
        tool handlers can report a clean "not available" status.
        """
        try:
            # Create directory if it doesn't exist
            Path(CHROMA_DB_PATH).mkdir(parents=True, exist_ok=True)

            # Initialize ChromaDB client
            self.chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

            # Get or create collection.
            # FIX: narrowed the original bare ``except:`` to ``except
            # Exception`` so KeyboardInterrupt/SystemExit are not swallowed.
            try:
                self.collection = self.chroma_client.get_collection("als_research")
                logger.info(f"Loaded existing ChromaDB collection with {self.collection.count()} papers")
            except Exception:
                self.collection = self.chroma_client.create_collection("als_research")
                logger.info("Created new ChromaDB collection")

            # Initialize embedding model - prefer biomedical models
            try:
                embed_model = HuggingFaceEmbedding(
                    model_name=EMBED_MODEL,
                    cache_folder="./embed_cache"
                )
                logger.info(f"Using embedding model: {EMBED_MODEL}")
            except Exception:
                logger.warning(f"Failed to load {EMBED_MODEL}, falling back to default")
                embed_model = HuggingFaceEmbedding(
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                    cache_folder="./embed_cache"
                )

            # Configure global LlamaIndex settings.
            # NOTE: ``Settings`` is process-wide, so this affects any other
            # LlamaIndex usage in the same process.
            Settings.embed_model = embed_model
            Settings.chunk_size = CHUNK_SIZE
            Settings.chunk_overlap = CHUNK_OVERLAP

            # Initialize vector store backed by the Chroma collection
            vector_store = ChromaVectorStore(chroma_collection=self.collection)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)

            # Load the index when the collection already has data, otherwise
            # start an empty one.
            if self.collection.count() > 0:
                self.index = VectorStoreIndex.from_vector_store(
                    vector_store,
                    storage_context=storage_context
                )
                logger.info("Loaded existing vector index")
            else:
                self.index = VectorStoreIndex(
                    [],
                    storage_context=storage_context
                )
                logger.info("Created new vector index")

            # Load paper metadata that accompanies the vectors
            self._load_metadata()

        except Exception as e:
            logger.error(f"Failed to initialize index: {e}")
            self.index = None

    def _load_metadata(self):
        """Load paper metadata from disk into the module-level dict."""
        global papers_metadata
        if self.metadata_path.exists():
            try:
                with open(self.metadata_path, 'r') as f:
                    papers_metadata = json.load(f)
                logger.info(f"Loaded metadata for {len(papers_metadata)} papers")
            except Exception as e:
                logger.error(f"Failed to load metadata: {e}")
                papers_metadata = {}
        else:
            papers_metadata = {}

    def _save_metadata(self):
        """Persist the module-level metadata dict to disk as JSON."""
        try:
            with open(self.metadata_path, 'w') as f:
                # default=str handles datetime and any other non-JSON types
                json.dump(papers_metadata, f, indent=2, default=str)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")

    def generate_paper_id(self, title: str, doi: Optional[str] = None) -> str:
        """Generate a stable unique ID for a paper.

        Prefers the DOI (globally unique); falls back to a hash of the
        lower-cased title. MD5 is used purely as a fingerprint here, not
        for anything security-sensitive.
        """
        if doi:
            return hashlib.md5(doi.encode()).hexdigest()
        return hashlib.md5(title.lower().encode()).hexdigest()

    async def index_paper(
        self,
        title: str,
        abstract: str,
        authors: List[str],
        doi: Optional[str] = None,
        journal: Optional[str] = None,
        year: Optional[int] = None,
        findings: Optional[str] = None,
        url: Optional[str] = None,
        paper_type: str = "research"
    ) -> Dict[str, Any]:
        """Index a research paper with metadata.

        Returns a dict with ``status`` of "success", "already_indexed", or
        "error", plus ``paper_id``/``title``/``message`` where applicable.
        """

        if not self.index:
            return {"status": "error", "message": "Index not initialized"}

        # Generate unique ID (DOI-based when available)
        paper_id = self.generate_paper_id(title, doi)

        # Skip duplicates so re-indexing the same paper is a cheap no-op
        if paper_id in papers_metadata:
            return {
                "status": "already_indexed",
                "paper_id": paper_id,
                "title": title,
                "message": "Paper already in research memory"
            }

        # Build the document text that gets chunked and embedded
        doc_text = f"Title: {title}\n\n"
        doc_text += f"Authors: {', '.join(authors)}\n\n"

        if journal:
            doc_text += f"Journal: {journal}\n"
        if year:
            doc_text += f"Year: {year}\n\n"

        doc_text += f"Abstract: {abstract}\n\n"

        if findings:
            doc_text += f"Key Findings: {findings}\n\n"

        # Create document with metadata (ChromaDB only accepts strings, not lists)
        # NOTE(review): optional fields (doi/journal/year/url) may be None
        # here - some ChromaDB versions reject None metadata values. Confirm
        # against the pinned chromadb version or filter Nones if it fails.
        metadata = {
            "paper_id": paper_id,
            "title": title,
            "authors": ", ".join(authors) if authors else "",  # Convert list to string
            "doi": doi,
            "journal": journal,
            "year": year,
            "url": url,
            "paper_type": paper_type,
            "indexed_at": datetime.now().isoformat()
        }

        document = Document(
            text=doc_text,
            metadata=metadata
        )

        try:
            # Add to the vector index (embeds and stores in Chroma)
            self.index.insert(document)

            # Mirror metadata in the side-car JSON store
            papers_metadata[paper_id] = metadata
            self._save_metadata()

            logger.info(f"Indexed paper: {title}")

            return {
                "status": "success",
                "paper_id": paper_id,
                "title": title,
                "message": f"Successfully indexed paper into research memory"
            }

        except Exception as e:
            logger.error(f"Failed to index paper: {e}")
            return {
                "status": "error",
                "message": f"Failed to index paper: {str(e)}"
            }

    async def search_similar(
        self,
        query: str,
        top_k: int = 5,
        include_scores: bool = True
    ) -> List[Dict[str, Any]]:
        """Search for similar research in memory.

        Uses the retriever directly (pure vector similarity, no LLM call).
        Returns a list of {text, metadata, score} dicts; empty on failure.
        """

        if not self.index:
            return []

        try:
            # Retriever does direct vector search without needing an LLM
            retriever = self.index.as_retriever(
                similarity_top_k=top_k
            )

            nodes = retriever.retrieve(query)

            results = []
            for node in nodes:
                result = {
                    # Trim long chunks so responses stay compact
                    "text": node.text[:500] + "..." if len(node.text) > 500 else node.text,
                    "metadata": node.metadata,
                    "score": node.score if include_scores else None
                }
                results.append(result)

            return results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# Global manager - will be initialized on first use
memory_manager = None  # ResearchMemoryManager, built lazily by ensure_initialized()
_initialization_lock = asyncio.Lock()  # Prevent race conditions during initialization
# NOTE(review): asyncio.Lock() is created at import time, outside any running
# event loop. That is fine on Python 3.10+, but binds to a loop on older
# versions - confirm the target runtime.
_initialization_started = False  # set True once a coroutine begins initialization
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
async def ensure_initialized():
    """Ensure the memory manager is initialized (lazy initialization).

    Initialization is expensive (loads the embedding model and opens the
    Chroma collection), so it is deferred to the first tool call and guarded
    by an asyncio lock so only one coroutine does the work.

    Returns:
        True when a ResearchMemoryManager is available, False otherwise.
    """
    global memory_manager, _initialization_started

    # Fast path: already initialized, no lock needed.
    if memory_manager is not None:
        return True

    # The lock serializes concurrent first calls.
    async with _initialization_lock:
        # Double-check after acquiring the lock - another coroutine may have
        # completed initialization while we were waiting on the lock.
        if memory_manager is not None:
            return True

        if not LLAMAINDEX_AVAILABLE:
            return False

        # FIX: the previous implementation busy-waited on
        # ``_initialization_started`` here *while still holding the lock*.
        # Any coroutine actually initializing would need this same lock, so
        # that loop could never make progress. The lock alone already gives
        # the required mutual exclusion, so the wait loop is removed.
        try:
            _initialization_started = True
            logger.info("🔄 Initializing LlamaIndex RAG system (this may take 20-30 seconds)...")
            logger.info("   Loading BioBERT embedding model...")

            # Construction does not raise if only the index build fails
            # (manager.index may still be None); callers that need the index
            # must check it themselves.
            memory_manager = ResearchMemoryManager()

            logger.info("✅ LlamaIndex RAG system initialized successfully")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to initialize LlamaIndex: {e}")
            # Reset the flag so a later call can retry initialization.
            _initialization_started = False
            return False
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@mcp.tool()
async def index_paper(
    title: str,
    abstract: str,
    authors: str,
    doi: Optional[str] = None,
    journal: Optional[str] = None,
    year: Optional[int] = None,
    findings: Optional[str] = None,
    url: Optional[str] = None
) -> str:
    """Index a research paper into persistent memory for future retrieval.

    The agent's research memory persists across sessions, building knowledge over time.

    Args:
        title: Paper title
        abstract: Paper abstract or summary
        authors: Comma-separated list of authors
        doi: Digital Object Identifier (optional)
        journal: Journal or preprint server name (optional)
        year: Publication year (optional)
        findings: Key findings or implications (optional)
        url: URL to paper (optional)

    Returns:
        Status of indexing operation
    """
    # Docstring above is published by FastMCP as the tool description.

    # Bail out early when the optional RAG stack is not installed.
    if not LLAMAINDEX_AVAILABLE:
        return json.dumps({
            "status": "error",
            "error": "LlamaIndex not installed",
            "message": "Install with: pip install llama-index chromadb sentence-transformers"
        }, indent=2)

    # Lazy initialization happens on the first tool call.
    if not await ensure_initialized():
        return json.dumps({
            "status": "error",
            "error": "Memory manager initialization failed",
            "message": "Check LlamaIndex configuration and dependencies"
        }, indent=2)

    try:
        # The tool takes authors as one comma-separated string; the manager
        # expects a list of names.
        author_names = [name.strip() for name in authors.split(",")]

        outcome = await memory_manager.index_paper(
            title=title,
            abstract=abstract,
            authors=author_names,
            doi=doi,
            journal=journal,
            year=year,
            findings=findings,
            url=url
        )

        status = outcome["status"]

        if status == "success":
            return json.dumps({
                "status": "success",
                "paper_id": outcome["paper_id"],
                "title": outcome["title"],
                "message": f"✅ Indexed into research memory. Total papers: {len(papers_metadata)}",
                "total_papers_indexed": len(papers_metadata)
            }, indent=2)

        if status == "already_indexed":
            return json.dumps({
                "status": "already_indexed",
                "paper_id": outcome["paper_id"],
                "title": outcome["title"],
                "message": "ℹ️ Paper already in research memory",
                "total_papers_indexed": len(papers_metadata)
            }, indent=2)

        # Anything else from the manager is reported as an error.
        return json.dumps({"status": "error", "error": "Indexing failed", "message": outcome.get("message", "Unknown error")}, indent=2)

    except Exception as e:
        logger.error(f"Error indexing paper: {e}")
        return json.dumps({"status": "error", "error": "Indexing error", "message": str(e)}, indent=2)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
@mcp.tool()
async def semantic_search(
    query: str,
    max_results: int = 5
) -> str:
    """Search research memory using semantic similarity.

    Finds papers similar to your query across all indexed research,
    even if they don't contain exact keywords.

    Args:
        query: Search query (can be a question, topic, or paper abstract)
        max_results: Maximum number of results to return (default: 5)

    Returns:
        Similar papers from research memory
    """
    # Docstring above is published by FastMCP as the tool description.

    if not LLAMAINDEX_AVAILABLE:
        return json.dumps({
            "status": "error",
            "error": "LlamaIndex not installed",
            "message": "Install with: pip install llama-index chromadb sentence-transformers"
        }, indent=2)

    # Lazy initialization happens on the first tool call.
    if not await ensure_initialized():
        return json.dumps({
            "status": "error",
            "error": "Memory manager initialization failed",
            "message": "Check LlamaIndex configuration and dependencies"
        }, indent=2)

    # Initialization can succeed while the index itself failed to build.
    if not memory_manager.index:
        return json.dumps({
            "status": "error",
            "error": "No research memory available",
            "message": "No papers have been indexed yet"
        }, indent=2)

    try:
        matches = await memory_manager.search_similar(query=query, top_k=max_results)

        if not matches:
            return json.dumps({
                "status": "no_results",
                "query": query,
                "message": "No similar research found in memory"
            }, indent=2)

        # Shape retriever hits into a compact, ranked result list.
        ranked = []
        for rank, match in enumerate(matches, 1):
            meta = match["metadata"]
            ranked.append({
                "rank": rank,
                "title": meta.get("title", "Unknown"),
                "authors": meta.get("authors", []),
                "year": meta.get("year"),
                "journal": meta.get("journal"),
                "doi": meta.get("doi"),
                "url": meta.get("url"),
                "similarity_score": round(match["score"], 3) if match["score"] else None,
                "excerpt": match["text"][:300] + "..."
            })

        return json.dumps({
            "status": "success",
            "query": query,
            "num_results": len(ranked),
            "results": ranked,
            "message": f"Found {len(ranked)} similar papers in research memory"
        }, indent=2)

    except Exception as e:
        logger.error(f"Search error: {e}")
        return json.dumps({"status": "error", "error": "Search failed", "message": str(e)}, indent=2)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
@mcp.tool()
async def get_research_connections(
    paper_title: str,
    connection_type: str = "similar",
    max_connections: int = 5
) -> str:
    """Discover connections between research papers in memory.

    Finds related papers that might share themes, methods, or findings.

    Args:
        paper_title: Title of paper to find connections for
        connection_type: Type of connections - "similar", "citations", "authors"
        max_connections: Maximum connections to return

    Returns:
        Connected papers with relationship descriptions
    """
    # Docstring above is published by FastMCP as the tool description.

    if not LLAMAINDEX_AVAILABLE:
        return json.dumps({
            "status": "error",
            "error": "LlamaIndex not installed",
            "message": "Install with: pip install llama-index chromadb sentence-transformers"
        }, indent=2)

    # Lazy initialization happens on the first tool call.
    if not await ensure_initialized():
        return json.dumps({
            "status": "error",
            "error": "Memory manager initialization failed",
            "message": "Check LlamaIndex configuration and dependencies"
        }, indent=2)

    try:
        # Only semantic similarity is implemented today; citation networks
        # and co-authorship graphs are future work.
        if connection_type != "similar":
            return json.dumps({
                "status": "not_implemented",
                "message": f"Connection type '{connection_type}' not yet implemented. Use 'similar' for now."
            }, indent=2)

        # Over-fetch by one because the query paper may match itself.
        matches = await memory_manager.search_similar(
            query=paper_title,
            top_k=max_connections + 1
        )

        # Drop the queried paper itself (case-insensitive title comparison).
        others = [
            m for m in matches
            if m["metadata"].get("title", "").lower() != paper_title.lower()
        ]

        if not others:
            return json.dumps({
                "status": "no_connections",
                "paper": paper_title,
                "message": "No connections found in research memory"
            }, indent=2)

        connections = []
        for m in others[:max_connections]:
            meta = m["metadata"]
            connections.append({
                "title": meta.get("title", "Unknown"),
                "authors": meta.get("authors", []),
                "year": meta.get("year"),
                "connection_strength": round(m["score"], 3) if m["score"] else None,
                "connection_type": "semantic_similarity",
                "url": meta.get("url")
            })

        return json.dumps({
            "status": "success",
            "paper": paper_title,
            "connection_type": connection_type,
            "num_connections": len(connections),
            "connections": connections,
            "message": f"Found {len(connections)} connected papers"
        }, indent=2)

    except Exception as e:
        logger.error(f"Error finding connections: {e}")
        return json.dumps({"status": "error", "error": "Connection search failed", "message": str(e)}, indent=2)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
@mcp.tool()
async def list_indexed_papers(
    limit: int = 20,
    sort_by: str = "date"
) -> str:
    """List papers currently in research memory.

    Shows what research the agent has learned from previously.

    Args:
        limit: Maximum papers to list (default: 20)
        sort_by: Sort order - "date" (indexed date) or "year" (publication year)

    Returns:
        List of indexed papers with metadata
    """

    if not papers_metadata:
        return json.dumps({
            "status": "empty",
            "message": "No papers indexed yet. Research memory is empty.",
            "total_papers": 0
        }, indent=2)

    try:
        # Build the listing from the side-car metadata (no vector DB needed)
        papers_list = []
        for paper_id, metadata in papers_metadata.items():
            # Authors were flattened to a comma-separated string for ChromaDB;
            # convert back to a list for callers.
            authors_str = metadata.get("authors", "")
            authors_list = authors_str.split(", ") if authors_str else []

            papers_list.append({
                "paper_id": paper_id,
                "title": metadata.get("title", "Unknown"),
                "authors": authors_list,
                "year": metadata.get("year"),
                "journal": metadata.get("journal"),
                "doi": metadata.get("doi"),
                "indexed_at": metadata.get("indexed_at"),
                "url": metadata.get("url")
            })

        # Sort. FIX: ``year`` is optional and may be *stored* as None, and
        # ``dict.get("year", 0)`` returns that stored None rather than the
        # default, so the old sort raised TypeError when None met an int.
        # Coalesce explicitly instead; same defensive coalescing for dates.
        if sort_by == "date":
            papers_list.sort(key=lambda x: x.get("indexed_at") or "", reverse=True)
        elif sort_by == "year":
            papers_list.sort(key=lambda x: x.get("year") or 0, reverse=True)
        # Any other sort_by value keeps insertion order (backward compatible).

        # Limit the listing size
        papers_list = papers_list[:limit]

        return json.dumps({
            "status": "success",
            "total_papers": len(papers_metadata),
            "showing": len(papers_list),
            "sort_by": sort_by,
            "papers": papers_list,
            "message": f"Research memory contains {len(papers_metadata)} papers"
        }, indent=2)

    except Exception as e:
        logger.error(f"Error listing papers: {e}")
        return json.dumps({"status": "error", "error": "Failed to list papers", "message": str(e)}, indent=2)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
@mcp.tool()
async def clear_research_memory(
    confirm: bool = False
) -> str:
    """Clear all papers from research memory.

    ⚠️ This will permanently delete all indexed research!

    Args:
        confirm: Must be True to actually clear memory

    Returns:
        Confirmation of memory clearing
    """
    # Rebound below when memory is wiped, hence the global declaration.
    global papers_metadata

    # Two-step confirmation guard: a call without confirm=True only reports
    # what would be deleted, so an automated caller cannot wipe memory by accident.
    if not confirm:
        return json.dumps({
            "status": "confirmation_required",
            "message": "⚠️ This will delete all research memory. Set confirm=True to proceed.",
            "current_papers": len(papers_metadata)
        }, indent=2)

    try:
        # Check if memory manager needs initialization
        # Only initialize if we have papers to clear
        if papers_metadata and not memory_manager:
            await ensure_initialized()

        # Clear ChromaDB collection
        if memory_manager and memory_manager.collection:
            # Delete and recreate collection (drops all stored vectors at once)
            memory_manager.chroma_client.delete_collection("als_research")
            memory_manager.collection = memory_manager.chroma_client.create_collection("als_research")

            # Reinitialize index
            # NOTE(review): relies on MemoryManager private helpers
            # (_initialize_index / _save_metadata); confirm they tolerate an
            # empty collection.
            memory_manager._initialize_index()

        # Clear metadata (count first so the success message is accurate)
        num_papers = len(papers_metadata)
        papers_metadata = {}

        # Save empty metadata so the on-disk state matches the cleared memory
        if memory_manager:
            memory_manager._save_metadata()

        logger.info(f"Cleared research memory: {num_papers} papers removed")

        return json.dumps({
            "status": "success",
            "message": f"✅ Research memory cleared. Removed {num_papers} papers.",
            "papers_removed": num_papers
        }, indent=2)

    except Exception as e:
        # Surface failures as structured JSON rather than raising, so the MCP
        # client always receives a parseable result.
        logger.error(f"Error clearing memory: {e}")
        return json.dumps({"status": "error", "error": "Failed to clear memory", "message": str(e)}, indent=2)
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
if __name__ == "__main__":
    # Check for required packages before announcing startup details.
    if not LLAMAINDEX_AVAILABLE:
        logger.error("LlamaIndex dependencies not installed!")
        logger.info("Install with: pip install llama-index-core llama-index-vector-stores-chroma")
        logger.info(" pip install chromadb sentence-transformers transformers")
    else:
        # Startup banner: surface the key configuration so operators can
        # verify paths/models from the logs.
        # Fix: dropped the needless f-prefix on a string with no placeholders (F541).
        logger.info("LlamaIndex RAG server starting...")
        logger.info(f"ChromaDB path: {CHROMA_DB_PATH}")
        logger.info(f"Embedding model: {EMBED_MODEL}")
        logger.info(f"Papers in memory: {len(papers_metadata)}")

    # Run the MCP server
    # NOTE(review): the server starts even when LLAMAINDEX_AVAILABLE is False;
    # presumably the tools then fail at call time — confirm this is intended.
    mcp.run(transport="stdio")
|
servers/pubmed_server.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pubmed_server.py
|
| 2 |
+
from mcp.server.fastmcp import FastMCP
|
| 3 |
+
import httpx
|
| 4 |
+
import logging
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Add parent directory to path for shared imports
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 10 |
+
|
| 11 |
+
from shared import (
|
| 12 |
+
config,
|
| 13 |
+
RateLimiter,
|
| 14 |
+
format_authors,
|
| 15 |
+
ErrorFormatter,
|
| 16 |
+
truncate_text
|
| 17 |
+
)
|
| 18 |
+
from shared.http_client import get_http_client
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# Create FastMCP server
|
| 25 |
+
mcp = FastMCP("pubmed-server")
|
| 26 |
+
|
| 27 |
+
# Rate limiting using shared utility
|
| 28 |
+
rate_limiter = RateLimiter(config.rate_limits.pubmed_delay)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@mcp.tool()
async def search_pubmed(
    query: str,
    max_results: int = 10,
    sort: str = "relevance"
) -> str:
    """Search PubMed for ALS research papers. Returns titles, abstracts, PMIDs, and publication dates.

    Args:
        query: Search query (e.g., 'ALS SOD1 therapy')
        max_results: Maximum number of results (default: 10)
        sort: Sort order - 'relevance' or 'date' (default: 'relevance')
    """
    try:
        logger.info(f"Searching PubMed for: {query}")

        # PubMed E-utilities API (no auth required)
        base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

        # Rate limiting (NCBI allows ~3 req/sec without an API key)
        await rate_limiter.wait()

        # Step 1: esearch returns matching PMIDs only
        search_params = {
            "db": "pubmed",
            "term": query,
            "retmax": max_results,
            "retmode": "json",
            "sort": sort
        }

        # Use shared HTTP client for connection pooling
        client = get_http_client(timeout=config.api.timeout)

        # Get PMIDs
        search_resp = await client.get(f"{base_url}/esearch.fcgi", params=search_params)
        search_resp.raise_for_status()
        search_data = search_resp.json()
        pmids = search_data.get("esearchresult", {}).get("idlist", [])

        if not pmids:
            logger.info(f"No results found for query: {query}")
            return ErrorFormatter.no_results(query)

        # Rate limiting before the second E-utilities call
        await rate_limiter.wait()

        # Step 2: efetch returns the full records (XML) for those PMIDs
        fetch_params = {
            "db": "pubmed",
            "id": ",".join(pmids),
            "retmode": "xml"
        }

        fetch_resp = await client.get(f"{base_url}/efetch.fcgi", params=fetch_params)
        fetch_resp.raise_for_status()

        # Parse XML and extract key info
        papers = parse_pubmed_xml(fetch_resp.text)

        result = f"Found {len(papers)} papers for query: '{query}'\n\n"
        for i, paper in enumerate(papers, 1):
            result += f"{i}. **{paper['title']}**\n"
            result += f" PMID: {paper['pmid']} | Published: {paper['date']}\n"
            result += f" Authors: {paper['authors']}\n"
            result += f" URL: https://pubmed.ncbi.nlm.nih.gov/{paper['pmid']}/\n"
            # Fix: pass the ellipsis as truncate_text's suffix instead of
            # appending a literal "..." unconditionally, so abstracts shorter
            # than the limit are no longer shown with a misleading ellipsis.
            # (Assumes truncate_text appends suffix only when truncating —
            # TODO confirm against shared.utils.)
            result += f" Abstract: {truncate_text(paper['abstract'], max_chars=300, suffix='...')}\n\n"

        logger.info(f"Successfully retrieved {len(papers)} papers")
        return result

    except httpx.TimeoutException:
        logger.error("PubMed API request timed out")
        return "Error: PubMed API request timed out. Please try again."
    except httpx.HTTPStatusError as e:
        logger.error(f"PubMed API error: {e}")
        return f"Error: PubMed API returned status {e.response.status_code}"
    except Exception as e:
        logger.error(f"Unexpected error in search_pubmed: {e}")
        return f"Error: {str(e)}"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@mcp.tool()
async def get_paper_details(pmid: str) -> str:
    """Get full details for a specific PubMed paper by PMID.

    Args:
        pmid: PubMed ID
    """
    try:
        logger.info(f"Fetching details for PMID: {pmid}")

        base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

        # Stay within NCBI's request-rate policy before hitting the API.
        await rate_limiter.wait()

        params = {
            "db": "pubmed",
            "id": pmid,
            "retmode": "xml"
        }

        # Pooled shared client avoids per-request connection setup.
        http = get_http_client(timeout=config.api.timeout)
        response = await http.get(f"{base_url}/efetch.fcgi", params=params)
        response.raise_for_status()

        parsed = parse_pubmed_xml(response.text)
        if not parsed:
            return ErrorFormatter.not_found("paper", pmid)

        paper = parsed[0]

        # Assemble the markdown report section by section.
        sections = [
            f"**{paper['title']}**\n",
            f"**PMID:** {paper['pmid']}",
            f"**Published:** {paper['date']}",
            f"**Authors:** {paper['authors']}\n",
            f"**Abstract:**\n{paper['abstract']}\n",
            f"**Journal:** {paper.get('journal', 'N/A')}",
            f"**DOI:** {paper.get('doi', 'N/A')}",
            f"**PubMed URL:** https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
        ]

        logger.info(f"Successfully retrieved details for PMID: {pmid}")
        return "\n".join(sections) + "\n"

    except httpx.TimeoutException:
        logger.error("PubMed API request timed out")
        return "Error: PubMed API request timed out. Please try again."
    except httpx.HTTPStatusError as e:
        logger.error(f"PubMed API error: {e}")
        return f"Error: PubMed API returned status {e.response.status_code}"
    except Exception as e:
        logger.error(f"Unexpected error in get_paper_details: {e}")
        return f"Error: {str(e)}"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def parse_pubmed_xml(xml_text: str) -> list[dict]:
    """Parse a PubMed efetch XML response into structured paper records.

    Args:
        xml_text: Raw XML returned by the E-utilities efetch endpoint.

    Returns:
        List of dicts with keys: title, abstract, pmid, date, authors,
        journal, doi. Articles that fail to parse are skipped (and logged)
        rather than aborting the whole batch; an unparseable document
        yields an empty list.
    """
    import xml.etree.ElementTree as ET

    papers = []

    try:
        root = ET.fromstring(xml_text)
    except ET.ParseError as e:
        logger.error(f"XML parsing error: {e}")
        return papers

    for article in root.findall(".//PubmedArticle"):
        try:
            # Title: itertext() flattens inline markup (<i>, <sub>, ...)
            title_elem = article.find(".//ArticleTitle")
            title = "".join(title_elem.itertext()) if title_elem is not None else "No title"

            # Abstract may be split into multiple labeled AbstractText sections.
            # Fix: decide presence via itertext() as well — elements whose
            # content starts with inline markup have .text == None and were
            # previously dropped even though they contain text.
            abstract_parts = []
            for abstract_elem in article.findall(".//AbstractText"):
                text = "".join(abstract_elem.itertext())
                if not text:
                    continue
                label = abstract_elem.get("Label", "")
                abstract_parts.append(f"{label}: {text}" if label else text)
            abstract = " ".join(abstract_parts) if abstract_parts else "No abstract available"

            # Extract PMID
            pmid_elem = article.find(".//PMID")
            pmid = pmid_elem.text if pmid_elem is not None else "Unknown"

            # Publication date lives under the journal issue; fall back to
            # DateCompleted when the journal record carries no PubDate.
            pub_date = article.find(".//MedlineCitation/Article/Journal/JournalIssue/PubDate")
            if pub_date is not None:
                year_elem = pub_date.find("Year")
                month_elem = pub_date.find("Month")
                year = year_elem.text if year_elem is not None else "Unknown"
                month = month_elem.text if month_elem is not None else ""
                date_str = f"{month} {year}" if month else year
            else:
                date_completed = article.find(".//DateCompleted")
                if date_completed is not None:
                    year_elem = date_completed.find("Year")
                    date_str = year_elem.text if year_elem is not None else "Unknown"
                else:
                    date_str = "Unknown"

            # Authors: individual LastName/ForeName pairs or a CollectiveName
            # (consortia, working groups, ...)
            authors = []
            for author in article.findall(".//Author"):
                last = author.find("LastName")
                first = author.find("ForeName")
                collective = author.find("CollectiveName")

                if collective is not None and collective.text:
                    authors.append(collective.text)
                elif last is not None and first is not None:
                    authors.append(f"{first.text} {last.text}")
                elif last is not None:
                    authors.append(last.text)

            # Format authors using shared utility (caps the displayed list)
            authors_str = format_authors("; ".join(authors), max_authors=3) if authors else "Unknown authors"

            # Journal name
            journal_elem = article.find(".//Journal/Title")
            journal = journal_elem.text if journal_elem is not None else "Unknown"

            # DOI: one of several ArticleId entries, keyed by IdType
            doi = next(
                (aid.text for aid in article.findall(".//ArticleId")
                 if aid.get("IdType") == "doi"),
                None,
            )

            papers.append({
                "title": title,
                "abstract": abstract,
                "pmid": pmid,
                "date": date_str,
                "authors": authors_str,
                "journal": journal,
                "doi": doi or "N/A"
            })

        except Exception as e:
            # Skip the malformed article but keep parsing the rest.
            logger.warning(f"Error parsing article: {e}")
            continue

    return papers
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
    # Run with stdio transport: FastMCP speaks the MCP protocol over
    # stdin/stdout when this server is launched as a subprocess by the host.
    mcp.run(transport="stdio")
|
shared/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# shared/__init__.py
|
| 2 |
+
"""Shared utilities and configuration for ALS Research Agent"""
|
| 3 |
+
|
| 4 |
+
from .config import config, AppConfig, APIConfig, RateLimitConfig, ContentLimits, SecurityConfig
|
| 5 |
+
from .utils import (
|
| 6 |
+
RateLimiter,
|
| 7 |
+
safe_api_call,
|
| 8 |
+
truncate_text,
|
| 9 |
+
format_authors,
|
| 10 |
+
clean_whitespace,
|
| 11 |
+
ErrorFormatter,
|
| 12 |
+
create_citation
|
| 13 |
+
)
|
| 14 |
+
from .cache import SimpleCache
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
# Configuration
|
| 18 |
+
'config',
|
| 19 |
+
'AppConfig',
|
| 20 |
+
'APIConfig',
|
| 21 |
+
'RateLimitConfig',
|
| 22 |
+
'ContentLimits',
|
| 23 |
+
'SecurityConfig',
|
| 24 |
+
# Utilities
|
| 25 |
+
'RateLimiter',
|
| 26 |
+
'safe_api_call',
|
| 27 |
+
'truncate_text',
|
| 28 |
+
'format_authors',
|
| 29 |
+
'clean_whitespace',
|
| 30 |
+
'ErrorFormatter',
|
| 31 |
+
'create_citation',
|
| 32 |
+
# Cache
|
| 33 |
+
'SimpleCache',
|
| 34 |
+
]
|
shared/cache.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# shared/cache.py
|
| 2 |
+
"""Simple in-memory cache for API responses"""
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from typing import Optional, Any
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SimpleCache:
    """TTL-bounded in-memory cache with true LRU eviction.

    Entries expire ``ttl`` seconds after they were *stored* (a hit does not
    extend the TTL). When the cache is full, the least-recently-used entry
    is evicted. Fix over the previous version: a hit now refreshes the
    entry's recency, so eviction is genuinely LRU rather than FIFO by
    insertion time.

    Internal layout: ``self.cache`` maps key -> (result, stored_at); the
    dict's insertion order doubles as the LRU order (oldest first).
    """

    def __init__(self, ttl: int = 3600, max_size: int = 100):
        """
        Initialize cache with TTL and size limits

        Args:
            ttl: Time to live in seconds (default: 1 hour)
            max_size: Maximum number of cached entries (default: 100)
        """
        self.cache = {}
        self.ttl = ttl
        self.max_size = max_size

    def _make_key(self, tool_name: str, arguments: dict) -> str:
        """Create cache key from tool name and arguments"""
        # Sort keys so logically-equal argument dicts produce the same key.
        args_str = json.dumps(arguments, sort_keys=True)
        key_str = f"{tool_name}:{args_str}"
        # md5 is fine here: the digest is a cache key, not a security token.
        return hashlib.md5(key_str.encode()).hexdigest()

    def get(self, tool_name: str, arguments: dict) -> Optional[str]:
        """Get cached result if available and not expired"""
        key = self._make_key(tool_name, arguments)

        if key in self.cache:
            result, timestamp = self.cache[key]

            # Check if expired (TTL measured from store time)
            if time.time() - timestamp < self.ttl:
                # Fix: refresh recency on a hit (re-insert at the end of the
                # dict's insertion order) so eviction is genuinely LRU.
                self.cache[key] = self.cache.pop(key)
                logger.info(f"Cache HIT for {tool_name}")
                return result
            else:
                # Remove expired entry eagerly
                del self.cache[key]
                logger.info(f"Cache EXPIRED for {tool_name}")

        logger.info(f"Cache MISS for {tool_name}")
        return None

    def set(self, tool_name: str, arguments: dict, result: str) -> None:
        """Store result in cache with LRU eviction if at capacity"""
        key = self._make_key(tool_name, arguments)

        if key in self.cache:
            # Delete first so the re-insert below lands at the most-recent
            # position in the insertion order.
            del self.cache[key]
        elif len(self.cache) >= self.max_size and self.cache:
            # Evict the least-recently-used entry (first in insertion order).
            lru_key = next(iter(self.cache))
            del self.cache[lru_key]
            logger.debug("Evicted oldest cache entry to maintain size limit")

        self.cache[key] = (result, time.time())
        logger.debug(f"Cached result for {tool_name} (cache size: {len(self.cache)}/{self.max_size})")

    def clear(self) -> None:
        """Clear all cache entries"""
        self.cache.clear()
        logger.info("Cache cleared")

    def size(self) -> int:
        """Get number of cached items"""
        return len(self.cache)

    def cleanup_expired(self) -> int:
        """Remove all expired entries and return count of removed items"""
        current_time = time.time()
        # Collect first, delete second: never mutate a dict while iterating it.
        expired_keys = [
            key for key, (_, timestamp) in self.cache.items()
            if current_time - timestamp >= self.ttl
        ]

        for key in expired_keys:
            del self.cache[key]

        if expired_keys:
            logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")

        return len(expired_keys)
|
shared/config.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# shared/config.py
|
| 2 |
+
"""Shared configuration for MCP servers"""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class APIConfig:
    """Configuration for outbound API calls (timeout, retries, UA string)."""
    # Per-request timeout in seconds.
    timeout: float = 15.0  # Reduced from 30s - PubMed typically responds in <1s
    # NOTE(review): max_retries is declared here but retry behavior is not
    # visible in this module — confirm which callers honor it.
    max_retries: int = 3
    # Identifies the bot to remote services; some APIs reject blank UAs.
    user_agent: str = "Mozilla/5.0 (compatible; ALS-Research-Bot/1.0)"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class RateLimitConfig:
    """Rate limiting configuration for different APIs.

    Each value is the minimum delay in seconds between consecutive requests,
    consumed by the shared RateLimiter utility.
    """
    # PubMed: 3 req/sec without key, 10 req/sec with key
    pubmed_delay: float = 0.34  # ~3 requests per second

    # ClinicalTrials.gov: conservative limit (API limit is ~50 req/min)
    clinicaltrials_delay: float = 1.5  # ~40 requests per minute (safe margin)

    # bioRxiv/medRxiv: be respectful
    biorxiv_delay: float = 1.0  # 1 request per second

    # General web fetching
    fetch_delay: float = 0.5
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class ContentLimits:
    """Content size and length limits used when fetching and summarizing."""
    # Maximum content size for downloads (10MB)
    max_content_size: int = 10 * 1024 * 1024

    # Maximum characters for LLM context
    max_text_chars: int = 8000

    # Maximum abstract preview length (characters)
    max_abstract_preview: int = 300

    # Maximum description preview length (characters)
    max_description_preview: int = 500
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class SecurityConfig:
    """Security-related configuration for outbound URL fetching.

    ``is_private_ip`` guards against SSRF: requests targeting loopback,
    private, link-local, or otherwise non-routable addresses should be
    refused by callers.
    """
    allowed_schemes: Optional[list[str]] = None
    blocked_hosts: Optional[list[str]] = None

    def __post_init__(self):
        # Defaults are filled in here (not as field defaults) because mutable
        # dataclass defaults would be shared across instances.
        if self.allowed_schemes is None:
            self.allowed_schemes = ['http', 'https']

        if self.blocked_hosts is None:
            self.blocked_hosts = [
                'localhost',
                '127.0.0.1',
                '0.0.0.0',
                '[::1]'
            ]

    def is_private_ip(self, hostname: str) -> bool:
        """Return True if *hostname* is blocklisted or a non-public IP literal.

        Improvements over the previous hand-rolled prefix checks: the whole
        loopback range (127.0.0.0/8), link-local (169.254.0.0/16), reserved
        and unspecified addresses, and IPv6 private/loopback addresses are
        now recognized via the stdlib ``ipaddress`` module. All hostnames the
        old checks flagged are still flagged (backward-compatible superset).

        Plain DNS names (other than entries in ``blocked_hosts``) return
        False — no DNS resolution is performed here.
        """
        import ipaddress

        hostname_lower = hostname.lower()

        # Exact matches against the explicit blocklist
        if hostname_lower in self.blocked_hosts:
            return True

        # Strip URL-style IPv6 brackets before parsing
        candidate = hostname_lower.strip('[]')
        try:
            ip = ipaddress.ip_address(candidate)
        except ValueError:
            # Not an IP literal
            return False

        return (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_reserved
            or ip.is_unspecified
        )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class AppConfig:
    """Application-wide configuration"""
    # Sub-configurations; filled in __post_init__ so each instance gets its
    # own objects (mutable dataclass defaults would be shared).
    api: Optional[APIConfig] = None
    rate_limits: Optional[RateLimitConfig] = None
    content_limits: Optional[ContentLimits] = None
    security: Optional[SecurityConfig] = None

    # Environment variables (env values override constructor-supplied ones)
    anthropic_api_key: Optional[str] = None
    anthropic_model: str = "claude-sonnet-4-5-20250929"
    gradio_port: int = 7860
    log_level: str = "INFO"

    # PubMed email (optional, increases rate limit)
    pubmed_email: Optional[str] = None

    def __post_init__(self):
        # Initialize sub-configs
        if self.api is None:
            self.api = APIConfig()
        if self.rate_limits is None:
            self.rate_limits = RateLimitConfig()
        if self.content_limits is None:
            self.content_limits = ContentLimits()
        if self.security is None:
            self.security = SecurityConfig()

        # Load from environment. Note: int() raises ValueError at import
        # time if GRADIO_SERVER_PORT is set but not numeric.
        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", self.anthropic_api_key)
        self.anthropic_model = os.getenv("ANTHROPIC_MODEL", self.anthropic_model)
        self.gradio_port = int(os.getenv("GRADIO_SERVER_PORT", self.gradio_port))
        self.log_level = os.getenv("LOG_LEVEL", self.log_level)
        self.pubmed_email = os.getenv("PUBMED_EMAIL", self.pubmed_email)

    @classmethod
    def from_env(cls) -> 'AppConfig':
        """Create configuration from environment variables"""
        return cls()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Global configuration instance
|
| 134 |
+
config = AppConfig.from_env()
|
shared/http_client.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Shared HTTP client with connection pooling for better performance.
|
| 4 |
+
All MCP servers should use this instead of creating new clients for each request.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import httpx
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
# Global HTTP client with connection pooling
|
| 11 |
+
# This maintains persistent connections to servers for faster subsequent requests
|
| 12 |
+
_http_client: Optional[httpx.AsyncClient] = None
|
| 13 |
+
|
| 14 |
+
def get_http_client(timeout: float = 30.0) -> httpx.AsyncClient:
    """Return the process-wide pooled HTTP client, creating it on demand.

    The client keeps connections alive between requests, so repeated calls
    to the same hosts skip TCP/TLS setup. The ``timeout`` argument only
    takes effect when the client is (re)created; callers needing a different
    timeout should use the CustomHTTPClient context manager instead.

    Args:
        timeout: Request timeout in seconds (default 30)

    Returns:
        Shared httpx.AsyncClient instance
    """
    global _http_client

    needs_new_client = _http_client is None or _http_client.is_closed
    if needs_new_client:
        pool_limits = httpx.Limits(
            max_connections=100,           # total connection cap
            max_keepalive_connections=20,  # connections kept warm for reuse
            keepalive_expiry=300,          # idle connections live 5 minutes
        )
        _http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(timeout),
            limits=pool_limits,
            follow_redirects=True,  # follow redirects by default
        )

    return _http_client
|
| 42 |
+
|
| 43 |
+
async def close_http_client():
    """Shut down the shared HTTP client; intended for process teardown."""
    global _http_client
    client = _http_client
    if client and not client.is_closed:
        await client.aclose()
        _http_client = None
|
| 49 |
+
|
| 50 |
+
# Context manager for temporary clients with custom settings
|
| 51 |
+
class CustomHTTPClient:
    """Async context manager yielding a one-off HTTP client with custom settings.

    Use this (rather than get_http_client) when a server needs a timeout or
    other option that differs from the shared pooled client; the temporary
    client is closed automatically on exit.
    """

    def __init__(self, timeout: float = 30.0, **kwargs):
        self.timeout = timeout
        self.kwargs = kwargs
        self.client = None

    async def __aenter__(self):
        options = dict(self.kwargs)
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            **options,
        )
        return self.client

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        client = self.client
        if client:
            await client.aclose()
|
shared/utils.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# shared/utils.py
|
| 2 |
+
"""Shared utilities for MCP servers"""
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import time
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Optional, Callable, Any
|
| 8 |
+
from mcp.types import TextContent
|
| 9 |
+
import httpx
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RateLimiter:
    """Enforces a minimum spacing between successive API requests."""

    def __init__(self, delay: float):
        """
        Initialize rate limiter

        Args:
            delay: Minimum delay between requests in seconds
        """
        self.delay = delay
        self.last_request_time: Optional[float] = None

    async def wait(self) -> None:
        """Sleep just long enough so calls are at least ``delay`` seconds apart."""
        previous = self.last_request_time
        if previous is not None:
            remaining = self.delay - (time.time() - previous)
            if remaining > 0:
                await asyncio.sleep(remaining)
        # Record when this request was released, for the next caller.
        self.last_request_time = time.time()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def safe_api_call(
    func: Callable,
    *args: Any,
    timeout: float = 30.0,
    error_prefix: str = "API",
    **kwargs: Any
) -> list[TextContent]:
    """
    Safely execute an API call with comprehensive error handling

    Args:
        func: Async function to call
        *args: Positional arguments for func
        timeout: Timeout in seconds
        error_prefix: Prefix for error messages
        **kwargs: Keyword arguments for func

    Returns:
        list[TextContent]: Result or error message
    """
    def _error(message: str) -> list[TextContent]:
        # Every failure path produces a single TextContent carrying the message.
        return [TextContent(type="text", text=message)]

    try:
        return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout)

    except asyncio.TimeoutError:
        # Our own wait_for deadline expired before the call finished.
        logger.error(f"{error_prefix} request timed out after {timeout}s")
        return _error(
            f"Error: {error_prefix} request timed out after {timeout} seconds. Please try again."
        )

    except httpx.TimeoutException:
        # The HTTP client's own timeout fired inside the call.
        logger.error(f"{error_prefix} request timed out")
        return _error(f"Error: {error_prefix} request timed out. Please try again.")

    except httpx.HTTPStatusError as e:
        # Server responded, but with a non-success status code.
        logger.error(f"{error_prefix} error: HTTP {e.response.status_code}")
        return _error(f"Error: {error_prefix} returned status {e.response.status_code}")

    except httpx.RequestError as e:
        # Network-level failure: DNS, connection refused, etc.
        logger.error(f"{error_prefix} request error: {e}")
        return _error(
            f"Error: Failed to connect to {error_prefix}. Please check your connection."
        )

    except Exception as e:
        # Last-resort catch so a tool call never propagates an exception upward.
        logger.error(f"Unexpected error in {error_prefix}: {e}", exc_info=True)
        return _error(f"Error: {str(e)}")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def truncate_text(text: str, max_chars: int = 8000, suffix: str = "...") -> str:
    """
    Truncate text to maximum length with suffix

    Args:
        text: Text to truncate
        max_chars: Maximum character count
        suffix: Suffix to add when truncated

    Returns:
        Truncated text (unchanged when it already fits)
    """
    if len(text) > max_chars:
        # The truncation notice (plus suffix) is appended after the cut,
        # so the returned string intentionally exceeds max_chars slightly.
        notice = f"\n\n[Content truncated at {max_chars} characters]{suffix}"
        return text[:max_chars] + notice
    return text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def format_authors(authors: str, max_authors: int = 3) -> str:
    """
    Format author list with et al. if needed

    Args:
        authors: Semicolon-separated author list
        max_authors: Maximum authors to show

    Returns:
        Formatted author string
    """
    # Empty string and the literal sentinel "Unknown" both mean no author data.
    if not authors or authors == "Unknown":
        return "Unknown authors"

    names = [name.strip() for name in authors.split(";")]

    if len(names) > max_authors:
        # Long lists are cut down and marked with the conventional "et al."
        return ", ".join(names[:max_authors]) + " et al."
    return ", ".join(names)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def clean_whitespace(text: str) -> str:
    """
    Clean up excessive whitespace in text

    Strips each line, splits lines into phrases, and rejoins only the
    non-empty phrases with newlines, dropping blank lines entirely.

    Args:
        text: Text to clean

    Returns:
        Cleaned text
    """
    # Strip leading/trailing whitespace from every line.
    lines = (line.strip() for line in text.splitlines())
    # NOTE(review): splits each stripped line on the literal delimiter below.
    # Confirm whether a multi-space delimiter was intended — the common
    # html-to-text recipe splits on a run of two spaces; a single-space
    # delimiter would put every word on its own output line.
    chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
    # Keep only non-empty chunks, one per output line.
    return '\n'.join(chunk for chunk in chunks if chunk)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ErrorFormatter:
    """Builds consistently worded error messages for the MCP servers."""

    @staticmethod
    def not_found(resource_type: str, identifier: str) -> str:
        """Message for a lookup that matched no resource."""
        return f"No {resource_type} found with identifier: {identifier}"

    @staticmethod
    def no_results(query: str, time_period: str = "") -> str:
        """Message for a search that returned no hits, optionally scoped to a period."""
        period_part = f" {time_period}" if time_period else ""
        return f"No results found for query: {query}{period_part}"

    @staticmethod
    def validation_error(field: str, issue: str) -> str:
        """Message for an input field that failed validation."""
        return f"Validation error: {field} - {issue}"

    @staticmethod
    def api_error(service: str, status_code: int) -> str:
        """Message for an upstream API that returned a failing HTTP status."""
        return f"Error: {service} API returned status {status_code}"
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def create_citation(
    identifier: str,
    identifier_type: str,
    url: Optional[str] = None
) -> str:
    """
    Create a formatted citation string

    Args:
        identifier: Citation identifier (PMID, DOI, NCT ID)
        identifier_type: Type of identifier
        url: Optional URL

    Returns:
        Formatted citation, e.g. "PMID: 12345" or "PMID: 12345 | URL: ..."
    """
    parts = [f"{identifier_type}: {identifier}"]
    if url:
        parts.append(f"URL: {url}")
    return " | ".join(parts)
|
smart_cache.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Smart Cache System for ALS Research Agent
|
| 4 |
+
Features:
|
| 5 |
+
- Query normalization to match similar queries
|
| 6 |
+
- Cache pre-warming with common queries
|
| 7 |
+
- High-frequency question optimization
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import hashlib
|
| 12 |
+
import re
|
| 13 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 14 |
+
from datetime import datetime, timedelta
|
| 15 |
+
import asyncio
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SmartCache:
    """Advanced caching system with query normalization and pre-warming.

    Entries live both in memory and in a JSON file under ``cache_dir``;
    similar queries ("gene therapy ALS" vs "ALS gene therapy") share results
    via a normalized-query index.
    """

    def __init__(self, cache_dir: str = ".cache", ttl_hours: int = 24):
        """
        Initialize smart cache system.

        Args:
            cache_dir: Directory for cache storage
            ttl_hours: Time-to-live for cached entries in hours
        """
        self.cache_dir = cache_dir
        self.ttl = timedelta(hours=ttl_hours)
        self.cache: Dict[str, Dict[str, Any]] = {}  # cache key -> entry (in-memory)
        self.normalized_cache: Dict[str, List[str]] = {}  # normalized-query hash -> original cache keys
        self.high_frequency_queries: Dict[str, Dict[str, Any]] = {}  # user-specified common queries
        self.query_stats: Dict[str, Dict[str, Any]] = {}  # normalized query -> frequency stats

        # Ensure cache directory exists
        import os
        os.makedirs(cache_dir, exist_ok=True)

        # Load persistent cache on init (also prunes expired entries).
        self.load_cache()

    def normalize_query(self, query: str) -> str:
        """
        Normalize query for better cache matching.

        Handles variations like:
        - "ALS gene therapy" vs "gene therapy ALS"
        - "What are the latest trials" vs "what are latest trials"
        - Different word orders, case, punctuation
        """
        normalized = query.lower().strip()

        # Common question/function words that carry no search meaning.
        question_words = [
            'what', 'how', 'when', 'where', 'why', 'who', 'which',
            'are', 'is', 'the', 'a', 'an', 'there', 'can', 'could',
            'would', 'should', 'do', 'does', 'did', 'have', 'has', 'had'
        ]

        # Replace punctuation with spaces so "ALS?" becomes "als".
        normalized = re.sub(r'[^\w\s]', ' ', normalized)

        # Drop the question words.
        words = normalized.split()
        content_words = [w for w in words if w not in question_words]

        # Sort alphabetically so word order no longer matters:
        # "ALS gene therapy" and "gene therapy ALS" normalize identically.
        content_words.sort()

        # Join and collapse any leftover whitespace.
        return ' '.join(' '.join(content_words).split())

    def generate_cache_key(self, query: str, include_normalization: bool = True) -> str:
        """
        Generate a cache key for a query.

        Args:
            query: The original query
            include_normalization: Whether to also record the normalized mapping

        Returns:
            Hash-based cache key (first 16 hex chars of SHA-256)
        """
        original_hash = hashlib.sha256(query.encode()).hexdigest()[:16]

        if include_normalization:
            # Record normalized-query -> original-key mapping for fuzzy lookups.
            normalized = self.normalize_query(query)
            normalized_hash = hashlib.sha256(normalized.encode()).hexdigest()[:16]

            if normalized_hash not in self.normalized_cache:
                self.normalized_cache[normalized_hash] = []
            if original_hash not in self.normalized_cache[normalized_hash]:
                self.normalized_cache[normalized_hash].append(original_hash)

        return original_hash

    def find_similar_cached(self, query: str) -> Optional[Dict[str, Any]]:
        """
        Find cached results for an identical or similar query.

        Args:
            query: The query to search for

        Returns:
            Cached result if a valid (non-expired) entry is found, None otherwise
        """
        # First try an exact match on the original query text.
        exact_key = self.generate_cache_key(query, include_normalization=False)
        if exact_key in self.cache:
            entry = self.cache[exact_key]
            if self._is_valid(entry):
                logger.info(f"Cache hit (exact): {query[:50]}...")
                self._update_stats(query)
                return entry['result']

        # Fall back to a match via the normalized form.
        normalized = self.normalize_query(query)
        normalized_key = hashlib.sha256(normalized.encode()).hexdigest()[:16]

        if normalized_key in self.normalized_cache:
            # Any original query that normalizes the same way is a candidate.
            for original_key in self.normalized_cache[normalized_key]:
                if original_key in self.cache:
                    entry = self.cache[original_key]
                    if self._is_valid(entry):
                        logger.info(f"Cache hit (normalized): {query[:50]}...")
                        self._update_stats(query)
                        return entry['result']

        logger.info(f"Cache miss: {query[:50]}...")
        return None

    def store(self, query: str, result: Any, metadata: Optional[Dict] = None):
        """
        Store a query result in cache and persist it to disk.

        Args:
            query: The original query
            result: The result to cache (must be JSON-serializable to persist)
            metadata: Optional metadata about the result
        """
        cache_key = self.generate_cache_key(query, include_normalization=True)

        self.cache[cache_key] = {
            'query': query,
            'result': result,
            'timestamp': datetime.now().isoformat(),
            'metadata': metadata or {},
            'access_count': 0
        }
        self._update_stats(query)

        # Persist to disk without blocking the event loop when one is running.
        # BUG FIX: asyncio.create_task() raises RuntimeError when no event
        # loop is running (e.g. store() called from synchronous code), so we
        # fall back to a plain synchronous save in that case.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            self.save_cache()
        else:
            asyncio.create_task(self._save_cache_async())

        logger.info(f"Cached result for: {query[:50]}...")

    def _is_valid(self, entry: Dict) -> bool:
        """Check if a cache entry is still valid (not expired)."""
        try:
            timestamp = datetime.fromisoformat(entry['timestamp'])
        except (KeyError, TypeError, ValueError):
            # Missing or malformed timestamp: treat the entry as expired.
            return False
        return datetime.now() - timestamp < self.ttl

    def _update_stats(self, query: str):
        """Update query frequency statistics (keyed by normalized query)."""
        normalized = self.normalize_query(query)
        if normalized not in self.query_stats:
            self.query_stats[normalized] = {'count': 0, 'last_access': None}

        self.query_stats[normalized]['count'] += 1
        self.query_stats[normalized]['last_access'] = datetime.now().isoformat()

    async def pre_warm_cache(self, queries: List[Dict[str, Any]],
                             search_func=None, llm_func=None):
        """
        Pre-warm cache with common queries.

        Args:
            queries: List of dicts with 'query', 'search_terms', 'use_claude' keys
            search_func: Async function to perform searches
            llm_func: Async function to call Claude for high-priority queries
        """
        logger.info(f"Pre-warming cache with {len(queries)} queries...")

        for query_config in queries:
            query = query_config['query']

            # Skip anything already answered by an exact or similar cached entry.
            if self.find_similar_cached(query):
                logger.info(f"Already cached: {query}")
                continue

            try:
                # Prefer curated search terms over the raw query when provided.
                search_terms = query_config.get('search_terms', query)
                use_claude = query_config.get('use_claude', False)

                if search_func:
                    logger.info(f"Pre-warming: {query}")

                    if use_claude and llm_func:
                        # High-priority queries go through the LLM pipeline.
                        result = await llm_func(search_terms)
                    else:
                        result = await search_func(search_terms)

                    self.store(query, result, {
                        'pre_warmed': True,
                        'optimized_terms': search_terms,
                        'used_claude': use_claude
                    })

                    # Small delay to avoid overwhelming upstream APIs.
                    await asyncio.sleep(1)

            except Exception as e:
                # Best-effort: a single failed query must not abort the warm-up.
                logger.error(f"Failed to pre-warm cache for '{query}': {e}")

    def add_high_frequency_query(self, query: str, config: Dict[str, Any]):
        """
        Add a high-frequency query configuration.

        Args:
            query: The query pattern
            config: Configuration dict with search_terms, use_claude, etc.
        """
        normalized = self.normalize_query(query)
        self.high_frequency_queries[normalized] = {
            'original': query,
            'config': config,
            'added': datetime.now().isoformat()
        }
        logger.info(f"Added high-frequency query: {query}")

    def get_high_frequency_config(self, query: str) -> Optional[Dict[str, Any]]:
        """
        Get configuration for a high-frequency query if it matches.

        Args:
            query: The query to check

        Returns:
            Configuration dict if this is a high-frequency query, else None
        """
        normalized = self.normalize_query(query)
        if normalized in self.high_frequency_queries:
            return self.high_frequency_queries[normalized]['config']
        return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics including the ten most frequent queries."""
        valid_entries = sum(1 for entry in self.cache.values() if self._is_valid(entry))
        total_entries = len(self.cache)

        top_queries = sorted(
            self.query_stats.items(),
            key=lambda x: x[1]['count'],
            reverse=True
        )[:10]

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'normalized_groups': len(self.normalized_cache),
            'high_frequency_queries': len(self.high_frequency_queries),
            'top_queries': [
                {'query': q, 'count': stats['count']}
                for q, stats in top_queries
            ]
        }

    def clear_expired(self):
        """Remove expired entries from the in-memory cache and re-persist."""
        expired_keys = [
            key for key, entry in self.cache.items()
            if not self._is_valid(entry)
        ]

        for key in expired_keys:
            del self.cache[key]

        if expired_keys:
            logger.info(f"Cleared {len(expired_keys)} expired cache entries")
            self.save_cache()

    def save_cache(self):
        """Persist cache, index, and stats to disk as one JSON file."""
        cache_file = f"{self.cache_dir}/smart_cache.json"
        try:
            with open(cache_file, 'w') as f:
                json.dump({
                    'cache': self.cache,
                    'normalized_cache': self.normalized_cache,
                    'high_frequency_queries': self.high_frequency_queries,
                    'query_stats': self.query_stats
                }, f, indent=2)
            logger.debug(f"Cache saved to {cache_file}")
        except Exception as e:
            # Persistence is best-effort; the in-memory cache keeps working.
            logger.error(f"Failed to save cache: {e}")

    async def _save_cache_async(self):
        """Async wrapper around save_cache that offloads the file I/O to a thread."""
        try:
            await asyncio.to_thread(self.save_cache)
        except Exception as e:
            logger.error(f"Failed to save cache asynchronously: {e}")

    def load_cache(self):
        """Load cache from disk; silently starts empty when no file exists."""
        cache_file = f"{self.cache_dir}/smart_cache.json"
        try:
            with open(cache_file, 'r') as f:
                data = json.load(f)
            self.cache = data.get('cache', {})
            self.normalized_cache = data.get('normalized_cache', {})
            self.high_frequency_queries = data.get('high_frequency_queries', {})
            self.query_stats = data.get('query_stats', {})

            # Prune anything that expired while the process was down.
            self.clear_expired()

            logger.info(f"Loaded cache with {len(self.cache)} entries")
        except FileNotFoundError:
            logger.info("No existing cache file found")
        except Exception as e:
            logger.error(f"Failed to load cache: {e}")
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# Configuration for common ALS queries to pre-warm
|
| 354 |
+
# Default query/search-term pairs used to pre-warm the cache at startup.
# 'use_claude' marks the queries worth the extra cost of an LLM-backed answer.
DEFAULT_PREWARM_QUERIES = [
    {
        'query': 'What are the latest ALS treatments?',
        'search_terms': 'ALS treatment therapy 2024 riluzole edaravone',
        'use_claude': True  # High-frequency, use Claude for best results
    },
    {
        'query': 'Gene therapy for ALS',
        'search_terms': 'ALS gene therapy SOD1 C9orf72 clinical trial',
        'use_claude': True
    },
    {
        'query': 'ALS clinical trials',
        'search_terms': 'ALS clinical trials recruiting phase 2 phase 3',
        'use_claude': False
    },
    {
        'query': 'What causes ALS?',
        'search_terms': 'ALS etiology pathogenesis genetic environmental factors',
        'use_claude': True
    },
    {
        'query': 'ALS symptoms and diagnosis',
        'search_terms': 'ALS symptoms diagnosis EMG criteria El Escorial',
        'use_claude': False
    },
    {
        'query': 'Stem cell therapy for ALS',
        'search_terms': 'ALS stem cell therapy mesenchymal clinical trial',
        'use_claude': False
    },
    {
        'query': 'ALS prognosis and life expectancy',
        'search_terms': 'ALS prognosis survival life expectancy factors',
        'use_claude': True
    },
    {
        'query': 'New ALS drugs',
        'search_terms': 'ALS new drugs FDA approved pipeline 2024',
        'use_claude': False
    },
    {
        'query': 'ALS biomarkers',
        'search_terms': 'ALS biomarkers neurofilament TDP-43 diagnostic prognostic',
        'use_claude': False
    },
    {
        'query': 'Is there a cure for ALS?',
        'search_terms': 'ALS cure breakthrough research treatment advances',
        'use_claude': True
    }
]
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def test_smart_cache():
    """Test the smart cache functionality.

    Manual smoke test: exercises query normalization, similar-query
    matching, and the statistics report, printing results to stdout.
    """
    print("Testing Smart Cache System")
    print("=" * 60)

    cache = SmartCache()

    # Test query normalization: (original query, words expected to survive).
    test_queries = [
        ("What are the latest ALS gene therapy trials?", "ALS gene therapy trials"),
        ("gene therapy ALS", "ALS gene therapy"),
        ("What is ALS?", "ALS"),
        ("HOW does riluzole work for ALS?", "ALS riluzole work"),
    ]

    print("\n1. Query Normalization Tests:")
    for original, expected_words in test_queries:
        normalized = cache.normalize_query(original)
        print(f"  Original: {original}")
        print(f"  Normalized: {normalized}")
        print(f"  Expected words present: {all(w in normalized for w in expected_words.lower().split())}")
        print()

    # Test similar query matching: store once, look up under rephrasings.
    print("\n2. Similar Query Matching:")
    cache.store("What are the latest ALS treatments?", {"result": "Treatment data"})

    similar_queries = [
        "latest ALS treatments",
        "ALS latest treatments",
        "What are latest treatments for ALS?",
        "treatments ALS latest"
    ]

    for query in similar_queries:
        result = cache.find_similar_cached(query)
        print(f"  Query: {query}")
        print(f"  Found: {result is not None}")

    # Test cache statistics summary.
    print("\n3. Cache Statistics:")
    stats = cache.get_cache_stats()
    print(f"  Total entries: {stats['total_entries']}")
    print(f"  Valid entries: {stats['valid_entries']}")
    print(f"  Normalized groups: {stats['normalized_groups']}")

    print("\n✅ Smart cache tests completed!")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# Run the self-test only when executed directly, not on import.
if __name__ == "__main__":
    test_smart_cache()
|