axegameon committed
Commit 3e435ad · verified · 1 Parent(s): 6260c2f

Upload ALSARA app files (#1)


- Upload ALSARA app files (2076784d2a441a06677136e33fa32aa2e9bacb72)

.env.example ADDED
@@ -0,0 +1,37 @@
+ # ALS Research Agent Environment Configuration
+
+ # Anthropic API Key (Required)
+ ANTHROPIC_API_KEY=your_anthropic_api_key_here
+
+ # Optional: Specify which Anthropic model to use
+ # Available models: claude-sonnet-4-5-20250929, claude-3-opus-20240229, claude-3-haiku-20240307
+ # Default: claude-sonnet-4-5-20250929
+ ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
+
+ # Optional: Gradio server configuration
+ GRADIO_SERVER_PORT=7860  # Default port for Gradio UI
+
+ # ElevenLabs Configuration (Optional - for voice capabilities)
+ ELEVENLABS_API_KEY=your_elevenlabs_api_key_here
+ ELEVENLABS_VOICE_ID=21m00Tcm4TlvDq8ikWAM  # Rachel voice (clear for medical terms)
+
+ # LlamaIndex RAG Configuration (Optional - for research memory)
+ CHROMA_DB_PATH=./chroma_db  # Path to persist vector database
+ LLAMAINDEX_EMBED_MODEL=dmis-lab/biobert-base-cased-v1.2  # Biomedical embedding model
+ LLAMAINDEX_CHUNK_SIZE=1024  # Text chunk size for indexing
+ LLAMAINDEX_CHUNK_OVERLAP=200  # Overlap between chunks
+
+ # Optional: Show agent thinking process in UI
+ SHOW_THINKING=false
+
+ # Optional: LLM provider preference
+ # Options: quality_optimize (best model), cost_optimize (cheaper model), auto (default)
+ LLM_PROVIDER_PREFERENCE=auto
+
+ # Research API Configuration (Optional)
+ # Configure these if you want to limit API usage
+ RATE_LIMIT_PUBMED_DELAY=1.0  # Delay between PubMed requests (seconds)
+ RATE_LIMIT_BIORXIV_DELAY=1.0  # Delay between bioRxiv requests (seconds)
+
+ # Optional: Max concurrent searches
+ MAX_CONCURRENT_SEARCHES=3
README.md CHANGED
@@ -1,14 +1,474 @@
  ---
- title: ALSARA
- emoji: 🔥
- colorFrom: indigo
- colorTo: red
  sdk: gradio
- sdk_version: 6.0.1
- app_file: app.py
- pinned: false
  license: mit
- short_description: ALSARA is an agentic research assistant for ALS
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ALSARA - ALS Agentic Research Agent
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ sdk_version: "6.0.0"
+ app_file: als_agent_app.py
  license: mit
+ short_description: AI research assistant for ALS research & trials
+ pinned: false
+ sponsors: Sambanova, Anthropic, ElevenLabs, LlamaIndex
+ tags:
+ - mcp-in-action-track-consumer
+ ---
+
+ # ALSARA - ALS Agentic Research Agent
+
+ ALSARA (ALS Agentic Research Assistant) is an AI-powered research tool that intelligently orchestrates multiple biomedical databases to answer complex questions about ALS (Amyotrophic Lateral Sclerosis) research, treatments, and clinical trials in real time.
+
+ Built around a 4-phase agentic workflow (Planning → Executing → Reflecting → Synthesis), ALSARA searches PubMed and 559,000+ clinical trials via the AACT database, and provides voice accessibility for ALS patients - delivering comprehensive research in 5-10 seconds.
+
+ Built with the Model Context Protocol (MCP), Gradio 6.x, and Anthropic Claude.
+
+ ## Key Features
+
+ ### Core Capabilities
+ - **4-Phase Agentic Workflow**: Intelligent planning, parallel execution, reflection with gap-filling, and comprehensive synthesis
+ - **Real-time Literature Search**: Query millions of PubMed peer-reviewed papers
+ - **Clinical Trial Discovery**: Access 559,000+ trials from the AACT PostgreSQL database (primary) with ClinicalTrials.gov fallback
+ - **Voice Accessibility**: Text-to-speech using ElevenLabs for ALS patients with limited mobility
+ - **Smart Caching**: Query normalization with a 24-hour TTL for instant responses to similar queries (see the sketch after this list)
+ - **Parallel Tool Execution**: 70% faster responses by running all searches simultaneously
+
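+ As a rough illustration of the kind of query normalization the smart cache relies on, here is a minimal sketch. The exact rules live in `smart_cache.py` and may differ; the helper names below are hypothetical:
+
+ ```python
+ import hashlib
+ import re
+
+ def normalize_query(query: str) -> str:
+     """Illustrative normalization: lowercase, strip punctuation, collapse whitespace."""
+     q = query.lower().strip()
+     q = re.sub(r"[^\w\s]", "", q)   # drop punctuation
+     q = re.sub(r"\s+", " ", q)      # collapse runs of whitespace
+     return q
+
+ def cache_key(query: str) -> str:
+     """Hash the normalized query so 'ALS gene therapy?' and 'als Gene Therapy' collide."""
+     return hashlib.sha256(normalize_query(query).encode("utf-8")).hexdigest()
+ ```
+
+ With a scheme like this, near-duplicate questions map to the same cache entry, which is what makes sub-100ms cache hits possible.
+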
+ ### Advanced Features
+ - **Multi-Provider LLM Support**: Claude primary with SambaNova Llama 3.3 70B fallback
+ - **Query Classification**: Smart routing between simple answers and complex research
+ - **Rate Limiting**: 30 requests/minute per user with exponential backoff
+ - **Memory Management**: Automatic conversation truncation and garbage collection
+ - **Health Monitoring**: Uptime tracking, error rates, and tool usage statistics
+ - **Citation Tracking**: All responses include PMIDs, DOIs, NCT IDs, and source references
+ - **Web Scraping**: Fetch full-text articles with SSRF protection
+ - **Export Conversations**: Download chat history as markdown files
+
+ ## Architecture
+
+ The system uses a multi-layer architecture:
+
+ ### 1. User Interface Layer
+ - **Gradio 6.x** web application with chat interface
+ - Real-time streaming responses
+ - Voice output controls
+ - Export and retry functionality
+
+ ### 2. Agentic Orchestration Layer
+ **4-Phase Workflow** (sketched below):
+ 1. **PLANNING**: The agent strategizes which databases to query
+ 2. **EXECUTING**: Parallel searches across all data sources
+ 3. **REFLECTING**: Evaluates results, identifies gaps, and runs additional searches
+ 4. **SYNTHESIS**: Comprehensive answer with citations and confidence scoring
+
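+ A minimal, self-contained sketch of how the four phases could be driven in code. The real orchestration lives in `als_agent_app.py` and `refactored_helpers.py`; the `run_search` stub and the planning heuristics here are purely illustrative:
+
+ ```python
+ import asyncio
+ from typing import List
+
+ async def run_search(term: str) -> str:
+     # Stand-in for a real MCP tool call (e.g., pubmed__search_pubmed)
+     await asyncio.sleep(0.1)
+     return f"results for '{term}'"
+
+ async def run_agentic_workflow(query: str) -> str:
+     # PLANNING: decide which searches to run (always anchored to ALS)
+     plan: List[str] = [f"{query} ALS", f"{query} ALS clinical trials"]
+     # EXECUTING: run all planned searches in parallel
+     results = list(await asyncio.gather(*(run_search(t) for t in plan)))
+     # REFLECTING: if coverage looks thin, run follow-up searches within this phase
+     if len(results) < 2:
+         results += await asyncio.gather(run_search(f"{query} amyotrophic lateral sclerosis"))
+     # SYNTHESIS: fold everything into one cited answer
+     return "\n".join(results)
+
+ print(asyncio.run(run_agentic_workflow("gene therapy")))
+ ```
+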
+ ### 3. LLM Provider Layer
+ - **Primary**: Anthropic Claude (claude-sonnet-4-5-20250929)
+ - **Fallback**: SambaNova Llama 3.3 70B (free alternative)
+ - Smart routing based on query complexity
+
+ ### 4. MCP Server Layer
+ Each server runs as a separate subprocess with JSON-RPC communication:
+
+ - **aact-server**: Primary clinical trials database (559,000+ trials)
+ - **pubmed-server**: PubMed literature search
+ - **fetch-server**: Web scraping with security hardening
+ - **elevenlabs-server**: Voice synthesis for accessibility
+ - **clinicaltrials_links**: Fallback trial links when AACT is unavailable
+ - **llamaindex-server**: RAG/semantic search (optional)
+
+ **Technical Note:** A custom MCP client (`custom_mcp_client.py`) works around SDK bugs with proper async/await handling, line-buffered I/O, and automatic retry logic.
+
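+ To make the subprocess/JSON-RPC model concrete, here is a heavily simplified sketch of spawning a server over stdio and issuing one `tools/call` request. The real client in `custom_mcp_client.py` also performs the MCP initialize handshake, retries, and buffered I/O, all omitted here:
+
+ ```python
+ import asyncio
+ import json
+
+ async def call_tool_once(server_script: str, tool: str, args: dict) -> str:
+     """Spawn an MCP server subprocess and send one JSON-RPC request (simplified)."""
+     proc = await asyncio.create_subprocess_exec(
+         "python", server_script,
+         stdin=asyncio.subprocess.PIPE,
+         stdout=asyncio.subprocess.PIPE,
+     )
+     request = {
+         "jsonrpc": "2.0",
+         "id": 1,
+         "method": "tools/call",
+         "params": {"name": tool, "arguments": args},
+     }
+     # The MCP stdio transport frames messages as newline-delimited JSON
+     proc.stdin.write((json.dumps(request) + "\n").encode())
+     await proc.stdin.drain()
+     line = await proc.stdout.readline()
+     proc.kill()
+     return line.decode()
+ ```
+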
+ ## Available Tools
+
+ The agent has access to specialized tools across 6 MCP servers:
+
+ ### AACT Clinical Trials Database Tools (PRIMARY)
+
+ #### 1. `aact__search_aact_trials`
+ Search 559,000+ clinical trials from the AACT PostgreSQL database.
+
+ **Parameters:**
+ - `condition` (string, optional): Medical condition (default: "ALS")
+ - `status` (string, optional): Trial status - "recruiting", "active", "completed", "all"
+ - `intervention` (string, optional): Treatment/drug name
+ - `sponsor` (string, optional): Trial sponsor organization
+ - `phase` (string, optional): Trial phase (1, 2, 3, 4)
+ - `max_results` (integer, optional): Maximum results (default: 10)
+
+ **Returns:** Comprehensive trial data with NCT IDs, titles, status, phases, enrollment, and locations.
+
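+ For example, a call from inside the agent loop might look like this, using the `call_mcp_tool` helper defined in `als_agent_app.py` (the argument values are illustrative):
+
+ ```python
+ result = await call_mcp_tool(
+     "aact__search_aact_trials",
+     {
+         "condition": "ALS",
+         "status": "recruiting",
+         "intervention": "gene therapy",
+         "max_results": 5,
+     },
+ )
+ ```
+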
+ #### 2. `aact__get_aact_trial`
+ Get complete details for a specific clinical trial.
+
+ **Parameters:**
+ - `nct_id` (string, required): ClinicalTrials.gov NCT ID
+
+ **Returns:** Full trial information including eligibility, outcomes, interventions, and contacts.
+
  ---

+ ### PubMed Literature Tools
+
+ #### 3. `pubmed__search_pubmed`
+ Search PubMed for peer-reviewed research papers.
+
+ **Parameters:**
+ - `query` (string, required): Search query (e.g., "ALS SOD1 therapy")
+ - `max_results` (integer, optional): Maximum results (default: 10)
+ - `sort` (string, optional): Sort by "relevance" or "date"
+
+ **Returns:** Papers with titles, abstracts, PMIDs, authors, and publication dates.
+
+ #### 4. `pubmed__get_paper_details`
+ Get complete details for a specific PubMed paper.
+
+ **Parameters:**
+ - `pmid` (string, required): PubMed ID
+
+ **Returns:** Full paper information including abstract, journal, DOI, and PubMed URL.
+
+ ---
+
+ ### Web Fetching Tools
+
+ #### 5. `fetch__fetch_url`
+ Fetch and extract content from web URLs with security hardening.
+
+ **Parameters:**
+ - `url` (string, required): URL to fetch
+ - `extract_text_only` (boolean, optional): Extract only text content (default: true)
+
+ **Returns:** Extracted webpage content with SSRF protection.
+
+ ---
+
+ ### Voice Accessibility Tools
+
+ #### 6. `elevenlabs__text_to_speech`
+ Convert research findings to audio for accessibility.
+
+ **Parameters:**
+ - `text` (string, required): Text to convert (max 2500 chars)
+ - `voice_id` (string, optional): Voice selection (default: Rachel - medical-friendly)
+ - `speed` (number, optional): Speech speed (0.5-2.0)
+
+ **Returns:** Audio stream for playback.
+
+ ---
+
+ ### Fallback Tools
+
+ #### 7. `clinicaltrials_links__get_known_als_trials`
+ Returns a curated list of important ALS trials when AACT is unavailable.
+
+ #### 8. `clinicaltrials_links__get_search_link`
+ Generates direct ClinicalTrials.gov search URLs.
+
+ ---
+
+ ### Tool Usage Notes
+
+ - **Rate Limiting**: All tools respect API rate limits (PubMed: 3 req/sec)
+ - **Caching**: Results are cached for 24 hours with smart query normalization
+ - **Connection Pooling**: AACT uses async PostgreSQL with 2-10 connections
+ - **Timeout Protection**: 90-second timeout with automatic retry (see the backoff sketch below)
+ - **Security**: SSRF protection, input validation, content size limits
+
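+ A rough sketch of the timeout-plus-retry behavior; the production version is `call_mcp_tool` in `als_agent_app.py`, which additionally caches results and formats errors:
+
+ ```python
+ import asyncio
+
+ async def with_retry(coro_factory, max_retries: int = 3, timeout: float = 90.0):
+     """Retry an async call with exponential backoff: 1s, 2s, 4s between attempts."""
+     for attempt in range(max_retries):
+         try:
+             return await asyncio.wait_for(coro_factory(), timeout=timeout)
+         except (asyncio.TimeoutError, ConnectionError):
+             if attempt == max_retries - 1:
+                 raise
+             await asyncio.sleep(2 ** attempt)
+ ```
+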
+ ## Quick Start
+
+ ### Prerequisites
+
+ - Python 3.10+ (3.12 recommended)
+ - Anthropic API key
+ - Git
+
+ ### Installation
+
+ 1. Clone the repository
+
+ ```bash
+ git clone https://github.com/yourusername/als-research-agent.git
+ cd als-research-agent
+ ```
+
+ 2. Create a virtual environment
+
+ ```bash
+ python3.12 -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ ```
+
+ 3. Install dependencies
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 4. Set up environment variables
+
+ Create a `.env` file:
+
+ ```bash
+ # Required
+ ANTHROPIC_API_KEY=sk-ant-xxx
+
+ # Recommended
+ ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
+ ELEVENLABS_API_KEY=xxx
+ ELEVENLABS_VOICE_ID=21m00Tcm4TlvDq8ikWAM  # Rachel voice
+
+ # Optional Features
+ ENABLE_RAG=false        # Enable semantic search (requires setup)
+ USE_FALLBACK_LLM=true   # Enable free SambaNova fallback
+ DISABLE_CACHE=false     # Disable smart caching
+
+ # Configuration
+ GRADIO_SERVER_PORT=7860
+ MAX_CONCURRENT_SEARCHES=3
+ RATE_LIMIT_PUBMED_DELAY=1.0
+ ```
+
+ 5. Run the application
+
+ ```bash
+ python als_agent_app.py
+ ```
+
+ or
+
+ ```bash
+ ./venv/bin/python3.12 als_agent_app.py 2>&1
+ ```
+
+ The app will launch at http://localhost:7860
+
+ ## Project Structure
+
+ ```
+ als-research-agent/
+ ├── README.md
+ ├── requirements.txt
+ ├── .env.example
+ ├── als_agent_app.py             # Main Gradio application (1835 lines)
+ ├── custom_mcp_client.py         # Custom MCP client implementation
+ ├── llm_client.py                # Multi-provider LLM abstraction
+ ├── query_classifier.py          # Research vs simple query detection
+ ├── smart_cache.py               # Query normalization and caching
+ ├── refactored_helpers.py        # Streaming and tool execution
+ ├── parallel_tool_execution.py   # Concurrent search management
+ ├── servers/
+ │   ├── aact_server.py           # AACT clinical trials database (PRIMARY)
+ │   ├── pubmed_server.py         # PubMed literature search
+ │   ├── fetch_server.py          # Web scraping with security
+ │   ├── elevenlabs_server.py     # Voice synthesis
+ │   ├── clinicaltrials_links.py  # Fallback trial links
+ │   └── llamaindex_server.py     # RAG/semantic search (optional)
+ ├── shared/
+ │   ├── __init__.py
+ │   ├── config.py                # Centralized configuration
+ │   ├── cache.py                 # TTL-based caching
+ │   └── utils.py                 # Rate limiting and formatting
+ └── tests/
+     ├── test_pubmed_server.py
+     ├── test_aact_server.py
+     ├── test_fetch_server.py
+     ├── test_elevenlabs.py
+     ├── test_integration.py
+     ├── test_llm_client.py
+     ├── test_performance.py
+     └── test_workflow_*.py
+ ```
+
+ ## Usage Examples
+
+ ### Example Queries
+
+ **Complex Research Questions:**
+ - "What are the latest gene therapy trials for SOD1 mutations with recent biomarker data?"
+ - "Compare antisense oligonucleotide therapies in Phase 2 or 3 trials"
+ - "Find recent PubMed papers on ALS protein aggregation from Japanese researchers"
+
+ **Clinical Trial Discovery:**
+ - "Active trials in Germany for bulbar onset ALS"
+ - "Recruiting trials for ALS patients under 40 with slow progression"
+ - "Phase 3 trials sponsored by Biogen or Ionis"
+
+ **Treatment Information:**
+ - "Compare efficacy of riluzole, edaravone, and AMX0035"
+ - "What combination therapies showed promise in 2024?"
+ - "Latest developments in stem cell therapy for ALS"
+
+ **Accessibility Features:**
+ - Click the voice icon to hear research summaries
+ - Adjustable speech speed for comfort
+ - Medical-friendly voice optimized for clarity
+
+ ## Performance Characteristics
+
+ - **Typical Response Time**: 5-10 seconds for complex queries
+ - **Parallel Speedup**: 70% faster than sequential searching
+ - **Cache Hit Time**: <100ms for similar queries (24-hour TTL)
+ - **Concurrent Handling**: 4 requests in ~8 seconds
+ - **Tool Call Timeout**: 90 seconds with automatic retry
+ - **Memory Limit**: 50 messages per conversation (~8-50KB per message)
+
+ ## Development
+
+ ### Running Tests
+
+ ```bash
+ # All tests
+ pytest tests/ -v
+
+ # Unit tests only
+ pytest tests/ -m "not integration"
+
+ # With coverage
+ pytest --cov=servers --cov-report=html
+
+ # Quick tests
+ ./run_quick_tests.sh
+ ```
+
+ ### Adding New MCP Servers
+
+ 1. Create a new server file in `servers/`
+ 2. Use the FastMCP API to implement tools:
+
+ ```python
+ from mcp.server.fastmcp import FastMCP
+
+ mcp = FastMCP("my-server")
+
+ @mcp.tool()
+ async def my_tool(param: str) -> str:
+     """Tool description"""
+     return f"Result: {param}"
+
+ if __name__ == "__main__":
+     mcp.run(transport="stdio")
+ ```
+
+ 3. Add the server to `als_agent_app.py` in `setup_mcp_servers()` (see the sketch below)
+ 4. Write tests in `tests/`
+
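+ Registration amounts to adding one entry to the `servers` dict inside `setup_mcp_servers()`; the `"my-server"` key and file name below are hypothetical:
+
+ ```python
+ # Inside setup_mcp_servers() in als_agent_app.py
+ servers = {
+     "pubmed": servers_dir / "pubmed_server.py",
+     "aact": servers_dir / "aact_server.py",
+     # ... existing servers ...
+     "my-server": servers_dir / "my_server.py",  # hypothetical new server
+ }
+ ```
+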
+ ## Deployment
+
+ ### Hugging Face Spaces
+
+ 1. Create a Gradio Space
+ 2. Push your code
+ 3. Add secrets:
+    - `ANTHROPIC_API_KEY` (required)
+    - `ELEVENLABS_API_KEY` (for voice features)
+
+ ### Docker
+
+ ```bash
+ docker build -t als-research-agent .
+ docker run -p 7860:7860 \
+   -e ANTHROPIC_API_KEY=your_key \
+   -e ELEVENLABS_API_KEY=your_key \
+   als-research-agent
+ ```
+
+ ### Cloud Deployment (Azure/AWS/GCP)
+
+ The application is containerized and ready for deployment on any cloud platform that supports Docker containers. See the deployment guides for specific platforms.
+
+ ## Troubleshooting
+
+ **MCP server not responding**
+ - Check the Python path and virtual environment activation
+ - Verify all dependencies are installed: `pip install -r requirements.txt`
+
+ **Rate limit exceeded**
+ - Add delays between requests
+ - Check your Anthropic API quota
+ - Use `USE_FALLBACK_LLM=true` for the free alternative
+
+ **Voice synthesis not working**
+ - Verify `ELEVENLABS_API_KEY` is set
+ - Check your API quota on the ElevenLabs dashboard
+ - The text may be too long (max 2500 chars)
+
+ **AACT database connection issues**
+ - The database may be under maintenance (Sunday 7 AM ET)
+ - Fallback to the `clinicaltrials_links` server activates automatically
+
+ **Cache not working**
+ - Check that `DISABLE_CACHE` is not set to true
+ - Verify the `.cache/` directory has write permissions
+
+ ## Resources
+
+ ### ALS Research Organizations
+ - ALS Association: https://www.als.org/
+ - ALS Therapy Development Institute: https://www.als.net/
+ - Answer ALS Data Portal: https://dataportal.answerals.org/
+ - International Alliance of ALS/MND Associations: https://www.als-mnd.org/
+
+ ### Data Sources
+ - PubMed E-utilities: https://www.ncbi.nlm.nih.gov/books/NBK25501/
+ - AACT Database: https://aact.ctti-clinicaltrials.org/
+ - ClinicalTrials.gov: https://clinicaltrials.gov/
+
+ ### Technologies
+ - Model Context Protocol: https://modelcontextprotocol.io/
+ - Gradio Documentation: https://www.gradio.app/docs/
+ - Anthropic Claude: https://www.anthropic.com/
+ - ElevenLabs API: https://elevenlabs.io/
+
+ ## Security & Privacy
+
+ - **No Patient Data Storage**: Conversations are not permanently stored
+ - **SSRF Protection**: Blocks access to private IPs and localhost
+ - **Input Validation**: Injection pattern detection and length limits
+ - **Rate Limiting**: Per-user request throttling
+ - **API Key Security**: All keys are stored as environment variables
+
+ ## License
+
+ MIT License - see the LICENSE file for details
+
+ ## Contributing
+
+ 1. Fork the repository
+ 2. Create a feature branch (`git checkout -b feature/amazing-feature`)
+ 3. Write tests for your changes
+ 4. Ensure all tests pass (`pytest`)
+ 5. Commit your changes (`git commit -m 'Add amazing feature'`)
+ 6. Push to the branch (`git push origin feature/amazing-feature`)
+ 7. Open a Pull Request
+
+ ## Future Enhancements
+
+ ### In Development
+ - **NCBI Gene Database**: Gene information and mutations
+ - **OMIM Integration**: Genetic disorder phenotypes
+ - **Protein Data Bank**: 3D protein structures
+ - **AlphaFold Database**: AI-predicted protein structures
+
+ ### Planned Features
+ - **Voice Input**: Speech recognition for queries
+ - **Patient Trial Matching**: Personalized eligibility assessment
+ - **Research Trend Analysis**: Track emerging themes
+ - **Alert System**: Notifications for new trials/papers
+ - **Enhanced Export**: BibTeX, CSV, PDF formats
+ - **Multi-language Support**: Global accessibility
+ - **Drug Repurposing Module**: Identify potential ALS treatments
+ - **arXiv Integration**: Computational biology papers
+
+ ## Acknowledgments
+
+ Built for the global ALS research community to accelerate the path to a cure.
+
+ Special thanks to:
+ - The MCP team for the Model Context Protocol
+ - Anthropic for Claude AI
+ - The open-source community for invaluable contributions
+
+ ---
+
+ **ALSARA - Accelerating ALS research, one query at a time.**
+
+ For questions, issues, or contributions, please open an issue on GitHub.
als_agent_app.py ADDED
@@ -0,0 +1,1832 @@
1
+ # als_agent_app.py
2
+ import gradio as gr
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import logging
7
+ from pathlib import Path
8
+ from datetime import datetime, timedelta
9
+ import sys
10
+ import time
11
+ from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator, Union
12
+ from collections import defaultdict
13
+ from dotenv import load_dotenv
14
+ import httpx
15
+ import base64
16
+ import tempfile
17
+ import re
18
+
19
+ # Load environment variables from .env file
20
+ load_dotenv()
21
+
22
+ # Add current directory to path for shared imports
23
+ sys.path.insert(0, str(Path(__file__).parent))
24
+ from shared import SimpleCache
25
+ from custom_mcp_client import MCPClientManager
26
+ from llm_client import UnifiedLLMClient
27
+ from smart_cache import SmartCache, DEFAULT_PREWARM_QUERIES
28
+
29
+ # Helper function imports for refactored code
30
+ from refactored_helpers import (
31
+ stream_with_retry,
32
+ execute_tool_calls,
33
+ build_assistant_message
34
+ )
35
+
36
+ # Configure logging
37
+ logging.basicConfig(
38
+ level=logging.INFO,
39
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
40
+ handlers=[
41
+ logging.StreamHandler(),
42
+ logging.FileHandler('app.log', mode='a', encoding='utf-8')
43
+ ]
44
+ )
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # Rate Limiter Class
48
+ class RateLimiter:
49
+ """Rate limiter to prevent API overload"""
50
+
51
+ def __init__(self, max_requests_per_minute: int = 30):
52
+ self.max_requests_per_minute = max_requests_per_minute
53
+ self.request_times = defaultdict(list)
54
+
55
+ async def check_rate_limit(self, key: str = "default") -> bool:
56
+ """Check if request is within rate limit"""
57
+ now = datetime.now()
58
+ minute_ago = now - timedelta(minutes=1)
59
+
60
+ # Clean old requests
61
+ self.request_times[key] = [
62
+ t for t in self.request_times[key]
63
+ if t > minute_ago
64
+ ]
65
+
66
+ # Check if under limit
67
+ if len(self.request_times[key]) >= self.max_requests_per_minute:
68
+ return False
69
+
70
+ # Record this request
71
+ self.request_times[key].append(now)
72
+ return True
73
+
74
+ async def wait_if_needed(self, key: str = "default"):
75
+ """Wait if rate limit exceeded"""
76
+ while not await self.check_rate_limit(key):
77
+ await asyncio.sleep(2) # Wait 2 seconds before retry
78
+
79
+ # Initialize rate limiter
80
+ rate_limiter = RateLimiter(max_requests_per_minute=30)
81
+
82
+ # Memory management settings
83
+ MAX_CONVERSATION_LENGTH = 50 # Maximum messages to keep in history
84
+ MEMORY_CLEANUP_INTERVAL = 300 # Cleanup every 5 minutes
85
+
86
+ async def cleanup_memory():
87
+ """Periodic memory cleanup task"""
88
+ while True:
89
+ try:
90
+ # Clean up expired cache entries
91
+ tool_cache.cleanup_expired()
92
+ smart_cache.cleanup() if smart_cache else None
93
+
94
+ # Force garbage collection for large cleanups
95
+ import gc
96
+ collected = gc.collect()
97
+ if collected > 0:
98
+ logger.debug(f"Memory cleanup: collected {collected} objects")
99
+
100
+ except Exception as e:
101
+ logger.error(f"Error during memory cleanup: {e}")
102
+
103
+ await asyncio.sleep(MEMORY_CLEANUP_INTERVAL)
104
+
105
+ # Start memory cleanup task
106
+ cleanup_task = None
107
+
108
+ # Track whether last response used research workflow (for voice button)
109
+ last_response_was_research = False
110
+
111
+ # Health monitoring
112
+ class HealthMonitor:
113
+ """Monitor system health and performance"""
114
+
115
+ def __init__(self):
116
+ self.start_time = datetime.now()
117
+ self.request_count = 0
118
+ self.error_count = 0
119
+ self.tool_call_count = defaultdict(int)
120
+ self.response_times = []
121
+ self.last_error = None
122
+
123
+ def record_request(self):
124
+ self.request_count += 1
125
+
126
+ def record_error(self, error: str):
127
+ self.error_count += 1
128
+ self.last_error = {"time": datetime.now(), "error": str(error)[:500]}
129
+
130
+ def record_tool_call(self, tool_name: str):
131
+ self.tool_call_count[tool_name] += 1
132
+
133
+ def record_response_time(self, duration: float):
134
+ self.response_times.append(duration)
135
+ # Keep only last 100 response times to avoid memory buildup
136
+ if len(self.response_times) > 100:
137
+ self.response_times = self.response_times[-100:]
138
+
139
+ def get_health_status(self) -> Dict[str, Any]:
140
+ """Get current health status"""
141
+ uptime = (datetime.now() - self.start_time).total_seconds()
142
+ avg_response_time = sum(self.response_times) / len(self.response_times) if self.response_times else 0
143
+
144
+ return {
145
+ "status": "healthy" if self.error_count < 10 else "degraded",
146
+ "uptime_seconds": uptime,
147
+ "request_count": self.request_count,
148
+ "error_count": self.error_count,
149
+ "error_rate": self.error_count / max(1, self.request_count),
150
+ "avg_response_time": avg_response_time,
151
+ "cache_size": tool_cache.size(),
152
+ "rate_limit_status": f"{len(rate_limiter.request_times)} active keys",
153
+ "most_used_tools": dict(sorted(self.tool_call_count.items(), key=lambda x: x[1], reverse=True)[:5]),
154
+ "last_error": self.last_error
155
+ }
156
+
157
+ # Initialize health monitor
158
+ health_monitor = HealthMonitor()
159
+
160
+ # Error message formatter
161
+ def format_error_message(error: Exception, context: str = "") -> str:
162
+ """Format error messages with helpful suggestions"""
163
+
164
+ error_str = str(error)
165
+ error_type = type(error).__name__
166
+
167
+ # Common error patterns and suggestions
168
+ if "timeout" in error_str.lower():
169
+ suggestion = """
170
+ **Suggestions:**
171
+ - Try simplifying your search query
172
+ - Break complex questions into smaller parts
173
+ - Check your internet connection
174
+ - The service may be temporarily overloaded - try again in a moment
175
+ """
176
+ elif "rate limit" in error_str.lower():
177
+ suggestion = """
178
+ **Suggestions:**
179
+ - Wait a moment before trying again
180
+ - Reduce the number of simultaneous searches
181
+ - Consider using cached results when available
182
+ """
183
+ elif "connection" in error_str.lower() or "network" in error_str.lower():
184
+ suggestion = """
185
+ **Suggestions:**
186
+ - Check your internet connection
187
+ - The external service may be temporarily unavailable
188
+ - Try again in a few moments
189
+ """
190
+ elif "invalid" in error_str.lower() or "validation" in error_str.lower():
191
+ suggestion = """
192
+ **Suggestions:**
193
+ - Check your query for special characters or formatting issues
194
+ - Ensure your question is clear and well-formed
195
+ - Avoid using HTML or script tags in your query
196
+ """
197
+ elif "memory" in error_str.lower() or "resource" in error_str.lower():
198
+ suggestion = """
199
+ **Suggestions:**
200
+ - The system may be under heavy load
201
+ - Try a simpler query
202
+ - Clear your browser cache and refresh the page
203
+ """
204
+ else:
205
+ suggestion = """
206
+ **Suggestions:**
207
+ - Try rephrasing your question
208
+ - Break complex queries into simpler parts
209
+ - If the error persists, please report it to support
210
+ """
211
+
212
+ formatted = f"""
213
+ ❌ **Error Encountered**
214
+
215
+ **Type:** {error_type}
216
+ **Details:** {error_str[:500]}
217
+ {f"**Context:** {context}" if context else ""}
218
+
219
+ {suggestion}
220
+
221
+ **Need Help?**
222
+ - Try the example queries in the sidebar
223
+ - Check the System Health tab for service status
224
+ - Report persistent issues on GitHub
225
+ """
226
+
227
+ return formatted.strip()
228
+
229
+ # Initialize the unified LLM client
230
+ # All provider logic is now handled inside UnifiedLLMClient
231
+ client = None # Initialize to None for proper cleanup handling
232
+ try:
233
+ client = UnifiedLLMClient()
234
+ logger.info(f"LLM client initialized: {client.get_provider_display_name()}")
235
+ except ValueError as e:
236
+ # Re-raise configuration errors with clear instructions
237
+ logger.error(f"LLM configuration error: {e}")
238
+ raise
239
+
240
+ # Global MCP client manager
241
+ mcp_manager = MCPClientManager()
242
+
243
+ # Internal thinking tags are always filtered for cleaner output
244
+
245
+ # Model configuration
246
+ # Use Claude 3.5 Sonnet with correct model ID that works with the API key
247
+ ANTHROPIC_MODEL = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929")
248
+ logger.info(f"Using model: {ANTHROPIC_MODEL}")
249
+
250
+ # Configuration for max tokens in responses
251
+ # Set MAX_RESPONSE_TOKENS in .env to control response length
252
+ # Claude 3.5 Sonnet supports up to 8192 tokens
253
+ MAX_RESPONSE_TOKENS = min(int(os.getenv("MAX_RESPONSE_TOKENS", "8192")), 8192)
254
+ logger.info(f"Max response tokens set to: {MAX_RESPONSE_TOKENS}")
255
+
256
+ # Global smart cache (24 hour TTL for research queries)
257
+ smart_cache = SmartCache(cache_dir=".cache", ttl_hours=24)
258
+
259
+ # Keep tool cache for MCP tool results
260
+ tool_cache = SimpleCache(ttl=3600)
261
+
262
+ # Cache for tool definitions to avoid repeated fetching
263
+ _cached_tools = None
264
+ _tools_cache_time = None
265
+ TOOLS_CACHE_TTL = 86400 # 24 hour cache for tool definitions (tools rarely change)
266
+
267
+ async def setup_mcp_servers() -> MCPClientManager:
268
+ """Initialize all MCP servers using custom client"""
269
+ logger.info("Setting up MCP servers...")
270
+
271
+ # Get the directory where this script is located
272
+ script_dir = Path(__file__).parent.resolve()
273
+ servers_dir = script_dir / "servers"
274
+
275
+ logger.info(f"Script directory: {script_dir}")
276
+ logger.info(f"Servers directory: {servers_dir}")
277
+
278
+ # Verify servers directory exists
279
+ if not servers_dir.exists():
280
+ logger.error(f"Servers directory not found: {servers_dir}")
281
+ raise FileNotFoundError(f"Servers directory not found: {servers_dir}")
282
+
283
+ # Add all servers to manager
284
+ servers = {
285
+ "pubmed": servers_dir / "pubmed_server.py",
286
+ "aact": servers_dir / "aact_server.py", # PRIMARY: AACT database for comprehensive clinical trials data
287
+ "trials_links": servers_dir / "clinicaltrials_links.py", # FALLBACK: Direct links and known ALS trials
288
+ "fetch": servers_dir / "fetch_server.py",
289
+ "elevenlabs": servers_dir / "elevenlabs_server.py", # Voice capabilities for accessibility
290
+ }
291
+
292
+ # bioRxiv temporarily disabled - commenting out to hide from users
293
+ # enable_biorxiv = os.getenv("ENABLE_BIORXIV", "true").lower() == "true"
294
+ # if enable_biorxiv:
295
+ # servers["biorxiv"] = servers_dir / "biorxiv_server.py"
296
+ # else:
297
+ # logger.info("⚠️ bioRxiv/medRxiv disabled for faster searches (set ENABLE_BIORXIV=true to enable)")
298
+
299
+ # Conditionally add LlamaIndex RAG based on environment variable
300
+ enable_rag = os.getenv("ENABLE_RAG", "false").lower() == "true"
301
+ if enable_rag:
302
+ logger.info("📚 RAG/LlamaIndex enabled (will add ~10s to startup for semantic search)")
303
+ servers["llamaindex"] = servers_dir / "llamaindex_server.py"
304
+ else:
305
+ logger.info("🚀 RAG/LlamaIndex disabled for faster startup (set ENABLE_RAG=true to enable)")
306
+
307
+ # Parallelize server initialization for faster startup
308
+ async def init_server(name: str, script_path: Path):
309
+ try:
310
+ await mcp_manager.add_server(name, str(script_path))
311
+ logger.info(f"✓ MCP server {name} initialized")
312
+ except Exception as e:
313
+ logger.error(f"Failed to initialize MCP server {name}: {e}")
314
+ raise
315
+
316
+ # Start all servers concurrently
317
+ tasks = [init_server(name, script_path) for name, script_path in servers.items()]
318
+ results = await asyncio.gather(*tasks, return_exceptions=True)
319
+
320
+ # Check for any failures
321
+ for i, result in enumerate(results):
322
+ if isinstance(result, Exception):
323
+ name = list(servers.keys())[i]
324
+ logger.error(f"Failed to initialize MCP server {name}: {result}")
325
+ raise result
326
+
327
+ logger.info("All MCP servers initialized successfully")
328
+ return mcp_manager
329
+
330
+ async def cleanup_mcp_servers() -> None:
331
+ """Cleanup MCP server sessions"""
332
+ logger.info("Cleaning up MCP server sessions...")
333
+ await mcp_manager.close_all()
334
+ logger.info("MCP cleanup complete")
335
+
336
+
337
+ def export_conversation(history: Optional[List[Any]]) -> Optional[Path]:
338
+ """Export conversation to markdown format"""
339
+ if not history:
340
+ return None
341
+
342
+ timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
343
+ filename = f"als_conversation_{timestamp}.md"
344
+
345
+ content = f"""# ALS Research Conversation
346
+ **Exported:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
347
+
348
+ ---
349
+
350
+ """
351
+
352
+ for i, (user_msg, assistant_msg) in enumerate(history, 1):
353
+ content += f"## Query {i}\n\n**User:** {user_msg}\n\n**Assistant:**\n{assistant_msg}\n\n---\n\n"
354
+
355
+ content += f"""
356
+ *Generated by ALSARA - ALS Agentic Research Agent*
357
+ *Total interactions: {len(history)}*
358
+ """
359
+
360
+ filepath = Path(filename)
361
+ filepath.write_text(content, encoding='utf-8')
362
+ logger.info(f"Exported conversation to {filename}")
363
+
364
+ return filepath
365
+
366
+ async def get_all_tools() -> List[Dict[str, Any]]:
367
+ """Retrieve all available tools from MCP servers with caching"""
368
+ global _cached_tools, _tools_cache_time
369
+
370
+ # Check if cache is valid
371
+ if _cached_tools and _tools_cache_time:
372
+ if time.time() - _tools_cache_time < TOOLS_CACHE_TTL:
373
+ logger.debug("Using cached tool definitions")
374
+ return _cached_tools
375
+
376
+ # Fetch fresh tool definitions
377
+ logger.info("Fetching fresh tool definitions from MCP servers")
378
+ all_tools = []
379
+
380
+ # Get tools from all servers
381
+ server_tools = await mcp_manager.list_all_tools()
382
+
383
+ for server_name, tools in server_tools.items():
384
+ for tool in tools:
385
+ # Convert MCP tool to Anthropic function format
386
+ all_tools.append({
387
+ "name": f"{server_name}__{tool['name']}",
388
+ "description": tool.get('description', ''),
389
+ "input_schema": tool.get('inputSchema', {})
390
+ })
391
+
392
+ # Update cache
393
+ _cached_tools = all_tools
394
+ _tools_cache_time = time.time()
395
+ logger.info(f"Cached {len(all_tools)} tool definitions")
396
+
397
+ return all_tools
398
+
399
+ async def call_mcp_tool(tool_name: str, arguments: Dict[str, Any], max_retries: int = 3) -> str:
400
+ """Execute an MCP tool call with caching, rate limiting, retry logic, and error handling"""
401
+
402
+ # Check cache first (no retries needed for cached results)
403
+ cached_result = tool_cache.get(tool_name, arguments)
404
+ if cached_result:
405
+ return cached_result
406
+
407
+ last_error = None
408
+
409
+ for attempt in range(max_retries):
410
+ try:
411
+ # Apply rate limiting
412
+ await rate_limiter.wait_if_needed(tool_name.split("__")[0])
413
+
414
+ # Parse tool name
415
+ if "__" not in tool_name:
416
+ logger.error(f"Invalid tool name format: {tool_name}")
417
+ return f"Error: Invalid tool name format: {tool_name}"
418
+
419
+ server_name, tool_method = tool_name.split("__", 1)
420
+
421
+ if attempt > 0:
422
+ logger.info(f"Retry {attempt}/{max_retries} for tool: {tool_method} on server: {server_name}")
423
+ else:
424
+ logger.info(f"Calling tool: {tool_method} on server: {server_name}")
425
+
426
+ # Call tool with timeout using custom client
427
+ result = await asyncio.wait_for(
428
+ mcp_manager.call_tool(server_name, tool_method, arguments),
429
+ timeout=90.0 # 90 second timeout for complex tool calls (BioRxiv searches can be slow)
430
+ )
431
+
432
+ # Result is already a string from custom client
433
+ final_result = result if result else "No content returned from tool"
434
+
435
+ # Cache the result
436
+ tool_cache.set(tool_name, arguments, final_result)
437
+
438
+ # Record successful tool call
439
+ health_monitor.record_tool_call(tool_name)
440
+
441
+ return final_result
442
+
443
+ except asyncio.TimeoutError as e:
444
+ last_error = e
445
+ logger.warning(f"Tool call timed out (attempt {attempt + 1}/{max_retries}): {tool_name}")
446
+ if attempt < max_retries - 1:
447
+ await asyncio.sleep(2 ** attempt) # Exponential backoff: 1s, 2s, 4s
448
+ continue
449
+ # Last attempt failed
450
+ timeout_error = TimeoutError(f"Tool timeout after {max_retries} attempts - the {server_name} server may be overloaded")
451
+ return format_error_message(timeout_error, context=f"Calling {tool_name}")
452
+
453
+ except ValueError as e:
454
+ logger.error(f"Invalid tool/server: {tool_name} - {e}")
455
+ return format_error_message(e, context=f"Invalid tool: {tool_name}")
456
+
457
+ except Exception as e:
458
+ last_error = e
459
+ logger.warning(f"Error calling tool {tool_name} (attempt {attempt + 1}/{max_retries}): {e}")
460
+ if attempt < max_retries - 1:
461
+ await asyncio.sleep(2 ** attempt) # Exponential backoff
462
+ continue
463
+ # Last attempt failed
464
+ return format_error_message(e, context=f"Tool {tool_name} failed after {max_retries} attempts")
465
+
466
+ # Should not reach here, but handle just in case
467
+ if last_error:
468
+ return f"Tool failed after {max_retries} attempts: {str(last_error)[:200]}"
469
+ return "Unexpected error in tool execution"
470
+
471
+ def filter_internal_tags(text: str) -> str:
472
+ """Remove all internal processing tags from the output."""
473
+ import re
474
+
475
+ # Remove internal tags and their content with single regex
476
+ text = re.sub(r'<(thinking|search_quality_reflection|search_quality_score)>.*?</\1>|<(thinking|search_quality_reflection|search_quality_score)>.*$', '', text, flags=re.DOTALL)
477
+
478
+ # Remove wrapper tags but keep content
479
+ text = re.sub(r'</?(result|answer)>', '', text)
480
+
481
+ # Fix phase formatting - ensure consistent formatting
482
+ # Add proper line breaks around phase headers
483
+ # First normalize any existing phase markers to be on their own line
484
+ phase_patterns = [
485
+ # Fix incorrect formats (missing asterisks) first
486
+ (r'(?<!\*)🎯\s*PLANNING:(?!\*)', r'**🎯 PLANNING:**'),
487
+ (r'(?<!\*)🔧\s*EXECUTING:(?!\*)', r'**🔧 EXECUTING:**'),
488
+ (r'(?<!\*)🤔\s*REFLECTING:(?!\*)', r'**🤔 REFLECTING:**'),
489
+ (r'(?<!\*)✅\s*SYNTHESIS:(?!\*)', r'**✅ SYNTHESIS:**'),
490
+
491
+ # Then ensure the markers are on new lines (if not already)
492
+ (r'(?<!\n)(\*\*🎯\s*PLANNING:\*\*)', r'\n\n\1'),
493
+ (r'(?<!\n)(\*\*🔧\s*EXECUTING:\*\*)', r'\n\n\1'),
494
+ (r'(?<!\n)(\*\*🤔\s*REFLECTING:\*\*)', r'\n\n\1'),
495
+ (r'(?<!\n)(\*\*✅\s*SYNTHESIS:\*\*)', r'\n\n\1'),
496
+
497
+ # Then add spacing after them
498
+ (r'(\*\*🎯\s*PLANNING:\*\*)', r'\1\n'),
499
+ (r'(\*\*🔧\s*EXECUTING:\*\*)', r'\1\n'),
500
+ (r'(\*\*🤔\s*REFLECTING:\*\*)', r'\1\n'),
501
+ (r'(\*\*✅\s*SYNTHESIS:\*\*)', r'\1\n'),
502
+ ]
503
+
504
+ for pattern, replacement in phase_patterns:
505
+ text = re.sub(pattern, replacement, text)
506
+
507
+ # Clean up excessive whitespace while preserving intentional formatting
508
+ text = re.sub(r'[ \t]+', ' ', text) # Multiple spaces to single space
509
+ text = re.sub(r'\n{4,}', '\n\n\n', text) # Maximum 3 newlines
510
+ text = re.sub(r'^\n+', '', text) # Remove leading newlines
511
+ text = re.sub(r'\n+$', '\n', text) # Single trailing newline
512
+
513
+ return text.strip()
514
+
515
+ def is_complex_query(message: str) -> bool:
516
+ """Detect complex queries that might need more iterations"""
517
+ complex_indicators = [
518
+ "genotyping", "genetic testing", "multiple", "comprehensive",
519
+ "all", "compare", "versus", "difference between", "systematic",
520
+ "gene-targeted", "gene targeted", "list the main", "what are all",
521
+ "complete overview", "detailed analysis", "in-depth"
522
+ ]
523
+ return any(indicator in message.lower() for indicator in complex_indicators)
524
+
525
+
526
+ def validate_query(message: str) -> Tuple[bool, str]:
527
+ """Validate and sanitize user input to prevent injection and abuse"""
528
+ # Check length
529
+ if not message or not message.strip():
530
+ return False, "Please enter a query"
531
+
532
+ if len(message) > 2000:
533
+ return False, "Query too long (maximum 2000 characters). Please shorten your question."
534
+
535
+ # Check for potential injection patterns
536
+ suspicious_patterns = [
537
+ r'<script', r'javascript:', r'onclick', r'onerror',
538
+ r'\bignore\s+previous\s+instructions\b',
539
+ r'\bsystem\s+prompt\b',
540
+ r'\bforget\s+everything\b',
541
+ r'\bdisregard\s+all\b'
542
+ ]
543
+
544
+ for pattern in suspicious_patterns:
545
+ if re.search(pattern, message, re.IGNORECASE):
546
+ logger.warning(f"Suspicious pattern detected in query: {pattern}")
547
+ return False, "Invalid query format. Please rephrase your question."
548
+
549
+ # Check for excessive repetition (potential spam)
550
+ words = message.lower().split()
551
+ if len(words) > 10:
552
+ # Check if any word appears too frequently
553
+ word_freq = {}
554
+ for word in words:
555
+ word_freq[word] = word_freq.get(word, 0) + 1
556
+
557
+ max_freq = max(word_freq.values())
558
+ if max_freq > len(words) * 0.5: # If any word is more than 50% of the query
559
+ return False, "Query appears to contain excessive repetition. Please rephrase."
560
+
561
+ return True, ""
562
+
563
+
564
+ async def als_research_agent(message: str, history: Optional[List[Dict[str, Any]]]) -> AsyncGenerator[str, None]:
565
+ """Main agent logic with streaming response and error handling"""
566
+
567
+ global last_response_was_research
568
+
569
+ start_time = time.time()
570
+ health_monitor.record_request()
571
+
572
+ try:
573
+ # Validate input first
574
+ valid, error_msg = validate_query(message)
575
+ if not valid:
576
+ yield f"⚠️ **Input Validation Error:** {error_msg}"
577
+ return
578
+
579
+ logger.info(f"Received valid query: {message[:100]}...") # Log first 100 chars
580
+
581
+ # Truncate history to prevent memory bloat
582
+ if history and len(history) > MAX_CONVERSATION_LENGTH:
583
+ logger.info(f"Truncating conversation history from {len(history)} to {MAX_CONVERSATION_LENGTH} messages")
584
+ history = history[-MAX_CONVERSATION_LENGTH:]
585
+
586
+ # System prompt
587
+ base_prompt = """You are ALSARA, an expert ALS (Amyotrophic Lateral Sclerosis) research assistant with agentic capabilities for planning, execution, and reflection.
588
+
589
+ CRITICAL CONTEXT: ALL queries should be interpreted in the context of ALS research unless explicitly stated otherwise.
590
+
591
+ MANDATORY SEARCH QUERY RULES:
592
+ 1. ALWAYS include "ALS" or "amyotrophic lateral sclerosis" in EVERY search query
593
+ 2. If the user's query doesn't mention ALS, ADD IT to your search terms
594
+ 3. This prevents irrelevant results from other conditions
595
+
596
+ Examples:
597
+ - User: "genotyping for gene targeted treatments" → Search: "genotyping ALS gene targeted treatments"
598
+ - User: "psilocybin clinical trials" → Search: "psilocybin ALS clinical trials"
599
+ - User: "stem cell therapy" → Search: "stem cell therapy ALS"
600
+ - User: "gene therapy trials" → Search: "gene therapy ALS trials"
601
+
602
+ Your capabilities:
603
+ - Search PubMed for peer-reviewed research papers
604
+ - Find active clinical trials in the AACT database"""
605
+
606
+ # Add RAG capability only if enabled
607
+ enable_rag = os.getenv("ENABLE_RAG", "false").lower() == "true"
608
+ if enable_rag:
609
+ base_prompt += """
610
+ - **Semantic search using RAG**: Instantly search cached ALS research papers using AI-powered semantic matching"""
611
+
612
+ base_prompt += """
613
+ - Fetch and analyze web content
614
+ - Synthesize information from multiple sources
615
+ - Provide citations with PMIDs, DOIs, and NCT IDs
616
+
617
+ === AGENTIC WORKFLOW (REQUIRED) ===
618
+
619
+ You MUST follow ALL FOUR phases for EVERY query - no exceptions:
620
+
621
+ 1. **🎯 PLANNING PHASE** (MANDATORY - ALWAYS FIRST):
622
+ Before using any tools, you MUST explicitly outline your search strategy:"""
623
+
624
+ if enable_rag:
625
+ base_prompt += """
626
+ - FIRST check semantic cache using RAG for instant results from indexed papers"""
627
+
628
+ base_prompt += """
629
+ - State what databases you will search and in what order
630
+ - ALWAYS plan to search PubMed for peer-reviewed research
631
+ - For clinical questions, also include AACT trials database
632
+ - Identify key search terms and variations
633
+ - Explain your prioritization approach
634
+ - Format: MUST start on a NEW LINE with "**🎯 PLANNING:**" followed by your strategy
635
+
636
+ 2. **🔧 EXECUTION PHASE** (MANDATORY - AFTER PLANNING):
637
+ - MUST mark this phase on a NEW LINE with "**🔧 EXECUTING:**"
638
+ - Execute your planned searches systematically"""
639
+
640
+ if enable_rag:
641
+ base_prompt += """
642
+ - START with semantic search using RAG for instant cached results"""
643
+
644
+ base_prompt += """
645
+ - MINIMUM requirement: Search PubMed for peer-reviewed literature
646
+ - For clinical questions, search AACT trials database
647
+ - Gather initial results from each source
648
+ - Show tool calls and results
649
+ - This phase is for INITIAL searches only (as planned)
650
+
651
+ 3. **🤔 REFLECTION PHASE** (MANDATORY - AFTER EXECUTION):
652
+ After tool execution, you MUST ALWAYS reflect before synthesizing:
653
+
654
+ CRITICAL FORMAT REQUIREMENTS:
655
+ - MUST be EXACTLY: **🤔 REFLECTING:**
656
+ - MUST include the asterisks (**) for bold formatting
657
+ - MUST start on a NEW LINE (never inline with other text)
658
+ - WRONG: "🤔 REFLECTING:" (missing asterisks)
659
+ - WRONG: "search completed🤔 REFLECTING:" (inline, not on new line)
660
+ - CORRECT: New line, then **🤔 REFLECTING:**
661
+
662
+ Content requirements:
663
+ - Evaluate: "Do I have sufficient information to answer comprehensively?"
664
+ - Identify gaps: "What aspects of the query remain unaddressed?"
665
+ - Decide: "Should I refine my search or proceed to synthesis?"
666
+
667
+ CRITICAL: If you need more searches:
668
+ - DO NOT start a new PLANNING phase
669
+ - DO NOT write new phase markers
670
+ - Stay WITHIN the REFLECTION phase
671
+ - Simply continue searching and analyzing while in REFLECTING mode
672
+ - Additional searches are part of reflection, not a new workflow
673
+
674
+ - NEVER skip this phase - it ensures answer quality
675
+
676
+ 4. **✅ SYNTHESIS PHASE** (MANDATORY - FINAL PHASE):
677
+ - MUST start on a NEW LINE with "**✅ SYNTHESIS:**"
678
+ - Provide comprehensive synthesis of all findings
679
+ - Include all citations with URLs
680
+ - Summarize key insights
681
+ - **CONFIDENCE SCORING**: Include confidence level for key claims:
682
+ • High confidence (🟢): Multiple peer-reviewed studies or systematic reviews
683
+ • Moderate confidence (🟡): Limited studies or preprints with consistent findings
684
+ • Low confidence (🔴): Single study, conflicting evidence, or theoretical basis
685
+ - This phase MUST appear in EVERY response
686
+
687
+ FORMATTING RULES:
688
+ - Each phase marker MUST appear on its own line
689
+ - Never put phase markers inline with other text
690
+ - Always use the exact format: **[emoji] PHASE_NAME:**
691
+ - MUST include asterisks for bold: **🤔 REFLECTING:** not just 🤔 REFLECTING:
692
+ - Each phase should appear EXACTLY ONCE per response - never repeat the workflow
693
+
694
+ CRITICAL WORKFLOW RULES:
695
+ - You MUST include ALL FOUR phases in your response
696
+ - Each phase appears EXACTLY ONCE (never repeat Planning→Executing→Reflecting→Synthesis)
697
+ - Missing any phase is unacceptable
698
+ - Duplicating phases is unacceptable
699
+ - The workflow is a SINGLE CYCLE:
700
+ 1. PLANNING (once at start)
701
+ 2. EXECUTING (initial searches)
702
+ 3. REFLECTING (evaluate AND do additional searches if needed - all within this phase)
703
+ 4. SYNTHESIS (final answer)
704
+ - NEVER restart the workflow - additional searches happen WITHIN reflection
705
+
706
+ CRITICAL SYNTHESIS RULES:
707
+ - You MUST ALWAYS end with a ✅ SYNTHESIS phase
708
+ - If searches fail, state "Despite search limitations..." and provide knowledge-based answer
709
+ - If you reach iteration limits, immediately provide synthesis
710
+ - NEVER end without synthesis - this is a MANDATORY requirement
711
+ - If uncertain, start synthesis with: "Based on available information..."
712
+
713
+ SYNTHESIS MUST INCLUDE:
714
+ 1. Direct answer to the user's question
715
+ 2. Key findings from successful searches (if any)
716
+ 3. Citations with clickable URLs
717
+ 4. If searches failed: explanation + knowledge-based answer
718
+ 5. Suggested follow-up questions or alternative approaches
719
+
720
+ === SELF-CORRECTION BEHAVIOR ===
721
+
722
+ If your searches return zero or insufficient results:
723
+ - Try broader search terms (remove qualifiers)
724
+ - Try alternative terminology or synonyms
725
+ - Search for related concepts
726
+ - Explicitly state what you tried and what you found
727
+
728
+ When answering:
729
+ 1. Be concise in explanations while maintaining clarity
730
+ 2. Focus on presenting search results efficiently
731
+ 3. Always cite sources with specific identifiers AND URLs:
732
+ - PubMed: Include PMID and URL (https://pubmed.ncbi.nlm.nih.gov/PMID/)
733
+ - Preprints: Include DOI and URL (https://doi.org/DOI)
734
+ - Clinical Trials: Include NCT ID and URL (https://clinicaltrials.gov/study/NCTID)
735
+ 4. Use numbered citations [1], [2] with a references section at the end
736
+ 5. Prioritize recent research (2023-2025)
737
+ 6. When discussing preprints, note they are NOT peer-reviewed
738
+ 7. Explain complex concepts clearly
739
+ 8. Acknowledge uncertainty when appropriate
740
+ 9. Suggest related follow-up questions
741
+
742
+ CRITICAL CITATION RULES:
743
+ - ONLY cite papers, preprints, and trials that you have ACTUALLY found using the search tools
744
+ - NEVER make up or invent citations, PMIDs, DOIs, or NCT IDs
745
+ - NEVER cite papers from your training data unless you have verified them through search
746
+ - If you cannot find specific research on a topic, explicitly state "No studies found" rather than inventing citations
747
+ - Every citation must come from actual search results obtained through the available tools
748
+ - If asked about a topic you know from training but haven't searched, you MUST search first before citing
749
+
750
+ IMPORTANT: When referencing papers in your final answer, ALWAYS include clickable URLs alongside citations to make it easy for users to access the sources.
751
+
752
+ Available tools:
753
+ - pubmed__search_pubmed: Search peer-reviewed research literature
754
+ - pubmed__get_paper_details: Get full paper details from PubMed (USE SPARINGLY - only for most relevant papers)
755
+ # - biorxiv__search_preprints: (temporarily unavailable)
756
+ # - biorxiv__get_preprint_details: (temporarily unavailable)
757
+ - aact__search_aact_trials: Search clinical trials (PRIMARY - use this first)
758
+ - aact__get_aact_trial: Get specific trial details from AACT database
759
+ - trials_links__get_known_als_trials: Get curated list of important ALS trials (FALLBACK)
760
+ - trials_links__get_search_link: Generate direct ClinicalTrials.gov search URLs
761
+ - fetch__fetch_url: Retrieve web content
762
+
763
+ PERFORMANCE OPTIMIZATION:
764
+ - Search results already contain abstracts - use these for initial synthesis
765
+ - Only fetch full details for papers that are DIRECTLY relevant to the query
766
+ - Limit detail fetches to 5-7 most relevant items per database
767
+ - Prioritize based on: recency, relevance to query, impact/importance
768
+
769
+ Search strategy:
770
+ 1. Search all relevant databases (PubMed, AACT clinical trials)
771
+ 2. ALWAYS supplement with web fetching to:
772
+ - Find additional information not in databases
773
+ - Access sponsor/institution websites
774
+ - Get recent news and updates
775
+ - Retrieve full-text content when needed
776
+ - Verify and expand on database results
777
+ 3. Synthesize all sources for comprehensive answers
778
+
779
+ For clinical trials - NEW ARCHITECTURE:
780
+ PRIMARY SOURCE - AACT Database:
781
+ - Use search_aact_trials FIRST - provides comprehensive clinical trials data from AACT database
782
+ - 559,000+ trials available with no rate limits
783
+ - Use uppercase status values: RECRUITING, ACTIVE_NOT_RECRUITING, NOT_YET_RECRUITING, COMPLETED
784
+ - For ALS searches, the condition "ALS" will automatically match related terms
785
+
786
+ FALLBACK - Links Server (when AACT unavailable):
787
+ - Use get_known_als_trials for curated list of 8 important ALS trials
788
+ - Use get_search_link to generate search URLs for clinical trials
789
+ - Use get_trial_link to generate direct links to specific trials
790
+
791
+ ADDITIONAL SOURCES:
792
+ - If specific NCT IDs are mentioned, can also use fetch__fetch_url with:
793
+ https://clinicaltrials.gov/study/{NCT_ID}
794
+ - Search sponsor websites, medical news, and university pages for updates
795
+
796
+ ARCHITECTURE FLOW:
797
+ User Query → AACT Database (Primary)
798
+ ↓
799
+ If AACT unavailable
800
+ ↓
801
+ Links Server (Fallback)
802
+ ↓
803
+ Direct links to trial websites
804
+
805
+ Note: Direct ClinicalTrials.gov API access is unavailable - using the AACT database instead
806
+ """
807
+
808
+ # Add enhanced instructions for Llama models to improve thoroughness
809
+ if client.is_using_llama_primary():
810
+ llama_enhancement = """
811
+
812
+ ENHANCED SEARCH REQUIREMENTS FOR COMPREHENSIVE RESULTS:
813
+ You MUST follow this structured approach for EVERY research query:
814
+
815
+ === MANDATORY SEARCH PHASES ===
816
+ Phase 1 - Comprehensive Database Search (ALL databases REQUIRED):
817
+ □ Search PubMed with multiple keyword variations
818
+ □ Search AACT database for clinical trials
819
+ □ Use at least 3-5 different search queries per database
820
+
821
+ Phase 2 - Strategic Detail Fetching (BE SELECTIVE):
822
+ □ Get paper details for the TOP 5-7 most relevant PubMed results
823
+ □ Get trial details for the TOP 3-4 most relevant clinical trials
824
+ □ ONLY fetch details for papers that are DIRECTLY relevant to the query
825
+ □ Use search result abstracts to prioritize which papers need full details
826
+
827
+ Phase 3 - Synthesis Requirements:
828
+ □ Include ALL relevant papers found (not just top 3-5)
829
+ □ Organize by subtopic or treatment approach
830
+ □ Provide complete citations with URLs
831
+
832
+ MINIMUM SEARCH STANDARDS:
833
+ - For general queries: At least 10-15 total searches across all databases
834
+ - For specific treatments: At least 5-7 searches per database
835
+ - For comprehensive reviews: At least 15-20 total searches
836
+ - NEVER stop after finding just 2-3 results
837
+
838
+ EXAMPLE SEARCH PATTERN for "gene therapy ALS":
839
+ 1. pubmed__search_pubmed: "gene therapy ALS"
840
+ 2. pubmed__search_pubmed: "AAV ALS treatment"
841
+ 3. pubmed__search_pubmed: "SOD1 gene therapy"
842
+ 4. pubmed__search_pubmed: "C9orf72 gene therapy"
843
+ 5. pubmed__search_pubmed: "viral vector ALS"
844
+ # 6. biorxiv__search_preprints: (temporarily unavailable)
845
+ # 7. biorxiv__search_preprints: (temporarily unavailable)
846
+ 6. aact__search_aact_trials: condition="ALS", intervention="gene therapy"
847
+ 7. aact__search_aact_trials: condition="ALS", intervention="AAV"
848
+ 8. [Get details for the TOP 5-7 most relevant results]
849
+ 9. [Web fetch for recent developments]
850
+
851
+ CRITICAL: Thoroughness is MORE important than speed. Users expect comprehensive results."""
852
+
853
+ system_prompt = base_prompt + llama_enhancement
854
+ logger.info("Using enhanced prompting for Llama model to improve search thoroughness")
855
+ else:
856
+ # Use base prompt directly for Claude
857
+ system_prompt = base_prompt
858
+
859
+ # Import query classifier
860
+ from query_classifier import QueryClassifier
861
+
862
+ # Classify the query to determine processing mode
863
+ classification = QueryClassifier.classify_query(message)
864
+ processing_hint = QueryClassifier.get_processing_hint(classification)
865
+ logger.info(f"Query classification: {classification}")
866
+
867
+ # Check smart cache for similar queries first
868
+ cached_result = smart_cache.find_similar_cached(message)
869
+ if cached_result:
870
+ logger.info(f"Smart cache hit for query: {message[:50]}...")
871
+ yield "🎯 **Using cached result** (similar query found)\n\n"
872
+ yield cached_result
873
+ return
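`smart_cache` is created earlier in the file. Purely to illustrate the similarity-lookup idea, a minimal version might use token-set Jaccard overlap; the class name, threshold, and storage below are assumptions, not the real `SmartCache`:

```python
# Hypothetical similarity-cache sketch - not the actual smart_cache implementation
from typing import Dict, Optional

class SimilarityCacheSketch:
    def __init__(self, threshold: float = 0.8):
        self._store: Dict[frozenset, str] = {}
        self.threshold = threshold

    @staticmethod
    def _tokens(query: str) -> frozenset:
        return frozenset(query.lower().split())

    def put(self, query: str, result: str) -> None:
        self._store[self._tokens(query)] = result

    def find_similar_cached(self, query: str) -> Optional[str]:
        q = self._tokens(query)
        for key, result in self._store.items():
            union = q | key
            # Jaccard similarity between the two token sets
            if union and len(q & key) / len(union) >= self.threshold:
                return result
        return None
```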
874
+
875
+ # Check if this is a high-frequency query with special config
876
+ high_freq_config = smart_cache.get_high_frequency_config(message)
877
+ if high_freq_config:
878
+ logger.info(f"High-frequency query detected with config: {high_freq_config}")
879
+ # Note: We could use optimized search terms or Claude here
880
+ # For now, just log it and continue with normal processing
881
+
882
+ # Get available tools
883
+ tools = await get_all_tools()
884
+
885
+ # Check if this is a simple query that doesn't need research
886
+ if not classification['requires_research']:
887
+ # Simple query - skip the full research workflow
888
+ logger.info(f"Simple query detected - using direct response mode: {classification['reason']}")
889
+
890
+ # Mark that this response won't use research workflow (disable voice button)
891
+ global last_response_was_research
892
+ last_response_was_research = False
893
+
894
+ # Use a simplified prompt for non-research queries
895
+ simple_prompt = """You are an AI assistant for ALS research questions.
896
+ For this query, provide a helpful, conversational response without using research tools.
897
+ Keep your response friendly and informative."""
898
+
899
+ # For simple queries, just make one API call without tools
900
+ messages = [
901
+ {"role": "system", "content": simple_prompt},
902
+ {"role": "user", "content": message}
903
+ ]
904
+
905
+ # Display processing hint
906
+ yield f"{processing_hint}\n\n"
907
+
908
+ # Single API call for simple response (no tools)
909
+ async for response_text, tool_calls, provider_used in stream_with_retry(
910
+ client=client,
911
+ messages=messages,
912
+ tools=None, # No tools for simple queries
913
+ system_prompt=simple_prompt,
914
+ max_retries=2,
915
+ model=ANTHROPIC_MODEL,
916
+ max_tokens=2000, # Shorter responses for simple queries
917
+ stream_name="simple response"
918
+ ):
919
+ yield response_text
920
+
921
+ # Return early - skip all the research phases
922
+ return
923
+
924
+ # Research query - use full workflow with tools
925
+ logger.info(f"Research query detected - using full workflow: {classification['reason']}")
926
+
927
+ # Mark that this response will use research workflow (enable voice button)
928
+ last_response_was_research = True
929
+ yield f"{processing_hint}\n\n"
930
+
931
+ # Build messages for research workflow
932
+ messages = [
933
+ {"role": "system", "content": system_prompt}
934
+ ]
935
+
936
+ # Add history (remove Gradio metadata)
937
+ if history:
938
+ # Only keep 'role' and 'content' fields from messages
939
+ for msg in history:
940
+ if isinstance(msg, dict):
941
+ messages.append({
942
+ "role": msg.get("role"),
943
+ "content": msg.get("content")
944
+ })
945
+ else:
946
+ messages.append(msg)
947
+
948
+ # Add current message
949
+ messages.append({"role": "user", "content": message})
950
+
951
+ # Initial API call with streaming using helper function
952
+ full_response = ""
953
+ tool_calls = []
954
+
955
+ # Use the stream_with_retry helper to handle all retry logic
956
+ provider_used = "Anthropic Claude" # Track which provider
957
+ async for response_text, current_tool_calls, provider_used in stream_with_retry(
958
+ client=client,
959
+ messages=messages,
960
+ tools=tools,
961
+ system_prompt=system_prompt,
962
+ max_retries=2, # Increased from 0 to allow retries
963
+ model=ANTHROPIC_MODEL,
964
+ max_tokens=MAX_RESPONSE_TOKENS,
965
+ stream_name="initial API call"
966
+ ):
967
+ full_response = response_text
968
+ tool_calls = current_tool_calls
969
+ # Apply single-pass filtering when yielding
970
+ # Optionally show provider info when using fallback
971
+ if provider_used != "Anthropic Claude" and response_text:
972
+ yield f"[Using {provider_used}]\n{filter_internal_tags(full_response)}"
973
+ else:
974
+ yield filter_internal_tags(full_response)
975
+
976
+ # Handle recursive tool calls (agent may need multiple searches)
977
+ tool_iteration = 0
978
+
979
+ # Adjust iteration limit based on query complexity
980
+ if is_complex_query(message):
981
+ max_tool_iterations = 5
982
+ logger.info("Complex query detected - allowing up to 5 iterations")
983
+ else:
984
+ max_tool_iterations = 3
985
+ logger.info("Standard query - allowing up to 3 iterations")
986
+
987
+ while tool_calls and tool_iteration < max_tool_iterations:
988
+ tool_iteration += 1
989
+ logger.info(f"Tool iteration {tool_iteration}: processing {len(tool_calls)} tool calls")
990
+
991
+ # No need to re-yield the planning phase - it was already shown
992
+
993
+ # Build assistant message using helper
994
+ assistant_content = build_assistant_message(
995
+ text_content=full_response,
996
+ tool_calls=tool_calls
997
+ )
998
+
999
+ messages.append({
1000
+ "role": "assistant",
1001
+ "content": assistant_content
1002
+ })
1003
+
1004
+ # Show working indicator for long searches
1005
+ num_tools = len(tool_calls)
1006
+ if num_tools > 0:
1007
+ working_text = f"\n⏳ **Searching {num_tools} database{'s' if num_tools > 1 else ''} in parallel...** "
1008
+ if num_tools > 2:
1009
+ working_text += f"(this typically takes 30-45 seconds)\n"
1010
+ elif num_tools > 1:
1011
+ working_text += f"(this typically takes 15-30 seconds)\n"
1012
+ else:
1013
+ working_text += "\n"
1014
+ full_response += working_text
1015
+ yield filter_internal_tags(full_response) # Show working indicator immediately
1016
+
1017
+ # Execute tool calls in parallel for better performance
1018
+ from parallel_tool_execution import execute_tool_calls_parallel
1019
+ progress_text, tool_results_content = await execute_tool_calls_parallel(
1020
+ tool_calls=tool_calls,
1021
+ call_mcp_tool_func=call_mcp_tool
1022
+ )
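`parallel_tool_execution.py` is not included in this commit. A minimal sketch of the gather-based pattern it presumably implements, with the return shape inferred from this call site (the `tool_result` blocks match the Anthropic tool-use format used elsewhere in this file):

```python
# Hypothetical sketch of execute_tool_calls_parallel - inferred from the call site, not the real module
import asyncio
from typing import Any, Callable, Dict, List, Tuple

async def execute_tool_calls_parallel_sketch(
    tool_calls: List[Dict[str, Any]],
    call_mcp_tool_func: Callable,
) -> Tuple[str, List[Dict[str, Any]]]:
    async def run_one(call: Dict[str, Any]) -> Dict[str, Any]:
        try:
            result = await call_mcp_tool_func(call["name"], call.get("input", {}))
        except Exception as e:
            result = f"Tool error: {e}"
        return {
            "type": "tool_result",
            "tool_use_id": call["id"],
            "content": [{"type": "text", "text": str(result)}],
        }

    # asyncio.gather runs all tool calls concurrently instead of sequentially
    results = await asyncio.gather(*(run_one(c) for c in tool_calls))
    progress = f"\n✓ Completed {len(results)} tool call(s)\n"
    return progress, list(results)
```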
1023
+
1024
+ # Add progress text to full response and yield accumulated content
1025
+ full_response += progress_text
1026
+ if progress_text:
1027
+ yield filter_internal_tags(full_response) # Yield full accumulated response
1028
+
1029
+ # Add single user message with ALL tool results
1030
+ messages.append({
1031
+ "role": "user",
1032
+ "content": tool_results_content
1033
+ })
1034
+
1035
+ # Smart reflection: Only add reflection prompt if results seem incomplete
1036
+ if tool_iteration == 1:
1037
+ # First iteration - use normal workflow with reflection
1038
+ # Check confidence indicators in tool results
1039
+ results_text = str(tool_results_content).lower()
1040
+
1041
+ # Indicators of low confidence/incomplete results
1042
+ low_confidence_indicators = [
1043
+ 'no results found', '0 results', 'no papers',
1044
+ 'no trials', 'limited', 'insufficient', 'few results'
1045
+ ]
1046
+
1047
+ # Indicators of high confidence/complete results
1048
+ high_confidence_indicators = [
1049
+ 'recent study', 'multiple studies', 'clinical trial',
1050
+ 'systematic review', 'meta-analysis', 'significant results'
1051
+ ]
1052
+
1053
+ # Count confidence indicators
1054
+ low_conf_count = sum(1 for ind in low_confidence_indicators if ind in results_text)
1055
+ high_conf_count = sum(1 for ind in high_confidence_indicators if ind in results_text)
1056
+
1057
+ # Calculate total results found across all tools
1058
+ import re
1059
+ result_numbers = re.findall(r'(\d+)\s+(?:results?|papers?|studies|trials?)', results_text)
1060
+ total_results = sum(int(n) for n in result_numbers) if result_numbers else 0
1061
+
1062
+ # Decide if reflection is needed - more aggressive skipping for performance
1063
+ needs_reflection = (
1064
+ low_conf_count > 1 or # Only if multiple low-confidence indicators
1065
+ (high_conf_count == 0 and total_results < 10) or # No high confidence AND few results
1066
+ total_results < 3 # Almost no results at all
1067
+ )
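A quick illustrative run of the counting part of this heuristic on two synthetic result blobs:

```python
# Illustrative only: how the result-count regex above behaves on synthetic text
import re

def count_results(text: str) -> int:
    nums = re.findall(r'(\d+)\s+(?:results?|papers?|studies|trials?)', text.lower())
    return sum(int(n) for n in nums)

sparse = "no results found in pubmed; aact returned 2 results"
rich = "multiple studies identified; pubmed returned 15 results"
print(count_results(sparse))  # 2  -> needs_reflection is True (total_results < 3)
print(count_results(rich))    # 15 -> with a high-confidence phrase present, reflection is skipped
```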
1068
+
1069
+ if needs_reflection:
1070
+ reflection_prompt = [
1071
+ {"type": "text", "text": "\n\n**SMART REFLECTION:** Based on the results so far, please evaluate:\n\n1. Do you have sufficient high-quality information to answer comprehensively?\n2. Are there important aspects that need more investigation?\n3. Would refining search terms or trying different databases help?\n\nIf confident with current information (found relevant studies/trials), proceed to synthesis with (**✅ ANSWER:**). Otherwise, use reflection markers (**🤔 REFLECTING:**) and search for missing information."}
1072
+ ]
1073
+ messages.append({
1074
+ "role": "user",
1075
+ "content": reflection_prompt
1076
+ })
1077
+ logger.info(f"Smart reflection triggered (low_conf:{low_conf_count}, high_conf:{high_conf_count}, results:{total_results})")
1078
+ else:
1079
+ # High confidence - skip reflection and go straight to synthesis
1080
+ logger.info(f"Skipping reflection - high confidence (low_conf:{low_conf_count}, high_conf:{high_conf_count}, results:{total_results})")
1081
+ # Add a synthesis-only prompt
1082
+ synthesis_prompt = [
1083
+ {"type": "text", "text": "\n\n**HIGH CONFIDENCE RESULTS:** The search returned comprehensive information. Please proceed directly to synthesis with (**✅ SYNTHESIS:**) and provide a complete answer based on the findings."}
1084
+ ]
1085
+ messages.append({
1086
+ "role": "user",
1087
+ "content": synthesis_prompt
1088
+ })
1089
+ else:
1090
+ # Subsequent iterations (tool_iteration > 1) - UPDATE existing synthesis without repeating workflow phases
1091
+ logger.info(f"Iteration {tool_iteration}: Updating synthesis with additional results")
1092
+ update_prompt = [
1093
+ {"type": "text", "text": "\n\n**ADDITIONAL RESULTS:** You have gathered more information. Please UPDATE your previous synthesis by integrating these new findings. Do NOT repeat the planning/executing/reflecting phases - just provide an updated synthesis that incorporates both the previous and new information. Continue directly with the updated content, no phase markers needed."}
1094
+ ]
1095
+ messages.append({
1096
+ "role": "user",
1097
+ "content": update_prompt
1098
+ })
1099
+
1100
+ # Second API call with tool results (with retry logic)
1101
+ logger.info("Starting second streaming API call with tool results...")
1102
+ logger.info(f"Messages array has {len(messages)} messages")
1103
+ logger.info(f"Last 3 messages: {json.dumps([{'role': m.get('role'), 'content_type': type(m.get('content')).__name__, 'content_len': len(str(m.get('content')))} for m in messages[-3:]], indent=2)}")
1104
+ # Log the actual tool results content
1105
+ logger.info(f"Tool results content ({len(tool_results_content)} items): {json.dumps(tool_results_content[:1], indent=2) if tool_results_content else 'EMPTY'}") # Log first item only to avoid spam
1106
+
1107
+ # Second streaming call for synthesis
1108
+ synthesis_response = ""
1109
+ additional_tool_calls = []
1110
+
1111
+ # For subsequent iterations, use modified system prompt that doesn't require all phases
1112
+ iteration_system_prompt = system_prompt
1113
+ if tool_iteration > 1:
1114
+ iteration_system_prompt = """You are an AI assistant specializing in ALS (Amyotrophic Lateral Sclerosis) research.
1115
+
1116
+ You are continuing your research with additional results. Please integrate the new findings into an updated response.
1117
+
1118
+ IMPORTANT: Do NOT repeat the workflow phases (Planning/Executing/Reflecting/Synthesis) - you've already done those.
1119
+ Simply provide updated content that incorporates both previous and new information.
1120
+ Start your response directly with the updated information, no phase markers needed."""
1121
+
1122
+ # Limit tools on subsequent iterations to prevent endless loops
1123
+ available_tools = tools if tool_iteration == 1 else [] # No more tools after first iteration
1124
+
1125
+ async for response_text, current_tool_calls, provider_used in stream_with_retry(
1126
+ client=client,
1127
+ messages=messages,
1128
+ tools=available_tools,
1129
+ system_prompt=iteration_system_prompt,
1130
+ max_retries=2,
1131
+ model=ANTHROPIC_MODEL,
1132
+ max_tokens=MAX_RESPONSE_TOKENS,
1133
+ stream_name="synthesis API call"
1134
+ ):
1135
+ synthesis_response = response_text
1136
+ additional_tool_calls = current_tool_calls
1137
+
1138
+ full_response += synthesis_response
1139
+ # Yield the full accumulated response including planning, execution, and synthesis
1140
+ yield filter_internal_tags(full_response)
1141
+
1142
+ # Check for additional tool calls
1143
+ if additional_tool_calls:
1144
+ logger.info(f"Found {len(additional_tool_calls)} recursive tool calls")
1145
+
1146
+ # Check if we're about to hit the iteration limit
1147
+ if tool_iteration >= (max_tool_iterations - 1): # Last iteration before limit
1148
+ # We're on the last allowed iteration
1149
+ logger.info(f"Approaching iteration limit ({max_tool_iterations}), wrapping up with current results")
1150
+
1151
+ # Don't execute more tools, instead trigger final synthesis
1152
+ # Add a user message to force final synthesis without tools
1153
+ messages.append({
1154
+ "role": "user",
1155
+ "content": [{"type": "text", "text": "Please provide a complete synthesis of all the information you've found so far. No more searches are available - summarize what you've discovered."}]
1156
+ })
1157
+
1158
+ # Make one final API call to synthesize all the results
1159
+ final_synthesis = ""
1160
+ async for response_text, _, provider_used in stream_with_retry(
1161
+ client=client,
1162
+ messages=messages,
1163
+ tools=[], # No tools for final synthesis
1164
+ system_prompt=system_prompt,
1165
+ max_retries=1,
1166
+ model=ANTHROPIC_MODEL,
1167
+ max_tokens=MAX_RESPONSE_TOKENS,
1168
+ stream_name="final synthesis"
1169
+ ):
1170
+ final_synthesis = response_text
1171
+
1172
+ full_response += final_synthesis
1173
+ # Yield the full accumulated response
1174
+ yield filter_internal_tags(full_response)
1175
+
1176
+ # Clear tool_calls to exit the loop gracefully
1177
+ tool_calls = []
1178
+ else:
1179
+ # We have room for more iterations, proceed normally
1180
+ # Build assistant message for recursive calls
1181
+ assistant_content = build_assistant_message(
1182
+ text_content=synthesis_response,
1183
+ tool_calls=additional_tool_calls
1184
+ )
1185
+
1186
+ messages.append({
1187
+ "role": "assistant",
1188
+ "content": assistant_content
1189
+ })
1190
+
1191
+ # Execute recursive tool calls
1192
+ progress_text, tool_results_content = await execute_tool_calls_parallel(
1193
+ tool_calls=additional_tool_calls,
1194
+ call_mcp_tool_func=call_mcp_tool
1195
+ )
1196
+
1197
+ full_response += progress_text
1198
+ # Yield the full accumulated response
1199
+ if progress_text:
1200
+ yield filter_internal_tags(full_response)
1201
+
1202
+ # Add results and continue loop
1203
+ messages.append({
1204
+ "role": "user",
1205
+ "content": tool_results_content
1206
+ })
1207
+
1208
+ # Set tool_calls for next iteration
1209
+ tool_calls = additional_tool_calls
1210
+ else:
1211
+ # No more tool calls, exit loop
1212
+ tool_calls = []
1213
+
1214
+ if tool_iteration >= max_tool_iterations:
1215
+ logger.warning(f"Reached maximum tool iterations ({max_tool_iterations})")
1216
+
1217
+ # Force synthesis if we haven't provided one yet
1218
+ if tool_iteration > 0 and "✅ SYNTHESIS:" not in full_response:
1219
+ logger.warning(f"No synthesis found after {tool_iteration} iterations, forcing synthesis")
1220
+
1221
+ # Add a forced synthesis prompt
1222
+ synthesis_prompt_content = [{"type": "text", "text": "You MUST now provide a ✅ SYNTHESIS phase. Synthesize whatever information you've gathered, even if searches were limited. If you couldn't find specific research, provide knowledge-based answers with appropriate caveats."}]
1223
+ messages.append({
1224
+ "role": "user",
1225
+ "content": synthesis_prompt_content
1226
+ })
1227
+
1228
+ # Make final synthesis call without tools
1229
+ forced_synthesis = ""
1230
+ async for response_text, _, _ in stream_with_retry(
1231
+ client=client,
1232
+ messages=messages,
1233
+ tools=[], # No tools - just synthesize
1234
+ system_prompt=system_prompt,
1235
+ max_retries=1,
1236
+ model=ANTHROPIC_MODEL,
1237
+ max_tokens=MAX_RESPONSE_TOKENS,
1238
+ stream_name="forced synthesis"
1239
+ ):
1240
+ forced_synthesis = response_text
1241
+
1242
+ full_response += "\n\n" + forced_synthesis
1243
+ # Yield the full accumulated response with forced synthesis
1244
+ yield filter_internal_tags(full_response)
1245
+
1246
+ # No final yield needed - response has already been yielded incrementally
1247
+
1248
+ # Record successful response time
1249
+ response_time = time.time() - start_time
1250
+ health_monitor.record_response_time(response_time)
1251
+ logger.info(f"Request completed in {response_time:.2f} seconds")
1252
+
1253
+ except Exception as e:
1254
+ logger.error(f"Error in als_research_agent: {e}", exc_info=True)
1255
+ health_monitor.record_error(str(e))
1256
+ error_message = format_error_message(e, context=f"Processing query: {message[:100]}...")
1257
+ yield error_message
1258
+
1259
+ # Gradio Interface
1260
+ async def main() -> None:
1261
+ """Main function to setup and launch the Gradio interface"""
1262
+ global cleanup_task
1263
+
1264
+ try:
1265
+ # Setup MCP servers
1266
+ logger.info("Setting up MCP servers...")
1267
+ await setup_mcp_servers()
1268
+ logger.info("MCP servers initialized successfully")
1269
+
1270
+ # Start memory cleanup task
1271
+ cleanup_task = asyncio.create_task(cleanup_memory())
1272
+ logger.info("Memory cleanup task started")
1273
+
1274
+ except Exception as e:
1275
+ logger.error(f"Failed to initialize MCP servers: {e}", exc_info=True)
1276
+ raise
1277
+
1278
+ # Create Gradio interface with export button
1279
+ with gr.Blocks() as demo:
1280
+ gr.Markdown("# 🧬 ALSARA - ALS Agentic Research Assistant ")
1281
+ gr.Markdown("Ask questions about ALS research, treatments, and clinical trials. This agent searches PubMed, AACT clinical trials database, and other sources in real-time.")
1282
+
1283
+ # Show LLM configuration status using unified client
1284
+ llm_status = f"🤖 **LLM Provider:** {client.get_provider_display_name()}"
1285
+ gr.Markdown(llm_status)
1286
+
1287
+ with gr.Tabs():
1288
+ with gr.TabItem("Chat"):
1289
+ chatbot = gr.Chatbot(
1290
+ height=600,
1291
+ show_label=False,
1292
+ allow_tags=True, # Allow custom HTML tags from LLMs (Gradio 6 default)
1293
+ elem_classes="chatbot-container"
1294
+ )
1295
+
1296
+ with gr.TabItem("System Health"):
1297
+ gr.Markdown("## 📊 System Health Monitor")
1298
+
1299
+ def format_health_status():
1300
+ """Format health status for display"""
1301
+ status = health_monitor.get_health_status()
1302
+ return f"""
1303
+ **Status:** {status['status'].upper()} {'✅' if status['status'] == 'healthy' else '⚠️'}
1304
+
1305
+ **Uptime:** {status['uptime_seconds'] / 3600:.1f} hours
1306
+ **Total Requests:** {status['request_count']}
1307
+ **Error Rate:** {status['error_rate']:.1%}
1308
+ **Avg Response Time:** {status['avg_response_time']:.2f}s
1309
+
1310
+ **Cache Status:**
1311
+ - Cache Size: {status['cache_size']} items
1312
+ - Rate Limiter: {status['rate_limit_status']}
1313
+
1314
+ **Most Used Tools:**
1315
+ {chr(10).join([f"- {tool}: {count} calls" for tool, count in status['most_used_tools'].items()])}
1316
+
1317
+ **Last Error:** {status['last_error']['error'] if status['last_error'] else 'None'}
1318
+ """
1319
+
1320
+ health_display = gr.Markdown(format_health_status())
1321
+ refresh_btn = gr.Button("🔄 Refresh Health Status")
1322
+ refresh_btn.click(fn=format_health_status, outputs=health_display)
1323
+
1324
+ with gr.Row():
1325
+ with gr.Column(scale=6):
1326
+ msg = gr.Textbox(
1327
+ placeholder="Ask about ALS research, treatments, or clinical trials...",
1328
+ container=False,
1329
+ label="Type your question or use voice input"
1330
+ )
1331
+ with gr.Column(scale=1):
1332
+ audio_input = gr.Audio(
1333
+ sources=["microphone"],
1334
+ type="filepath",
1335
+ label="🎤 Voice Input"
1336
+ )
1337
+ export_btn = gr.DownloadButton("💾 Export", scale=1)
1338
+
1339
+ with gr.Row():
1340
+ submit_btn = gr.Button("Submit", variant="primary")
1341
+ retry_btn = gr.Button("🔄 Retry")
1342
+ undo_btn = gr.Button("↩️ Undo")
1343
+ clear_btn = gr.Button("🗑️ Clear")
1344
+ speak_btn = gr.Button("🔊 Read Last Response", variant="secondary", interactive=False)
1345
+
1346
+ # Audio output component (initially hidden)
1347
+ with gr.Row(visible=False) as audio_row:
1348
+ audio_output = gr.Audio(
1349
+ label="🔊 Voice Output",
1350
+ type="filepath",
1351
+ autoplay=True,
1352
+ visible=True
1353
+ )
1354
+
1355
+ gr.Examples(
1356
+ examples=[
1357
+ "Psilocybin trials and use in therapy",
1358
+ "Role of Omega-3 and omega-6 fatty acids in ALS treatment",
1359
+ "List the main genes that should be tested for ALS gene therapy eligibility",
1360
+ "What are the latest SOD1-targeted therapies in recent preprints?",
1361
+ "Find recruiting clinical trials for bulbar-onset ALS",
1362
+ "Explain the role of TDP-43 in ALS pathology",
1363
+ "What is the current status of tofersen clinical trials?",
1364
+ "Are there any new combination therapies being studied?",
1365
+ "What's the latest research on ALS biomarkers from the past 60 days?",
1366
+ "Search PubMed for recent ALS gene therapy research"
1367
+ ],
1368
+ inputs=msg
1369
+ )
1370
+
1371
+ # Chat interface logic with improved error handling
1372
+ async def respond(message: str, history: Optional[List[Dict[str, str]]]) -> AsyncGenerator[List[Dict[str, str]], None]:
1373
+ history = history or []
1374
+ # Append user message
1375
+ history.append({"role": "user", "content": message})
1376
+ # Append empty assistant message
1377
+ history.append({"role": "assistant", "content": ""})
1378
+
1379
+ try:
1380
+ # Pass history without the new messages to als_research_agent
1381
+ async for response in als_research_agent(message, history[:-2]):
1382
+ # Update the last assistant message in place
1383
+ history[-1]['content'] = response
1384
+ yield history
1385
+ except Exception as e:
1386
+ logger.error(f"Error in respond: {e}", exc_info=True)
1387
+ error_msg = f"❌ Error: {str(e)}"
1388
+ history[-1]['content'] = error_msg
1389
+ yield history
1390
+
1391
+ def update_speak_button():
1392
+ """Update the speak button state based on last_response_was_research"""
1393
+ global last_response_was_research
1394
+ return gr.update(interactive=last_response_was_research)
1395
+
1396
+ def undo_last(history: Optional[List[Dict[str, str]]]) -> Optional[List[Dict[str, str]]]:
1397
+ """Remove the last message pair from history"""
1398
+ if history and len(history) >= 2:
1399
+ # Remove last user message and assistant response
1400
+ return history[:-2]
1401
+ return history
1402
+
1403
+ async def retry_last(history: Optional[List[Dict[str, str]]]) -> AsyncGenerator[List[Dict[str, str]], None]:
1404
+ """Retry the last query with error handling"""
1405
+ if history and len(history) >= 2:
1406
+ # Get the last user message
1407
+ last_user_msg = history[-2]["content"] if history[-2]["role"] == "user" else None
1408
+ if last_user_msg:
1409
+ # Remove last assistant message, keep user message
1410
+ history = history[:-1]
1411
+ # Add new empty assistant message
1412
+ history.append({"role": "assistant", "content": ""})
1413
+ try:
1414
+ # Resubmit (pass history without the last user and assistant messages)
1415
+ async for response in als_research_agent(last_user_msg, history[:-2]):
1416
+ # Update the last assistant message in place
1417
+ history[-1]['content'] = response
1418
+ yield history
1419
+ except Exception as e:
1420
+ logger.error(f"Error in retry_last: {e}", exc_info=True)
1421
+ error_msg = f"❌ Error during retry: {str(e)}"
1422
+ history[-1]['content'] = error_msg
1423
+ yield history
1424
+ else:
1425
+ yield history
1426
+ else:
1427
+ yield history
1428
+
1429
+ async def process_voice_input(audio_file):
1430
+ """Process voice input and convert to text"""
1431
+ try:
1432
+ if audio_file is None:
1433
+ return ""
1434
+
1435
+ # Try to use speech recognition if available
1436
+ try:
1437
+ import speech_recognition as sr
1438
+ recognizer = sr.Recognizer()
1439
+
1440
+ # Load audio file
1441
+ with sr.AudioFile(audio_file) as source:
1442
+ audio_data = recognizer.record(source)
1443
+
1444
+ # Use Google's free speech recognition
1445
+ try:
1446
+ text = recognizer.recognize_google(audio_data)
1447
+ logger.info(f"Voice input transcribed: {text[:50]}...")
1448
+ return text
1449
+ except sr.UnknownValueError:
1450
+ logger.warning("Could not understand audio")
1451
+ return ""
1452
+ except sr.RequestError as e:
1453
+ logger.error(f"Speech recognition service error: {e}")
1454
+ return ""
1455
+
1456
+ except ImportError:
1457
+ logger.warning("speech_recognition not available")
1458
+ return ""
1459
+
1460
+ except Exception as e:
1461
+ logger.error(f"Error processing voice input: {e}")
1462
+ return ""
1463
+
1464
+ async def speak_last_response(history: Optional[List[Dict[str, str]]]) -> Tuple[gr.update, gr.update]:
1465
+ """Convert the last assistant response to speech using ElevenLabs"""
1466
+ try:
1467
+ # Check if the last response was from research workflow
1468
+ global last_response_was_research
1469
+ if not last_response_was_research:
1470
+ # This shouldn't happen since button is disabled, but handle it gracefully
1471
+ logger.info("Last response was not research-based, voice synthesis not available")
1472
+ return gr.update(visible=False), gr.update(value=None)
1473
+
1474
+ # Check ELEVENLABS_API_KEY
1475
+ api_key = os.getenv("ELEVENLABS_API_KEY")
1476
+ if not api_key:
1477
+ logger.warning("No ELEVENLABS_API_KEY configured")
1478
+ return gr.update(visible=True), gr.update(
1479
+ value=None,
1480
+ label="⚠️ Voice service unavailable - Please set ELEVENLABS_API_KEY"
1481
+ )
1482
+
1483
+ if not history or len(history) < 1:
1484
+ logger.warning("No history available for text-to-speech")
1485
+ return gr.update(visible=True), gr.update(
1486
+ value=None,
1487
+ label="⚠️ No conversation history to read"
1488
+ )
1489
+
1490
+ # Get the last assistant response
1491
+ last_response = None
1492
+
1493
+ # Detect and handle different history formats
1494
+ if isinstance(history, list) and len(history) > 0:
1495
+ # Check if history is a list of lists (Gradio chatbot format)
1496
+ if isinstance(history[0], list) and len(history[0]) == 2:
1497
+ # Format: [[user_msg, assistant_msg], ...]
1498
+ logger.info("Detected Gradio list-of-lists history format")
1499
+ for i, exchange in enumerate(reversed(history)):
1500
+ if len(exchange) == 2 and exchange[1]: # assistant message is second
1501
+ last_response = exchange[1]
1502
+ break
1503
+ elif isinstance(history[0], dict):
1504
+ # Format: [{"role": "user", "content": "..."}, ...]
1505
+ logger.info("Detected dict-based history format")
1506
+ for i, msg in enumerate(reversed(history)):
1507
+ if msg.get("role") == "assistant" and msg.get("content"):
1508
+ content = msg["content"]
1509
+ # CRITICAL FIX: Handle Claude API content blocks
1510
+ if isinstance(content, list):
1511
+ # Extract text from content blocks
1512
+ text_parts = []
1513
+ for block in content:
1514
+ if isinstance(block, dict):
1515
+ # Handle text block
1516
+ if block.get("type") == "text" and "text" in block:
1517
+ text_parts.append(block["text"])
1518
+ # Handle string content in dict
1519
+ elif "content" in block and isinstance(block["content"], str):
1520
+ text_parts.append(block["content"])
1521
+ elif isinstance(block, str):
1522
+ text_parts.append(block)
1523
+ last_response = "\n".join(text_parts)
1524
+ else:
1525
+ # Content is already a string
1526
+ last_response = content
1527
+ break
1528
+ elif isinstance(history[0], str):
1529
+ # Simple string list - take the last one
1530
+ logger.info("Detected simple string list history format")
1531
+ last_response = history[-1] if history else None
1532
+ else:
1533
+ # Unknown format - try to extract what we can
1534
+ logger.warning(f"Unknown history format: {type(history[0])}")
1535
+ # Try to convert to string as last resort
1536
+ try:
1537
+ last_response = str(history[-1]) if history else None
1538
+ except Exception as e:
1539
+ logger.error(f"Failed to extract last response: {e}")
1540
+
1541
+ if not last_response:
1542
+ logger.warning("No assistant response found in history")
1543
+ return gr.update(visible=True), gr.update(
1544
+ value=None,
1545
+ label="⚠️ No assistant response found to read"
1546
+ )
1547
+
1548
+ # Clean the response text (remove markdown, internal tags, etc.)
1549
+ # Convert to string if not already (safety check)
1550
+ last_response = str(last_response)
1551
+
1552
+ # IMPORTANT: Extract only the synthesis/main answer, skip references and "for more information"
1553
+ # Find where to cut off the response
1554
+ cutoff_patterns = [
1555
+ # Clear section headers with colons - most reliable indicators
1556
+ r'\n\s*(?:For (?:more|additional|further) (?:information|details|reading))\s*[::]',
1557
+ r'\n\s*(?:References?|Sources?|Citations?|Bibliography)\s*[::]',
1558
+ r'\n\s*(?:Additional (?:resources?|information|reading|materials?))\s*[::]',
1559
+
1560
+ # Markdown headers for reference sections (must be on their own line)
1561
+ r'\n\s*#{1,6}\s+(?:References?|Sources?|Citations?|Bibliography)\s*$',
1562
+ r'\n\s*#{1,6}\s+(?:For (?:more|additional|further) (?:information|details))\s*$',
1563
+ r'\n\s*#{1,6}\s+(?:Additional (?:Resources?|Information|Reading))\s*$',
1564
+ r'\n\s*#{1,6}\s+(?:Further Reading|Learn More)\s*$',
1565
+
1566
+ # Bold headers for reference sections (with newline after)
1567
+ r'\n\s*\*\*(?:References?|Sources?|Citations?)\*\*\s*[::]?\s*\n',
1568
+ r'\n\s*\*\*(?:For (?:more|additional) information)\*\*\s*[::]?\s*\n',
1569
+
1570
+ # Phrases that clearly introduce reference lists
1571
+ r'\n\s*(?:Here are|Below are|The following are)\s+(?:the |some |additional )?(?:references|sources|citations|papers cited|studies referenced)',
1572
+ r'\n\s*(?:References used|Sources consulted|Papers cited|Studies referenced)\s*[::]',
1573
+ r'\n\s*(?:Key|Recent|Selected|Relevant)\s+(?:references?|publications?|citations)\s*[::]',
1574
+
1575
+ # Clinical trials section headers with clear separators
1576
+ r'\n\s*(?:Clinical trials?|Studies|Research papers?)\s+(?:referenced|cited|mentioned|used)\s*[::]',
1577
+ r'\n\s*(?:AACT|ClinicalTrials\.gov)\s+(?:database entries?|trial IDs?|references?)\s*[::]',
1578
+
1579
+ # Web link sections
1580
+ r'\n\s*(?:Links?|URLs?|Websites?|Web resources?)\s*[::]',
1581
+ r'\n\s*(?:Visit|See|Check out)\s+(?:these|the following)\s+(?:links?|websites?|resources?)',
1582
+ r'\n\s*(?:Learn more|Read more|Find out more|Get more information)\s+(?:at|here|below)\s*[::]',
1583
+
1584
+ # Academic citation lists (only when preceded by double newline or clear separator)
1585
+ r'\n\n\s*\d+\.\s+[A-Z][a-z]+.*?et al\..*?(?:PMID|DOI|Journal)',
1586
+ r'\n\n\s*\[1\]\s+[A-Z][a-z]+.*?(?:et al\.|https?://)',
1587
+
1588
+ # Direct ID listings (clearly separate from main content)
1589
+ r'\n\s*(?:PMID|DOI|NCT)\s*[::]\s*\d+',
1590
+ r'\n\s*(?:Trial IDs?|Study IDs?)\s*[::]',
1591
+
1592
+ # Footer sections
1593
+ r'\n\s*(?:Note|Notes|Disclaimer|Important notice)\s*[::]',
1594
+ r'\n\s*(?:Data (?:source|from)|Database|Repository)\s*[::]',
1595
+ r'\n\s*(?:Retrieved from|Accessed via|Source database)\s*[::]',
1596
+ ]
1597
+
1598
+ # FIRST: Extract ONLY the synthesis section (after ✅ SYNTHESIS:)
1599
+ # More robust pattern that handles various formatting
1600
+ synthesis_patterns = [
1601
+ r'✅\s*\*{0,2}SYNTHESIS\*{0,2}\s*:?\s*\n+(.*)', # Standard format with newline
1602
+ r'\*\*✅\s*SYNTHESIS:\*\*\s*(.*)', # Bold format
1603
+ r'✅\s*SYNTHESIS:\s*(.*)', # Simple format
1604
+ r'SYNTHESIS:\s*(.*)', # Fallback without emoji
1605
+ ]
1606
+
1607
+ synthesis_text = None
1608
+ for pattern in synthesis_patterns:
1609
+ synthesis_match = re.search(pattern, last_response, re.IGNORECASE | re.DOTALL)
1610
+ if synthesis_match:
1611
+ synthesis_text = synthesis_match.group(1)
1612
+ logger.info(f"Extracted synthesis section using pattern: {pattern[:30]}...")
1613
+ break
1614
+
1615
+ if synthesis_text:
1616
+ logger.info("Extracted synthesis section for voice reading")
1617
+ else:
1618
+ # Fallback: if no synthesis marker found, use the whole response
1619
+ synthesis_text = last_response
1620
+ logger.info("No synthesis marker found, using full response")
1621
+
1622
+ # THEN: Remove references and footer sections
1623
+ for pattern in cutoff_patterns:
1624
+ match = re.search(pattern, synthesis_text, re.IGNORECASE | re.MULTILINE)
1625
+ if match:
1626
+ synthesis_text = synthesis_text[:match.start()]
1627
+ logger.info(f"Truncated response at pattern: {pattern[:50]}...")
1628
+ break
1629
+
1630
+ # Now clean the synthesis text
1631
+ clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', synthesis_text) # Remove bold
1632
+ clean_text = re.sub(r'\*(.*?)\*', r'\1', clean_text) # Remove italic
1633
+ clean_text = re.sub(r'#{1,6}\s*(.*?)\n', r'\1. ', clean_text) # Remove headers
1634
+ clean_text = re.sub(r'```.*?```', '', clean_text, flags=re.DOTALL) # Remove code blocks
1635
+ clean_text = re.sub(r'`(.*?)`', r'\1', clean_text) # Remove inline code
1636
+ clean_text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', clean_text) # Remove links
1637
+ clean_text = re.sub(r'<[^>]+>', '', clean_text) # Remove HTML tags
1638
+ clean_text = re.sub(r'\n{3,}', '\n\n', clean_text) # Reduce multiple newlines
1639
+
1640
+ # Strip leading/trailing whitespace
1641
+ clean_text = clean_text.strip()
1642
+
1643
+ # Ensure we have something to read
1644
+ if not clean_text or len(clean_text) < 10:
1645
+ logger.warning("Synthesis text too short after cleaning, using original")
1646
+ clean_text = last_response[:2500] # Fallback to first 2500 chars
1647
+ # Check if ElevenLabs server is available
1648
+ try:
1649
+ server_tools = await mcp_manager.list_all_tools()
1650
+ elevenlabs_available = any('elevenlabs' in server_name for server_name in server_tools)
1651
+ if not elevenlabs_available:
1652
+ logger.error("ElevenLabs server not available in MCP tools")
1653
+ return gr.update(visible=True), gr.update(
1654
+ value=None,
1655
+ label="⚠️ Voice service not available - Please set ELEVENLABS_API_KEY"
1656
+ )
1657
+ except Exception as e:
1658
+ logger.error(f"Failed to check ElevenLabs availability: {e}", exc_info=True)
1659
+ return gr.update(visible=True), gr.update(
1660
+ value=None,
1661
+ label="⚠️ Voice service not available"
1662
+ )
1663
+
1664
+ # Remove phase markers from text
1665
+ clean_text = re.sub(r'\*\*[🎯🔧🤔✅].*?:\*\*', '', clean_text)
1666
+ # Call ElevenLabs text-to-speech through MCP
1667
+ logger.info(f"Calling ElevenLabs text-to-speech with {len(clean_text)} characters...")
1668
+ try:
1669
+ result = await call_mcp_tool(
1670
+ "elevenlabs__text_to_speech",
1671
+ {"text": clean_text, "speed": 0.95} # Slightly slower for clarity
1672
+ )
1673
+ except Exception as e:
1674
+ logger.error(f"MCP tool call failed: {e}", exc_info=True)
1675
+ raise
1676
+
1677
+ # Parse the result
1678
+ try:
1679
+ result_data = json.loads(result) if isinstance(result, str) else result
1680
+ # Check for API key error
1681
+ if "ELEVENLABS_API_KEY not configured" in str(result):
1682
+ logger.error("ElevenLabs API key not configured - found in result string")
1683
+ return gr.update(visible=True), gr.update(
1684
+ value=None,
1685
+ label="⚠️ Voice service unavailable - Please set ELEVENLABS_API_KEY environment variable"
1686
+ )
1687
+
1688
+ if result_data.get("status") == "success" and result_data.get("audio_base64"):
1689
+ # Save audio to temporary file
1690
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
1691
+ audio_data = base64.b64decode(result_data["audio_base64"])
1692
+ tmp_file.write(audio_data)
1693
+ audio_path = tmp_file.name
1694
+
1695
+ logger.info(f"Audio successfully generated and saved to: {audio_path}")
1696
+ return gr.update(visible=True), gr.update(
1697
+ value=audio_path,
1698
+ visible=True,
1699
+ label="🔊 Click to play voice output"
1700
+ )
1701
+ elif result_data.get("status") == "error":
1702
+ error_msg = result_data.get("message", "Unknown error")
1703
+ error_type = result_data.get("error", "Unknown")
1704
+ logger.error(f"ElevenLabs error - Type: {error_type}, Message: {error_msg}")
1705
+ return gr.update(visible=True), gr.update(
1706
+ value=None,
1707
+ label=f"⚠️ Voice service error: {error_msg}"
1708
+ )
1709
+ else:
1710
+ logger.error(f"Unexpected result structure")
1711
+ return gr.update(visible=True), gr.update(
1712
+ value=None,
1713
+ label="⚠️ Voice service returned no audio"
1714
+ )
1715
+ except json.JSONDecodeError as e:
1716
+ logger.error(f"JSON decode error: {e}")
1717
+ logger.error(f"Failed to parse ElevenLabs response, first 500 chars: {str(result)[:500]}")
1718
+ return gr.update(visible=True), gr.update(
1719
+ value=None,
1720
+ label="⚠️ Voice service response error"
1721
+ )
1722
+ except Exception as e:
1723
+ logger.error(f"Unexpected error in result parsing: {e}", exc_info=True)
1724
+ raise
1725
+
1726
+ except Exception as e:
1727
+ logger.error(f"Error in speak_last_response: {e}", exc_info=True)
1728
+ return gr.update(visible=True), gr.update(
1729
+ value=None,
1730
+ label=f"⚠️ Voice service error: {str(e)}"
1731
+ )
1732
+
1733
+ msg.submit(
1734
+ respond, [msg, chatbot], [chatbot],
1735
+ api_name="chat"
1736
+ ).then(
1737
+ update_speak_button, None, [speak_btn]
1738
+ ).then(
1739
+ lambda: "", None, [msg]
1740
+ )
1741
+
1742
+ # Add event handler for audio input
1743
+ audio_input.stop_recording(
1744
+ process_voice_input,
1745
+ inputs=[audio_input],
1746
+ outputs=[msg]
1747
+ ).then(
1748
+ lambda: None,
1749
+ outputs=[audio_input] # Clear audio after processing
1750
+ )
1751
+
1752
+ submit_btn.click(
1753
+ respond, [msg, chatbot], [chatbot],
1754
+ api_name="chat_button"
1755
+ ).then(
1756
+ update_speak_button, None, [speak_btn]
1757
+ ).then(
1758
+ lambda: "", None, [msg]
1759
+ )
1760
+
1761
+ retry_btn.click(
1762
+ retry_last, [chatbot], [chatbot],
1763
+ api_name="retry"
1764
+ ).then(
1765
+ update_speak_button, None, [speak_btn]
1766
+ )
1767
+
1768
+ undo_btn.click(
1769
+ undo_last, [chatbot], [chatbot],
1770
+ api_name="undo"
1771
+ )
1772
+
1773
+ clear_btn.click(
1774
+ lambda: None, None, chatbot,
1775
+ queue=False,
1776
+ api_name="clear"
1777
+ ).then(
1778
+ lambda: gr.update(interactive=False), None, [speak_btn]
1779
+ )
1780
+
1781
+ export_btn.click(
1782
+ export_conversation, chatbot, export_btn,
1783
+ api_name="export"
1784
+ )
1785
+
1786
+ speak_btn.click(
1787
+ speak_last_response, [chatbot], [audio_row, audio_output],
1788
+ api_name="speak"
1789
+ )
1790
+
1791
+ # Enable queue for streaming to work
1792
+ demo.queue()
1793
+
1794
+ try:
1795
+ # Use environment variable for port, default to 7860 for HuggingFace
1796
+ port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
1797
+ demo.launch(
1798
+ server_name="0.0.0.0",
1799
+ server_port=port,
1800
+ share=False
1801
+ )
1802
+ except KeyboardInterrupt:
1803
+ logger.info("Received keyboard interrupt, shutting down...")
1804
+ except Exception as e:
1805
+ logger.error(f"Error during launch: {e}", exc_info=True)
1806
+ finally:
1807
+ # Cleanup
1808
+ logger.info("Cleaning up resources...")
1809
+ await cleanup_mcp_servers()
1810
+
1811
+ if __name__ == "__main__":
1812
+ try:
1813
+ asyncio.run(main())
1814
+ except KeyboardInterrupt:
1815
+ logger.info("Application terminated by user")
1816
+ except Exception as e:
1817
+ logger.error(f"Application error: {e}", exc_info=True)
1818
+ raise
1819
+ finally:
1820
+ # Cancel cleanup task if running
1821
+ if cleanup_task and not cleanup_task.done():
1822
+ cleanup_task.cancel()
1823
+ logger.info("Cancelled memory cleanup task")
1824
+
1825
+ # Cleanup unified LLM client
1826
+ if client is not None:
1827
+ try:
1828
+ asyncio.run(client.cleanup())
1829
+ logger.info("LLM client cleanup completed")
1830
+ except Exception as e:
1831
+ logger.warning(f"LLM client cleanup error: {e}")
1832
+ pass
custom_mcp_client.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ Custom MCP client using direct subprocess communication.
3
+ This bypasses the buggy stdio_client from mcp.client.stdio.
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import logging
9
+ import subprocess
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class MCPClient:
18
+ """Custom MCP client using direct subprocess communication"""
19
+
20
+ def __init__(self, server_script: str, server_name: str):
21
+ self.server_script = server_script
22
+ self.server_name = server_name
23
+ self.process: Optional[subprocess.Popen] = None
24
+ self.message_id = 0
25
+ self._initialized = False
26
+ self.script_path = server_script # Store for potential restart
27
+
28
+ async def start(self):
29
+ """Start the MCP server subprocess"""
30
+ logger.info(f"Starting MCP server: {self.server_name}")
31
+
32
+ self.process = subprocess.Popen(
33
+ [sys.executable, self.server_script],
34
+ stdin=subprocess.PIPE,
35
+ stdout=subprocess.PIPE,
36
+ stderr=subprocess.PIPE,
37
+ text=True,
38
+ bufsize=1 # Line-buffered I/O to prevent 8KB truncation
39
+ )
40
+
41
+ # Initialize the session
42
+ await self._initialize()
43
+ logger.info(f"Successfully started MCP server: {self.server_name}")
44
+
45
+ async def _initialize(self):
46
+ """Initialize the MCP session"""
47
+ init_message = {
48
+ "jsonrpc": "2.0",
49
+ "id": self._next_id(),
50
+ "method": "initialize",
51
+ "params": {
52
+ "protocolVersion": "2024-11-05",
53
+ "capabilities": {},
54
+ "clientInfo": {
55
+ "name": "als-research-agent",
56
+ "version": "1.0.0"
57
+ }
58
+ }
59
+ }
60
+
61
+ response = await self._send_request(init_message)
62
+ if "result" in response:
63
+ self._initialized = True
64
+ logger.info(f"Initialized {self.server_name}: {response['result'].get('serverInfo', {})}")
65
+ else:
66
+ raise Exception(f"Initialization failed: {response}")
67
+
68
+ def _next_id(self) -> int:
69
+ """Get next message ID"""
70
+ self.message_id += 1
71
+ return self.message_id
72
+
73
+ async def _send_request(self, message: Dict[str, Any]) -> Dict[str, Any]:
74
+ """Send a JSON-RPC request and wait for response"""
75
+ if not self.process:
76
+ raise RuntimeError("Server not started")
77
+
78
+ # Check if process is still alive
79
+ if self.process.poll() is not None:
80
+ # Process has terminated
81
+ raise RuntimeError(f"Server {self.server_name} has terminated unexpectedly")
82
+
83
+ # Send request
84
+ request_json = json.dumps(message) + "\n"
85
+ self.process.stdin.write(request_json)
86
+ self.process.stdin.flush()
87
+
88
+ # Read response with timeout
89
+ try:
90
+ response_line = await asyncio.wait_for(
91
+ asyncio.to_thread(self.process.stdout.readline),
92
+ timeout=60.0 # Extended timeout for LlamaIndex/RAG server initialization
93
+ )
94
+
95
+ if not response_line:
96
+ raise Exception("Server closed stdout")
97
+
98
+ return json.loads(response_line)
99
+ except asyncio.TimeoutError:
100
+ raise Exception("Request timed out")
101
+
102
+ async def list_tools(self) -> List[Dict[str, Any]]:
103
+ """List available tools"""
104
+ if not self._initialized:
105
+ raise RuntimeError("Client not initialized")
106
+
107
+ message = {
108
+ "jsonrpc": "2.0",
109
+ "id": self._next_id(),
110
+ "method": "tools/list",
111
+ "params": {}
112
+ }
113
+
114
+ response = await self._send_request(message)
115
+ if "result" in response:
116
+ return response["result"].get("tools", [])
117
+ else:
118
+ raise Exception(f"List tools failed: {response}")
119
+
120
+ async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
121
+ """Call a tool"""
122
+ if not self._initialized:
123
+ raise RuntimeError("Client not initialized")
124
+
125
+ message = {
126
+ "jsonrpc": "2.0",
127
+ "id": self._next_id(),
128
+ "method": "tools/call",
129
+ "params": {
130
+ "name": tool_name,
131
+ "arguments": arguments
132
+ }
133
+ }
134
+
135
+ response = await self._send_request(message)
136
+ if "result" in response:
137
+ # Extract result from response
138
+ result = response["result"]
139
+
140
+ # Handle different response formats
141
+ if isinstance(result, dict):
142
+ # New format with 'result' field
143
+ if "result" in result:
144
+ return result["result"]
145
+ # Content array format
146
+ elif "content" in result:
147
+ content = result["content"]
148
+ if isinstance(content, list) and len(content) > 0:
149
+ return content[0].get("text", str(content))
150
+ return str(content)
151
+ else:
152
+ return str(result)
153
+ else:
154
+ return str(result)
155
+ else:
156
+ error = response.get("error", {})
157
+ raise Exception(f"Tool call failed: {error.get('message', response)}")
158
+
159
+ async def close(self):
160
+ """Close the MCP client and terminate server"""
161
+ if self.process:
162
+ logger.info(f"Closing MCP server: {self.server_name}")
163
+ self.process.terminate()
164
+ try:
165
+ self.process.wait(timeout=5)
166
+ except subprocess.TimeoutExpired:
167
+ self.process.kill()
168
+ self.process.wait()
169
+ self.process = None
170
+ self._initialized = False
171
+
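A minimal usage sketch for this client; the server script path and tool name below are placeholders, not paths from this repo:

```python
# Hypothetical usage of MCPClient - path and tool name are illustrative
import asyncio

async def demo_client():
    client = MCPClient("mcp_servers/pubmed_server.py", "pubmed")  # illustrative path
    await client.start()
    try:
        tools = await client.list_tools()
        print([t["name"] for t in tools])
        print(await client.call_tool("search_pubmed", {"query": "ALS biomarkers"}))  # illustrative tool
    finally:
        await client.close()

asyncio.run(demo_client())
```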
172
+
173
+ class MCPClientManager:
174
+ """Manage multiple MCP clients"""
175
+
176
+ def __init__(self):
177
+ self.clients: Dict[str, MCPClient] = {}
178
+
179
+ async def add_server(self, name: str, script_path: str):
180
+ """Add and start an MCP server"""
181
+ client = MCPClient(script_path, name)
182
+ await client.start()
183
+ self.clients[name] = client
184
+ logger.info(f"Added MCP server: {name}")
185
+
186
+ async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> str:
187
+ """Call a tool on a specific server"""
188
+ if server_name not in self.clients:
189
+ raise ValueError(f"Server not found: {server_name}")
190
+
191
+ return await self.clients[server_name].call_tool(tool_name, arguments)
192
+
193
+ async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
194
+ """List tools from all servers, handling failures gracefully"""
195
+ all_tools = {}
196
+ failed_servers = []
197
+
198
+ for name, client in self.clients.items():
199
+ try:
200
+ tools = await client.list_tools()
201
+ for tool in tools:
202
+ tool['server'] = name # Add server info to each tool
203
+ all_tools[name] = tools
204
+ except Exception as e:
205
+ logger.error(f"Failed to list tools from server {name}: {e}")
206
+ failed_servers.append(name)
207
+ # Continue with other servers instead of failing entirely
208
+ all_tools[name] = []
209
+
210
+ if failed_servers:
211
+ logger.warning(f"Some servers failed to respond: {', '.join(failed_servers)}")
212
+ # Try to restart failed servers
213
+ for server_name in failed_servers:
214
+ try:
215
+ client = self.clients[server_name]
216
+ script_path = client.script_path if hasattr(client, 'script_path') else None
217
+ if script_path:
218
+ logger.info(f"Attempting to restart {server_name} server...")
219
+ await client.close()
220
+ # Re-add the server (which will restart it)
221
+ await self.add_server(server_name, script_path)
222
+ # Try listing tools again after restart
223
+ tools = await self.clients[server_name].list_tools()
224
+ for tool in tools:
225
+ tool['server'] = server_name
226
+ all_tools[server_name] = tools
227
+ logger.info(f"Successfully restarted {server_name} server")
228
+ except Exception as restart_error:
229
+ logger.error(f"Failed to restart {server_name}: {restart_error}")
230
+ # Remove the failed server from clients to prevent further errors
231
+ if server_name in self.clients:
232
+ del self.clients[server_name]
233
+
234
+ return all_tools
235
+
236
+ async def close_all(self):
237
+ """Close all MCP clients"""
238
+ for client in self.clients.values():
239
+ await client.close()
240
+ self.clients.clear()
241
+ logger.info("All MCP servers closed")
llm_client.py ADDED
@@ -0,0 +1,439 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unified LLM Client - Single interface for all LLM providers
4
+ Handles Anthropic, SambaNova, and automatic fallback logic internally
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ import asyncio
10
+ import httpx
11
+ from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
12
+ from anthropic import AsyncAnthropic
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class UnifiedLLMClient:
20
+ """
21
+ Unified client that abstracts all LLM provider logic.
22
+ Provides a single, clean interface to the application.
23
+ """
24
+
25
+ def __init__(self):
26
+ """Initialize the unified client with automatic provider selection"""
27
+ self.primary_client = None
28
+ self.fallback_router = None
29
+ self.provider_name = None
30
+ self.config = self._load_configuration()
31
+ self._initialize_providers()
32
+
33
+ def _load_configuration(self) -> Dict[str, Any]:
34
+ """Load configuration from environment variables"""
35
+ return {
36
+ "anthropic_api_key": os.getenv("ANTHROPIC_API_KEY"),
37
+ "use_fallback": os.getenv("USE_FALLBACK_LLM", "false").lower() == "true",
38
+ "provider_preference": os.getenv("LLM_PROVIDER_PREFERENCE", "auto"),
39
+ "default_model": os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929"),
40
+ "max_retries": int(os.getenv("LLM_MAX_RETRIES", "2")),
41
+ "is_hf_space": os.getenv("SPACE_ID") is not None,
42
+ "enable_smart_routing": os.getenv("ENABLE_SMART_ROUTING", "false").lower() == "true"
43
+ }
44
+
45
+ def _initialize_providers(self):
46
+ """Initialize LLM providers based on configuration"""
47
+
48
+ # Try to initialize Anthropic first
49
+ if self.config["anthropic_api_key"]:
50
+ try:
51
+ self.primary_client = AsyncAnthropic(api_key=self.config["anthropic_api_key"])
52
+ self.provider_name = "Anthropic Claude"
53
+ logger.info("Anthropic client initialized successfully")
54
+ except Exception as e:
55
+ logger.warning(f"Failed to initialize Anthropic client: {e}")
56
+ self.primary_client = None
57
+
58
+ # Initialize fallback if needed
59
+ if self.config["use_fallback"] or not self.primary_client:
60
+ try:
61
+ from llm_providers import llm_router
62
+ self.fallback_router = llm_router
63
+
64
+ if not self.primary_client:
65
+ self.provider_name = "SambaNova Llama 3.3 70B"
66
+ logger.info("Using SambaNova as primary provider")
67
+ else:
68
+ logger.info("SambaNova fallback configured for automatic failover")
69
+
70
+ except ImportError:
71
+ logger.warning("Fallback LLM provider not available")
72
+
73
+ if not self.primary_client:
74
+ self._raise_configuration_error()
75
+
76
+ def _raise_configuration_error(self):
77
+ """Raise appropriate error for missing configuration"""
78
+ if self.config["is_hf_space"]:
79
+ raise ValueError(
80
+ "🚨 No LLM provider configured!\n\n"
81
+ "Option 1: Add your Anthropic API key as a Space secret:\n"
82
+ "1. Go to your Space Settings\n"
83
+ "2. Add secret: ANTHROPIC_API_KEY = your_key\n\n"
84
+ "Option 2: Enable free SambaNova fallback:\n"
85
+ "Add secret: USE_FALLBACK_LLM = true"
86
+ )
87
+ else:
88
+ raise ValueError(
89
+ "No LLM provider configured.\n\n"
90
+ "Option 1: Add to .env file:\n"
91
+ "ANTHROPIC_API_KEY=your_api_key_here\n\n"
92
+ "Option 2: Enable free SambaNova:\n"
93
+ "USE_FALLBACK_LLM=true"
94
+ )
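From the application side, usage boils down to constructing the client once and iterating the stream; a hedged sketch consistent with the `stream()` signature defined below:

```python
# Hypothetical caller-side usage of UnifiedLLMClient
import asyncio

async def demo_stream():
    llm = UnifiedLLMClient()  # raises ValueError if no provider is configured
    messages = [{"role": "user", "content": "Summarize recent ALS biomarker research."}]
    async for text, tool_calls, provider in llm.stream(messages=messages, max_tokens=512):
        print(f"[{provider}] {text[-80:]}")

asyncio.run(demo_stream())
```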
95
+
96
+ async def stream(
97
+ self,
98
+ messages: List[Dict],
99
+ tools: List[Dict] = None,
100
+ system_prompt: str = None,
101
+ model: str = None,
102
+ max_tokens: int = 8192,
103
+ temperature: float = 0.7
104
+ ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
105
+ """
106
+ Stream responses from the LLM with automatic fallback.
107
+
108
+ This is the main interface - it handles all provider selection,
109
+ retries, and fallback logic internally.
110
+
111
+ Yields: (response_text, tool_calls, provider_used)
112
+ """
113
+
114
+ # Use default model if not specified
115
+ if model is None:
116
+ model = self.config["default_model"]
117
+
118
+ # Track which provider we're using
119
+ provider_used = self.provider_name
120
+
121
+ # Determine provider order based on preference
122
+ use_anthropic_first = True
123
+ if self.config["provider_preference"] == "cost_optimize" and self.fallback_router:
124
+ # With cost_optimize, prefer SambaNova first
125
+ use_anthropic_first = False
126
+
127
+ # Apply smart routing if enabled
128
+ if self.config.get("enable_smart_routing", False) and self.primary_client and self.fallback_router:
129
+ # Extract the last user message for analysis
130
+ last_message = ""
131
+ for msg in reversed(messages):
132
+ if msg.get("role") == "user":
133
+ if isinstance(msg.get("content"), str):
134
+ last_message = msg["content"]
135
+ elif isinstance(msg.get("content"), list):
136
+ # Extract text from content blocks
137
+ for block in msg["content"]:
138
+ if isinstance(block, dict) and block.get("type") == "text":
139
+ last_message = block.get("text", "")
140
+ break
141
+ break
142
+
143
+ if last_message:
144
+ # Classify the query
145
+ query_type = self.classify_query_complexity(
146
+ last_message,
147
+ len(tools) if tools else 0
148
+ )
149
+
150
+ # Override provider preference based on classification
151
+ if query_type == "simple":
152
+ if use_anthropic_first:
153
+ logger.info(f"Smart routing: Directing simple query to Llama for cost savings: '{last_message[:80]}...'")
154
+ use_anthropic_first = False
155
+ elif query_type == "complex":
156
+ if not use_anthropic_first:
157
+ logger.info(f"Smart routing: Directing complex query to Claude for better quality: '{last_message[:80]}...'")
158
+ use_anthropic_first = True
159
+
160
+ # Try first provider based on preference
161
+ if use_anthropic_first and self.primary_client:
162
+ try:
163
+ async for result in self._stream_anthropic(
164
+ messages, tools, system_prompt, model, max_tokens, temperature
165
+ ):
166
+ yield result
167
+ return # Success, exit
168
+ except Exception as e:
169
+ logger.warning(f"Primary provider failed: {e}")
170
+
171
+ # Fall through to fallback if available
172
+ if not self.fallback_router:
173
+ raise
174
+
175
+ # Try fallback provider
176
+ if self.fallback_router:
177
+             if not use_anthropic_first:
+                 logger.info("Using SambaNova as primary provider (cost_optimize mode)")
+             elif not self.primary_client:
+                 logger.info("Using fallback LLM provider")
179
+
180
+ try:
181
+ # Override provider preference to force SambaNova when smart routing decided to use it
182
+ effective_preference = "cost_optimize" if not use_anthropic_first else self.config["provider_preference"]
183
+
184
+ async for text, tool_calls, provider in self.fallback_router.stream_with_fallback(
185
+ messages=messages,
186
+ tools=tools or [],
187
+ system_prompt=system_prompt,
188
+ model=model,
189
+ max_tokens=max_tokens,
190
+ provider_preference=effective_preference
191
+ ):
192
+ yield (text, tool_calls, provider)
193
+
194
+ # If we used SambaNova first successfully with cost_optimize, we're done
195
+ if not use_anthropic_first:
196
+ return
197
+
198
+ except Exception as e:
199
+ if not use_anthropic_first and self.primary_client:
200
+ # SambaNova failed in cost_optimize mode, try Anthropic
201
+ logger.warning(f"SambaNova failed in cost_optimize mode: {e}, falling back to Anthropic")
202
+ try:
203
+ async for result in self._stream_anthropic(
204
+ messages, tools, system_prompt, model, max_tokens, temperature
205
+ ):
206
+ yield result
207
+ return # Success, exit
208
+ except Exception as anthropic_error:
209
+ logger.error(f"All LLM providers failed: SambaNova: {e}, Anthropic: {anthropic_error}")
210
+ raise RuntimeError("All LLM providers failed. Please check configuration.")
211
+ else:
212
+ logger.error(f"All LLM providers failed: {e}")
213
+ raise RuntimeError("All LLM providers failed. Please check configuration.")
214
+ else:
215
+ raise RuntimeError("No LLM providers available")
216
+
217
+ async def _stream_anthropic(
218
+ self,
219
+ messages: List[Dict],
220
+ tools: List[Dict],
221
+ system_prompt: str,
222
+ model: str,
223
+ max_tokens: int,
224
+ temperature: float
225
+ ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
226
+ """Stream from Anthropic with retry logic"""
227
+
228
+ retry_delay = 1
229
+ last_error = None
230
+
231
+ # Skip system message if it's in messages array
232
+ api_messages = messages[1:] if messages and messages[0].get("role") == "system" else messages
233
+
234
+ # Use system prompt or extract from messages
235
+ if not system_prompt and messages and messages[0].get("role") == "system":
236
+ system_prompt = messages[0].get("content", "")
237
+
238
+ for attempt in range(self.config["max_retries"] + 1):
239
+ try:
240
+ logger.info(f"Streaming from Anthropic (attempt {attempt + 1})")
241
+
242
+ accumulated_text = ""
243
+ tool_calls = []
244
+
245
+ # Create the stream
246
+ stream_params = {
247
+ "model": model,
248
+ "max_tokens": max_tokens,
249
+ "messages": api_messages,
250
+ "temperature": temperature
251
+ }
252
+
253
+ if system_prompt:
254
+ stream_params["system"] = system_prompt
255
+
256
+ if tools:
257
+ stream_params["tools"] = tools
258
+
259
+ async with self.primary_client.messages.stream(**stream_params) as stream:
260
+ async for event in stream:
261
+ if event.type == "content_block_start":
262
+ if event.content_block.type == "tool_use":
263
+ tool_calls.append({
264
+ "id": event.content_block.id,
265
+ "name": event.content_block.name,
266
+ "input": {}
267
+ })
268
+
269
+ elif event.type == "content_block_delta":
270
+ if event.delta.type == "text_delta":
271
+ accumulated_text += event.delta.text
272
+ yield (accumulated_text, tool_calls, "Anthropic Claude")
273
+
274
+ # Get final message
275
+ final_message = await stream.get_final_message()
276
+
277
+ # Rebuild tool calls from final message
278
+ tool_calls.clear()
279
+ for block in final_message.content:
280
+ if block.type == "tool_use":
281
+ tool_calls.append({
282
+ "id": block.id,
283
+ "name": block.name,
284
+ "input": block.input
285
+ })
286
+ elif block.type == "text" and block.text:
287
+ if block.text not in accumulated_text:
288
+ accumulated_text += block.text
289
+
290
+ yield (accumulated_text, tool_calls, "Anthropic Claude")
291
+ return # Success
292
+
293
+ except (httpx.RemoteProtocolError, httpx.ReadError) as e:
294
+ last_error = e
295
+ logger.warning(f"Network error on attempt {attempt + 1}: {e}")
296
+
297
+ if attempt < self.config["max_retries"]:
298
+ await asyncio.sleep(retry_delay)
299
+ retry_delay *= 2
300
+ else:
301
+ raise
302
+
303
+ except Exception as e:
304
+ logger.error(f"Anthropic streaming error: {e}")
305
+ raise
306
+
307
+ def get_status(self) -> Dict[str, Any]:
308
+ """Get current client status and configuration"""
309
+ return {
310
+ "primary_provider": "Anthropic" if self.primary_client else None,
311
+ "fallback_enabled": bool(self.fallback_router),
312
+ "current_provider": self.provider_name,
313
+ "provider_preference": self.config["provider_preference"],
314
+ "max_retries": self.config["max_retries"]
315
+ }
316
+
317
+ def is_using_llama_primary(self) -> bool:
318
+ """Check if Llama/SambaNova is the primary provider"""
319
+ # Check if cost_optimize preference is set and fallback is available
320
+ if self.config.get("provider_preference") == "cost_optimize" and self.fallback_router:
321
+ return True
322
+ # Check if we have no Anthropic client and are using SambaNova
323
+ if not self.primary_client and self.fallback_router:
324
+ return True
325
+ return False
326
+
327
+ def classify_query_complexity(self, message: str, tools_count: int = 0) -> str:
328
+ """
329
+ Classify query as 'simple' or 'complex' based on content analysis.
330
+
331
+ Args:
332
+ message: The user's query text
333
+ tools_count: Number of tools available for this query
334
+
335
+ Returns:
336
+ 'simple' | 'complex' - The query classification
337
+ """
338
+ message_lower = message.lower()
339
+
340
+ # Simple query indicators (good for Llama)
341
+ simple_patterns = [
342
+ "what is", "define", "when was", "who is", "list of",
343
+ "how many", "name the", "what does", "explain what",
344
+ "is there", "are there", "can you list", "tell me about",
345
+ "what are the symptoms", "side effects of", "list the",
346
+ "symptoms of", "treatment for", "causes of"
347
+ ]
348
+
349
+ # Complex query indicators (better for Claude)
350
+ complex_patterns = [
351
+ "analyze", "compare", "evaluate", "synthesize", "comprehensive",
352
+ "all", "every", "detailed", "mechanism", "pathophysiology",
353
+ "genotyping", "gene therapy", "combination therapy",
354
+ "latest research", "recent studies", "cutting-edge",
355
+ "molecular", "genetic mutation", "therapeutic pipeline",
356
+ "clinical trial results", "meta-analysis", "systematic review",
357
+ # Enhanced trial-related patterns
358
+ "trials", "clinical trials", "studies", "clinical study",
359
+ "NCT", "recruiting", "enrollment", "study protocol",
360
+ "phase 1", "phase 2", "phase 3", "phase 4", "early phase",
361
+ "investigational", "experimental", "novel treatment",
362
+ "treatment pipeline", "research pipeline", "drug development"
363
+ ]
364
+
365
+ # Count pattern matches
366
+ simple_score = sum(1 for pattern in simple_patterns if pattern in message_lower)
367
+ complex_score = sum(1 for pattern in complex_patterns if pattern in message_lower)
368
+
369
+ # Decision logic
370
+ if complex_score > 0:
371
+ # Any complex indicator suggests complex query
372
+ return "complex"
373
+ elif simple_score > 0 and len(message) < 150:
374
+ # Simple pattern and short query
375
+ return "simple"
376
+ elif len(message) > 300:
377
+ # Long queries are likely complex
378
+ return "complex"
379
+ elif tools_count > 8:
380
+ # Many tools suggest complex analysis needed
381
+ return "complex"
382
+ else:
383
+ # Default to complex for safety (better quality)
384
+ return "complex" if self.primary_client else "simple"
385
+
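+     # Illustrative results from classify_query_complexity (per the patterns above):
+     #     "what is riluzole"                  -> "simple"
+     #     "compare SOD1 gene therapy trials"  -> "complex"
+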
386
+ def get_provider_display_name(self) -> str:
387
+ """Get a user-friendly provider status string"""
388
+ if self.primary_client and self.fallback_router:
389
+ # Both providers available
390
+ if self.config["provider_preference"] == "cost_optimize":
391
+ status = "SambaNova Llama 3.3 70B (primary, cost-optimized) with Anthropic Claude fallback"
392
+ elif self.config["provider_preference"] == "quality_first":
393
+ status = "Anthropic Claude (primary, quality-first) with SambaNova fallback"
394
+ else: # auto
395
+ status = "Anthropic Claude (with SambaNova fallback)"
396
+ elif self.primary_client:
397
+ status = "Anthropic Claude"
398
+ elif self.fallback_router:
399
+ status = f"SambaNova Llama 3.3 70B ({self.config['provider_preference']} mode)"
400
+ else:
401
+ status = "Not configured"
402
+
403
+ return status
404
+
405
+ async def cleanup(self):
406
+ """Clean up resources"""
407
+ if self.fallback_router:
408
+ try:
409
+ await self.fallback_router.cleanup()
410
+             except Exception:
+                 # Best-effort cleanup; never let shutdown errors propagate
+                 pass
412
+
413
+ async def __aenter__(self):
414
+ """Async context manager entry"""
415
+ return self
416
+
417
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
418
+ """Async context manager exit"""
419
+ await self.cleanup()
420
+
421
+
422
+ # Global instance (optional - can be created per request instead)
423
+ _global_client: Optional[UnifiedLLMClient] = None
424
+
425
+
426
+ def get_llm_client() -> UnifiedLLMClient:
427
+ """Get or create the global LLM client instance"""
428
+ global _global_client
429
+ if _global_client is None:
430
+ _global_client = UnifiedLLMClient()
431
+ return _global_client
432
+
433
+
434
+ async def cleanup_global_client():
435
+ """Clean up the global client instance"""
436
+ global _global_client
437
+ if _global_client:
438
+ await _global_client.cleanup()
439
+ _global_client = None
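+
+
+ # Minimal usage sketch (illustrative, not the app's entry point): the unified
+ # client hides all provider selection behind one async generator.
+ #
+ #     async def demo():
+ #         client = get_llm_client()
+ #         messages = [{"role": "user", "content": "Summarize recent ALS trial news"}]
+ #         async for text, tool_calls, provider in client.stream(messages):
+ #             print(f"[{provider}] {len(text)} chars so far")
+ #         await cleanup_global_client()
+ #
+ #     asyncio.run(demo())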
llm_providers.py ADDED
@@ -0,0 +1,357 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multi-LLM provider support with fallback logic.
4
+ Includes SambaNova free tier as primary fallback option.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ import httpx
11
+ import asyncio
12
+ from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
13
+ from anthropic import AsyncAnthropic
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv()
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class SambaNovaProvider:
22
+ """
23
+ SambaNova Cloud provider - requires API key for access.
24
+ Get your API key at https://cloud.sambanova.ai/
25
+ Includes $5-30 free credits for new accounts.
26
+ """
27
+
28
+ BASE_URL = "https://api.sambanova.ai/v1"
29
+
30
+ # Available models
31
+ MODELS = {
32
+ "llama-3.3-70b": "Meta-Llama-3.3-70B-Instruct", # Latest and best!
33
+ "llama-3.1-405b": "Meta-Llama-3.1-405B-Instruct",
34
+ "llama-3.1-70b": "Meta-Llama-3.1-70B-Instruct",
35
+ "llama-3.1-8b": "Meta-Llama-3.1-8B-Instruct",
36
+ "llama-3.2-11b": "Llama-3.2-11B-Vision-Instruct",
37
+ "llama-3.2-3b": "Llama-3.2-3B-Instruct",
38
+ "llama-3.2-1b": "Llama-3.2-1B-Instruct"
39
+ }
40
+
41
+ def __init__(self, api_key: Optional[str] = None):
42
+ """
43
+ Initialize SambaNova provider.
44
+ API key is REQUIRED - get yours at https://cloud.sambanova.ai/
45
+ """
46
+ self.api_key = api_key or os.getenv("SAMBANOVA_API_KEY")
47
+ if not self.api_key:
48
+ raise ValueError(
49
+ "SAMBANOVA_API_KEY is required for SambaNova API access.\n"
50
+ "Get your API key at: https://cloud.sambanova.ai/\n"
51
+ "Then set it in your .env file: SAMBANOVA_API_KEY=your_key_here"
52
+ )
53
+ self.client = httpx.AsyncClient(timeout=60.0)
54
+
55
+ async def stream(
56
+ self,
57
+ messages: List[Dict],
58
+ system: str = None,
59
+ tools: List[Dict] = None,
60
+ model: str = "llama-3.1-70b",
61
+ max_tokens: int = 4096,
62
+ temperature: float = 0.7
63
+ ) -> AsyncGenerator[Tuple[str, List[Dict]], None]:
64
+ """
65
+ Stream responses from SambaNova API.
66
+ Compatible interface with Anthropic streaming.
67
+ """
68
+
69
+ # Select the full model name
70
+ full_model = self.MODELS.get(model, self.MODELS["llama-3.1-70b"])
71
+
72
+ # Convert messages to OpenAI format (SambaNova uses OpenAI-compatible API)
73
+ formatted_messages = []
74
+
75
+ # Add system message if provided
76
+ if system:
77
+ formatted_messages.append({
78
+ "role": "system",
79
+ "content": system
80
+ })
81
+
82
+ # Convert Anthropic message format to OpenAI format
83
+ for msg in messages:
84
+ if msg["role"] == "user":
85
+ formatted_messages.append({
86
+ "role": "user",
87
+ "content": msg.get("content", "")
88
+ })
89
+ elif msg["role"] == "assistant":
90
+ # Handle assistant messages with potential tool calls
91
+ content = msg.get("content", "")
92
+ if isinstance(content, list):
93
+ # Extract text from content blocks
94
+ text_parts = []
95
+ for block in content:
96
+ if block.get("type") == "text":
97
+ text_parts.append(block.get("text", ""))
98
+ content = "\n".join(text_parts)
99
+
100
+ formatted_messages.append({
101
+ "role": "assistant",
102
+ "content": content
103
+ })
104
+
105
+ # Prepare request payload
106
+ payload = {
107
+ "model": full_model,
108
+ "messages": formatted_messages,
109
+ "max_tokens": max_tokens,
110
+ "temperature": temperature,
111
+ "stream": True,
112
+ "stream_options": {"include_usage": True}
113
+ }
114
+
115
+ # Add tools if provided (for models that support it)
116
+ if tools and model in ["llama-3.3-70b", "llama-3.1-405b", "llama-3.1-70b"]:
117
+ # Convert Anthropic tool format to OpenAI format
118
+ openai_tools = []
119
+ for tool in tools:
120
+ openai_tools.append({
121
+ "type": "function",
122
+ "function": {
123
+ "name": tool["name"],
124
+ "description": tool.get("description", ""),
125
+ "parameters": tool.get("input_schema", {})
126
+ }
127
+ })
128
+ payload["tools"] = openai_tools
129
+ payload["tool_choice"] = "auto"
130
+
131
+ # Headers - API key is always required now
132
+ headers = {
133
+ "Content-Type": "application/json",
134
+ "Authorization": f"Bearer {self.api_key}"
135
+ }
136
+
137
+ try:
138
+ # Make streaming request
139
+ accumulated_text = ""
140
+ tool_calls = []
141
+
142
+ async with self.client.stream(
143
+ "POST",
144
+ f"{self.BASE_URL}/chat/completions",
145
+ json=payload,
146
+ headers=headers
147
+ ) as response:
148
+ response.raise_for_status()
149
+
150
+ async for line in response.aiter_lines():
151
+ if line.startswith("data: "):
152
+ data = line[6:] # Remove "data: " prefix
153
+
154
+ if data == "[DONE]":
155
+ break
156
+
157
+ try:
158
+ chunk = json.loads(data)
159
+
160
+ # Handle usage-only chunks (sent at end of stream)
161
+ if "usage" in chunk and ("choices" not in chunk or len(chunk.get("choices", [])) == 0):
162
+ # This is a usage statistics chunk, skip it
163
+ logger.debug(f"Received usage chunk: {chunk.get('usage', {})}")
164
+ continue
165
+
166
+ # Extract content from chunk
167
+ if "choices" in chunk and len(chunk["choices"]) > 0:
168
+ choice = chunk["choices"][0]
169
+ delta = choice.get("delta", {})
170
+
171
+ # Handle text content
172
+ if "content" in delta and delta["content"]:
173
+ accumulated_text += delta["content"]
174
+ yield (accumulated_text, tool_calls)
175
+
176
+ # Handle tool calls (if supported)
177
+ if "tool_calls" in delta:
178
+ for tc in delta["tool_calls"]:
179
+ # Convert OpenAI tool call format to Anthropic format
180
+ tool_calls.append({
181
+ "id": tc.get("id", f"tool_{len(tool_calls)}"),
182
+ "name": tc.get("function", {}).get("name", ""),
183
+ "input": json.loads(tc.get("function", {}).get("arguments", "{}"))
184
+ })
185
+
186
+ except json.JSONDecodeError:
187
+ logger.warning(f"Failed to parse SSE data: {data}")
188
+ continue
189
+
190
+ # Final yield with complete results
191
+ yield (accumulated_text, tool_calls)
192
+
193
+ except httpx.HTTPStatusError as e:
194
+ if e.response.status_code == 410:
195
+ logger.error("SambaNova API endpoint has been discontinued (410 GONE)")
196
+ raise RuntimeError(
197
+ "SambaNova API endpoint no longer exists. "
198
+ "Make sure you have a valid API key set in SAMBANOVA_API_KEY."
199
+ )
200
+ elif e.response.status_code == 401:
201
+ logger.error("SambaNova API authentication failed")
202
+ raise RuntimeError(
203
+ "SambaNova authentication failed. Please check your API key."
204
+ )
205
+ else:
206
+ logger.error(f"SambaNova API error: {e}")
207
+ raise
208
+ except httpx.HTTPError as e:
209
+ logger.error(f"SambaNova API error: {e}")
210
+ raise
211
+
212
+ async def close(self):
213
+ """Close the HTTP client"""
214
+ await self.client.aclose()
215
+
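+ # Usage sketch (illustrative): the provider mirrors the Anthropic streaming shape.
+ #
+ #     provider = SambaNovaProvider()  # reads SAMBANOVA_API_KEY from the environment
+ #     async for text, tool_calls in provider.stream(
+ #         messages=[{"role": "user", "content": "Hi"}],
+ #         model="llama-3.3-70b",
+ #     ):
+ #         ...
+ #     await provider.close()
+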
216
+
217
+ class LLMRouter:
218
+ """
219
+ Routes LLM requests to appropriate providers with fallback logic.
220
+ """
221
+
222
+ def __init__(self):
223
+ self.providers = {}
224
+ self._setup_providers()
225
+
226
+ def _setup_providers(self):
227
+ """Initialize available providers"""
228
+
229
+ # Primary: Anthropic (if API key available)
230
+ anthropic_key = os.getenv("ANTHROPIC_API_KEY")
231
+ if anthropic_key:
232
+ self.providers["anthropic"] = AsyncAnthropic(api_key=anthropic_key)
233
+ logger.info("Anthropic provider initialized")
234
+
235
+         # Fallback: SambaNova (requires SAMBANOVA_API_KEY; skipped if not configured,
+         # so importing this module never fails just because the key is missing)
+         try:
+             self.providers["sambanova"] = SambaNovaProvider()
+             logger.info("SambaNova provider initialized")
+         except ValueError as e:
+             logger.warning(f"SambaNova provider unavailable: {e}")
238
+
239
+ async def stream_with_fallback(
240
+ self,
241
+ messages: List[Dict],
242
+ tools: List[Dict],
243
+ system_prompt: str,
244
+ model: str = None,
245
+ max_tokens: int = 4096,
246
+ provider_preference: str = "auto"
247
+ ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
248
+ """
249
+ Stream from LLM with automatic fallback.
250
+ Returns (text, tool_calls, provider_used) tuples.
251
+ """
252
+
253
+ # Determine provider order based on preference
254
+ if provider_preference == "cost_optimize":
255
+ # Prefer free SambaNova first
256
+ provider_order = ["sambanova", "anthropic"]
257
+ elif provider_preference == "quality_first":
258
+ # Prefer Anthropic first
259
+ provider_order = ["anthropic", "sambanova"]
260
+ else: # auto
261
+ # Use Anthropic if available, fall back to SambaNova
262
+ provider_order = ["anthropic", "sambanova"] if "anthropic" in self.providers else ["sambanova"]
263
+
264
+ last_error = None
265
+
266
+ for provider_name in provider_order:
267
+ if provider_name not in self.providers:
268
+ continue
269
+
270
+ try:
271
+ logger.info(f"Attempting to use {provider_name} provider...")
272
+
273
+ if provider_name == "anthropic":
274
+ # Use existing Anthropic streaming
275
+ provider = self.providers["anthropic"]
276
+
277
+ # Stream from Anthropic
278
+ accumulated_text = ""
279
+ tool_calls = []
280
+
281
+ async with provider.messages.stream(
282
+ model=model or "claude-sonnet-4-5-20250929",
283
+ max_tokens=max_tokens,
284
+ messages=messages,
285
+ system=system_prompt,
286
+ tools=tools
287
+ ) as stream:
288
+ async for event in stream:
289
+ if event.type == "content_block_start":
290
+ if event.content_block.type == "tool_use":
291
+ tool_calls.append({
292
+ "id": event.content_block.id,
293
+ "name": event.content_block.name,
294
+ "input": {}
295
+ })
296
+
297
+ elif event.type == "content_block_delta":
298
+ if event.delta.type == "text_delta":
299
+ accumulated_text += event.delta.text
300
+ yield (accumulated_text, tool_calls, "Anthropic Claude")
301
+
302
+ # Get final message
303
+ final_message = await stream.get_final_message()
304
+
305
+ # Rebuild tool calls from final message
306
+ tool_calls.clear()
307
+ for block in final_message.content:
308
+ if block.type == "tool_use":
309
+ tool_calls.append({
310
+ "id": block.id,
311
+ "name": block.name,
312
+ "input": block.input
313
+ })
314
+
315
+ yield (accumulated_text, tool_calls, "Anthropic Claude")
316
+ return # Success!
317
+
318
+ elif provider_name == "sambanova":
319
+ # Use SambaNova streaming
320
+ provider = self.providers["sambanova"]
321
+
322
+ # Determine which Llama model to use
323
+ if max_tokens > 8192:
324
+ samba_model = "llama-3.1-405b" # Largest model for complex tasks
325
+ else:
326
+ # Default to Llama 3.3 70B - newest and best for most tasks
327
+ samba_model = "llama-3.3-70b"
328
+
329
+ async for text, tool_calls in provider.stream(
330
+ messages=messages,
331
+ system=system_prompt,
332
+ tools=tools,
333
+ model=samba_model,
334
+ max_tokens=max_tokens
335
+ ):
336
+ yield (text, tool_calls, f"SambaNova {samba_model}")
337
+
338
+ return # Success!
339
+
340
+ except Exception as e:
341
+ logger.warning(f"Provider {provider_name} failed: {e}")
342
+ last_error = e
343
+ continue
344
+
345
+ # All providers failed
346
+ error_msg = f"All LLM providers failed. Last error: {last_error}"
347
+ logger.error(error_msg)
348
+ raise Exception(error_msg)
349
+
350
+ async def cleanup(self):
351
+ """Clean up provider resources"""
352
+ if "sambanova" in self.providers:
353
+ await self.providers["sambanova"].close()
354
+
355
+
356
+ # Global router instance
357
+ llm_router = LLMRouter()
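+
+ # Usage sketch (illustrative): the router yields (text, tool_calls, provider)
+ # tuples and moves to the next provider automatically when one fails.
+ #
+ #     async for text, calls, provider in llm_router.stream_with_fallback(
+ #         messages=[{"role": "user", "content": "Hello"}],
+ #         tools=[],
+ #         system_prompt="You are a helpful research assistant.",
+ #         provider_preference="cost_optimize",
+ #     ):
+ #         print(provider, text[-40:])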
parallel_tool_execution.py ADDED
@@ -0,0 +1,235 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Parallel tool execution optimization for ALS Research Agent
4
+ This module replaces sequential tool execution with parallel execution
5
+ to reduce response time by ~60-70% for multi-tool queries.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import List, Dict, Tuple, Any
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ async def execute_single_tool(
16
+ tool_call: Dict,
17
+ call_mcp_tool_func,
18
+ index: int
19
+ ) -> Tuple[int, str, Dict]:
20
+ """
21
+ Execute a single tool call asynchronously.
22
+ Returns (index, progress_text, result_dict) to maintain order.
23
+ """
24
+ tool_name = tool_call["name"]
25
+ tool_args = tool_call["input"]
26
+
27
+ # Show search info in progress text
28
+ tool_display = tool_name.replace('__', ' → ')
29
+ search_info = ""
30
+ if "query" in tool_args:
31
+ search_info = f" `{tool_args['query'][:50]}{'...' if len(tool_args['query']) > 50 else ''}`"
32
+ elif "condition" in tool_args:
33
+ search_info = f" `{tool_args['condition'][:50]}{'...' if len(tool_args['condition']) > 50 else ''}`"
34
+
35
+ try:
36
+ # Call MCP tool
37
+ start_time = asyncio.get_event_loop().time()
38
+ tool_result = await call_mcp_tool_func(tool_name, tool_args)
39
+ elapsed = asyncio.get_event_loop().time() - start_time
40
+
41
+ logger.info(f"Tool {tool_name} completed in {elapsed:.2f}s")
42
+
43
+ # Check for zero results to provide clear indicators
44
+ has_results = True
45
+ results_count = 0
46
+
47
+ if isinstance(tool_result, str):
48
+ result_lower = tool_result.lower()
49
+
50
+ # Check for specific result counts
51
+ import re
52
+ count_matches = re.findall(r'found (\d+) (?:papers?|trials?|preprints?|results?)', result_lower)
53
+ if count_matches:
54
+ results_count = int(count_matches[0])
55
+
56
+ # Check for no results
57
+ if any(phrase in result_lower for phrase in [
58
+ "no results found", "0 results", "no papers found",
59
+ "no trials found", "no preprints found", "not found",
60
+ "zero results", "no matches"
61
+ ]) or results_count == 0:
62
+ has_results = False
63
+
64
+ # Create clear success/failure indicator
65
+ if has_results:
66
+ if results_count > 0:
67
+ progress_text = f"\n✅ **Found {results_count} results:** {tool_display}{search_info}"
68
+ else:
69
+ progress_text = f"\n✅ **Success:** {tool_display}{search_info}"
70
+ else:
71
+ progress_text = f"\n⚠️ **No results:** {tool_display}{search_info} - will try alternatives"
72
+
73
+ # Add timing for long operations
74
+ if elapsed > 5:
75
+ progress_text += f" (took {elapsed:.1f}s)"
76
+
77
+ # Check for zero results to enable self-correction
78
+ if not has_results:
79
+ # Add self-correction hint to the result
80
+ tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
81
+ tool_result += "1. Broadening search terms (remove qualifiers)\n"
82
+ tool_result += "2. Using alternative terminology or synonyms\n"
83
+ tool_result += "3. Searching related concepts\n"
84
+ tool_result += "4. Checking for typos in search terms"
85
+
86
+ result_dict = {
87
+ "type": "tool_result",
88
+ "tool_use_id": tool_call["id"],
89
+ "content": tool_result
90
+ }
91
+
92
+ return index, progress_text, result_dict
93
+
94
+ except Exception as e:
95
+ logger.error(f"Error executing tool {tool_name}: {e}")
96
+
97
+ # Clear failure indicator for errors
98
+ progress_text = f"\n❌ **Failed:** {tool_display}{search_info} - {str(e)[:50]}"
99
+
100
+ error_result = {
101
+ "type": "tool_result",
102
+ "tool_use_id": tool_call["id"],
103
+ "content": f"Error executing tool: {str(e)}"
104
+ }
105
+ return index, progress_text, error_result
106
+
107
+
108
+ async def execute_tool_calls_parallel(
109
+ tool_calls: List[Dict],
110
+ call_mcp_tool_func
111
+ ) -> Tuple[str, List[Dict]]:
112
+ """
113
+ Execute tool calls in parallel and collect results.
114
+ Maintains the original order of tool calls in results.
115
+
116
+ Returns: (progress_text, tool_results_content)
117
+ """
118
+ if not tool_calls:
119
+ return "", []
120
+
121
+ # Track execution time for progress reporting
122
+ start_time = asyncio.get_event_loop().time()
123
+
124
+ # Log parallel execution
125
+ logger.info(f"Executing {len(tool_calls)} tools in parallel")
126
+
127
+ # Create tasks for parallel execution
128
+ tasks = [
129
+ execute_single_tool(tool_call, call_mcp_tool_func, i)
130
+ for i, tool_call in enumerate(tool_calls)
131
+ ]
132
+
133
+ # Execute all tasks in parallel
134
+ results = await asyncio.gather(*tasks, return_exceptions=True)
135
+
136
+ # Sort results by index to maintain original order
137
+ sorted_results = sorted(
138
+ [r for r in results if not isinstance(r, Exception)],
139
+ key=lambda x: x[0]
140
+ )
141
+
142
+ # Combine results with progress summary
143
+ completed_count = len(sorted_results)
144
+ total_count = len(tool_calls)
145
+
146
+ # Create progress summary with timing info
147
+ elapsed_time = asyncio.get_event_loop().time() - start_time
148
+
149
+ if elapsed_time > 5:
150
+ timing_info = f" in {elapsed_time:.1f}s"
151
+ else:
152
+ timing_info = ""
153
+
154
+ progress_text = f"\n📊 **Search Progress:** Completed {completed_count}/{total_count} searches{timing_info}\n"
155
+
156
+ tool_results_content = []
157
+
158
+ for index, prog_text, result_dict in sorted_results:
159
+ progress_text += prog_text
160
+ tool_results_content.append(result_dict)
161
+
162
+ # Handle any exceptions
163
+ for i, result in enumerate(results):
164
+ if isinstance(result, Exception):
165
+ logger.error(f"Task {i} failed with exception: {result}")
166
+ # Add error result for failed tasks
167
+ if i < len(tool_calls):
168
+                 # Results are matched by tool_use_id, so appending keeps them valid
+                 tool_results_content.append({
169
+ "type": "tool_result",
170
+ "tool_use_id": tool_calls[i]["id"],
171
+ "content": f"Tool execution failed: {str(result)}"
172
+ })
173
+
174
+ return progress_text, tool_results_content
175
+
176
+
177
+ # Backward compatibility wrapper
178
+ async def execute_tool_calls_optimized(
179
+ tool_calls: List[Dict],
180
+ call_mcp_tool_func,
181
+ parallel: bool = True
182
+ ) -> Tuple[str, List[Dict]]:
183
+ """
184
+ Execute tool calls with optional parallel execution.
185
+
186
+ Args:
187
+ tool_calls: List of tool calls to execute
188
+ call_mcp_tool_func: Function to call MCP tools
189
+ parallel: If True, execute tools in parallel; if False, execute sequentially
190
+
191
+ Returns: (progress_text, tool_results_content)
192
+ """
193
+ if parallel and len(tool_calls) > 1:
194
+ # Use parallel execution for multiple tools
195
+ return await execute_tool_calls_parallel(tool_calls, call_mcp_tool_func)
196
+ else:
197
+ # Fall back to sequential execution (import from original)
198
+ from refactored_helpers import execute_tool_calls
199
+ return await execute_tool_calls(tool_calls, call_mcp_tool_func)
200
+
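+ # Usage sketch (illustrative; `call_mcp_tool` here stands for any coroutine
+ # taking (tool_name, tool_args) and returning a result string):
+ #
+ #     progress, results = await execute_tool_calls_optimized(
+ #         tool_calls, call_mcp_tool, parallel=True
+ #     )
+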
201
+
202
+ def estimate_time_savings(num_tools: int, avg_tool_time: float = 3.5) -> Dict[str, float]:
203
+ """
204
+ Estimate time savings from parallel execution.
205
+
206
+ Args:
207
+ num_tools: Number of tools to execute
208
+ avg_tool_time: Average time per tool in seconds
209
+
210
+ Returns: Dictionary with timing estimates
211
+ """
212
+ sequential_time = num_tools * avg_tool_time
213
+ # Parallel time is roughly the time of the slowest tool plus overhead
214
+ parallel_time = avg_tool_time + 0.5 # 0.5s overhead for coordination
215
+
216
+ savings = sequential_time - parallel_time
217
+ savings_percent = (savings / sequential_time) * 100 if sequential_time > 0 else 0
218
+
219
+ return {
220
+ "sequential_time": sequential_time,
221
+ "parallel_time": parallel_time,
222
+ "time_saved": savings,
223
+ "savings_percent": savings_percent
224
+ }
225
+
226
+
227
+ # Test the optimization
228
+ if __name__ == "__main__":
229
+ # Test time savings estimation
230
+ for n in [2, 3, 4, 5]:
231
+ estimates = estimate_time_savings(n)
232
+ print(f"\n{n} tools:")
233
+ print(f" Sequential: {estimates['sequential_time']:.1f}s")
234
+ print(f" Parallel: {estimates['parallel_time']:.1f}s")
235
+ print(f" Savings: {estimates['time_saved']:.1f}s ({estimates['savings_percent']:.0f}%)")
query_classifier.py ADDED
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Query Classification Module for ALS Research Agent
4
+ Determines whether a query requires full research workflow or simple response
5
+ """
6
+
7
+ import re
8
+ from typing import Any, Dict
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class QueryClassifier:
15
+ """Classify queries as research-required or simple questions"""
16
+
17
+ # Keywords that indicate ALS research is needed
18
+ RESEARCH_KEYWORDS = [
19
+ # Disease-specific terms
20
+ 'als', 'amyotrophic lateral sclerosis', 'motor neuron disease',
21
+ 'mnd', 'lou gehrig', 'ftd', 'frontotemporal dementia',
22
+
23
+ # Medical research terms
24
+ 'clinical trial', 'treatment', 'therapy', 'drug', 'medication',
25
+ 'gene therapy', 'stem cell', 'biomarker', 'diagnosis',
26
+ 'prognosis', 'survival', 'progression', 'symptom',
27
+ 'cure', 'breakthrough', 'research', 'study', 'paper',
28
+ 'latest', 'recent', 'new findings', 'advances',
29
+
30
+ # Specific ALS-related
31
+ 'riluzole', 'edaravone', 'radicava', 'relyvrio', 'qalsody',
32
+ 'tofersen', 'sod1', 'c9orf72', 'tdp-43', 'fus',
33
+
34
+ # Research actions
35
+ 'find studies', 'search papers', 'what research',
36
+ 'clinical evidence', 'scientific literature'
37
+ ]
38
+
39
+ # Keywords that indicate simple/general questions
40
+ SIMPLE_KEYWORDS = [
41
+ 'hello', 'hi', 'hey', 'thanks', 'thank you',
42
+ 'how are you', "what's your name", 'who are you',
43
+ 'what can you do', 'help', 'test', 'testing',
44
+ 'explain', 'define', 'what is', 'what are',
45
+ 'how does', 'why', 'when', 'where', 'who'
46
+ ]
47
+
48
+ # Exclusion patterns for non-research queries
49
+ NON_RESEARCH_PATTERNS = [
50
+ r'^(hi|hello|hey|thanks|thank you)',
51
+ r'^test\s',
52
+ r'^how (are you|do you)',
53
+ r'^what (is|are) (the|a|an)\s+\w+$', # Simple definitions
54
+ r'^(explain|define)\s+\w+$', # Simple explanations
55
+ r'^\w{1,3}$', # Very short queries
56
+ ]
57
+
58
+ @classmethod
59
+     def classify_query(cls, query: str) -> Dict[str, Any]:
60
+ """
61
+ Classify a query and determine processing strategy.
62
+
63
+ Returns:
64
+ Dict with:
65
+ - requires_research: bool - Whether to use full research workflow
66
+ - confidence: float - Confidence in classification (0-1)
67
+ - reason: str - Explanation of classification
68
+ - suggested_mode: str - 'research' or 'simple'
69
+ """
70
+ query_lower = query.lower().strip()
71
+
72
+ # Check for very short or empty queries
73
+ if len(query_lower) < 5:
74
+ return {
75
+ 'requires_research': False,
76
+ 'confidence': 0.9,
77
+ 'reason': 'Query too short for research',
78
+ 'suggested_mode': 'simple'
79
+ }
80
+
81
+ # Check exclusion patterns first
82
+ for pattern in cls.NON_RESEARCH_PATTERNS:
83
+ if re.match(pattern, query_lower):
84
+ return {
85
+ 'requires_research': False,
86
+ 'confidence': 0.85,
87
+ 'reason': 'Matches non-research pattern',
88
+ 'suggested_mode': 'simple'
89
+ }
90
+
91
+ # Count research keywords
92
+ research_score = sum(
93
+ 1 for keyword in cls.RESEARCH_KEYWORDS
94
+ if keyword in query_lower
95
+ )
96
+
97
+ # Count simple keywords
98
+ simple_score = sum(
99
+ 1 for keyword in cls.SIMPLE_KEYWORDS
100
+ if keyword in query_lower
101
+ )
102
+
103
+ # Check for question complexity
104
+ has_multiple_questions = query.count('?') > 1
105
+ has_complex_structure = len(query.split()) > 15
106
+ mentions_comparison = any(word in query_lower for word in
107
+ ['compare', 'versus', 'vs', 'difference between'])
108
+
109
+ # Decision logic - Conservative approach for ALS research agent
110
+
111
+ # FIRST: Check if this is truly just a greeting/thanks (only these skip research)
112
+ greeting_only = query_lower in ['hi', 'hello', 'hey', 'thanks', 'thank you', 'bye', 'goodbye', 'test']
113
+ if greeting_only and research_score == 0:
114
+ return {
115
+ 'requires_research': False,
116
+ 'confidence': 0.95,
117
+ 'reason': 'Pure greeting or acknowledgment',
118
+ 'suggested_mode': 'simple'
119
+ }
120
+
121
+ # SECOND: If ANY research keyword is present, use research mode
122
+ # This includes "ALS", "treatment", "therapy", etc.
123
+ if research_score >= 1:
124
+ return {
125
+ 'requires_research': True,
126
+ 'confidence': min(0.95, 0.7 + research_score * 0.1),
127
+ 'reason': f'Contains research-related terms ({research_score} keywords)',
128
+ 'suggested_mode': 'research'
129
+ }
130
+
131
+ # THIRD: Check for questions about the agent itself
132
+ about_agent = any(phrase in query_lower for phrase in [
133
+ 'who are you', 'what can you do', 'how do you work',
134
+ 'what are you', 'your capabilities'
135
+ ])
136
+ if about_agent:
137
+ return {
138
+ 'requires_research': False,
139
+ 'confidence': 0.85,
140
+ 'reason': 'Question about the agent itself',
141
+ 'suggested_mode': 'simple'
142
+ }
143
+
144
+ # DEFAULT: For an ALS research agent, when in doubt, use research mode
145
+ # This is safer than potentially missing important medical queries
146
+ return {
147
+ 'requires_research': True,
148
+ 'confidence': 0.6,
149
+ 'reason': 'Default to research mode for potential medical queries',
150
+ 'suggested_mode': 'research'
151
+ }
152
+
153
+ @classmethod
154
+ def should_use_tools(cls, query: str) -> bool:
155
+ """Quick check if query needs research tools"""
156
+ classification = cls.classify_query(query)
157
+ return classification['requires_research'] and classification['confidence'] > 0.65
158
+
159
+ @classmethod
160
+ def get_processing_hint(cls, classification: Dict) -> str:
161
+ """Get a hint for how to process the query"""
162
+ if classification['requires_research']:
163
+ return "🔬 Using full research workflow ..."
164
+ else:
165
+ return "💬 Providing direct response without research tools"
166
+
167
+
168
+ def test_classifier():
169
+ """Test the classifier with example queries"""
170
+ test_queries = [
171
+ # Should require research
172
+ "What are the latest gene therapy trials for ALS?",
173
+ "Compare riluzole and edaravone effectiveness",
174
+ "Find recent studies on SOD1 mutations",
175
+ "What breakthroughs in ALS treatment happened in 2024?",
176
+ "Are there any promising stem cell therapies for motor neuron disease?",
177
+
178
+ # Should NOT require research
179
+ "Hello, how are you?",
180
+ "What is your name?",
181
+ "Test",
182
+ "Thanks for your help",
183
+ "Explain what a database is",
184
+ "What time is it?",
185
+ "How do I use this app?",
186
+ ]
187
+
188
+ print("Query Classification Test Results")
189
+ print("=" * 60)
190
+
191
+ for query in test_queries:
192
+ result = QueryClassifier.classify_query(query)
193
+ print(f"\nQuery: {query[:50]}...")
194
+ print(f"Requires Research: {result['requires_research']}")
195
+ print(f"Confidence: {result['confidence']:.2f}")
196
+ print(f"Reason: {result['reason']}")
197
+ print(f"Mode: {result['suggested_mode']}")
198
+ print("-" * 40)
199
+
200
+
201
+ if __name__ == "__main__":
202
+ test_classifier()
refactored_helpers.py ADDED
@@ -0,0 +1,200 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Helper functions to consolidate duplicate code in als_agent_app.py
4
+ Refactored to improve efficiency and reduce redundancy
5
+ """
6
+
7
+ import logging
+ from typing import AsyncGenerator, List, Dict, Tuple
12
+ from llm_client import UnifiedLLMClient
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ async def stream_with_retry(
18
+ client,
19
+ messages: List[Dict],
20
+ tools: List[Dict],
21
+ system_prompt: str,
22
+ max_retries: int = 2,
23
+ model: str = None,
24
+ max_tokens: int = 8192,
25
+ stream_name: str = "API call",
26
+ temperature: float = 0.7
27
+ ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
28
+ """
29
+ Simplified wrapper that delegates to UnifiedLLMClient.
30
+
31
+ The client parameter can be:
32
+ - An Anthropic client (for backward compatibility)
33
+ - A UnifiedLLMClient instance
34
+ - None (will create a UnifiedLLMClient)
35
+
36
+ Yields: (response_text, tool_calls, provider_used) tuples
37
+ """
38
+
39
+ # If client is None or is an Anthropic client, use UnifiedLLMClient
40
+ if client is None or not hasattr(client, 'stream'):
41
+ # Create or get a UnifiedLLMClient instance
42
+ llm_client = UnifiedLLMClient()
43
+
44
+ logger.info(f"Using {llm_client.get_provider_display_name()} for {stream_name}")
45
+
46
+ try:
47
+ # Use the unified client's stream method
48
+ async for text, tool_calls, provider in llm_client.stream(
49
+ messages=messages,
50
+ tools=tools,
51
+ system_prompt=system_prompt,
52
+ model=model,
53
+ max_tokens=max_tokens,
54
+ temperature=temperature
55
+ ):
56
+ yield (text, tool_calls, provider)
57
+ finally:
58
+ # Clean up if we created the client
59
+ await llm_client.cleanup()
60
+
61
+ else:
62
+ # Client is already a UnifiedLLMClient
63
+ logger.info(f"Using provided {client.get_provider_display_name()} for {stream_name}")
64
+
65
+ async for text, tool_calls, provider in client.stream(
66
+ messages=messages,
67
+ tools=tools,
68
+ system_prompt=system_prompt,
69
+ model=model,
70
+ max_tokens=max_tokens,
71
+ temperature=temperature
72
+ ):
73
+ yield (text, tool_calls, provider)
74
+
75
+
76
+ async def execute_tool_calls(
77
+ tool_calls: List[Dict],
78
+ call_mcp_tool_func
79
+ ) -> Tuple[str, List[Dict]]:
80
+ """
81
+ Execute tool calls and collect results.
82
+ Consolidates duplicate tool execution logic.
83
+ Now includes self-correction hints for zero results.
84
+
85
+ Returns: (progress_text, tool_results_content)
86
+ """
87
+ progress_text = ""
88
+ tool_results_content = []
89
+ zero_result_tools = []
90
+
91
+ for tool_call in tool_calls:
92
+ tool_name = tool_call["name"]
93
+ tool_args = tool_call["input"]
94
+
95
+ # Single, clean execution marker with search info
96
+ tool_display = tool_name.replace('__', ' → ')
97
+
98
+ # Show key search parameters
99
+ search_info = ""
100
+ if "query" in tool_args:
101
+ search_info = f" `{tool_args['query'][:50]}{'...' if len(tool_args['query']) > 50 else ''}`"
102
+ elif "condition" in tool_args:
103
+ search_info = f" `{tool_args['condition'][:50]}{'...' if len(tool_args['condition']) > 50 else ''}`"
104
+
105
+ progress_text += f"\n🔧 **Searching:** {tool_display}{search_info}\n"
106
+
107
+ # Call MCP tool
108
+ tool_result = await call_mcp_tool_func(tool_name, tool_args)
109
+
110
+ # Check for zero results to enable self-correction
111
+ if isinstance(tool_result, str):
112
+ result_lower = tool_result.lower()
113
+ if any(phrase in result_lower for phrase in [
114
+ "no results found", "0 results", "no papers found",
115
+ "no trials found", "no preprints found", "not found",
116
+ "zero results", "no matches"
117
+ ]):
118
+ zero_result_tools.append((tool_name, tool_args))
119
+
120
+ # Add self-correction hint to the result
121
+ tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
122
+ tool_result += "1. Broadening search terms (remove qualifiers)\n"
123
+ tool_result += "2. Using alternative terminology or synonyms\n"
124
+ tool_result += "3. Searching related concepts\n"
125
+ tool_result += "4. Checking for typos in search terms"
126
+
127
+ # Add to results array
128
+ tool_results_content.append({
129
+ "type": "tool_result",
130
+ "tool_use_id": tool_call["id"],
131
+ "content": tool_result
132
+ })
133
+
134
+ return progress_text, tool_results_content
135
+
136
+
137
+ def build_assistant_message(
138
+ text_content: str,
139
+ tool_calls: List[Dict],
140
+ strip_markers: List[str] = None
141
+ ) -> List[Dict]:
142
+ """
143
+ Build assistant message content with text and tool uses.
144
+ Consolidates duplicate message building logic.
145
+
146
+ Args:
147
+ text_content: Text content to include
148
+ tool_calls: List of tool calls to include
149
+ strip_markers: Optional list of text markers to strip from content
150
+
151
+ Returns: List of content blocks for assistant message
152
+ """
153
+ assistant_content = []
154
+
155
+ # Process text content
156
+ if text_content and text_content.strip():
157
+ processed_text = text_content
158
+
159
+ # Strip any specified markers
160
+ if strip_markers:
161
+ for marker in strip_markers:
162
+ processed_text = processed_text.replace(marker, "")
163
+
164
+ processed_text = processed_text.strip()
165
+
166
+ if processed_text:
167
+ assistant_content.append({
168
+ "type": "text",
169
+ "text": processed_text
170
+ })
171
+
172
+ # Add tool uses
173
+ for tc in tool_calls:
174
+ assistant_content.append({
175
+ "type": "tool_use",
176
+ "id": tc["id"],
177
+ "name": tc["name"],
178
+ "input": tc["input"]
179
+ })
180
+
181
+ return assistant_content
182
+
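+ # Illustrative output shape:
+ #     build_assistant_message("Done.", [{"id": "t1", "name": "search", "input": {}}])
+ #     -> [{"type": "text", "text": "Done."},
+ #         {"type": "tool_use", "id": "t1", "name": "search", "input": {}}]
+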
183
+
184
+ def should_continue_iterations(
185
+ iteration_count: int,
186
+ max_iterations: int,
187
+ tool_calls: List[Dict]
188
+ ) -> bool:
189
+ """
190
+ Check if tool iterations should continue.
191
+ Centralizes iteration control logic.
192
+ """
193
+ if not tool_calls:
194
+ return False
195
+
196
+ if iteration_count >= max_iterations:
197
+ logger.warning(f"Reached maximum tool iterations ({max_iterations})")
198
+ return False
199
+
200
+ return True
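+
+ # Typical driver loop (sketch; MAX_ITERATIONS and call_mcp_tool are
+ # illustrative names from the calling application):
+ #
+ #     iteration = 0
+ #     tool_calls = first_round_tool_calls
+ #     while should_continue_iterations(iteration, MAX_ITERATIONS, tool_calls):
+ #         progress, results = await execute_tool_calls(tool_calls, call_mcp_tool)
+ #         iteration += 1
+ #         ...  # send results back to the model and collect new tool_calls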
requirements.txt ADDED
@@ -0,0 +1,37 @@
1
+ # Core Dependencies
2
+ # REQUIRES PYTHON 3.10+ (Recommended: Python 3.12)
3
+ # Compatible with Gradio 5.x and 6.x
4
+ gradio>=6.0.0
5
+ anthropic>=0.34.0
6
+ mcp>=1.21.2 # FastMCP API required
7
+ httpx>=0.25.0
8
+
9
+ # Web Scraping
10
+ beautifulsoup4>=4.12.0
11
+ lxml>=4.9.0
12
+
13
+ # Database Access (for AACT clinical trials database)
14
+ psycopg2-binary>=2.9.0
15
+ asyncpg>=0.29.0 # For async PostgreSQL with connection pooling
16
+
17
+ # Configuration
18
+ python-dotenv>=1.0.0
19
+
20
+ # Testing
21
+ pytest>=7.4.0
22
+ pytest-asyncio>=0.21.0
23
+ pytest-cov>=4.1.0
24
+ pytest-mock>=3.12.0
25
+
26
+ # RAG and Research Memory (LlamaIndex)
27
+ llama-index-core>=0.11.0
28
+ llama-index-vector-stores-chroma>=0.2.0
29
+ llama-index-embeddings-huggingface>=0.3.0
30
+ chromadb>=0.5.0
31
+ sentence-transformers>=3.0.0
32
+ transformers>=4.30.0
33
+
34
+ # Development
35
+ black>=23.0.0
36
+ flake8>=6.1.0
37
+ mypy>=1.7.0
servers/aact_server.py ADDED
@@ -0,0 +1,472 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ AACT Database MCP Server with Connection Pooling
4
+ Provides access to ClinicalTrials.gov data through the AACT PostgreSQL database.
5
+
6
+ AACT (Aggregate Analysis of ClinicalTrials.gov) is maintained by Duke University
7
+ and FDA, providing complete ClinicalTrials.gov data updated daily.
8
+
9
+ Database access: aact-db.ctti-clinicaltrials.org
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import logging
16
+ import asyncio
17
+ from typing import Optional, Dict, Any, List
18
+ from datetime import datetime, timedelta
19
+ from pathlib import Path
20
+
21
+ # Add parent directory to path for shared imports
22
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
+ from shared.config import config
24
+
25
+ # Load .env file if it exists
26
+ try:
27
+ from dotenv import load_dotenv
28
+ env_path = Path(__file__).parent.parent / '.env'
29
+ if env_path.exists():
30
+ load_dotenv(env_path)
31
+ logging.info(f"Loaded .env file from {env_path}")
32
+ except ImportError:
33
+ pass
34
+
35
+ # Try both asyncpg (preferred) and psycopg2 (fallback)
36
+ try:
37
+ import asyncpg
38
+ ASYNCPG_AVAILABLE = True
39
+ except ImportError:
40
+ ASYNCPG_AVAILABLE = False
41
+
42
+ try:
43
+ import psycopg2
44
+ from psycopg2.extras import RealDictCursor
45
+ POSTGRES_AVAILABLE = True
46
+ except ImportError:
47
+ POSTGRES_AVAILABLE = False
48
+
49
+ from mcp.server.fastmcp import FastMCP
50
+
51
+ # Setup logging
52
+ logging.basicConfig(level=logging.INFO)
53
+ logger = logging.getLogger(__name__)
54
+
55
+ # Initialize MCP server
56
+ mcp = FastMCP("aact-database")
57
+
58
+ # Database configuration
59
+ AACT_HOST = os.getenv("AACT_HOST", "aact-db.ctti-clinicaltrials.org")
60
+ AACT_PORT = os.getenv("AACT_PORT", "5432")
61
+ AACT_DB = os.getenv("AACT_DB", "aact")
62
+ AACT_USER = os.getenv("AACT_USER", "aact")
63
+ AACT_PASSWORD = os.getenv("AACT_PASSWORD", "")
64
+
65
+ # Global connection pool (initialized once)
66
+ # Annotation is quoted so this module still imports when asyncpg is absent
+ _connection_pool: Optional["asyncpg.Pool"] = None
67
+
68
+ async def get_connection_pool() -> "asyncpg.Pool":
69
+ """Get or create the global connection pool"""
70
+ global _connection_pool
71
+
72
+ if _connection_pool is None or _connection_pool._closed:
73
+ logger.info("Creating new database connection pool...")
74
+
75
+ # Build connection URL
76
+ if AACT_PASSWORD:
77
+ dsn = f"postgresql://{AACT_USER}:{AACT_PASSWORD}@{AACT_HOST}:{AACT_PORT}/{AACT_DB}"
78
+ else:
79
+ dsn = f"postgresql://{AACT_USER}@{AACT_HOST}:{AACT_PORT}/{AACT_DB}"
80
+
81
+ try:
82
+ _connection_pool = await asyncpg.create_pool(
83
+ dsn=dsn,
84
+ min_size=2, # Minimum connections in pool
85
+ max_size=10, # Maximum connections in pool
86
+ max_queries=50000, # Max queries per connection before recycling
87
+ max_inactive_connection_lifetime=300, # Close idle connections after 5 min
88
+ command_timeout=60.0, # Query timeout
89
+ statement_cache_size=20, # Cache prepared statements
90
+ )
91
+ logger.info("Connection pool created successfully")
92
+ except Exception as e:
93
+ logger.error(f"Failed to create connection pool: {e}")
94
+ raise
95
+
96
+ return _connection_pool
97
+
98
+ async def execute_query_pooled(query: str, params: tuple = ()) -> List[Dict]:
99
+ """Execute query using connection pool (asyncpg)"""
100
+ pool = await get_connection_pool()
101
+
102
+ async with pool.acquire() as conn:
103
+ # Convert rows to dicts
104
+ rows = await conn.fetch(query, *params)
105
+ return [dict(row) for row in rows]
106
+
107
+ def execute_query_sync(query: str, params: tuple = ()) -> List[Dict]:
108
+ """Fallback: Execute query synchronously (psycopg2)"""
109
+ conn = None
110
+ cursor = None
111
+ try:
112
+ # Build connection string
113
+ conn_params = {
114
+ "host": AACT_HOST,
115
+ "port": AACT_PORT,
116
+ "dbname": AACT_DB,
117
+ "user": AACT_USER,
118
+ }
119
+
120
+ if AACT_PASSWORD:
121
+ conn_params["password"] = AACT_PASSWORD
122
+
123
+ conn = psycopg2.connect(**conn_params)
124
+ cursor = conn.cursor(cursor_factory=RealDictCursor)
125
+         # Named parameters tolerate placeholders that repeat in the query
+         cursor.execute(query, {f"p{i + 1}": v for i, v in enumerate(params)})
126
+ results = cursor.fetchall()
127
+
128
+ return [dict(row) for row in results]
129
+
130
+ except Exception as e:
131
+ logger.error(f"Database query failed: {e}")
132
+ raise
133
+ finally:
134
+ if cursor:
135
+ cursor.close()
136
+ if conn:
137
+ conn.close()
138
+
139
+ async def execute_query(query: str, params: tuple = ()) -> List[Dict]:
140
+ """Execute query using best available method"""
141
+ if ASYNCPG_AVAILABLE:
142
+ # Prefer asyncpg with connection pooling
143
+ return await execute_query_pooled(query, params)
144
+ elif POSTGRES_AVAILABLE:
145
+ # Fallback to synchronous psycopg2
146
+ return await asyncio.to_thread(execute_query_sync, query, params)
147
+ else:
148
+ raise RuntimeError("No PostgreSQL driver available (install asyncpg or psycopg2)")
149
+
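+ # Example (illustrative): results come back as a list of plain dicts, e.g.
+ #     rows = await execute_query("SELECT nct_id FROM studies LIMIT $1", (5,))
+ #     rows[0]["nct_id"]  # -> "NCT0..."
+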
150
+ @mcp.tool()
151
+ async def search_als_trials(
152
+ status: Optional[str] = "RECRUITING",
153
+ phase: Optional[str] = None,
154
+ intervention: Optional[str] = None,
155
+ location: Optional[str] = None,
156
+ max_results: int = 20
157
+ ) -> str:
158
+ """Search for ALS clinical trials in the AACT database.
159
+
160
+ Args:
161
+ status: Trial status (RECRUITING, ENROLLING_BY_INVITATION, ACTIVE_NOT_RECRUITING, COMPLETED)
162
+ phase: Trial phase (PHASE_1, PHASE_2, PHASE_3, PHASE_4, EARLY_PHASE_1)
163
+ intervention: Type of intervention to search for
164
+ location: Country or region
165
+ max_results: Maximum number of results to return
166
+ """
167
+
168
+ if not (ASYNCPG_AVAILABLE or POSTGRES_AVAILABLE):
169
+ return json.dumps({
170
+ "error": "Database not available",
171
+ "message": "PostgreSQL driver not installed. Install asyncpg or psycopg2-binary."
172
+ })
173
+
174
+ logger.info(f"🔎 AACT Search: status={status}, phase={phase}, intervention={intervention}, location={location}")
175
+
176
+ try:
177
+ # Build the query with proper filters
178
+ base_query = """
179
+ SELECT DISTINCT
180
+ s.nct_id,
181
+ s.brief_title,
182
+ s.overall_status,
183
+ s.phase,
184
+ s.enrollment,
185
+ s.start_date,
186
+ s.completion_date,
187
+ s.study_type,
188
+ s.official_title,
189
+ d.name as sponsor,
190
+ STRING_AGG(DISTINCT i.name, ', ') as interventions,
191
+ STRING_AGG(DISTINCT c.name, ', ') as conditions,
192
+ COUNT(DISTINCT f.id) as num_locations
193
+ FROM studies s
194
+ LEFT JOIN sponsors sp ON s.nct_id = sp.nct_id AND sp.lead_or_collaborator = 'lead'
195
+ LEFT JOIN responsible_parties d ON sp.nct_id = d.nct_id
196
+ LEFT JOIN interventions i ON s.nct_id = i.nct_id
197
+ LEFT JOIN conditions c ON s.nct_id = c.nct_id
198
+ LEFT JOIN facilities f ON s.nct_id = f.nct_id
199
+ WHERE (
200
+ LOWER(c.name) LIKE '%amyotrophic lateral sclerosis%' OR
201
+ LOWER(c.name) LIKE '%als %' OR
202
+ LOWER(c.name) LIKE '% als' OR
203
+ LOWER(c.name) LIKE '%motor neuron disease%' OR
204
+ LOWER(c.name) LIKE '%lou gehrig%' OR
205
+ LOWER(s.brief_title) LIKE '%amyotrophic lateral sclerosis%' OR
206
+ LOWER(s.brief_title) LIKE '%als %' OR
207
+ LOWER(s.brief_title) LIKE '% als' OR
208
+ LOWER(s.official_title) LIKE '%amyotrophic lateral sclerosis%' OR
209
+ LOWER(s.official_title) LIKE '%als %' OR
210
+ LOWER(s.official_title) LIKE '% als'
211
+ )
212
+ """
213
+
214
+ # Apply filters
215
+ conditions = []
216
+ params = []
217
+ param_count = 1
218
+
219
+ if status:
220
+ conditions.append(f"UPPER(s.overall_status) = ${param_count}")
221
+ params.append(status.upper())
222
+ param_count += 1
223
+
224
+ if phase:
225
+ phase_map = {
226
+ 'PHASE_1': 'Phase 1',
227
+ 'PHASE_2': 'Phase 2',
228
+ 'PHASE_3': 'Phase 3',
229
+ 'PHASE_4': 'Phase 4',
230
+ 'EARLY_PHASE_1': 'Early Phase 1'
231
+ }
232
+ mapped_phase = phase_map.get(phase.upper(), phase)
233
+ conditions.append(f"s.phase = ${param_count}")
234
+ params.append(mapped_phase)
235
+ param_count += 1
236
+
237
+ if intervention:
238
+ conditions.append(f"LOWER(i.name) LIKE ${param_count}")
239
+ params.append(f"%{intervention.lower()}%")
240
+ param_count += 1
241
+
242
+ if location:
243
+ base_query = base_query.replace("LEFT JOIN facilities f", "INNER JOIN facilities f")
244
+ conditions.append(f"(LOWER(f.country) LIKE ${param_count} OR LOWER(f.state) LIKE ${param_count})")
245
+ params.append(f"%{location.lower()}%")
246
+ param_count += 1
247
+
248
+ # Add conditions to query
249
+ if conditions:
250
+ base_query += " AND " + " AND ".join(conditions)
251
+
252
+ # Add GROUP BY and ORDER BY
253
+ base_query += f"""
254
+ GROUP BY s.nct_id, s.brief_title, s.overall_status, s.phase,
255
+ s.enrollment, s.start_date, s.completion_date,
256
+ s.study_type, s.official_title, d.name
257
+ ORDER BY
258
+ CASE UPPER(s.overall_status)
259
+ WHEN 'RECRUITING' THEN 1
260
+ WHEN 'ENROLLING_BY_INVITATION' THEN 2
261
+ WHEN 'ACTIVE_NOT_RECRUITING' THEN 3
262
+ WHEN 'NOT_YET_RECRUITING' THEN 4
263
+ ELSE 5
264
+ END,
265
+ s.start_date DESC NULLS LAST
266
+ LIMIT ${param_count}
267
+ """
268
+ params.append(max_results)
269
+
270
+ # Execute query
271
+ logger.debug(f"📊 Executing query with {len(params)} parameters")
272
+ results = await execute_query(base_query, tuple(params))
273
+
274
+ logger.info(f"✅ AACT Results: Found {len(results) if results else 0} trials")
275
+
276
+ if not results:
277
+ return json.dumps({
278
+ "message": "No ALS trials found matching your criteria",
279
+ "total": 0,
280
+ "trials": []
281
+ })
282
+
283
+ # Format results
284
+ trials = []
285
+ for row in results:
286
+ trial = {
287
+ "nct_id": row['nct_id'],
288
+ "title": row['brief_title'],
289
+ "status": row['overall_status'],
290
+ "phase": row['phase'],
291
+ "enrollment": row['enrollment'],
292
+ "sponsor": row['sponsor'],
293
+ "interventions": row['interventions'],
294
+ "conditions": row['conditions'],
295
+ "locations_count": row['num_locations'],
296
+ "start_date": str(row['start_date']) if row['start_date'] else None,
297
+ "completion_date": str(row['completion_date']) if row['completion_date'] else None,
298
+ "url": f"https://clinicaltrials.gov/study/{row['nct_id']}"
299
+ }
300
+ trials.append(trial)
301
+
302
+ return json.dumps({
303
+ "message": f"Found {len(trials)} ALS clinical trials",
304
+ "total": len(trials),
305
+ "trials": trials
306
+ }, indent=2)
307
+
308
+ except Exception as e:
309
+ logger.error(f"❌ AACT Database query failed: {e}")
310
+ logger.error(f" Query type: search_als_trials")
311
+ logger.error(f" Parameters: status={status}, phase={phase}, intervention={intervention}")
312
+ return json.dumps({
313
+ "error": "Database query failed",
314
+ "message": str(e)
315
+ })
316
+
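For local smoke-testing, the tool coroutine can be awaited directly and its JSON payload parsed; a minimal sketch (assumes a reachable AACT connection, and that the decorated function is still directly awaitable in your FastMCP version):

```python
import asyncio
import json

async def demo():
    raw = await search_als_trials(status="RECRUITING", phase="PHASE_3", max_results=5)
    payload = json.loads(raw)
    for trial in payload.get("trials", []):
        print(f"{trial['nct_id']}: {trial['title']} ({trial['status']})")

# asyncio.run(demo())
```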
317
+ @mcp.tool()
318
+ async def get_trial_details(nct_id: str) -> str:
319
+ """Get detailed information about a specific clinical trial.
320
+
321
+ Args:
322
+ nct_id: The NCT ID of the trial (e.g., 'NCT04856982')
323
+ """
324
+
325
+ if not (ASYNCPG_AVAILABLE or POSTGRES_AVAILABLE):
326
+ return json.dumps({
327
+ "error": "Database not available",
328
+ "message": "PostgreSQL driver not installed."
329
+ })
330
+
331
+ try:
332
+ # Main trial information
333
+ main_query = """
334
+ SELECT
335
+ s.nct_id,
336
+ s.brief_title,
337
+ s.official_title,
338
+ s.overall_status,
339
+ s.phase,
340
+ s.study_type,
341
+ s.enrollment,
342
+ s.start_date,
343
+ s.primary_completion_date,
344
+ s.completion_date,
345
+ s.first_posted_date,
346
+ s.last_update_posted_date,
347
+ s.why_stopped,
348
+ b.description as brief_summary,
349
+ dd.description as detailed_description,
350
+ e.criteria as eligibility_criteria,
351
+ e.gender,
352
+ e.minimum_age,
353
+ e.maximum_age,
354
+ e.healthy_volunteers,
355
+ rp.name as sponsor,
356
+ rp.responsible_party_type
357
+ FROM studies s
358
+ LEFT JOIN brief_summaries b ON s.nct_id = b.nct_id
359
+ LEFT JOIN detailed_descriptions dd ON s.nct_id = dd.nct_id
360
+ LEFT JOIN eligibilities e ON s.nct_id = e.nct_id
361
+ LEFT JOIN responsible_parties rp ON s.nct_id = rp.nct_id
362
+ WHERE s.nct_id = $1
363
+ """
364
+
365
+ results = await execute_query(main_query, (nct_id,))
366
+
367
+ if not results:
368
+ return json.dumps({
369
+ "error": "Trial not found",
370
+ "message": f"No trial found with NCT ID: {nct_id}"
371
+ })
372
+
373
+ trial_info = results[0]
374
+
375
+ # Get outcomes
376
+ outcomes_query = """
377
+ SELECT outcome_type, measure, time_frame, description
378
+ FROM outcomes
379
+ WHERE nct_id = $1
380
+ ORDER BY outcome_type, id
381
+ LIMIT 20
382
+ """
383
+ outcomes = await execute_query(outcomes_query, (nct_id,))
384
+
385
+ # Get interventions
386
+ interventions_query = """
387
+ SELECT intervention_type, name, description
388
+ FROM interventions
389
+ WHERE nct_id = $1
390
+ """
391
+ interventions = await execute_query(interventions_query, (nct_id,))
392
+
393
+ # Get locations
394
+ locations_query = """
395
+ SELECT name, city, state, country, status
396
+ FROM facilities
397
+ WHERE nct_id = $1
398
+ LIMIT 50
399
+ """
400
+ locations = await execute_query(locations_query, (nct_id,))
401
+
402
+ # Format the response
403
+ return json.dumps({
404
+ "nct_id": trial_info['nct_id'],
405
+ "title": trial_info['brief_title'],
406
+ "official_title": trial_info['official_title'],
407
+ "status": trial_info['overall_status'],
408
+ "phase": trial_info['phase'],
409
+ "study_type": trial_info['study_type'],
410
+ "enrollment": trial_info['enrollment'],
411
+ "sponsor": trial_info['sponsor'],
412
+ "dates": {
413
+ "start": str(trial_info['start_date']) if trial_info['start_date'] else None,
414
+ "primary_completion": str(trial_info['primary_completion_date']) if trial_info['primary_completion_date'] else None,
415
+ "completion": str(trial_info['completion_date']) if trial_info['completion_date'] else None,
416
+ "first_posted": str(trial_info['first_posted_date']) if trial_info['first_posted_date'] else None,
417
+ "last_updated": str(trial_info['last_update_posted_date']) if trial_info['last_update_posted_date'] else None
418
+ },
419
+ "summary": trial_info['brief_summary'],
420
+ "detailed_description": trial_info['detailed_description'],
421
+ "eligibility": {
422
+ "criteria": trial_info['eligibility_criteria'],
423
+ "gender": trial_info['gender'],
424
+ "age_range": f"{trial_info['minimum_age'] or 'N/A'} - {trial_info['maximum_age'] or 'N/A'}",
425
+ "healthy_volunteers": trial_info['healthy_volunteers']
426
+ },
427
+ "outcomes": [
428
+ {
429
+ "type": o['outcome_type'],
430
+ "measure": o['measure'],
431
+ "time_frame": o['time_frame'],
432
+ "description": o['description']
433
+ } for o in outcomes
434
+ ],
435
+ "interventions": [
436
+ {
437
+ "type": i['intervention_type'],
438
+ "name": i['name'],
439
+ "description": i['description']
440
+ } for i in interventions
441
+ ],
442
+ "locations": [
443
+ {
444
+ "name": l['name'],
445
+ "city": l['city'],
446
+ "state": l['state'],
447
+ "country": l['country'],
448
+ "status": l['status']
449
+ } for l in locations
450
+ ],
451
+ "url": f"https://clinicaltrials.gov/study/{nct_id}"
452
+ }, indent=2)
453
+
454
+ except Exception as e:
455
+ logger.error(f"Failed to get trial details: {e}")
456
+ return json.dumps({
457
+ "error": "Database query failed",
458
+ "message": str(e)
459
+ })
460
+
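The response formatting above repeats `str(x) if x else None` for every date field; a tiny helper would collapse the pattern (hypothetical refactor, not part of the server as written):

```python
def iso_or_none(value) -> str | None:
    # Stringify an AACT date value, preserving NULL as JSON null.
    return str(value) if value else None

# e.g. "start": iso_or_none(trial_info["start_date"])
```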
461
+ # Cleanup on shutdown
462
+ # Note: FastMCP doesn't have a built-in shutdown handler
463
+ # The connection pool will be closed when the process ends
464
+ # async def cleanup():
465
+ # """Close the connection pool on shutdown"""
466
+ # global _connection_pool
467
+ # if _connection_pool:
468
+ # await _connection_pool.close()
469
+ # logger.info("Connection pool closed")
470
+
471
+ if __name__ == "__main__":
472
+ mcp.run()
servers/biorxiv_server.py ADDED
@@ -0,0 +1,440 @@
1
+ # biorxiv_server.py
2
+ from mcp.server.fastmcp import FastMCP
3
+ import httpx
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ import sys
7
+ from pathlib import Path
8
+ import re
9
+
10
+ # Add parent directory to path for shared imports
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ from shared import (
14
+ config,
15
+ RateLimiter,
16
+ format_authors,
17
+ ErrorFormatter,
18
+ truncate_text
19
+ )
20
+ from shared.http_client import get_http_client, CustomHTTPClient
21
+
22
+ # Configure logging with DEBUG for detailed troubleshooting
23
+ logging.basicConfig(
24
+ level=logging.DEBUG,
25
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
26
+ )
27
+ logger = logging.getLogger(__name__)
28
+
29
+ mcp = FastMCP("biorxiv-server")
30
+
31
+ # Rate limiting using shared utility
32
+ rate_limiter = RateLimiter(config.rate_limits.biorxiv_delay)
33
+
34
+
35
+ def preprocess_query(query: str) -> tuple[list[str], list[str]]:
36
+ """Preprocess query into search terms and handle synonyms.
37
+
38
+ Returns:
39
+ tuple of (primary_terms, all_search_terms)
40
+ """
41
+ # Convert to lowercase for matching
42
+ query_lower = query.lower()
43
+
44
+ # Common ALS-related synonyms and variations
45
+ synonyms = {
46
+ 'als': ['amyotrophic lateral sclerosis', 'motor neuron disease', 'motor neurone disease', 'lou gehrig'],
47
+ 'amyotrophic lateral sclerosis': ['als', 'motor neuron disease'],
48
+ 'mnd': ['motor neuron disease', 'motor neurone disease', 'als'],
49
+ 'sod1': ['superoxide dismutase 1', 'cu/zn superoxide dismutase'],
50
+ 'tdp-43': ['tdp43', 'tardbp', 'tar dna binding protein'],
51
+ 'c9orf72': ['c9', 'chromosome 9 open reading frame 72'],
52
+ 'fus': ['fused in sarcoma', 'tls'],
53
+ }
54
+
55
+ # Split query into individual terms (handle multiple spaces and special chars)
56
+ # Keep hyphenated words together (like TDP-43)
57
+ terms = re.split(r'\s+', query_lower.strip())
58
+
59
+ # Build comprehensive search term list
60
+ all_terms = []
61
+ primary_terms = []
62
+
63
+ for term in terms:
64
+ # Skip very short terms unless they're known abbreviations
65
+ if len(term) < 3 and term not in ['als', 'mnd', 'fus', 'c9']:
66
+ continue
67
+
68
+ primary_terms.append(term)
69
+ all_terms.append(term)
70
+
71
+ # Add synonyms if they exist
72
+ if term in synonyms:
73
+ all_terms.extend(synonyms[term])
74
+
75
+ # Remove duplicates while preserving order
76
+ seen = set()
77
+ all_terms = [t for t in all_terms if not (t in seen or seen.add(t))]
78
+ seen = set()  # reset: reusing the filled set would filter out every primary term
+ primary_terms = [t for t in primary_terms if not (t in seen or seen.add(t))]
79
+
80
+ return primary_terms, all_terms
81
+
82
+
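To make the synonym expansion concrete, this is the expected output for a typical query, given the synonym table above:

```python
primary, expanded = preprocess_query("ALS TDP-43")
# primary  -> ['als', 'tdp-43']
# expanded -> ['als', 'amyotrophic lateral sclerosis', 'motor neuron disease',
#              'motor neurone disease', 'lou gehrig', 'tdp-43', 'tdp43',
#              'tardbp', 'tar dna binding protein']
```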
83
+ def matches_query(paper: dict, primary_terms: list[str], all_terms: list[str], require_all: bool = False) -> bool:
84
+ """Check if a paper matches the search query.
85
+
86
+ Args:
87
+ paper: Paper dictionary from bioRxiv API
88
+ primary_terms: Main search terms from user query
89
+ all_terms: All search terms including synonyms
90
+ require_all: If True, require ALL primary terms. If False, require ANY term.
91
+
92
+ Returns:
93
+ True if paper matches search criteria
94
+ """
95
+ # Get searchable text
96
+ title = paper.get("title", "").lower()
97
+ abstract = paper.get("abstract", "").lower()
98
+ searchable_text = f" {title} {abstract} " # Add spaces for boundary matching
99
+
100
+ # DEBUG: Log paper being checked
101
+ paper_doi = paper.get("doi", "unknown")
102
+ logger.debug(f"🔍 Checking paper: {title[:60]}... (DOI: {paper_doi})")
103
+
104
+ if not searchable_text.strip():
105
+ logger.debug(f" ❌ Rejected: No title/abstract")
106
+ return False
107
+
108
+ # For ALS specifically, need to be careful about word boundaries
109
+ has_any_match = False
110
+ matched_term = None
111
+ for term in all_terms:
112
+ # For short terms like "ALS", require word boundaries
113
+ if len(term) <= 3:
114
+ # Check for word boundary match
115
+ pattern = r'\b' + re.escape(term) + r'\b'
116
+ if re.search(pattern, searchable_text, re.IGNORECASE):
117
+ has_any_match = True
118
+ matched_term = term
119
+ break
120
+ else:
121
+ # For longer terms, can be more lenient
122
+ if term.lower() in searchable_text:
123
+ has_any_match = True
124
+ matched_term = term
125
+ break
126
+
127
+ if not has_any_match:
128
+ logger.debug(f" ❌ Rejected: No term match. Terms searched: {all_terms[:3]}...")
129
+ return False
130
+
131
+ logger.debug(f" ✅ Matched on term: '{matched_term}'")
132
+
133
+ # If we only need any match, we're done
134
+ if not require_all:
135
+ return True
136
+
137
+ # For require_all, check that all primary terms are present
138
+ # Allow for word boundaries to avoid partial matches
139
+ for term in primary_terms:
140
+ # Create pattern that matches the term as a whole word or part of hyphenated word
141
+ # This handles cases like "TDP-43" or "SOD1"
142
+ pattern = r'\b' + re.escape(term) + r'(?:\b|[-])'
143
+ if not re.search(pattern, searchable_text, re.IGNORECASE):
144
+ return False
145
+
146
+ return True
147
+
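The word-boundary rule for short terms is what keeps "ALS" from matching inside unrelated words; a quick illustration:

```python
import re

print(bool(re.search(r"\bals\b", " signals from recent trials ", re.IGNORECASE)))  # False
print(bool(re.search(r"\bals\b", " an ALS cohort ", re.IGNORECASE)))               # True
```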
148
+
149
+ @mcp.tool()
150
+ async def search_preprints(
151
+ query: str,
152
+ server: str = "both",
153
+ max_results: int = 10,
154
+ days_back: int = 365
155
+ ) -> str:
156
+ """Search bioRxiv and medRxiv for ALS preprints. Returns recent preprints before peer review.
157
+
158
+ Args:
159
+ query: Search query (e.g., 'ALS TDP-43')
160
+ server: Which server to search - one of: biorxiv, medrxiv, both (default: both)
161
+ max_results: Maximum number of results (default: 10)
162
+ days_back: Number of days to look back (default: 365 - about 1 year)
163
+ """
164
+ try:
165
+ logger.info(f"🔎 Searching bioRxiv/medRxiv for: '{query}'")
166
+ logger.info(f" Parameters: server={server}, max_results={max_results}, days_back={days_back}")
167
+
168
+ # Preprocess query for better matching
169
+ primary_terms, all_terms = preprocess_query(query)
170
+ logger.info(f"📝 Search terms: primary={primary_terms}, all={all_terms}")
171
+
172
+ # Calculate date range
173
+ end_date = datetime.now()
174
+ start_date = end_date - timedelta(days=days_back)
175
+
176
+ # Format dates for API (YYYY-MM-DD)
177
+ start_date_str = start_date.strftime("%Y-%m-%d")
178
+ end_date_str = end_date.strftime("%Y-%m-%d")
179
+ logger.info(f"📅 Date range: {start_date_str} to {end_date_str}")
180
+
181
+ # bioRxiv/medRxiv API endpoint
182
+ base_url = "https://api.biorxiv.org/details"
183
+
184
+ all_results = []
185
+ servers_to_search = []
186
+
187
+ if server in ["biorxiv", "both"]:
188
+ servers_to_search.append("biorxiv")
189
+ if server in ["medrxiv", "both"]:
190
+ servers_to_search.append("medrxiv")
191
+
192
+ # Use a custom HTTP client with proper timeout for bioRxiv
193
+ # Don't use shared client as it may have conflicting timeout settings
194
+ async with CustomHTTPClient(timeout=15.0) as client:
195
+ for srv in servers_to_search:
196
+ try:
197
+ cursor = 0
198
+ found_in_server = []
199
+ max_iterations = 1 # Only check first page (100 papers) for much faster response
200
+ iteration = 0
201
+
202
+ while iteration < max_iterations:
203
+ # Rate limiting
204
+ await rate_limiter.wait()
205
+
206
+ # Search by date range with cursor for pagination
207
+ url = f"{base_url}/{srv}/{start_date_str}/{end_date_str}/{cursor}"
208
+
209
+ logger.info(f"🌐 Querying {srv} API (page {iteration+1}, cursor={cursor})")
210
+ logger.info(f" URL: {url}")
211
+ response = await client.get(url)
212
+ response.raise_for_status()
213
+ data = response.json()
214
+
215
+ # Extract collection
216
+ collection = data.get("collection", [])
217
+
218
+ if not collection:
219
+ logger.info(f"📭 No more results from {srv}")
220
+ break
221
+
222
+ logger.info(f"📦 Fetched {len(collection)} papers from API")
223
+
224
+ # Show first few papers for debugging
225
+ if iteration == 0 and collection:
226
+ logger.info(" Sample papers from API:")
227
+ for i, paper in enumerate(collection[:3]):
228
+ logger.info(f" {i+1}. {paper.get('title', 'No title')[:60]}...")
229
+
230
+ # Filter papers using improved matching
231
+ # Start with lenient matching (ANY term)
232
+ logger.debug(f"🔍 Starting to filter {len(collection)} papers...")
233
+ filtered = [
234
+ paper for paper in collection
235
+ if matches_query(paper, primary_terms, all_terms, require_all=False)
236
+ ]
237
+
238
+ logger.info(f"✅ Filtered results: {len(filtered)}/{len(collection)} papers matched")
239
+
240
+ if len(filtered) > 0:
241
+ logger.info(" Matched papers:")
242
+ for i, paper in enumerate(filtered[:3]):
243
+ logger.info(f" {i+1}. {paper.get('title', 'No title')[:60]}...")
244
+
245
+ found_in_server.extend(filtered)
246
+ logger.info(f"📊 Running total for {srv}: {len(found_in_server)} papers")
247
+
248
+ # Check if we have enough results
249
+ if len(found_in_server) >= max_results:
250
+ logger.info(f"Reached max_results limit ({max_results})")
251
+ break
252
+
253
+ # Continue searching if we haven't found enough (no-op while max_iterations == 1; kept for when deeper pagination is re-enabled)
254
+ if len(found_in_server) < 5 and iteration < max_iterations - 1:
255
+ # Keep searching for more results
256
+ pass
257
+ elif len(found_in_server) > 0 and iteration >= 3:
258
+ # Found some results after reasonable search
259
+ logger.info(f"Found {len(found_in_server)} results after {iteration+1} pages")
260
+ break
261
+
262
+ # Check for more pages
263
+ messages = data.get("messages", [])
264
+
265
+ # The API returns "cursor" in messages for next page
266
+ has_more = False
267
+ for msg in messages:
268
+ if "cursor=" in str(msg):
269
+ try:
270
+ cursor_str = str(msg).split("cursor=")[1].split()[0]
271
+ next_cursor = int(cursor_str)
272
+ if next_cursor > cursor:
273
+ cursor = next_cursor
274
+ has_more = True
275
+ break
276
+ except (ValueError, IndexError):
277
+ pass
278
+
279
+ # Alternative: increment by collection size
280
+ if not has_more:
281
+ if len(collection) >= 100:
282
+ cursor += len(collection)
283
+ else:
284
+ # Less than full page means we've reached the end
285
+ break
286
+
287
+ iteration += 1
288
+
289
+ all_results.extend(found_in_server[:max_results])
290
+ logger.info(f"🏁 Total results from {srv}: {len(found_in_server)} papers found")
291
+
292
+ except httpx.HTTPStatusError as e:
293
+ logger.warning(f"Error searching {srv}: {e}")
294
+ continue
295
+ except Exception as e:
296
+ logger.warning(f"Unexpected error searching {srv}: {e}")
297
+ continue
298
+
299
+ # If no results with lenient matching, provide helpful message
300
+ if not all_results:
301
+ logger.warning(f"⚠️ No preprints found for query: {query}")
302
+
303
+ # Provide suggestions for improving search
304
+ suggestions = []
305
+ if len(primary_terms) > 3:
306
+ suggestions.append("Try using fewer search terms")
307
+ if not any(term in ['als', 'amyotrophic lateral sclerosis', 'motor neuron'] for term in all_terms):
308
+ suggestions.append("Add 'ALS' or 'motor neuron disease' to your search")
309
+ if days_back < 365:
310
+ suggestions.append(f"Expand the time range beyond {days_back} days")
311
+
312
+ suggestion_text = ""
313
+ if suggestions:
314
+ suggestion_text = "\n\nSuggestions:\n" + "\n".join(f"- {s}" for s in suggestions)
315
+
316
+ return f"No preprints found for query: '{query}' in the last {days_back} days{suggestion_text}"
317
+
318
+ # Sort by date (most recent first)
319
+ all_results.sort(key=lambda x: x.get("date", ""), reverse=True)
320
+
321
+ # Limit results
322
+ all_results = all_results[:max_results]
323
+
324
+ logger.info(f"🎯 FINAL RESULTS: Returning {len(all_results)} preprints for '{query}'")
325
+ if all_results:
326
+ logger.info(" Top results:")
327
+ for i, paper in enumerate(all_results[:3], 1):
328
+ logger.info(f" {i}. {paper.get('title', 'No title')[:60]}...")
329
+ logger.info(f" DOI: {paper.get('doi', 'unknown')}, Date: {paper.get('date', 'unknown')}")
330
+
331
+ # Format results
332
+ result = f"Found {len(all_results)} preprints for query: '{query}'\n\n"
333
+
334
+ for i, paper in enumerate(all_results, 1):
335
+ title = paper.get("title", "No title")
336
+ doi = paper.get("doi", "Unknown")
337
+ date = paper.get("date", "Unknown")
338
+ authors = paper.get("authors", "Unknown authors")
339
+ authors_str = format_authors(authors, max_authors=3)
340
+
341
+ abstract = paper.get("abstract", "No abstract available")
342
+ category = paper.get("category", "")
343
+ server_name = paper.get("server", "bioRxiv")  # both servers share the 10.1101 DOI prefix, so the DOI alone cannot distinguish them
344
+
345
+ result += f"{i}. **{title}**\n"
346
+ result += f" DOI: {doi} | {server_name} | Posted: {date}\n"
347
+ result += f" Authors: {authors_str}\n"
348
+ if category:
349
+ result += f" Category: {category}\n"
350
+ result += f" Abstract: {truncate_text(abstract, max_chars=300, suffix='')}\n"
351
+ result += f" URL: https://doi.org/{doi}\n\n"
352
+
353
+ logger.info(f"Successfully retrieved {len(all_results)} preprints")
354
+ return result
355
+
356
+ except httpx.TimeoutException:
357
+ logger.error("bioRxiv/medRxiv API request timed out")
358
+ return "Error: bioRxiv/medRxiv API request timed out. Please try again."
359
+ except httpx.HTTPStatusError as e:
360
+ logger.error(f"bioRxiv/medRxiv API error: {e}")
361
+ return f"Error: bioRxiv/medRxiv API returned status code {e.response.status_code}"
362
+ except Exception as e:
363
+ logger.error(f"Unexpected error in search_preprints: {e}")
364
+ return f"Error searching preprints: {str(e)}"
365
+
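For reference, the details endpoint returns JSON shaped roughly like the sketch below; the field names match what the code above consumes, but treat the exact schema as an assumption rather than a spec:

```python
sample_response = {
    "messages": [{"status": "ok", "count": 100, "cursor": "100"}],
    "collection": [
        {
            "doi": "10.1101/2024.01.01.123456",
            "title": "TDP-43 aggregation in ALS models",
            "authors": "Doe, J.; Smith, A.",
            "date": "2024-01-05",
            "category": "neuroscience",
            "abstract": "...",
            "server": "biorxiv",
        }
    ],
}
```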
366
+
367
+ @mcp.tool()
368
+ async def get_preprint_details(doi: str) -> str:
369
+ """Get full details for a specific bioRxiv/medRxiv preprint by DOI.
370
+
371
+ Args:
372
+ doi: The DOI of the preprint (e.g., '10.1101/2024.01.01.123456')
373
+ """
374
+ try:
375
+ logger.info(f"Getting details for DOI: {doi}")
376
+
377
+ # Ensure DOI is properly formatted
378
+ if not doi.startswith("10.1101/"):
379
+ doi = f"10.1101/{doi}"
380
+
381
+ # Determine server from DOI
382
+ # bioRxiv DOIs typically have format: 10.1101/YYYY.MM.DD.NNNNNN
383
+ # medRxiv DOIs are similar but the content determines the server
384
+
385
+ # Use shared HTTP client for connection pooling
386
+ client = get_http_client(timeout=30.0)
387
+ # Try the DOI endpoint
388
+ url = f"https://api.biorxiv.org/details/{doi}"
389
+
390
+ response = await client.get(url)
391
+
392
+ if response.status_code == 404:
393
+ # Try with both servers
394
+ for server in ["biorxiv", "medrxiv"]:
395
+ url = f"https://api.biorxiv.org/details/{server}/{doi}"
396
+ response = await client.get(url)
397
+ if response.status_code == 200:
398
+ break
399
+ else:
400
+ return f"Preprint with DOI {doi} not found"
401
+
402
+ response.raise_for_status()
403
+ data = response.json()
404
+
405
+ collection = data.get("collection", [])
406
+ if not collection:
407
+ return f"No details found for DOI: {doi}"
408
+
409
+ # Get the first (and should be only) paper
410
+ paper = collection[0]
411
+
412
+ title = paper.get("title", "No title")
413
+ date = paper.get("date", "Unknown")
414
+ authors = paper.get("authors", "Unknown authors")
415
+ abstract = paper.get("abstract", "No abstract available")
416
+ category = paper.get("category", "")
417
+ server_name = paper.get("server", "Unknown")
418
+
419
+ result = f"**{title}**\n\n"
420
+ result += f"**DOI:** {doi}\n"
421
+ result += f"**Server:** {server_name}\n"
422
+ result += f"**Posted:** {date}\n"
423
+ if category:
424
+ result += f"**Category:** {category}\n"
425
+ result += f"**Authors:** {authors}\n\n"
426
+ result += f"**Abstract:**\n{abstract}\n\n"
427
+ result += f"**Full Text URL:** https://doi.org/{doi}\n"
428
+
429
+ return result
430
+
431
+ except httpx.HTTPStatusError as e:
432
+ logger.error(f"Error fetching preprint details: {e}")
433
+ return f"Error fetching preprint details: HTTP {e.response.status_code}"
434
+ except Exception as e:
435
+ logger.error(f"Unexpected error getting preprint details: {e}")
436
+ return f"Error getting preprint details: {str(e)}"
437
+
438
+
439
+ if __name__ == "__main__":
440
+ mcp.run(transport="stdio")
servers/clinicaltrials_links.py ADDED
@@ -0,0 +1,245 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simplified ClinicalTrials.gov Link Generator
4
+ Provides direct links and known trials as fallback when AACT is unavailable
5
+ """
6
+
7
+ from mcp.server.fastmcp import FastMCP
8
+ import logging
9
+ from typing import Optional
10
+ from urllib.parse import quote_plus
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ mcp = FastMCP("clinicaltrials-links")
17
+
18
+ # Known important ALS trials (updated periodically)
19
+ KNOWN_ALS_TRIALS = {
20
+ "NCT05112094": {
21
+ "title": "Tofersen (ATLAS)",
22
+ "description": "SOD1-targeted antisense therapy for SOD1-ALS",
23
+ "status": "Active",
24
+ "sponsor": "Biogen"
25
+ },
26
+ "NCT04856982": {
27
+ "title": "HEALEY ALS Platform Trial",
28
+ "description": "Multiple drugs tested simultaneously",
29
+ "status": "Recruiting",
30
+ "sponsor": "Massachusetts General Hospital"
31
+ },
32
+ "NCT04768972": {
33
+ "title": "Ravulizumab",
34
+ "description": "Complement C5 inhibition",
35
+ "status": "Active",
36
+ "sponsor": "Alexion"
37
+ },
38
+ "NCT05370950": {
39
+ "title": "Pridopidine",
40
+ "description": "Sigma-1 receptor agonist",
41
+ "status": "Recruiting",
42
+ "sponsor": "Prilenia"
43
+ },
44
+ "NCT04632225": {
45
+ "title": "NurOwn",
46
+ "description": "MSC-NTF cells (mesenchymal stem cells)",
47
+ "status": "Active",
48
+ "sponsor": "BrainStorm Cell"
49
+ },
50
+ "NCT07204977": {
51
+ "title": "Acamprosate",
52
+ "description": "C9orf72 hexanucleotide repeat expansion treatment",
53
+ "status": "Recruiting",
54
+ "sponsor": "Mayo Clinic"
55
+ },
56
+ "NCT07161999": {
57
+ "title": "COYA 302",
58
+ "description": "Regulatory T-cell therapy",
59
+ "status": "Recruiting",
60
+ "sponsor": "Coya Therapeutics"
61
+ },
62
+ "NCT07023835": {
63
+ "title": "Usnoflast",
64
+ "description": "Anti-inflammatory for ALS",
65
+ "status": "Recruiting",
66
+ "sponsor": "Seelos Therapeutics"
67
+ }
68
+ }
69
+
70
+
71
+ @mcp.tool()
72
+ async def get_trial_link(nct_id: str) -> str:
73
+ """Generate direct link to a ClinicalTrials.gov trial page.
74
+
75
+ Args:
76
+ nct_id: NCT identifier (e.g., 'NCT05112094')
77
+ """
78
+ nct_id = nct_id.upper()
79
+ url = f"https://clinicaltrials.gov/study/{nct_id}"
80
+
81
+ result = f"**Direct link to trial {nct_id}:**\n{url}\n\n"
82
+
83
+ # Add info if it's a known trial
84
+ if nct_id in KNOWN_ALS_TRIALS:
85
+ trial = KNOWN_ALS_TRIALS[nct_id]
86
+ result += f"**{trial['title']}**\n"
87
+ result += f"Description: {trial['description']}\n"
88
+ result += f"Status: {trial['status']}\n"
89
+ result += f"Sponsor: {trial['sponsor']}\n"
90
+
91
+ return result
92
+
93
+
94
+ @mcp.tool()
95
+ async def get_search_link(
96
+ condition: str = "ALS",
97
+ status: Optional[str] = None,
98
+ intervention: Optional[str] = None,
99
+ location: Optional[str] = None
100
+ ) -> str:
101
+ """Generate direct search link for ClinicalTrials.gov.
102
+
103
+ Args:
104
+ condition: Medical condition (default: ALS)
105
+ status: Trial status (recruiting, active, completed)
106
+ intervention: Treatment/drug name
107
+ location: Country or city
108
+ """
109
+ base_url = "https://clinicaltrials.gov/search"
110
+ params = []
111
+
112
+ # Add condition
113
+ params.append(f"cond={quote_plus(condition)}")
114
+
115
+ # Map status to ClinicalTrials.gov format
116
+ if status:
117
+ status_lower = status.lower()
118
+ if "recruit" in status_lower:
119
+ params.append("recrs=a") # Recruiting
120
+ elif "active" in status_lower:
121
+ params.append("recrs=d") # Active, not recruiting
122
+ elif "complet" in status_lower:
123
+ params.append("recrs=e") # Completed
124
+
125
+ # Add intervention
126
+ if intervention:
127
+ params.append(f"intr={quote_plus(intervention)}")
128
+
129
+ # Add location
130
+ if location:
131
+ params.append(f"locn={quote_plus(location)}")
132
+
133
+ # Build URL
134
+ search_url = f"{base_url}?{'&'.join(params)}"
135
+
136
+ result = f"**Direct search on ClinicalTrials.gov:**\n\n"
137
+ result += f"Search parameters:\n"
138
+ result += f"- Condition: {condition}\n"
139
+ if status:
140
+ result += f"- Status: {status}\n"
141
+ if intervention:
142
+ result += f"- Intervention: {intervention}\n"
143
+ if location:
144
+ result += f"- Location: {location}\n"
145
+ result += f"\n🔗 **Search URL:** {search_url}\n"
146
+ result += f"\nClick the link above to see results on ClinicalTrials.gov"
147
+
148
+ return result
149
+
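As a concrete example, a recruiting tofersen search produces a URL like the one below (assuming ClinicalTrials.gov keeps supporting these legacy query parameters):

```python
from urllib.parse import quote_plus

params = [f"cond={quote_plus('ALS')}", "recrs=a", f"intr={quote_plus('tofersen')}"]
print(f"https://clinicaltrials.gov/search?{'&'.join(params)}")
# https://clinicaltrials.gov/search?cond=ALS&recrs=a&intr=tofersen
```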
150
+
151
+ @mcp.tool()
152
+ async def get_known_als_trials(
153
+ status_filter: Optional[str] = None
154
+ ) -> str:
155
+ """Get list of known important ALS trials.
156
+
157
+ Args:
158
+ status_filter: Filter by status (recruiting, active, all)
159
+ """
160
+ result = "**Important ALS Clinical Trials:**\n\n"
161
+
162
+ if not KNOWN_ALS_TRIALS:
163
+ return "No known trials available in offline database."
164
+
165
+ count = 0
166
+ for nct_id, trial in KNOWN_ALS_TRIALS.items():
167
+ # Apply status filter if provided
168
+ if status_filter:
169
+ filter_lower = status_filter.lower()
170
+ trial_status = trial['status'].lower()
171
+
172
+ if filter_lower == "recruiting" and "recruit" not in trial_status:
173
+ continue
174
+ elif filter_lower == "active" and "active" not in trial_status:
175
+ continue
176
+ elif filter_lower == "completed" and "complet" not in trial_status:
177
+ continue
178
+
179
+ count += 1
180
+ result += f"{count}. **{trial['title']}** ({nct_id})\n"
181
+ result += f" {trial['description']}\n"
182
+ result += f" Status: {trial['status']} | Sponsor: {trial['sponsor']}\n"
183
+ result += f" 🔗 https://clinicaltrials.gov/study/{nct_id}\n\n"
184
+
185
+ if count == 0:
186
+ result += f"No trials found with status filter: {status_filter}\n"
187
+ else:
188
+ result += f"\n📌 *This is a curated list. For comprehensive search, use AACT database server.*"
189
+
190
+ return result
191
+
192
+
193
+ @mcp.tool()
194
+ async def get_trial_resources() -> str:
195
+ """Get helpful resources for finding clinical trials."""
196
+
197
+ resources = """**Clinical Trials Resources for ALS:**
198
+
199
+ **Official Databases:**
200
+ 1. **ClinicalTrials.gov**: https://clinicaltrials.gov/search?cond=ALS
201
+ - Official US trials registry
202
+ - Most comprehensive for US trials
203
+
204
+ 2. **WHO ICTRP**: https://trialsearch.who.int/
205
+ - International trials from all countries
206
+ - Includes non-US trials
207
+
208
+ 3. **EU Clinical Trials Register**: https://www.clinicaltrialsregister.eu/
209
+ - European trials database
210
+
211
+ **ALS-Specific Resources:**
212
+ 1. **Northeast ALS Consortium (NEALS)**: https://www.neals.org/
213
+ - Network of ALS clinical trial sites
214
+ - Trial matching service
215
+
216
+ 2. **ALS Therapy Development Institute**: https://www.als.net/clinical-trials/
217
+ - Independent ALS research organization
218
+ - Trial tracker and updates
219
+
220
+ 3. **I AM ALS Registry**: https://iamals.org/get-help/clinical-trials/
221
+ - Patient-focused trial information
222
+ - Trial matching assistance
223
+
224
+ **Major ALS Clinical Centers:**
225
+ - Massachusetts General Hospital (Healey Center)
226
+ - Johns Hopkins ALS Clinic
227
+ - Mayo Clinic ALS Center
228
+ - Cleveland Clinic Lou Ruvo Center
229
+ - UCSF ALS Center
230
+
231
+ **Tips for Finding Trials:**
232
+ 1. Use condition terms: "ALS", "Amyotrophic Lateral Sclerosis", "Motor Neuron Disease"
233
+ 2. Check recruiting AND not-yet-recruiting trials
234
+ 3. Consider trials at different phases (1, 2, 3)
235
+ 4. Look for platform trials testing multiple drugs
236
+ 5. Contact trial coordinators directly for eligibility
237
+
238
+ **Note:** For programmatic access to trial data, use the AACT database server which provides complete ClinicalTrials.gov data without API restrictions.
239
+ """
240
+
241
+ return resources
242
+
243
+
244
+ if __name__ == "__main__":
245
+ mcp.run(transport="stdio")
servers/elevenlabs_server.py ADDED
@@ -0,0 +1,561 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ ElevenLabs MCP Server for Voice Capabilities
4
+ Provides text-to-speech and speech-to-text for ALS Research Agent
5
+
6
+ This server enables voice accessibility features crucial for ALS patients
7
+ who may have limited mobility but retain cognitive function.
8
+ """
9
+
10
+ from mcp.server.fastmcp import FastMCP
11
+ import httpx
12
+ import logging
13
+ import os
14
+ import base64
15
+ import json
16
+ from typing import Optional, Dict, Any
17
+ from pathlib import Path
18
+ import sys
19
+
20
+ # Add parent directory to path for shared imports
21
+ sys.path.insert(0, str(Path(__file__).parent.parent))
22
+
23
+ from shared import config
24
+ from shared.http_client import get_http_client
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Initialize MCP server
31
+ mcp = FastMCP("elevenlabs-voice")
32
+
33
+ # ElevenLabs API configuration
34
+ ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
35
+ ELEVENLABS_API_BASE = "https://api.elevenlabs.io/v1"
36
+
37
+ # Default voice settings optimized for clarity (important for ALS patients)
38
+ DEFAULT_VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM") # Rachel voice (clear and calm)
39
+ DEFAULT_MODEL = "eleven_turbo_v2_5" # Turbo v2.5 - Fastest model available (40% faster than v2)
40
+
41
+ # Voice settings for accessibility
42
+ VOICE_SETTINGS = {
43
+ "stability": 0.5, # Balanced for speed and clarity (turbo model)
44
+ "similarity_boost": 0.5, # Balanced setting for faster processing
45
+ "style": 0.0, # Neutral style for clarity
46
+ "use_speaker_boost": True # Enhanced clarity
47
+ }
48
+
49
+
50
+ @mcp.tool()
51
+ async def text_to_speech(
52
+ text: str,
53
+ voice_id: Optional[str] = None,
54
+ output_format: str = "mp3_44100_128",
55
+ speed: float = 1.0
56
+ ) -> str:
57
+ """Convert text to speech optimized for ALS patients.
58
+
59
+ Args:
60
+ text: Text to convert to speech (research findings, paper summaries, etc.)
61
+ voice_id: ElevenLabs voice ID (defaults to clear, calm voice)
62
+ output_format: Audio format (mp3_44100_128, mp3_44100_192, pcm_16000, etc.)
63
+ speed: Speech rate (0.5-2.0, default 1.0 - can be slower for clarity)
64
+
65
+ Returns:
66
+ Base64 encoded audio data and metadata
67
+ """
68
+ try:
69
+ if not ELEVENLABS_API_KEY:
70
+ return json.dumps({
71
+ "status": "error",
72
+ "error": "ELEVENLABS_API_KEY not configured",
73
+ "message": "Please set your ElevenLabs API key in .env file"
74
+ }, indent=2)
75
+
76
+ # Limit text length to avoid ElevenLabs API timeouts
77
+ # Testing shows 2500 chars is safe, 5000 chars times out
78
+ max_length = 2500
79
+ if len(text) > max_length:
80
+ logger.warning(f"Text truncated from {len(text)} to {max_length} characters to avoid timeout")
81
+ # Try to truncate at a sentence boundary
82
+ truncated = text[:max_length]
83
+ last_period = truncated.rfind('.')
84
+ last_newline = truncated.rfind('\n')
85
+ # Use the latest sentence/paragraph boundary
86
+ boundary = max(last_period, last_newline)
87
+ if boundary > max_length - 500: # If there's a boundary in the last 500 chars
88
+ text = truncated[:boundary + 1]
89
+ else:
90
+ text = truncated + "..."
91
+
92
+ voice_id = voice_id or DEFAULT_VOICE_ID
93
+
94
+ # Prepare the request
95
+ url = f"{ELEVENLABS_API_BASE}/text-to-speech/{voice_id}?output_format={output_format}"  # forward the requested format; otherwise the parameter is silently ignored
96
+
97
+ headers = {
98
+ "xi-api-key": ELEVENLABS_API_KEY,
99
+ "Content-Type": "application/json"
100
+ }
101
+
102
+ # Adjust voice settings for speed
103
+ adjusted_settings = VOICE_SETTINGS.copy()
104
+ if speed < 1.0:
105
+ # Slower speech - increase stability for clarity
106
+ adjusted_settings["stability"] = min(1.0, adjusted_settings["stability"] + 0.1)
107
+
108
+ payload = {
109
+ "text": text,
110
+ "model_id": DEFAULT_MODEL,
111
+ "voice_settings": adjusted_settings
112
+ }
113
+
114
+ logger.info(f"Converting text to speech: {len(text)} characters")
115
+
116
+ # Set timeout based on text length (with 2500 char limit, 45s should be enough)
117
+ timeout = 45.0
118
+ logger.info(f"Using timeout of {timeout} seconds")
119
+
120
+ # Use shared HTTP client for connection pooling
121
+ client = get_http_client(timeout=timeout)
122
+ response = await client.post(url, json=payload, headers=headers)
123
+ response.raise_for_status()
124
+
125
+ # Get the audio data
126
+ audio_data = response.content
127
+
128
+ # Encode to base64 for transmission
129
+ audio_base64 = base64.b64encode(audio_data).decode('utf-8')
130
+
131
+ # Return structured response
132
+ result = {
133
+ "status": "success",
134
+ "audio_base64": audio_base64,
135
+ "format": output_format,
136
+ "duration_estimate": len(text) / 150 * 60, # Rough estimate: 150 words/min
137
+ "text_length": len(text),
138
+ "voice_id": voice_id,
139
+ "message": "Audio generated successfully. Use the audio_base64 field to play the audio."
140
+ }
141
+
142
+ logger.info(f"Successfully generated {len(audio_data)} bytes of audio")
143
+ return json.dumps(result, indent=2)
144
+
145
+ except httpx.HTTPStatusError as e:
146
+ logger.error(f"ElevenLabs API error: {e}")
147
+ if e.response.status_code == 401:
148
+ return json.dumps({
149
+ "status": "error",
150
+ "error": "Authentication failed",
151
+ "message": "Check your ELEVENLABS_API_KEY"
152
+ }, indent=2)
153
+ elif e.response.status_code == 429:
154
+ return json.dumps({
155
+ "status": "error",
156
+ "error": "Rate limit exceeded",
157
+ "message": "Please wait before trying again"
158
+ }, indent=2)
159
+ else:
160
+ return json.dumps({
161
+ "status": "error",
162
+ "error": f"API error: {e.response.status_code}",
163
+ "message": str(e)
164
+ }, indent=2)
165
+
166
+ except Exception as e:
167
+ logger.error(f"Unexpected error in text_to_speech: {e}")
168
+ return json.dumps({
169
+ "status": "error",
170
+ "error": "Text-to-speech error",
171
+ "message": str(e)
172
+ }, indent=2)
173
+
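On the consuming side, the base64 payload decodes straight to a playable file; a minimal sketch assuming the success schema above:

```python
import base64
import json

def save_tts_audio(tts_json: str, path: str = "summary.mp3") -> bool:
    # Decode the tool's JSON response and write the audio bytes to disk.
    data = json.loads(tts_json)
    if data.get("status") != "success":
        return False
    with open(path, "wb") as f:
        f.write(base64.b64decode(data["audio_base64"]))
    return True
```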
174
+
175
+ @mcp.tool()
176
+ async def create_audio_summary(
177
+ content: str,
178
+ summary_type: str = "research",
179
+ max_duration: int = 60
180
+ ) -> str:
181
+ """Create an audio summary of research content optimized for listening.
182
+
183
+ This tool reformats technical content into a more listenable format
184
+ before converting to speech - important for complex medical research.
185
+
186
+ Args:
187
+ content: Research content to summarize (paper abstract, findings, etc.)
188
+ summary_type: Type of summary - "research", "clinical", "patient-friendly"
189
+ max_duration: Target duration in seconds (affects summary length)
190
+
191
+ Returns:
192
+ Audio summary with both text and audio versions
193
+ """
194
+ try:
195
+ # Calculate target word count (assuming 150 words per minute)
196
+ target_words = int((max_duration / 60) * 150)
197
+
198
+ # Process content based on summary type
199
+ if summary_type == "patient-friendly":
200
+ # Simplify medical jargon for patients/families
201
+ processed_text = _simplify_medical_content(content, target_words)
202
+ elif summary_type == "clinical":
203
+ # Focus on clinical relevance
204
+ processed_text = _extract_clinical_relevance(content, target_words)
205
+ else: # research
206
+ # Standard research summary
207
+ processed_text = _create_research_summary(content, target_words)
208
+
209
+ # Add intro for context
210
+ intro = "Here's your audio research summary: "
211
+ final_text = intro + processed_text
212
+
213
+ # Convert to speech
214
+ tts_result = await text_to_speech(
215
+ text=final_text,
216
+ speed=0.95 # Slightly slower for complex content
217
+ )
218
+
219
+ # Parse the TTS result
220
+ tts_data = json.loads(tts_result)
221
+
222
+ if tts_data.get("status") != "success":
223
+ return tts_result # Return error from TTS
224
+
225
+ # Return enhanced result
226
+ result = {
227
+ "status": "success",
228
+ "audio_base64": tts_data["audio_base64"],
229
+ "text_summary": processed_text,
230
+ "summary_type": summary_type,
231
+ "word_count": len(processed_text.split()),
232
+ "estimated_duration": tts_data["duration_estimate"],
233
+ "format": tts_data["format"],
234
+ "message": f"Audio summary created: {summary_type} format, ~{int(tts_data['duration_estimate'])} seconds"
235
+ }
236
+
237
+ return json.dumps(result, indent=2)
238
+
239
+ except Exception as e:
240
+ logger.error(f"Error creating audio summary: {e}")
241
+ return json.dumps({
242
+ "status": "error",
243
+ "error": "Summary creation error",
244
+ "message": str(e)
245
+ }, indent=2)
246
+
247
+
248
+ @mcp.tool()
249
+ async def list_voices() -> str:
250
+ """List available voices optimized for medical/research content.
251
+
252
+ Returns voices suitable for clear pronunciation of medical terminology.
253
+ """
254
+ try:
255
+ if not ELEVENLABS_API_KEY:
256
+ return json.dumps({
257
+ "status": "error",
258
+ "error": "ELEVENLABS_API_KEY not configured",
259
+ "message": "Please set your ElevenLabs API key in .env file"
260
+ }, indent=2)
261
+
262
+ url = f"{ELEVENLABS_API_BASE}/voices"
263
+ headers = {"xi-api-key": ELEVENLABS_API_KEY}
264
+
265
+ # Use shared HTTP client for connection pooling
266
+ client = get_http_client(timeout=10.0)
267
+ response = await client.get(url, headers=headers)
268
+ response.raise_for_status()
269
+
270
+ data = response.json()
271
+ voices = data.get("voices", [])
272
+
273
+ # Filter and rank voices for medical content
274
+ recommended_voices = []
275
+ for voice in voices:
276
+ # Prefer clear, professional voices
277
+ labels = voice.get("labels", {})
278
+ if any(label in ["clear", "professional", "narration"] for label in labels.values()):
279
+ recommended_voices.append({
280
+ "voice_id": voice["voice_id"],
281
+ "name": voice["name"],
282
+ "preview_url": voice.get("preview_url"),
283
+ "description": voice.get("description", ""),
284
+ "recommended_for": "medical_content"
285
+ })
286
+
287
+ # Add all other voices
288
+ other_voices = []
289
+ for voice in voices:
290
+ if voice["voice_id"] not in [v["voice_id"] for v in recommended_voices]:
291
+ other_voices.append({
292
+ "voice_id": voice["voice_id"],
293
+ "name": voice["name"],
294
+ "preview_url": voice.get("preview_url"),
295
+ "description": voice.get("description", "")
296
+ })
297
+
298
+ result = {
299
+ "status": "success",
300
+ "recommended_voices": recommended_voices[:5], # Top 5 recommended
301
+ "other_voices": other_voices[:10], # Limit for clarity
302
+ "total_voices": len(voices),
303
+ "message": "Recommended voices are optimized for clear medical terminology pronunciation"
304
+ }
305
+
306
+ return json.dumps(result, indent=2)
307
+
308
+ except Exception as e:
309
+ logger.error(f"Error listing voices: {e}")
310
+ return json.dumps({
311
+ "status": "error",
312
+ "error": "Failed to list voices",
313
+ "message": str(e)
314
+ }, indent=2)
315
+
316
+
317
+ @mcp.tool()
318
+ async def pronunciation_guide(
319
+ medical_terms: list[str],
320
+ include_audio: bool = True
321
+ ) -> str:
322
+ """Generate pronunciation guide for medical terms.
323
+
324
+ Critical for ALS patients/caregivers learning about complex terminology.
325
+
326
+ Args:
327
+ medical_terms: List of medical terms to pronounce
328
+ include_audio: Whether to include audio pronunciation
329
+
330
+ Returns:
331
+ Pronunciation guide with optional audio
332
+ """
333
+ try:
334
+ results = []
335
+
336
+ for term in medical_terms[:10]: # Limit to prevent long processing
337
+ # Create phonetic breakdown
338
+ phonetic = _get_phonetic_spelling(term)
339
+
340
+ # Create pronunciation text
341
+ pronunciation_text = f"{term}. {phonetic}. {term}."
342
+
343
+ result_entry = {
344
+ "term": term,
345
+ "phonetic": phonetic
346
+ }
347
+
348
+ if include_audio:
349
+ # Generate audio
350
+ tts_result = await text_to_speech(
351
+ text=pronunciation_text,
352
+ speed=0.8 # Slower for clarity
353
+ )
354
+
355
+ tts_data = json.loads(tts_result)
356
+ if tts_data.get("status") == "success":
357
+ result_entry["audio_base64"] = tts_data["audio_base64"]
358
+
359
+ results.append(result_entry)
360
+
361
+ return json.dumps({
362
+ "status": "success",
363
+ "pronunciations": results,
364
+ "message": f"Generated pronunciation guide for {len(results)} terms"
365
+ }, indent=2)
366
+
367
+ except Exception as e:
368
+ logger.error(f"Error creating pronunciation guide: {e}")
369
+ return json.dumps({
370
+ "status": "error",
371
+ "error": "Pronunciation guide error",
372
+ "message": str(e)
373
+ }, indent=2)
374
+
375
+
376
+ # Helper functions for content processing
377
+
378
+ def _simplify_medical_content(content: str, target_words: int) -> str:
379
+ """Simplify medical content for patient understanding."""
380
+ # This would ideally use NLP, but for now, basic simplification
381
+
382
+ # First, strip references for cleaner audio
383
+ content = _strip_references(content)
384
+
385
+ # Common medical term replacements
386
+ replacements = {
387
+ "amyotrophic lateral sclerosis": "ALS or Lou Gehrig's disease",
388
+ "motor neurons": "nerve cells that control muscles",
389
+ "neurodegeneration": "nerve cell damage",
390
+ "pathogenesis": "disease development",
391
+ "etiology": "cause",
392
+ "prognosis": "expected outcome",
393
+ "therapeutic": "treatment",
394
+ "pharmacological": "drug-based",
395
+ "intervention": "treatment",
396
+ "mortality": "death rate",
397
+ "morbidity": "illness rate"
398
+ }
399
+
400
+ simplified = content.lower()
401
+ for term, replacement in replacements.items():
402
+ simplified = simplified.replace(term, replacement)
403
+
404
+ # Truncate to target length
405
+ words = simplified.split()
406
+ if len(words) > target_words:
407
+ words = words[:target_words]
408
+ simplified = " ".join(words) + "..."
409
+
410
+ return simplified.capitalize()
411
+
412
+
413
+ def _extract_clinical_relevance(content: str, target_words: int) -> str:
414
+ """Extract clinically relevant information."""
415
+ # Focus on treatment, outcomes, and practical implications
416
+
417
+ # First, strip references for cleaner audio
418
+ content = _strip_references(content)
419
+
420
+ # Look for key clinical phrases
421
+ clinical_markers = [
422
+ "treatment", "therapy", "outcome", "survival", "progression",
423
+ "clinical trial", "efficacy", "safety", "adverse", "benefit",
424
+ "patient", "dose", "administration"
425
+ ]
426
+
427
+ sentences = content.split(". ")
428
+ relevant_sentences = []
429
+
430
+ for sentence in sentences:
431
+ if any(marker in sentence.lower() for marker in clinical_markers):
432
+ relevant_sentences.append(sentence)
433
+
434
+ result = ". ".join(relevant_sentences)
435
+
436
+ # Truncate to target length
437
+ words = result.split()
438
+ if len(words) > target_words:
439
+ words = words[:target_words]
440
+ result = " ".join(words) + "..."
441
+
442
+ return result
443
+
444
+
445
+ def _create_research_summary(content: str, target_words: int) -> str:
446
+ """Create a research-focused summary."""
447
+ # Extract key findings and implications
448
+
449
+ # First, strip references section if present
450
+ content = _strip_references(content)
451
+
452
+ # Simply truncate for now (could be enhanced with NLP)
453
+ words = content.split()
454
+ if len(words) > target_words:
455
+ words = words[:target_words]
456
+ content = " ".join(words) + "..."
457
+
458
+ return content
459
+
460
+
461
+ def _strip_references(content: str) -> str:
462
+ """Remove references section and citations from content for audio reading."""
463
+ import re
464
+
465
+ # Extract only synthesis content if it's marked
466
+ synthesis_match = re.search(r'✅\s*SYNTHESIS:?\s*(.*?)(?=##?\s*References|##?\s*Bibliography|$)',
467
+ content, flags=re.DOTALL | re.IGNORECASE)
468
+ if synthesis_match:
469
+ content = synthesis_match.group(1)
470
+
471
+ # Remove References section (multiple possible formats)
472
+ patterns_to_remove = [
473
+ r'##?\s*References.*$', # ## References or # References to end
474
+ r'##?\s*Bibliography.*$', # Bibliography section
475
+ r'##?\s*Citations.*$', # Citations section
476
+ r'##?\s*Works Cited.*$', # Works Cited section
477
+ r'##?\s*Key References.*$', # Key References section
478
+ ]
479
+
480
+ for pattern in patterns_to_remove:
481
+ content = re.sub(pattern, '', content, flags=re.DOTALL | re.IGNORECASE)
482
+
483
+ # Remove inline citations like [1], [2,3], [PMID: 12345678]
484
+ content = re.sub(r'\[[\d,\s]+\]', '', content) # [1], [2,3], etc.
485
+ content = re.sub(r'\[PMID:\s*\d+\]', '', content) # [PMID: 12345678]
486
+ content = re.sub(r'\[NCT\d+\]', '', content) # [NCT12345678]
487
+
488
+ # Remove URLs for cleaner audio
489
+ content = re.sub(r'https?://[^\s\)]+', '', content)
490
+ content = re.sub(r'www\.[^\s\)]+', '', content)
491
+
492
+ # Remove PMID/DOI/NCT references
493
+ content = re.sub(r'PMID:\s*\d+', '', content)
494
+ content = re.sub(r'DOI:\s*[^\s]+', '', content)
495
+ content = re.sub(r'NCT\d{8}', '', content)
496
+
497
+ # Remove markdown formatting that sounds awkward in audio
498
+ content = re.sub(r'\*\*(.*?)\*\*', r'\1', content) # Remove bold
499
+ content = re.sub(r'\*(.*?)\*', r'\1', content) # Remove italic
500
+ content = re.sub(r'`(.*?)`', r'\1', content) # Remove inline code
501
+ content = re.sub(r'#{1,6}\s*', '', content) # Remove headers
502
+ content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE) # Remove bullet points
503
+ content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE) # Remove numbered lists
504
+
505
+ # Replace markdown links with just the text
506
+ content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
507
+
508
+ # Clean up extra whitespace
509
+ content = re.sub(r'\s+', ' ', content)
510
+ content = re.sub(r'\n{3,}', '\n\n', content)
511
+
512
+ return content.strip()
513
+
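A rough before/after of the stripping (exact spacing depends on the regex cleanup above):

```python
raw = "Riluzole modestly extends survival [1,2] (PMID: 9476265). See https://doi.org/10.1000/x for details."
print(_strip_references(raw))
# roughly: "Riluzole modestly extends survival (). See for details."
```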
514
+
515
+ def _get_phonetic_spelling(term: str) -> str:
516
+ """Generate phonetic spelling for medical terms."""
517
+ # Basic phonetic rules for medical terms
518
+ # This could be enhanced with a medical pronunciation dictionary
519
+
520
+ phonetic_map = {
521
+ "amyotrophic": "AM-ee-oh-TROH-fik",
522
+ "lateral": "LAT-er-al",
523
+ "sclerosis": "skleh-ROH-sis",
524
+ "tdp-43": "T-D-P forty-three",
525
+ "riluzole": "RIL-you-zole",
526
+ "edaravone": "ed-AR-a-vone",
527
+ "tofersen": "TOE-fer-sen",
528
+ "neurofilament": "NUR-oh-FIL-a-ment",
529
+ "astrocyte": "AS-tro-site",
530
+ "oligodendrocyte": "oh-li-go-DEN-dro-site"
531
+ }
532
+
533
+ term_lower = term.lower()
534
+ if term_lower in phonetic_map:
535
+ return phonetic_map[term_lower]
536
+
537
+ # Basic syllable breakdown for unknown terms
538
+ # This is very simplified and could be improved
539
+ syllables = []
540
+ current = ""
541
+ for char in term:
542
+ if char in "aeiouAEIOU" and current:
543
+ syllables.append(current + char)
544
+ current = ""
545
+ else:
546
+ current += char
547
+ if current:
548
+ syllables.append(current)
549
+
550
+ return "-".join(syllables).upper()
551
+
552
+
553
+ if __name__ == "__main__":
554
+ # Check for API key
555
+ if not ELEVENLABS_API_KEY:
556
+ logger.warning("ELEVENLABS_API_KEY not set in environment")
557
+ logger.warning("Voice features will be limited without API key")
558
+ logger.info("Get your API key at: https://elevenlabs.io")
559
+
560
+ # Run the MCP server
561
+ mcp.run(transport="stdio")
servers/fetch_server.py ADDED
@@ -0,0 +1,206 @@
1
+ # fetch_server.py
2
+ from mcp.server.fastmcp import FastMCP
3
+ import httpx
4
+ from bs4 import BeautifulSoup
5
+ from urllib.parse import urlparse
6
+ import logging
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add parent directory to path for shared imports
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ from shared import (
14
+ config,
15
+ clean_whitespace,
16
+ truncate_text
17
+ )
18
+ from shared.http_client import get_http_client
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ mcp = FastMCP("fetch-server")
25
+
26
+ def validate_url(url: str) -> tuple[bool, str]:
27
+ """Validate URL for security concerns. Returns (is_valid, error_message)"""
28
+ try:
29
+ parsed = urlparse(url)
30
+
31
+ # Check scheme using shared config
32
+ if parsed.scheme not in config.security.allowed_schemes:
33
+ return False, f"Invalid URL scheme. Only {', '.join(config.security.allowed_schemes)} are allowed."
34
+
35
+ # Check for blocked hosts (SSRF protection)
36
+ hostname = parsed.hostname
37
+ if not hostname:
38
+ return False, "Invalid URL: no hostname found."
39
+
40
+ # Use shared security config for SSRF checks
41
+ if config.security.is_private_ip(hostname):
42
+ return False, "Access to localhost/private IPs is not allowed."
43
+
44
+ return True, ""
45
+
46
+ except Exception as e:
47
+ return False, f"Invalid URL: {str(e)}"
48
+
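Expected behavior of the validator, assuming `config.security` permits http/https and blocks private ranges (the exact error strings come from the code above and the shared config):

```python
print(validate_url("https://pubmed.ncbi.nlm.nih.gov/12345/"))
# (True, "")
print(validate_url("http://127.0.0.1:8080/admin"))
# (False, "Access to localhost/private IPs is not allowed.")
print(validate_url("ftp://example.org/file"))
# (False, "Invalid URL scheme. Only http, https are allowed.")  # message depends on allowed_schemes
```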
49
+ def parse_clinical_trial_page(soup: BeautifulSoup, url: str) -> str | None:
50
+ """Parse ClinicalTrials.gov trial detail page for structured data."""
51
+ # Check if this is a ClinicalTrials.gov page
52
+ if "clinicaltrials.gov" not in url.lower():
53
+ return None
54
+
55
+ # Extract NCT ID from URL
56
+ import re
57
+ nct_match = re.search(r'NCT\d{8}', url)
58
+ nct_id = nct_match.group() if nct_match else "Unknown"
59
+
60
+ # Try to extract key trial information
61
+ trial_info = []
62
+ trial_info.append(f"**NCT ID:** {nct_id}")
63
+ trial_info.append(f"**URL:** {url}")
64
+
65
+ # Look for title
66
+ title = soup.find('h1')
67
+ if title:
68
+ trial_info.append(f"**Title:** {title.get_text(strip=True)}")
69
+
70
+ # Look for status (various patterns)
71
+ status_patterns = [
72
+ soup.find('span', string=re.compile(r'Recruiting|Active|Completed|Enrolling', re.I)),
73
+ soup.find('div', string=re.compile(r'Recruitment Status', re.I))
74
+ ]
75
+ for pattern in status_patterns:
76
+ if pattern:
77
+ status_text = pattern.get_text(strip=True) if hasattr(pattern, 'get_text') else str(pattern)
78
+ trial_info.append(f"**Status:** {status_text}")
79
+ break
80
+
81
+ # Look for study description
82
+ desc_section = soup.find('div', {'class': re.compile('description', re.I)})
83
+ if desc_section:
84
+ desc_text = desc_section.get_text(strip=True)[:500]
85
+ trial_info.append(f"**Description:** {desc_text}...")
86
+
87
+ # Look for conditions
88
+ conditions = soup.find_all(string=re.compile(r'Condition', re.I))
89
+ if conditions:
90
+ for cond in conditions[:1]: # Just first mention
91
+ parent = cond.parent
92
+ if parent:
93
+ trial_info.append(f"**Condition:** {parent.get_text(strip=True)[:200]}")
94
+ break
95
+
96
+ # Look for interventions
97
+ interventions = soup.find_all(string=re.compile(r'Intervention', re.I))
98
+ if interventions:
99
+ for inter in interventions[:1]: # Just first mention
100
+ parent = inter.parent
101
+ if parent:
102
+ trial_info.append(f"**Intervention:** {parent.get_text(strip=True)[:200]}")
103
+ break
104
+
105
+ # Look for sponsor
106
+ sponsor = soup.find(string=re.compile(r'Sponsor', re.I))
107
+ if sponsor and sponsor.parent:
108
+ trial_info.append(f"**Sponsor:** {sponsor.parent.get_text(strip=True)[:100]}")
109
+
110
+ # Locations/Sites
111
+ locations = soup.find_all(string=re.compile(r'Location|Site', re.I))
112
+ if locations:
113
+ location_texts = []
114
+ for loc in locations[:3]: # First 3 locations
115
+ if loc.parent:
116
+ location_texts.append(loc.parent.get_text(strip=True)[:50])
117
+ if location_texts:
118
+ trial_info.append(f"**Locations:** {', '.join(location_texts)}")
119
+
120
+ if len(trial_info) > 2: # If we found meaningful data
121
+ return "\n\n".join(trial_info) + "\n\n**Note:** This is extracted from the trial webpage. Some details may be incomplete due to page structure variations."
122
+
123
+ return None
124
+
125
+ @mcp.tool()
126
+ async def fetch_url(url: str, extract_text_only: bool = True) -> str:
127
+ """Fetch content from a URL (paper abstract page, news article, etc.).
128
+
129
+ Args:
130
+ url: URL to fetch
131
+ extract_text_only: Extract only main text content (default: True)
132
+ """
133
+ try:
134
+ logger.info(f"Fetching URL: {url}")
135
+
136
+ # Validate URL
137
+ is_valid, error_msg = validate_url(url)
138
+ if not is_valid:
139
+ logger.warning(f"URL validation failed: {error_msg}")
140
+ return f"Error: {error_msg}"
141
+
142
+ # Use shared HTTP client for connection pooling
143
+ client = get_http_client(timeout=config.api.timeout)
144
+ response = await client.get(url, headers={
145
+ "User-Agent": config.api.user_agent
146
+ })
147
+ response.raise_for_status()
148
+
149
+ # Check content size using shared config
150
+ content_length = response.headers.get('content-length')
151
+ if content_length and int(content_length) > config.content_limits.max_content_size:
152
+ logger.warning(f"Content too large: {content_length} bytes")
153
+ return f"Error: Content size ({content_length} bytes) exceeds maximum allowed size of {config.content_limits.max_content_size} bytes"
154
+
155
+ # Check actual content size
156
+ if len(response.content) > config.content_limits.max_content_size:
157
+ logger.warning(f"Content too large: {len(response.content)} bytes")
158
+ return f"Error: Content size exceeds maximum allowed size of {config.content_limits.max_content_size} bytes"
159
+
160
+ if extract_text_only:
161
+ soup = BeautifulSoup(response.text, 'html.parser')
162
+
163
+ # Check if this is a clinical trial page and try enhanced parsing
164
+ trial_data = parse_clinical_trial_page(soup, url)
165
+ if trial_data:
166
+ logger.info(f"Successfully parsed clinical trial page: {url}")
167
+ return trial_data
168
+
169
+ # Otherwise, do standard text extraction
170
+ # Remove script and style elements
171
+ for script in soup(["script", "style", "meta", "link"]):
172
+ script.decompose()
173
+
174
+ # Get text
175
+ text = soup.get_text()
176
+
177
+ # Clean up whitespace using shared utility
178
+ text = clean_whitespace(text)
179
+
180
+ # Limit to reasonable size for LLM context using shared utility
181
+ text = truncate_text(text, max_chars=config.content_limits.max_text_chars)
182
+
183
+ logger.info(f"Successfully fetched and extracted text from {url}")
184
+ return text
185
+ else:
186
+ # Return raw HTML, but still limit size using shared utility
187
+ html = truncate_text(response.text, max_chars=config.content_limits.max_text_chars)
188
+
189
+ logger.info(f"Successfully fetched raw HTML from {url}")
190
+ return html
191
+
192
+ except httpx.TimeoutException:
193
+ logger.error(f"Request to {url} timed out")
194
+ return f"Error: Request timed out after {config.api.timeout} seconds"
195
+ except httpx.HTTPStatusError as e:
196
+ logger.error(f"HTTP error fetching {url}: {e}")
197
+ return f"Error: HTTP {e.response.status_code} - {e.response.reason_phrase}"
198
+ except httpx.RequestError as e:
199
+ logger.error(f"Request error fetching {url}: {e}")
200
+ return f"Error: Failed to fetch URL - {str(e)}"
201
+ except Exception as e:
202
+ logger.error(f"Unexpected error fetching {url}: {e}")
203
+ return f"Error: {str(e)}"
204
+
205
+ if __name__ == "__main__":
206
+ mcp.run(transport="stdio")
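
For reference, a minimal sketch of how the SSRF guard above behaves (assuming this module is importable as servers.fetch_server and the default SecurityConfig is in effect):

from servers.fetch_server import validate_url

for url in [
    "https://pubmed.ncbi.nlm.nih.gov/12345/",  # allowed: https scheme, public host
    "http://127.0.0.1:8080/admin",             # blocked: localhost is in blocked_hosts
    "ftp://example.org/data.csv",              # blocked: scheme not in allowed_schemes
]:
    ok, err = validate_url(url)
    print(url, "->", "allowed" if ok else f"blocked ({err})")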
servers/llamaindex_server.py ADDED
@@ -0,0 +1,729 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ LlamaIndex MCP Server for Research Memory and RAG
4
+ Provides persistent memory and semantic search capabilities for ALS Research Agent
5
+
6
+ This server enables the agent to remember all research it encounters, build
7
+ knowledge over time, and discover connections between papers.
8
+ """
9
+
10
+ from mcp.server.fastmcp import FastMCP
11
+ import logging
12
+ import os
13
+ import json
14
+ import hashlib
15
+ from typing import Optional, List, Dict, Any
16
+ from pathlib import Path
17
+ import sys
18
+ from datetime import datetime
19
+ import asyncio
20
+
21
+ # Add parent directory to path for shared imports
22
+ sys.path.insert(0, str(Path(__file__).parent.parent))
23
+
24
+ from shared import config
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Initialize MCP server
31
+ mcp = FastMCP("llamaindex-rag")
32
+
33
+ # Import LlamaIndex components (will be installed)
34
+ try:
35
+ from llama_index.core import (
36
+ VectorStoreIndex,
37
+ Document,
38
+ StorageContext,
39
+ Settings,
40
+ load_index_from_storage
41
+ )
42
+ from llama_index.core.node_parser import SentenceSplitter
43
+ from llama_index.vector_stores.chroma import ChromaVectorStore
44
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
45
+ import chromadb
46
+ LLAMAINDEX_AVAILABLE = True
47
+ except ImportError:
48
+ LLAMAINDEX_AVAILABLE = False
49
+ logger.warning("LlamaIndex not installed. Install with: pip install llama-index chromadb sentence-transformers")
50
+
51
+ # Configuration
52
+ CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH", "./chroma_db")
53
+ EMBED_MODEL = os.getenv("LLAMAINDEX_EMBED_MODEL", "dmis-lab/biobert-base-cased-v1.2")
54
+ CHUNK_SIZE = int(os.getenv("LLAMAINDEX_CHUNK_SIZE", "1024"))
55
+ CHUNK_OVERLAP = int(os.getenv("LLAMAINDEX_CHUNK_OVERLAP", "200"))
56
+
57
+ # Global index storage
58
+ research_index = None
59
+ chroma_client = None
60
+ collection = None
61
+ papers_metadata = {} # Store paper metadata separately
62
+
63
+
64
+ class ResearchMemoryManager:
65
+ """Manages persistent research memory using LlamaIndex and ChromaDB"""
66
+
67
+ def __init__(self):
68
+ self.index = None
69
+ self.chroma_client = None
70
+ self.collection = None
71
+ self.metadata_path = Path(CHROMA_DB_PATH) / "metadata.json"
72
+
73
+ if LLAMAINDEX_AVAILABLE:
74
+ self._initialize_index()
75
+
76
+ def _initialize_index(self):
77
+ """Initialize or load existing index from ChromaDB"""
78
+ try:
79
+ # Create directory if it doesn't exist
80
+ Path(CHROMA_DB_PATH).mkdir(parents=True, exist_ok=True)
81
+
82
+ # Initialize ChromaDB client
83
+ self.chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
84
+
85
+ # Get or create collection
86
+ try:
87
+ self.collection = self.chroma_client.get_collection("als_research")
88
+ logger.info(f"Loaded existing ChromaDB collection with {self.collection.count()} papers")
89
+ except Exception:
90
+ self.collection = self.chroma_client.create_collection("als_research")
91
+ logger.info("Created new ChromaDB collection")
92
+
93
+ # Initialize embedding model - prefer biomedical models
94
+ try:
95
+ embed_model = HuggingFaceEmbedding(
96
+ model_name=EMBED_MODEL,
97
+ cache_folder="./embed_cache"
98
+ )
99
+ logger.info(f"Using embedding model: {EMBED_MODEL}")
100
+ except Exception as e:
101
+ logger.warning(f"Failed to load {EMBED_MODEL} ({e}); falling back to default embedding model")
102
+ embed_model = HuggingFaceEmbedding(
103
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
104
+ cache_folder="./embed_cache"
105
+ )
106
+
107
+ # Configure settings
108
+ Settings.embed_model = embed_model
109
+ Settings.chunk_size = CHUNK_SIZE
110
+ Settings.chunk_overlap = CHUNK_OVERLAP
111
+
112
+ # Initialize vector store
113
+ vector_store = ChromaVectorStore(chroma_collection=self.collection)
114
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
115
+
116
+ # Create or load index
117
+ if self.collection.count() > 0:
118
+ # Load existing index
119
+ self.index = VectorStoreIndex.from_vector_store(
120
+ vector_store,
121
+ storage_context=storage_context
122
+ )
123
+ logger.info("Loaded existing vector index")
124
+ else:
125
+ # Create new index
126
+ self.index = VectorStoreIndex(
127
+ [],
128
+ storage_context=storage_context
129
+ )
130
+ logger.info("Created new vector index")
131
+
132
+ # Load metadata
133
+ self._load_metadata()
134
+
135
+ except Exception as e:
136
+ logger.error(f"Failed to initialize index: {e}")
137
+ self.index = None
138
+
139
+ def _load_metadata(self):
140
+ """Load paper metadata from disk"""
141
+ global papers_metadata
142
+ if self.metadata_path.exists():
143
+ try:
144
+ with open(self.metadata_path, 'r') as f:
145
+ papers_metadata = json.load(f)
146
+ logger.info(f"Loaded metadata for {len(papers_metadata)} papers")
147
+ except Exception as e:
148
+ logger.error(f"Failed to load metadata: {e}")
149
+ papers_metadata = {}
150
+ else:
151
+ papers_metadata = {}
152
+
153
+ def _save_metadata(self):
154
+ """Save paper metadata to disk"""
155
+ try:
156
+ with open(self.metadata_path, 'w') as f:
157
+ json.dump(papers_metadata, f, indent=2, default=str)
158
+ except Exception as e:
159
+ logger.error(f"Failed to save metadata: {e}")
160
+
161
+ def generate_paper_id(self, title: str, doi: Optional[str] = None) -> str:
162
+ """Generate unique ID for a paper"""
163
+ if doi:
164
+ return hashlib.md5(doi.encode()).hexdigest()
165
+ return hashlib.md5(title.lower().encode()).hexdigest()
166
+
167
+ async def index_paper(
168
+ self,
169
+ title: str,
170
+ abstract: str,
171
+ authors: List[str],
172
+ doi: Optional[str] = None,
173
+ journal: Optional[str] = None,
174
+ year: Optional[int] = None,
175
+ findings: Optional[str] = None,
176
+ url: Optional[str] = None,
177
+ paper_type: str = "research"
178
+ ) -> Dict[str, Any]:
179
+ """Index a research paper with metadata"""
180
+
181
+ if not self.index:
182
+ return {"status": "error", "message": "Index not initialized"}
183
+
184
+ # Generate unique ID
185
+ paper_id = self.generate_paper_id(title, doi)
186
+
187
+ # Check if already indexed
188
+ if paper_id in papers_metadata:
189
+ return {
190
+ "status": "already_indexed",
191
+ "paper_id": paper_id,
192
+ "title": title,
193
+ "message": "Paper already in research memory"
194
+ }
195
+
196
+ # Prepare document text
197
+ doc_text = f"Title: {title}\n\n"
198
+ doc_text += f"Authors: {', '.join(authors)}\n\n"
199
+
200
+ if journal:
201
+ doc_text += f"Journal: {journal}\n"
202
+ if year:
203
+ doc_text += f"Year: {year}\n\n"
204
+
205
+ doc_text += f"Abstract: {abstract}\n\n"
206
+
207
+ if findings:
208
+ doc_text += f"Key Findings: {findings}\n\n"
209
+
210
+ # Create document with metadata (ChromaDB only accepts strings, not lists)
211
+ metadata = {
212
+ "paper_id": paper_id,
213
+ "title": title,
214
+ "authors": ", ".join(authors) if authors else "", # Convert list to string
215
+ "doi": doi,
216
+ "journal": journal,
217
+ "year": year,
218
+ "url": url,
219
+ "paper_type": paper_type,
220
+ "indexed_at": datetime.now().isoformat()
221
+ }
222
+
223
+ document = Document(
224
+ text=doc_text,
225
+ metadata=metadata
226
+ )
227
+
228
+ try:
229
+ # Add to index
230
+ self.index.insert(document)
231
+
232
+ # Store metadata
233
+ papers_metadata[paper_id] = metadata
234
+ self._save_metadata()
235
+
236
+ logger.info(f"Indexed paper: {title}")
237
+
238
+ return {
239
+ "status": "success",
240
+ "paper_id": paper_id,
241
+ "title": title,
242
+ "message": "Successfully indexed paper into research memory"
243
+ }
244
+
245
+ except Exception as e:
246
+ logger.error(f"Failed to index paper: {e}")
247
+ return {
248
+ "status": "error",
249
+ "message": f"Failed to index paper: {str(e)}"
250
+ }
251
+
252
+ async def search_similar(
253
+ self,
254
+ query: str,
255
+ top_k: int = 5,
256
+ include_scores: bool = True
257
+ ) -> List[Dict[str, Any]]:
258
+ """Search for similar research in memory"""
259
+
260
+ if not self.index:
261
+ return []
262
+
263
+ try:
264
+ # Use retriever for direct vector search (no LLM needed)
265
+ retriever = self.index.as_retriever(
266
+ similarity_top_k=top_k
267
+ )
268
+
269
+ # Search using retriever
270
+ nodes = retriever.retrieve(query)
271
+
272
+ results = []
273
+ for node in nodes:
274
+ result = {
275
+ "text": node.text[:500] + "..." if len(node.text) > 500 else node.text,
276
+ "metadata": node.metadata,
277
+ "score": node.score if include_scores else None
278
+ }
279
+ results.append(result)
280
+
281
+ return results
282
+
283
+ except Exception as e:
284
+ logger.error(f"Search failed: {e}")
285
+ return []
286
+
287
+
288
+ # Global manager - will be initialized on first use
289
+ memory_manager = None
290
+ _initialization_lock = asyncio.Lock() # Prevent race conditions during initialization
291
+ _initialization_started = False
292
+
293
+
294
+ async def ensure_initialized():
295
+ """Ensure the memory manager is initialized (lazy initialization)."""
296
+ global memory_manager, _initialization_started
297
+
298
+ # Quick check without lock
299
+ if memory_manager is not None:
300
+ return True
301
+
302
+ # Thread-safe initialization
303
+ async with _initialization_lock:
304
+ # Double-check after acquiring lock
305
+ if memory_manager is not None:
306
+ return True
307
+
308
+ if not LLAMAINDEX_AVAILABLE:
309
+ return False
310
+
311
+ if _initialization_started:
312
+ # Another thread is initializing, wait for it
313
+ while memory_manager is None and _initialization_started:
314
+ await asyncio.sleep(0.1)
315
+ return memory_manager is not None
316
+
317
+ try:
318
+ _initialization_started = True
319
+ logger.info("🔄 Initializing LlamaIndex RAG system (this may take 20-30 seconds)...")
320
+ logger.info(" Loading BioBERT embedding model...")
321
+
322
+ # Initialize the memory manager
323
+ memory_manager = ResearchMemoryManager()
324
+
325
+ logger.info("✅ LlamaIndex RAG system initialized successfully")
326
+ return True
327
+
328
+ except Exception as e:
329
+ logger.error(f"❌ Failed to initialize LlamaIndex: {e}")
330
+ _initialization_started = False
331
+ return False
332
+
333
+
334
+ @mcp.tool()
335
+ async def index_paper(
336
+ title: str,
337
+ abstract: str,
338
+ authors: str,
339
+ doi: Optional[str] = None,
340
+ journal: Optional[str] = None,
341
+ year: Optional[int] = None,
342
+ findings: Optional[str] = None,
343
+ url: Optional[str] = None
344
+ ) -> str:
345
+ """Index a research paper into persistent memory for future retrieval.
346
+
347
+ The agent's research memory persists across sessions, building knowledge over time.
348
+
349
+ Args:
350
+ title: Paper title
351
+ abstract: Paper abstract or summary
352
+ authors: Comma-separated list of authors
353
+ doi: Digital Object Identifier (optional)
354
+ journal: Journal or preprint server name (optional)
355
+ year: Publication year (optional)
356
+ findings: Key findings or implications (optional)
357
+ url: URL to paper (optional)
358
+
359
+ Returns:
360
+ Status of indexing operation
361
+ """
362
+
363
+ if not LLAMAINDEX_AVAILABLE:
364
+ return json.dumps({
365
+ "status": "error",
366
+ "error": "LlamaIndex not installed",
367
+ "message": "Install with: pip install llama-index chromadb sentence-transformers"
368
+ }, indent=2)
369
+
370
+ # Lazy initialization on first use
371
+ if not await ensure_initialized():
372
+ return json.dumps({
373
+ "status": "error",
374
+ "error": "Memory manager initialization failed",
375
+ "message": "Check LlamaIndex configuration and dependencies"
376
+ }, indent=2)
377
+
378
+ try:
379
+ # Parse authors
380
+ authors_list = [a.strip() for a in authors.split(",")]
381
+
382
+ result = await memory_manager.index_paper(
383
+ title=title,
384
+ abstract=abstract,
385
+ authors=authors_list,
386
+ doi=doi,
387
+ journal=journal,
388
+ year=year,
389
+ findings=findings,
390
+ url=url
391
+ )
392
+
393
+ if result["status"] == "success":
394
+ return json.dumps({
395
+ "status": "success",
396
+ "paper_id": result["paper_id"],
397
+ "title": result["title"],
398
+ "message": f"✅ Indexed into research memory. Total papers: {len(papers_metadata)}",
399
+ "total_papers_indexed": len(papers_metadata)
400
+ }, indent=2)
401
+
402
+ elif result["status"] == "already_indexed":
403
+ return json.dumps({
404
+ "status": "already_indexed",
405
+ "paper_id": result["paper_id"],
406
+ "title": result["title"],
407
+ "message": "ℹ️ Paper already in research memory",
408
+ "total_papers_indexed": len(papers_metadata)
409
+ }, indent=2)
410
+
411
+ else:
412
+ return json.dumps({"status": "error", "error": "Indexing failed", "message": result.get("message", "Unknown error")}, indent=2)
413
+
414
+ except Exception as e:
415
+ logger.error(f"Error indexing paper: {e}")
416
+ return json.dumps({"status": "error", "error": "Indexing error", "message": str(e)}, indent=2)
417
+
418
+
419
+ @mcp.tool()
420
+ async def semantic_search(
421
+ query: str,
422
+ max_results: int = 5
423
+ ) -> str:
424
+ """Search research memory using semantic similarity.
425
+
426
+ Finds papers similar to your query across all indexed research,
427
+ even if they don't contain exact keywords.
428
+
429
+ Args:
430
+ query: Search query (can be a question, topic, or paper abstract)
431
+ max_results: Maximum number of results to return (default: 5)
432
+
433
+ Returns:
434
+ Similar papers from research memory
435
+ """
436
+
437
+ if not LLAMAINDEX_AVAILABLE:
438
+ return json.dumps({
439
+ "status": "error",
440
+ "error": "LlamaIndex not installed",
441
+ "message": "Install with: pip install llama-index chromadb sentence-transformers"
442
+ }, indent=2)
443
+
444
+ # Lazy initialization on first use
445
+ if not await ensure_initialized():
446
+ return json.dumps({
447
+ "status": "error",
448
+ "error": "Memory manager initialization failed",
449
+ "message": "Check LlamaIndex configuration and dependencies"
450
+ }, indent=2)
451
+
452
+ if not memory_manager.index:
453
+ return json.dumps({
454
+ "status": "error",
455
+ "error": "No research memory available",
456
+ "message": "No papers have been indexed yet"
457
+ }, indent=2)
458
+
459
+ try:
460
+ results = await memory_manager.search_similar(
461
+ query=query,
462
+ top_k=max_results
463
+ )
464
+
465
+ if not results:
466
+ return json.dumps({
467
+ "status": "no_results",
468
+ "query": query,
469
+ "message": "No similar research found in memory"
470
+ }, indent=2)
471
+
472
+ # Format results
473
+ formatted_results = []
474
+ for i, result in enumerate(results, 1):
475
+ metadata = result["metadata"]
476
+ formatted_results.append({
477
+ "rank": i,
478
+ "title": metadata.get("title", "Unknown"),
479
+ "authors": metadata.get("authors", ""),  # stored as a comma-separated string
480
+ "year": metadata.get("year"),
481
+ "journal": metadata.get("journal"),
482
+ "doi": metadata.get("doi"),
483
+ "url": metadata.get("url"),
484
+ "similarity_score": round(result["score"], 3) if result["score"] else None,
485
+ "excerpt": result["text"][:300] + "..."
486
+ })
487
+
488
+ return json.dumps({
489
+ "status": "success",
490
+ "query": query,
491
+ "num_results": len(formatted_results),
492
+ "results": formatted_results,
493
+ "message": f"Found {len(formatted_results)} similar papers in research memory"
494
+ }, indent=2)
495
+
496
+ except Exception as e:
497
+ logger.error(f"Search error: {e}")
498
+ return json.dumps({"status": "error", "error": "Search failed", "message": str(e)}, indent=2)
499
+
500
+
501
+ @mcp.tool()
502
+ async def get_research_connections(
503
+ paper_title: str,
504
+ connection_type: str = "similar",
505
+ max_connections: int = 5
506
+ ) -> str:
507
+ """Discover connections between research papers in memory.
508
+
509
+ Finds related papers that might share themes, methods, or findings.
510
+
511
+ Args:
512
+ paper_title: Title of paper to find connections for
513
+ connection_type: Type of connections - "similar", "citations", "authors"
514
+ max_connections: Maximum connections to return
515
+
516
+ Returns:
517
+ Connected papers with relationship descriptions
518
+ """
519
+
520
+ if not LLAMAINDEX_AVAILABLE:
521
+ return json.dumps({
522
+ "status": "error",
523
+ "error": "LlamaIndex not installed",
524
+ "message": "Install with: pip install llama-index chromadb sentence-transformers"
525
+ }, indent=2)
526
+
527
+ # Lazy initialization on first use
528
+ if not await ensure_initialized():
529
+ return json.dumps({
530
+ "status": "error",
531
+ "error": "Memory manager initialization failed",
532
+ "message": "Check LlamaIndex configuration and dependencies"
533
+ }, indent=2)
534
+
535
+ try:
536
+ # For now, we'll use similarity search
537
+ # Future: implement citation networks, co-authorship graphs
538
+
539
+ if connection_type == "similar":
540
+ # Search for papers similar to this title
541
+ results = await memory_manager.search_similar(
542
+ query=paper_title,
543
+ top_k=max_connections + 1 # +1 because it might include itself
544
+ )
545
+
546
+ # Filter out the paper itself
547
+ filtered_results = []
548
+ for result in results:
549
+ if result["metadata"].get("title", "").lower() != paper_title.lower():
550
+ filtered_results.append(result)
551
+
552
+ if not filtered_results:
553
+ return json.dumps({
554
+ "status": "no_connections",
555
+ "paper": paper_title,
556
+ "message": "No connections found in research memory"
557
+ }, indent=2)
558
+
559
+ connections = []
560
+ for result in filtered_results[:max_connections]:
561
+ metadata = result["metadata"]
562
+ connections.append({
563
+ "title": metadata.get("title", "Unknown"),
564
+ "authors": metadata.get("authors", ""),
565
+ "year": metadata.get("year"),
566
+ "connection_strength": round(result["score"], 3) if result["score"] else None,
567
+ "connection_type": "semantic_similarity",
568
+ "url": metadata.get("url")
569
+ })
570
+
571
+ return json.dumps({
572
+ "status": "success",
573
+ "paper": paper_title,
574
+ "connection_type": connection_type,
575
+ "num_connections": len(connections),
576
+ "connections": connections,
577
+ "message": f"Found {len(connections)} connected papers"
578
+ }, indent=2)
579
+
580
+ else:
581
+ return json.dumps({
582
+ "status": "not_implemented",
583
+ "message": f"Connection type '{connection_type}' not yet implemented. Use 'similar' for now."
584
+ }, indent=2)
585
+
586
+ except Exception as e:
587
+ logger.error(f"Error finding connections: {e}")
588
+ return json.dumps({"status": "error", "error": "Connection search failed", "message": str(e)}, indent=2)
589
+
590
+
591
+ @mcp.tool()
592
+ async def list_indexed_papers(
593
+ limit: int = 20,
594
+ sort_by: str = "date"
595
+ ) -> str:
596
+ """List papers currently in research memory.
597
+
598
+ Shows what research the agent has learned from previously.
599
+
600
+ Args:
601
+ limit: Maximum papers to list (default: 20)
602
+ sort_by: Sort order - "date" (indexed date) or "year" (publication year)
603
+
604
+ Returns:
605
+ List of indexed papers with metadata
606
+ """
607
+
608
+ if not papers_metadata:
609
+ return json.dumps({
610
+ "status": "empty",
611
+ "message": "No papers indexed yet. Research memory is empty.",
612
+ "total_papers": 0
613
+ }, indent=2)
614
+
615
+ try:
616
+ # Get papers list
617
+ papers_list = []
618
+ for paper_id, metadata in papers_metadata.items():
619
+ # Convert authors string back to list
620
+ authors_str = metadata.get("authors", "")
621
+ authors_list = authors_str.split(", ") if authors_str else []
622
+
623
+ papers_list.append({
624
+ "paper_id": paper_id,
625
+ "title": metadata.get("title", "Unknown"),
626
+ "authors": authors_list,
627
+ "year": metadata.get("year"),
628
+ "journal": metadata.get("journal"),
629
+ "doi": metadata.get("doi"),
630
+ "indexed_at": metadata.get("indexed_at"),
631
+ "url": metadata.get("url")
632
+ })
633
+
634
+ # Sort
635
+ if sort_by == "date":
636
+ papers_list.sort(key=lambda x: x.get("indexed_at", ""), reverse=True)
637
+ elif sort_by == "year":
638
+ papers_list.sort(key=lambda x: x.get("year") or 0, reverse=True)
639
+
640
+ # Limit
641
+ papers_list = papers_list[:limit]
642
+
643
+ return json.dumps({
644
+ "status": "success",
645
+ "total_papers": len(papers_metadata),
646
+ "showing": len(papers_list),
647
+ "sort_by": sort_by,
648
+ "papers": papers_list,
649
+ "message": f"Research memory contains {len(papers_metadata)} papers"
650
+ }, indent=2)
651
+
652
+ except Exception as e:
653
+ logger.error(f"Error listing papers: {e}")
654
+ return json.dumps({"status": "error", "error": "Failed to list papers", "message": str(e)}, indent=2)
655
+
656
+
657
+ @mcp.tool()
658
+ async def clear_research_memory(
659
+ confirm: bool = False
660
+ ) -> str:
661
+ """Clear all papers from research memory.
662
+
663
+ ⚠️ This will permanently delete all indexed research!
664
+
665
+ Args:
666
+ confirm: Must be True to actually clear memory
667
+
668
+ Returns:
669
+ Confirmation of memory clearing
670
+ """
671
+ global papers_metadata
672
+
673
+ if not confirm:
674
+ return json.dumps({
675
+ "status": "confirmation_required",
676
+ "message": "⚠️ This will delete all research memory. Set confirm=True to proceed.",
677
+ "current_papers": len(papers_metadata)
678
+ }, indent=2)
679
+
680
+ try:
681
+ # Check if memory manager needs initialization
682
+ # Only initialize if we have papers to clear
683
+ if papers_metadata and not memory_manager:
684
+ await ensure_initialized()
685
+
686
+ # Clear ChromaDB collection
687
+ if memory_manager and memory_manager.collection:
688
+ # Delete and recreate collection
689
+ memory_manager.chroma_client.delete_collection("als_research")
690
+ memory_manager.collection = memory_manager.chroma_client.create_collection("als_research")
691
+
692
+ # Reinitialize index
693
+ memory_manager._initialize_index()
694
+
695
+ # Clear metadata
696
+ num_papers = len(papers_metadata)
697
+ papers_metadata = {}
698
+
699
+ # Save empty metadata
700
+ if memory_manager:
701
+ memory_manager._save_metadata()
702
+
703
+ logger.info(f"Cleared research memory: {num_papers} papers removed")
704
+
705
+ return json.dumps({
706
+ "status": "success",
707
+ "message": f"✅ Research memory cleared. Removed {num_papers} papers.",
708
+ "papers_removed": num_papers
709
+ }, indent=2)
710
+
711
+ except Exception as e:
712
+ logger.error(f"Error clearing memory: {e}")
713
+ return json.dumps({"status": "error", "error": "Failed to clear memory", "message": str(e)}, indent=2)
714
+
715
+
716
+ if __name__ == "__main__":
717
+ # Check for required packages
718
+ if not LLAMAINDEX_AVAILABLE:
719
+ logger.error("LlamaIndex dependencies not installed!")
720
+ logger.info("Install with: pip install llama-index-core llama-index-vector-stores-chroma")
721
+ logger.info(" pip install chromadb sentence-transformers transformers")
722
+ else:
723
+ logger.info("LlamaIndex RAG server starting...")
724
+ logger.info(f"ChromaDB path: {CHROMA_DB_PATH}")
725
+ logger.info(f"Embedding model: {EMBED_MODEL}")
726
+ logger.info(f"Papers in memory: {len(papers_metadata)}")
727
+
728
+ # Run the MCP server
729
+ mcp.run(transport="stdio")
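
The deduplication above keys each paper on an MD5 hash of its DOI when one is available, falling back to the lowercased title; a self-contained sketch of that behavior (the sample DOI and titles are hypothetical):

import hashlib

def paper_id(title: str, doi: str | None = None) -> str:
    # Mirrors ResearchMemoryManager.generate_paper_id: DOI takes precedence
    if doi:
        return hashlib.md5(doi.encode()).hexdigest()
    return hashlib.md5(title.lower().encode()).hexdigest()

# Same DOI -> same ID, so re-indexing a reprint is caught as "already_indexed"
a = paper_id("TDP-43 Proteinopathy in ALS", doi="10.1000/xyz123")
b = paper_id("TDP-43 proteinopathy in ALS (reprint)", doi="10.1000/xyz123")
assert a == b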
servers/pubmed_server.py ADDED
@@ -0,0 +1,269 @@
1
+ # pubmed_server.py
2
+ from mcp.server.fastmcp import FastMCP
3
+ import httpx
4
+ import logging
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ # Add parent directory to path for shared imports
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+
11
+ from shared import (
12
+ config,
13
+ RateLimiter,
14
+ format_authors,
15
+ ErrorFormatter,
16
+ truncate_text
17
+ )
18
+ from shared.http_client import get_http_client
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Create FastMCP server
25
+ mcp = FastMCP("pubmed-server")
26
+
27
+ # Rate limiting using shared utility
28
+ rate_limiter = RateLimiter(config.rate_limits.pubmed_delay)
29
+
30
+
31
+ @mcp.tool()
32
+ async def search_pubmed(
33
+ query: str,
34
+ max_results: int = 10,
35
+ sort: str = "relevance"
36
+ ) -> str:
37
+ """Search PubMed for ALS research papers. Returns titles, abstracts, PMIDs, and publication dates.
38
+
39
+ Args:
40
+ query: Search query (e.g., 'ALS SOD1 therapy')
41
+ max_results: Maximum number of results (default: 10)
42
+ sort: Sort order - 'relevance' or 'date' (default: 'relevance')
43
+ """
44
+ try:
45
+ logger.info(f"Searching PubMed for: {query}")
46
+
47
+ # PubMed E-utilities API (no auth required)
48
+ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
49
+
50
+ # Rate limiting
51
+ await rate_limiter.wait()
52
+
53
+ # Step 1: Search for PMIDs
54
+ search_params = {
55
+ "db": "pubmed",
56
+ "term": query,
57
+ "retmax": max_results,
58
+ "retmode": "json",
59
+ "sort": sort
60
+ }
61
+
62
+ # Use shared HTTP client for connection pooling
63
+ client = get_http_client(timeout=config.api.timeout)
64
+
65
+ # Get PMIDs
66
+ search_resp = await client.get(f"{base_url}/esearch.fcgi", params=search_params)
67
+ search_resp.raise_for_status()
68
+ search_data = search_resp.json()
69
+ pmids = search_data.get("esearchresult", {}).get("idlist", [])
70
+
71
+ if not pmids:
72
+ logger.info(f"No results found for query: {query}")
73
+ return ErrorFormatter.no_results(query)
74
+
75
+ # Rate limiting
76
+ await rate_limiter.wait()
77
+
78
+ # Step 2: Fetch details for PMIDs
79
+ fetch_params = {
80
+ "db": "pubmed",
81
+ "id": ",".join(pmids),
82
+ "retmode": "xml"
83
+ }
84
+
85
+ fetch_resp = await client.get(f"{base_url}/efetch.fcgi", params=fetch_params)
86
+ fetch_resp.raise_for_status()
87
+
88
+ # Parse XML and extract key info
89
+ papers = parse_pubmed_xml(fetch_resp.text)
90
+
91
+ result = f"Found {len(papers)} papers for query: '{query}'\n\n"
92
+ for i, paper in enumerate(papers, 1):
93
+ result += f"{i}. **{paper['title']}**\n"
94
+ result += f" PMID: {paper['pmid']} | Published: {paper['date']}\n"
95
+ result += f" Authors: {paper['authors']}\n"
96
+ result += f" URL: https://pubmed.ncbi.nlm.nih.gov/{paper['pmid']}/\n"
97
+ result += f" Abstract: {paper['abstract'][:300]}{'...' if len(paper['abstract']) > 300 else ''}\n\n"
98
+
99
+ logger.info(f"Successfully retrieved {len(papers)} papers")
100
+ return result
101
+
102
+ except httpx.TimeoutException:
103
+ logger.error("PubMed API request timed out")
104
+ return "Error: PubMed API request timed out. Please try again."
105
+ except httpx.HTTPStatusError as e:
106
+ logger.error(f"PubMed API error: {e}")
107
+ return f"Error: PubMed API returned status {e.response.status_code}"
108
+ except Exception as e:
109
+ logger.error(f"Unexpected error in search_pubmed: {e}")
110
+ return f"Error: {str(e)}"
111
+
112
+
113
+ @mcp.tool()
114
+ async def get_paper_details(pmid: str) -> str:
115
+ """Get full details for a specific PubMed paper by PMID.
116
+
117
+ Args:
118
+ pmid: PubMed ID
119
+ """
120
+ try:
121
+ logger.info(f"Fetching details for PMID: {pmid}")
122
+
123
+ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
124
+
125
+ # Rate limiting
126
+ await rate_limiter.wait()
127
+
128
+ fetch_params = {
129
+ "db": "pubmed",
130
+ "id": pmid,
131
+ "retmode": "xml"
132
+ }
133
+
134
+ # Use shared HTTP client for connection pooling
135
+ client = get_http_client(timeout=config.api.timeout)
136
+ fetch_resp = await client.get(f"{base_url}/efetch.fcgi", params=fetch_params)
137
+ fetch_resp.raise_for_status()
138
+
139
+ papers = parse_pubmed_xml(fetch_resp.text)
140
+
141
+ if not papers:
142
+ return ErrorFormatter.not_found("paper", pmid)
143
+
144
+ paper = papers[0]
145
+
146
+ # Format detailed response
147
+ result = f"**{paper['title']}**\n\n"
148
+ result += f"**PMID:** {paper['pmid']}\n"
149
+ result += f"**Published:** {paper['date']}\n"
150
+ result += f"**Authors:** {paper['authors']}\n\n"
151
+ result += f"**Abstract:**\n{paper['abstract']}\n\n"
152
+ result += f"**Journal:** {paper.get('journal', 'N/A')}\n"
153
+ result += f"**DOI:** {paper.get('doi', 'N/A')}\n"
154
+ result += f"**PubMed URL:** https://pubmed.ncbi.nlm.nih.gov/{pmid}/\n"
155
+
156
+ logger.info(f"Successfully retrieved details for PMID: {pmid}")
157
+ return result
158
+
159
+ except httpx.TimeoutException:
160
+ logger.error("PubMed API request timed out")
161
+ return "Error: PubMed API request timed out. Please try again."
162
+ except httpx.HTTPStatusError as e:
163
+ logger.error(f"PubMed API error: {e}")
164
+ return f"Error: PubMed API returned status {e.response.status_code}"
165
+ except Exception as e:
166
+ logger.error(f"Unexpected error in get_paper_details: {e}")
167
+ return f"Error: {str(e)}"
168
+
169
+
170
+ def parse_pubmed_xml(xml_text: str) -> list[dict]:
171
+ """Parse PubMed XML response into structured data with error handling"""
172
+ import xml.etree.ElementTree as ET
173
+
174
+ papers = []
175
+
176
+ try:
177
+ root = ET.fromstring(xml_text)
178
+ except ET.ParseError as e:
179
+ logger.error(f"XML parsing error: {e}")
180
+ return papers
181
+
182
+ for article in root.findall(".//PubmedArticle"):
183
+ try:
184
+ # Extract title
185
+ title_elem = article.find(".//ArticleTitle")
186
+ title = "".join(title_elem.itertext()) if title_elem is not None else "No title"
187
+
188
+ # Extract abstract (may have multiple AbstractText elements)
189
+ abstract_parts = []
190
+ for abstract_elem in article.findall(".//AbstractText"):
191
+ if abstract_elem is not None and abstract_elem.text:
192
+ label = abstract_elem.get("Label", "")
193
+ text = "".join(abstract_elem.itertext())
194
+ if label:
195
+ abstract_parts.append(f"{label}: {text}")
196
+ else:
197
+ abstract_parts.append(text)
198
+ abstract = " ".join(abstract_parts) if abstract_parts else "No abstract available"
199
+
200
+ # Extract PMID
201
+ pmid_elem = article.find(".//PMID")
202
+ pmid = pmid_elem.text if pmid_elem is not None else "Unknown"
203
+
204
+ # Extract publication date (MedlineCitation/Article/Journal/JournalIssue/PubDate)
205
+ pub_date = article.find(".//MedlineCitation/Article/Journal/JournalIssue/PubDate")
206
+ if pub_date is not None:
207
+ year_elem = pub_date.find("Year")
208
+ month_elem = pub_date.find("Month")
209
+ year = year_elem.text if year_elem is not None else "Unknown"
210
+ month = month_elem.text if month_elem is not None else ""
211
+ date_str = f"{month} {year}" if month else year
212
+ else:
213
+ # Try alternative date location
214
+ date_completed = article.find(".//DateCompleted")
215
+ if date_completed is not None:
216
+ year_elem = date_completed.find("Year")
217
+ year = year_elem.text if year_elem is not None else "Unknown"
218
+ date_str = year
219
+ else:
220
+ date_str = "Unknown"
221
+
222
+ # Extract authors
223
+ authors = []
224
+ for author in article.findall(".//Author"):
225
+ last = author.find("LastName")
226
+ first = author.find("ForeName")
227
+ collective = author.find("CollectiveName")
228
+
229
+ if collective is not None and collective.text:
230
+ authors.append(collective.text)
231
+ elif last is not None and first is not None:
232
+ authors.append(f"{first.text} {last.text}")
233
+ elif last is not None:
234
+ authors.append(last.text)
235
+
236
+ # Format authors using shared utility
237
+ authors_str = format_authors("; ".join(authors), max_authors=3) if authors else "Unknown authors"
238
+
239
+ # Extract journal name
240
+ journal_elem = article.find(".//Journal/Title")
241
+ journal = journal_elem.text if journal_elem is not None else "Unknown"
242
+
243
+ # Extract DOI
244
+ doi = None
245
+ for article_id in article.findall(".//ArticleId"):
246
+ if article_id.get("IdType") == "doi":
247
+ doi = article_id.text
248
+ break
249
+
250
+ papers.append({
251
+ "title": title,
252
+ "abstract": abstract,
253
+ "pmid": pmid,
254
+ "date": date_str,
255
+ "authors": authors_str,
256
+ "journal": journal,
257
+ "doi": doi or "N/A"
258
+ })
259
+
260
+ except Exception as e:
261
+ logger.warning(f"Error parsing article: {e}")
262
+ continue
263
+
264
+ return papers
265
+
266
+
267
+ if __name__ == "__main__":
268
+ # Run with stdio transport
269
+ mcp.run(transport="stdio")
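
The search tool above wraps NCBI's two-step E-utilities flow: esearch returns matching PMIDs, then efetch returns the full records for those PMIDs. A standalone sketch of the same flow, without the shared client, rate limiter, or XML parsing:

import asyncio
import httpx

async def demo() -> None:
    base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    async with httpx.AsyncClient(timeout=15.0) as client:
        # Step 1: esearch returns PMIDs for the query as JSON
        resp = await client.get(f"{base}/esearch.fcgi", params={
            "db": "pubmed", "term": "ALS SOD1 therapy",
            "retmax": 3, "retmode": "json", "sort": "relevance"})
        pmids = resp.json().get("esearchresult", {}).get("idlist", [])
        # Step 2: efetch returns the full records for those PMIDs as XML
        resp = await client.get(f"{base}/efetch.fcgi", params={
            "db": "pubmed", "id": ",".join(pmids), "retmode": "xml"})
        print(resp.text[:200])

asyncio.run(demo())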
shared/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # shared/__init__.py
2
+ """Shared utilities and configuration for ALS Research Agent"""
3
+
4
+ from .config import config, AppConfig, APIConfig, RateLimitConfig, ContentLimits, SecurityConfig
5
+ from .utils import (
6
+ RateLimiter,
7
+ safe_api_call,
8
+ truncate_text,
9
+ format_authors,
10
+ clean_whitespace,
11
+ ErrorFormatter,
12
+ create_citation
13
+ )
14
+ from .cache import SimpleCache
15
+
16
+ __all__ = [
17
+ # Configuration
18
+ 'config',
19
+ 'AppConfig',
20
+ 'APIConfig',
21
+ 'RateLimitConfig',
22
+ 'ContentLimits',
23
+ 'SecurityConfig',
24
+ # Utilities
25
+ 'RateLimiter',
26
+ 'safe_api_call',
27
+ 'truncate_text',
28
+ 'format_authors',
29
+ 'clean_whitespace',
30
+ 'ErrorFormatter',
31
+ 'create_citation',
32
+ # Cache
33
+ 'SimpleCache',
34
+ ]
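
Consumer-side usage of this package surface looks roughly like (a sketch):

from shared import config, RateLimiter, truncate_text

limiter = RateLimiter(config.rate_limits.pubmed_delay)  # ~3 requests/second
print(config.anthropic_model)
print(truncate_text("a" * 10_000, max_chars=100))  # truncated, with a notice appended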
shared/cache.py ADDED
@@ -0,0 +1,94 @@
1
+ # shared/cache.py
2
+ """Simple in-memory cache for API responses"""
3
+
4
+ import time
5
+ import hashlib
6
+ import json
7
+ from typing import Optional, Any
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class SimpleCache:
14
+ """Simple TTL-based in-memory cache with size limits"""
15
+
16
+ def __init__(self, ttl: int = 3600, max_size: int = 100):
17
+ """
18
+ Initialize cache with TTL and size limits
19
+
20
+ Args:
21
+ ttl: Time to live in seconds (default: 1 hour)
22
+ max_size: Maximum number of cached entries (default: 100)
23
+ """
24
+ self.cache = {}
25
+ self.ttl = ttl
26
+ self.max_size = max_size
27
+
28
+ def _make_key(self, tool_name: str, arguments: dict) -> str:
29
+ """Create cache key from tool name and arguments"""
30
+ # Sort dict for consistent hashing
31
+ args_str = json.dumps(arguments, sort_keys=True)
32
+ key_str = f"{tool_name}:{args_str}"
33
+ return hashlib.md5(key_str.encode()).hexdigest()
34
+
35
+ def get(self, tool_name: str, arguments: dict) -> Optional[str]:
36
+ """Get cached result if available and not expired"""
37
+ key = self._make_key(tool_name, arguments)
38
+
39
+ if key in self.cache:
40
+ result, timestamp = self.cache[key]
41
+
42
+ # Check if expired
43
+ if time.time() - timestamp < self.ttl:
44
+ logger.info(f"Cache HIT for {tool_name}")
45
+ return result
46
+ else:
47
+ # Remove expired entry
48
+ del self.cache[key]
49
+ logger.info(f"Cache EXPIRED for {tool_name}")
50
+
51
+ logger.info(f"Cache MISS for {tool_name}")
52
+ return None
53
+
54
+ def set(self, tool_name: str, arguments: dict, result: str) -> None:
55
+ """Store result in cache, evicting the oldest entry if at capacity"""
56
+ key = self._make_key(tool_name, arguments)
57
+
58
+ # Check if we need to evict an entry
59
+ if len(self.cache) >= self.max_size and key not in self.cache:
60
+ # Evict the entry with the oldest insert time (FIFO-style; get() never refreshes timestamps, so this is not true LRU)
61
+ if self.cache: # Safety check
62
+ oldest_key = min(self.cache.keys(),
63
+ key=lambda k: self.cache[k][1])
64
+ del self.cache[oldest_key]
65
+ logger.debug("Evicted oldest cache entry to maintain size limit")
66
+
67
+ self.cache[key] = (result, time.time())
68
+ logger.debug(f"Cached result for {tool_name} (cache size: {len(self.cache)}/{self.max_size})")
69
+
70
+ def clear(self) -> None:
71
+ """Clear all cache entries"""
72
+ self.cache.clear()
73
+ logger.info("Cache cleared")
74
+
75
+ def size(self) -> int:
76
+ """Get number of cached items"""
77
+ return len(self.cache)
78
+
79
+ def cleanup_expired(self) -> int:
80
+ """Remove all expired entries and return count of removed items"""
81
+ expired_keys = []
82
+ current_time = time.time()
83
+
84
+ for key, (result, timestamp) in self.cache.items():
85
+ if current_time - timestamp >= self.ttl:
86
+ expired_keys.append(key)
87
+
88
+ for key in expired_keys:
89
+ del self.cache[key]
90
+
91
+ if expired_keys:
92
+ logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")
93
+
94
+ return len(expired_keys)
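
A minimal read-through usage sketch of SimpleCache (the query and cached value are placeholders):

from shared.cache import SimpleCache

cache = SimpleCache(ttl=3600, max_size=100)
args = {"query": "ALS SOD1 therapy", "max_results": 10}

result = cache.get("search_pubmed", args)    # None on first call (cache MISS)
if result is None:
    result = "...expensive API response..."  # placeholder for the real tool call
    cache.set("search_pubmed", args, result)

print(cache.size())             # 1
print(cache.cleanup_expired())  # 0: nothing has expired yet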
shared/config.py ADDED
@@ -0,0 +1,134 @@
1
+ # shared/config.py
2
+ """Shared configuration for MCP servers"""
3
+
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+
9
+ @dataclass
10
+ class APIConfig:
11
+ """Configuration for API calls"""
12
+ timeout: float = 15.0 # Reduced from 30s - PubMed typically responds in <1s
13
+ max_retries: int = 3
14
+ user_agent: str = "Mozilla/5.0 (compatible; ALS-Research-Bot/1.0)"
15
+
16
+
17
+ @dataclass
18
+ class RateLimitConfig:
19
+ """Rate limiting configuration for different APIs"""
20
+ # PubMed: 3 req/sec without key, 10 req/sec with key
21
+ pubmed_delay: float = 0.34 # ~3 requests per second
22
+
23
+ # ClinicalTrials.gov: conservative limit (API limit is ~50 req/min)
24
+ clinicaltrials_delay: float = 1.5 # ~40 requests per minute (safe margin)
25
+
26
+ # bioRxiv/medRxiv: be respectful
27
+ biorxiv_delay: float = 1.0 # 1 request per second
28
+
29
+ # General web fetching
30
+ fetch_delay: float = 0.5
31
+
32
+
33
+ @dataclass
34
+ class ContentLimits:
35
+ """Content size and length limits"""
36
+ # Maximum content size for downloads (10MB)
37
+ max_content_size: int = 10 * 1024 * 1024
38
+
39
+ # Maximum characters for LLM context
40
+ max_text_chars: int = 8000
41
+
42
+ # Maximum abstract preview length
43
+ max_abstract_preview: int = 300
44
+
45
+ # Maximum description preview length
46
+ max_description_preview: int = 500
47
+
48
+
49
+ @dataclass
50
+ class SecurityConfig:
51
+ """Security-related configuration"""
52
+ allowed_schemes: Optional[list[str]] = None
53
+ blocked_hosts: Optional[list[str]] = None
54
+
55
+ def __post_init__(self):
56
+ if self.allowed_schemes is None:
57
+ self.allowed_schemes = ['http', 'https']
58
+
59
+ if self.blocked_hosts is None:
60
+ self.blocked_hosts = [
61
+ 'localhost',
62
+ '127.0.0.1',
63
+ '0.0.0.0',
64
+ '[::1]'
65
+ ]
66
+
67
+ def is_private_ip(self, hostname: str) -> bool:
68
+ """Check if hostname is a private IP"""
69
+ hostname_lower = hostname.lower()
70
+
71
+ # Check exact matches
72
+ if hostname_lower in self.blocked_hosts:
73
+ return True
74
+
75
+ # Check private IP ranges
76
+ if hostname_lower.startswith(('192.168.', '10.')):
77
+ return True
78
+
79
+ # Check 172.16-31 range
80
+ if hostname_lower.startswith('172.'):
81
+ try:
82
+ second_octet = int(hostname.split('.')[1])
83
+ if 16 <= second_octet <= 31:
84
+ return True
85
+ except (ValueError, IndexError):
86
+ pass
87
+
88
+ return False
89
+
90
+
91
+ @dataclass
92
+ class AppConfig:
93
+ """Application-wide configuration"""
94
+ # API configurations
95
+ api: Optional[APIConfig] = None
96
+ rate_limits: Optional[RateLimitConfig] = None
97
+ content_limits: Optional[ContentLimits] = None
98
+ security: Optional[SecurityConfig] = None
99
+
100
+ # Environment variables
101
+ anthropic_api_key: Optional[str] = None
102
+ anthropic_model: str = "claude-sonnet-4-5-20250929"
103
+ gradio_port: int = 7860
104
+ log_level: str = "INFO"
105
+
106
+ # PubMed email (optional, increases rate limit)
107
+ pubmed_email: Optional[str] = None
108
+
109
+ def __post_init__(self):
110
+ # Initialize sub-configs
111
+ if self.api is None:
112
+ self.api = APIConfig()
113
+ if self.rate_limits is None:
114
+ self.rate_limits = RateLimitConfig()
115
+ if self.content_limits is None:
116
+ self.content_limits = ContentLimits()
117
+ if self.security is None:
118
+ self.security = SecurityConfig()
119
+
120
+ # Load from environment
121
+ self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", self.anthropic_api_key)
122
+ self.anthropic_model = os.getenv("ANTHROPIC_MODEL", self.anthropic_model)
123
+ self.gradio_port = int(os.getenv("GRADIO_SERVER_PORT", self.gradio_port))
124
+ self.log_level = os.getenv("LOG_LEVEL", self.log_level)
125
+ self.pubmed_email = os.getenv("PUBMED_EMAIL", self.pubmed_email)
126
+
127
+ @classmethod
128
+ def from_env(cls) -> 'AppConfig':
129
+ """Create configuration from environment variables"""
130
+ return cls()
131
+
132
+
133
+ # Global configuration instance
134
+ config = AppConfig.from_env()
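
Because AppConfig reads the environment in __post_init__, settings can be overridden by exporting variables before the module is imported; a sketch:

import os
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
os.environ["GRADIO_SERVER_PORT"] = "7861"

from shared.config import AppConfig

cfg = AppConfig.from_env()
print(cfg.anthropic_model, cfg.gradio_port)      # claude-3-haiku-20240307 7861
print(cfg.security.is_private_ip("172.20.0.1"))  # True: inside the 172.16-31 private range
print(cfg.security.is_private_ip("8.8.8.8"))     # False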
shared/http_client.py ADDED
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Shared HTTP client with connection pooling for better performance.
4
+ All MCP servers should use this instead of creating new clients for each request.
5
+ """
6
+
7
+ import httpx
8
+ from typing import Optional
9
+
10
+ # Global HTTP client with connection pooling
11
+ # This maintains persistent connections to servers for faster subsequent requests
12
+ _http_client: Optional[httpx.AsyncClient] = None
13
+
14
+ def get_http_client(timeout: float = 30.0) -> httpx.AsyncClient:
15
+ """
16
+ Get the shared HTTP client with connection pooling.
17
+
18
+ NOTE: For different timeout values, use CustomHTTPClient context manager
19
+ instead to avoid conflicts between servers.
20
+
21
+ Args:
22
+ timeout: Request timeout in seconds (default 30)
23
+
24
+ Returns:
25
+ Shared httpx.AsyncClient instance
26
+ """
27
+ global _http_client
28
+
29
+ if _http_client is None or _http_client.is_closed:
30
+ _http_client = httpx.AsyncClient(
31
+ timeout=httpx.Timeout(timeout),
32
+ limits=httpx.Limits(
33
+ max_connections=100, # Maximum number of connections
34
+ max_keepalive_connections=20, # Keep 20 connections alive for reuse
35
+ keepalive_expiry=300 # Keep connections alive for 5 minutes
36
+ ),
37
+ # Follow redirects by default
38
+ follow_redirects=True
39
+ )
40
+
41
+ return _http_client
42
+
43
+ async def close_http_client():
44
+ """Close the shared HTTP client (call on shutdown)."""
45
+ global _http_client
46
+ if _http_client and not _http_client.is_closed:
47
+ await _http_client.aclose()
48
+ _http_client = None
49
+
50
+ # Context manager for temporary clients with custom settings
51
+ class CustomHTTPClient:
52
+ """Context manager for creating temporary HTTP clients with custom settings."""
53
+
54
+ def __init__(self, timeout: float = 30.0, **kwargs):
55
+ self.timeout = timeout
56
+ self.kwargs = kwargs
57
+ self.client = None
58
+
59
+ async def __aenter__(self):
60
+ self.client = httpx.AsyncClient(
61
+ timeout=httpx.Timeout(self.timeout),
62
+ **self.kwargs
63
+ )
64
+ return self.client
65
+
66
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
67
+ if self.client:
68
+ await self.client.aclose()
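
A usage sketch contrasting the pooled client with a one-off CustomHTTPClient (the URLs are placeholders):

import asyncio
from shared.http_client import CustomHTTPClient, close_http_client, get_http_client

async def main() -> None:
    # Pooled client: reused across calls, with keep-alive connections
    client = get_http_client(timeout=15.0)
    resp = await client.get("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/einfo.fcgi")
    print(resp.status_code)

    # One-off client with custom settings, closed automatically on exit
    async with CustomHTTPClient(timeout=60.0, follow_redirects=True) as one_off:
        resp = await one_off.get("https://clinicaltrials.gov/")
        print(resp.status_code)

    await close_http_client()  # release pooled connections on shutdown

asyncio.run(main())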
shared/utils.py ADDED
@@ -0,0 +1,194 @@
1
+ # shared/utils.py
2
+ """Shared utilities for MCP servers"""
3
+
4
+ import asyncio
5
+ import time
6
+ import logging
7
+ from typing import Optional, Callable, Any
8
+ from mcp.types import TextContent
9
+ import httpx
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class RateLimiter:
15
+ """Rate limiter for API calls"""
16
+
17
+ def __init__(self, delay: float):
18
+ """
19
+ Initialize rate limiter
20
+
21
+ Args:
22
+ delay: Minimum delay between requests in seconds
23
+ """
24
+ self.delay = delay
25
+ self.last_request_time: Optional[float] = None
26
+
27
+ async def wait(self) -> None:
28
+ """Wait if necessary to respect rate limit"""
29
+ if self.last_request_time is not None:
30
+ elapsed = time.time() - self.last_request_time
31
+ if elapsed < self.delay:
32
+ await asyncio.sleep(self.delay - elapsed)
33
+ self.last_request_time = time.time()
34
+
35
+
36
+ async def safe_api_call(
37
+ func: Callable,
38
+ *args: Any,
39
+ timeout: float = 30.0,
40
+ error_prefix: str = "API",
41
+ **kwargs: Any
42
+ ) -> list[TextContent]:
43
+ """
44
+ Safely execute an API call with comprehensive error handling
45
+
46
+ Args:
47
+ func: Async function to call
48
+ *args: Positional arguments for func
49
+ timeout: Timeout in seconds
50
+ error_prefix: Prefix for error messages
51
+ **kwargs: Keyword arguments for func
52
+
53
+ Returns:
54
+ list[TextContent]: Result or error message
55
+ """
56
+ try:
57
+ return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout)
58
+
59
+ except asyncio.TimeoutError:
60
+ logger.error(f"{error_prefix} request timed out after {timeout}s")
61
+ return [TextContent(
62
+ type="text",
63
+ text=f"Error: {error_prefix} request timed out after {timeout} seconds. Please try again."
64
+ )]
65
+
66
+ except httpx.TimeoutException:
67
+ logger.error(f"{error_prefix} request timed out")
68
+ return [TextContent(
69
+ type="text",
70
+ text=f"Error: {error_prefix} request timed out. Please try again."
71
+ )]
72
+
73
+ except httpx.HTTPStatusError as e:
74
+ logger.error(f"{error_prefix} error: HTTP {e.response.status_code}")
75
+ return [TextContent(
76
+ type="text",
77
+ text=f"Error: {error_prefix} returned status {e.response.status_code}"
78
+ )]
79
+
80
+ except httpx.RequestError as e:
81
+ logger.error(f"{error_prefix} request error: {e}")
82
+ return [TextContent(
83
+ type="text",
84
+ text=f"Error: Failed to connect to {error_prefix}. Please check your connection."
85
+ )]
86
+
87
+ except Exception as e:
88
+ logger.error(f"Unexpected error in {error_prefix}: {e}", exc_info=True)
89
+ return [TextContent(
90
+ type="text",
91
+ text=f"Error: {str(e)}"
92
+ )]
93
+
94
+
95
+ def truncate_text(text: str, max_chars: int = 8000, suffix: str = "...") -> str:
96
+ """
97
+ Truncate text to maximum length with suffix
98
+
99
+ Args:
100
+ text: Text to truncate
101
+ max_chars: Maximum character count
102
+ suffix: Suffix to add when truncated
103
+
104
+ Returns:
105
+ Truncated text
106
+ """
107
+ if len(text) <= max_chars:
108
+ return text
109
+
110
+ return text[:max_chars] + f"\n\n[Content truncated at {max_chars} characters]{suffix}"
111
+
112
+
113
+ def format_authors(authors: str, max_authors: int = 3) -> str:
114
+ """
115
+ Format author list with et al. if needed
116
+
117
+ Args:
118
+ authors: Semicolon-separated author list
119
+ max_authors: Maximum authors to show
120
+
121
+ Returns:
122
+ Formatted author string
123
+ """
124
+ if not authors or authors == "Unknown":
125
+ return "Unknown authors"
126
+
127
+ author_list = [a.strip() for a in authors.split(";")]
128
+
129
+ if len(author_list) <= max_authors:
130
+ return ", ".join(author_list)
131
+
132
+ return ", ".join(author_list[:max_authors]) + " et al."
133
+
134
+
135
+ def clean_whitespace(text: str) -> str:
136
+ """
137
+ Clean up excessive whitespace in text
138
+
139
+ Args:
140
+ text: Text to clean
141
+
142
+ Returns:
143
+ Cleaned text
144
+ """
145
+ lines = (line.strip() for line in text.splitlines())
146
+ chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
147
+ return '\n'.join(chunk for chunk in chunks if chunk)
148
+
149
+
150
+ class ErrorFormatter:
151
+ """Consistent error message formatting"""
152
+
153
+ @staticmethod
154
+ def not_found(resource_type: str, identifier: str) -> str:
155
+ """Format not found error"""
156
+ return f"No {resource_type} found with identifier: {identifier}"
157
+
158
+ @staticmethod
159
+ def no_results(query: str, time_period: str = "") -> str:
160
+ """Format no results error"""
161
+ time_str = f" {time_period}" if time_period else ""
162
+ return f"No results found for query: {query}{time_str}"
163
+
164
+ @staticmethod
165
+ def validation_error(field: str, issue: str) -> str:
166
+ """Format validation error"""
167
+ return f"Validation error: {field} - {issue}"
168
+
169
+ @staticmethod
170
+ def api_error(service: str, status_code: int) -> str:
171
+ """Format API error"""
172
+ return f"Error: {service} API returned status {status_code}"
173
+
174
+
175
+ def create_citation(
176
+ identifier: str,
177
+ identifier_type: str,
178
+ url: Optional[str] = None
179
+ ) -> str:
180
+ """
181
+ Create a formatted citation string
182
+
183
+ Args:
184
+ identifier: Citation identifier (PMID, DOI, NCT ID)
185
+ identifier_type: Type of identifier
186
+ url: Optional URL
187
+
188
+ Returns:
189
+ Formatted citation
190
+ """
191
+ citation = f"{identifier_type}: {identifier}"
192
+ if url:
193
+ citation += f" | URL: {url}"
194
+ return citation
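
A few quick usage sketches for the helpers above:

import asyncio
from shared.utils import RateLimiter, format_authors, truncate_text

print(format_authors("Smith J; Doe A; Lee K; Park M"))  # Smith J, Doe A, Lee K et al.
print(truncate_text("short text", max_chars=8000))      # returned unchanged: under the limit

async def main() -> None:
    limiter = RateLimiter(delay=0.34)  # ~3 requests/second, as used by the PubMed server
    for _ in range(3):
        await limiter.wait()           # sleeps only when called faster than the delay

asyncio.run(main())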
smart_cache.py ADDED
@@ -0,0 +1,458 @@
+ #!/usr/bin/env python3
+ """
+ Smart Cache System for ALS Research Agent
+ Features:
+ - Query normalization to match similar queries
+ - Cache pre-warming with common queries
+ - High-frequency question optimization
+ """
+
+ import json
+ import hashlib
+ import os
+ import re
+ import asyncio
+ import logging
+ from typing import Dict, List, Optional, Any
+ from datetime import datetime, timedelta
+
+ logger = logging.getLogger(__name__)
+
+
+ class SmartCache:
+     """Advanced caching system with query normalization and pre-warming"""
+
+     def __init__(self, cache_dir: str = ".cache", ttl_hours: int = 24):
+         """
+         Initialize smart cache system.
+
+         Args:
+             cache_dir: Directory for cache storage
+             ttl_hours: Time-to-live for cached entries in hours
+         """
+         self.cache_dir = cache_dir
+         self.ttl = timedelta(hours=ttl_hours)
+         self.cache = {}  # In-memory cache
+         self.normalized_cache = {}  # Maps normalized queries to original cache keys
+         self.high_frequency_queries = {}  # User-specified common queries
+         self.query_stats = {}  # Track query frequency
+
+         # Ensure cache directory exists (os is imported at module level)
+         os.makedirs(cache_dir, exist_ok=True)
+
+         # Load persistent cache on init
+         self.load_cache()
+
+     def normalize_query(self, query: str) -> str:
+         """
+         Normalize query for better cache matching.
+
+         Handles variations like:
+         - "ALS gene therapy" vs "gene therapy ALS"
+         - "What are the latest trials" vs "what are latest trials"
+         - Different word orders, case, punctuation
+         """
+         # Convert to lowercase
+         normalized = query.lower().strip()
+
+         # Remove common question words that don't affect meaning
+         question_words = [
+             'what', 'how', 'when', 'where', 'why', 'who', 'which',
+             'are', 'is', 'the', 'a', 'an', 'there', 'can', 'could',
+             'would', 'should', 'do', 'does', 'did', 'have', 'has', 'had',
+             # common filler words, so phrasings like "treatments for ALS"
+             # match "ALS treatments"
+             'for', 'and'
+         ]
+
+         # Remove punctuation
+         normalized = re.sub(r'[^\w\s]', ' ', normalized)
+
+         # Split into words and remove question words
+         words = normalized.split()
+         content_words = [w for w in words if w not in question_words]
+
+         # Sort words alphabetically for consistent ordering
+         # This makes "ALS gene therapy" match "gene therapy ALS"
+         content_words.sort()
+
+         # Join back together
+         normalized = ' '.join(content_words)
+
+         # Remove extra whitespace
+         normalized = ' '.join(normalized.split())
+
+         return normalized
+
+
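+     # Illustrative: these two phrasings normalize to the same key and
+     # therefore share a cache entry:
+     #   normalize_query("What are the latest ALS gene therapy trials?")
+     #   normalize_query("gene therapy trials ALS latest")
+     #   -> both yield "als gene latest therapy trials"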
+     def generate_cache_key(self, query: str, include_normalization: bool = True) -> str:
+         """
+         Generate a cache key for a query.
+
+         Args:
+             query: The original query
+             include_normalization: Whether to also store normalized version
+
+         Returns:
+             Hash-based cache key
+         """
+         # Generate hash of original query
+         original_hash = hashlib.sha256(query.encode()).hexdigest()[:16]
+
+         if include_normalization:
+             # Also store mapping from normalized query to this cache key
+             normalized = self.normalize_query(query)
+             normalized_hash = hashlib.sha256(normalized.encode()).hexdigest()[:16]
+
+             # Store mapping for future lookups
+             if normalized_hash not in self.normalized_cache:
+                 self.normalized_cache[normalized_hash] = []
+             if original_hash not in self.normalized_cache[normalized_hash]:
+                 self.normalized_cache[normalized_hash].append(original_hash)
+
+         return original_hash
+
+     def find_similar_cached(self, query: str) -> Optional[Dict[str, Any]]:
+         """
+         Find cached results for similar queries.
+
+         Args:
+             query: The query to search for
+
+         Returns:
+             Cached result if found, None otherwise
+         """
+         # First try exact match
+         exact_key = self.generate_cache_key(query, include_normalization=False)
+         if exact_key in self.cache:
+             entry = self.cache[exact_key]
+             if self._is_valid(entry):
+                 logger.info(f"Cache hit (exact): {query[:50]}...")
+                 self._update_stats(query)
+                 return entry['result']
+
+         # Try normalized match
+         normalized = self.normalize_query(query)
+         normalized_key = hashlib.sha256(normalized.encode()).hexdigest()[:16]
+
+         if normalized_key in self.normalized_cache:
+             # Check all original queries that normalize to this
+             for original_key in self.normalized_cache[normalized_key]:
+                 if original_key in self.cache:
+                     entry = self.cache[original_key]
+                     if self._is_valid(entry):
+                         logger.info(f"Cache hit (normalized): {query[:50]}...")
+                         self._update_stats(query)
+                         return entry['result']
+
+         logger.info(f"Cache miss: {query[:50]}...")
+         return None
+
+     def store(self, query: str, result: Any, metadata: Optional[Dict] = None):
+         """
+         Store a query result in cache.
+
+         Args:
+             query: The original query
+             result: The result to cache
+             metadata: Optional metadata about the result
+         """
+         cache_key = self.generate_cache_key(query, include_normalization=True)
+
+         entry = {
+             'query': query,
+             'result': result,
+             'timestamp': datetime.now().isoformat(),
+             'metadata': metadata or {},
+             'access_count': 0
+         }
+
+         self.cache[cache_key] = entry
+         self._update_stats(query)
+
+         # Persist to disk asynchronously (non-blocking) when an event loop is
+         # running; otherwise fall back to a synchronous save so store() also
+         # works from plain synchronous code
+         try:
+             asyncio.get_running_loop()
+             asyncio.create_task(self._save_cache_async())
+         except RuntimeError:
+             self.save_cache()
+
+         logger.info(f"Cached result for: {query[:50]}...")
+
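+     # Illustrative round trip: a result stored under one phrasing is
+     # returned for a reordered phrasing (result value is made up):
+     #   cache.store("What are the latest ALS treatments?", {"papers": [...]})
+     #   cache.find_similar_cached("latest treatments ALS")  # -> {"papers": [...]}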
+     def _is_valid(self, entry: Dict) -> bool:
+         """Check if a cache entry is still valid (not expired)"""
+         try:
+             timestamp = datetime.fromisoformat(entry['timestamp'])
+             age = datetime.now() - timestamp
+             return age < self.ttl
+         except (KeyError, TypeError, ValueError):
+             # Malformed entry (missing or unparseable timestamp)
+             return False
+
+     def _update_stats(self, query: str):
+         """Update query frequency statistics"""
+         normalized = self.normalize_query(query)
+         if normalized not in self.query_stats:
+             self.query_stats[normalized] = {'count': 0, 'last_access': None}
+
+         self.query_stats[normalized]['count'] += 1
+         self.query_stats[normalized]['last_access'] = datetime.now().isoformat()
+
+     async def pre_warm_cache(self, queries: List[Dict[str, Any]],
+                              search_func=None, llm_func=None):
+         """
+         Pre-warm cache with common queries.
+
+         Args:
+             queries: List of dicts with 'query', 'search_terms', 'use_claude' keys
+             search_func: Async function to perform searches
+             llm_func: Async function to call Claude for high-priority queries
+         """
+         logger.info(f"Pre-warming cache with {len(queries)} queries...")
+
+         for query_config in queries:
+             query = query_config['query']
+
+             # Check if already cached
+             if self.find_similar_cached(query):
+                 logger.info(f"Already cached: {query}")
+                 continue
+
+             try:
+                 # Use optimized search terms if provided
+                 search_terms = query_config.get('search_terms', query)
+                 use_claude = query_config.get('use_claude', False)
+
+                 if search_func:
+                     # Perform search with optimized terms
+                     logger.info(f"Pre-warming: {query}")
+
+                     if use_claude and llm_func:
+                         # Use Claude for high-priority queries
+                         result = await llm_func(search_terms)
+                     else:
+                         # Use standard search
+                         result = await search_func(search_terms)
+
+                     # Cache the result
+                     self.store(query, result, {
+                         'pre_warmed': True,
+                         'optimized_terms': search_terms,
+                         'used_claude': use_claude
+                     })
+
+                     # Small delay to avoid overwhelming APIs
+                     await asyncio.sleep(1)
+
+             except Exception as e:
+                 logger.error(f"Failed to pre-warm cache for '{query}': {e}")
+
+     def add_high_frequency_query(self, query: str, config: Dict[str, Any]):
+         """
+         Add a high-frequency query configuration.
+
+         Args:
+             query: The query pattern
+             config: Configuration dict with search_terms, use_claude, etc.
+         """
+         normalized = self.normalize_query(query)
+         self.high_frequency_queries[normalized] = {
+             'original': query,
+             'config': config,
+             'added': datetime.now().isoformat()
+         }
+         logger.info(f"Added high-frequency query: {query}")
+
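+     # Illustrative wiring (config keys mirror DEFAULT_PREWARM_QUERIES below):
+     #   cache.add_high_frequency_query(
+     #       "What are the latest ALS treatments?",
+     #       {"search_terms": "ALS treatment therapy 2024", "use_claude": True},
+     #   )
+     #   cache.get_high_frequency_config("latest ALS treatments")  # -> that config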
+     def get_high_frequency_config(self, query: str) -> Optional[Dict[str, Any]]:
+         """
+         Get configuration for a high-frequency query if it matches.
+
+         Args:
+             query: The query to check
+
+         Returns:
+             Configuration dict if this is a high-frequency query
+         """
+         normalized = self.normalize_query(query)
+         if normalized in self.high_frequency_queries:
+             return self.high_frequency_queries[normalized]['config']
+         return None
+
+     def get_cache_stats(self) -> Dict[str, Any]:
+         """Get cache statistics"""
+         valid_entries = sum(1 for entry in self.cache.values() if self._is_valid(entry))
+         total_entries = len(self.cache)
+
+         # Get top queries
+         top_queries = sorted(
+             self.query_stats.items(),
+             key=lambda x: x[1]['count'],
+             reverse=True
+         )[:10]
+
+         return {
+             'total_entries': total_entries,
+             'valid_entries': valid_entries,
+             'expired_entries': total_entries - valid_entries,
+             'normalized_groups': len(self.normalized_cache),
+             'high_frequency_queries': len(self.high_frequency_queries),
+             'top_queries': [
+                 {'query': q, 'count': stats['count']}
+                 for q, stats in top_queries
+             ]
+         }
+
+     def clear_expired(self):
+         """Remove expired entries from cache"""
+         expired_keys = [
+             key for key, entry in self.cache.items()
+             if not self._is_valid(entry)
+         ]
+
+         for key in expired_keys:
+             del self.cache[key]
+
+         if expired_keys:
+             logger.info(f"Cleared {len(expired_keys)} expired cache entries")
+             self.save_cache()
+
+     def save_cache(self):
+         """Persist cache to disk"""
+         cache_file = f"{self.cache_dir}/smart_cache.json"
+         try:
+             with open(cache_file, 'w') as f:
+                 json.dump({
+                     'cache': self.cache,
+                     'normalized_cache': self.normalized_cache,
+                     'high_frequency_queries': self.high_frequency_queries,
+                     'query_stats': self.query_stats
+                 }, f, indent=2)
+             logger.debug(f"Cache saved to {cache_file}")
+         except Exception as e:
+             logger.error(f"Failed to save cache: {e}")
+
+     async def _save_cache_async(self):
+         """Async version of save_cache that doesn't block"""
+         try:
+             await asyncio.to_thread(self.save_cache)
+         except Exception as e:
+             logger.error(f"Failed to save cache asynchronously: {e}")
+
+     def load_cache(self):
+         """Load cache from disk"""
+         cache_file = f"{self.cache_dir}/smart_cache.json"
+         try:
+             with open(cache_file, 'r') as f:
+                 data = json.load(f)
+             self.cache = data.get('cache', {})
+             self.normalized_cache = data.get('normalized_cache', {})
+             self.high_frequency_queries = data.get('high_frequency_queries', {})
+             self.query_stats = data.get('query_stats', {})
+
+             # Clear expired entries on load
+             self.clear_expired()
+
+             logger.info(f"Loaded cache with {len(self.cache)} entries")
+         except FileNotFoundError:
+             logger.info("No existing cache file found")
+         except Exception as e:
+             logger.error(f"Failed to load cache: {e}")
+
+
+ # Configuration for common ALS queries to pre-warm
+ DEFAULT_PREWARM_QUERIES = [
+     {
+         'query': 'What are the latest ALS treatments?',
+         'search_terms': 'ALS treatment therapy 2024 riluzole edaravone',
+         'use_claude': True  # High-frequency, use Claude for best results
+     },
+     {
+         'query': 'Gene therapy for ALS',
+         'search_terms': 'ALS gene therapy SOD1 C9orf72 clinical trial',
+         'use_claude': True
+     },
+     {
+         'query': 'ALS clinical trials',
+         'search_terms': 'ALS clinical trials recruiting phase 2 phase 3',
+         'use_claude': False
+     },
+     {
+         'query': 'What causes ALS?',
+         'search_terms': 'ALS etiology pathogenesis genetic environmental factors',
+         'use_claude': True
+     },
+     {
+         'query': 'ALS symptoms and diagnosis',
+         'search_terms': 'ALS symptoms diagnosis EMG criteria El Escorial',
+         'use_claude': False
+     },
+     {
+         'query': 'Stem cell therapy for ALS',
+         'search_terms': 'ALS stem cell therapy mesenchymal clinical trial',
+         'use_claude': False
+     },
+     {
+         'query': 'ALS prognosis and life expectancy',
+         'search_terms': 'ALS prognosis survival life expectancy factors',
+         'use_claude': True
+     },
+     {
+         'query': 'New ALS drugs',
+         'search_terms': 'ALS new drugs FDA approved pipeline 2024',
+         'use_claude': False
+     },
+     {
+         'query': 'ALS biomarkers',
+         'search_terms': 'ALS biomarkers neurofilament TDP-43 diagnostic prognostic',
+         'use_claude': False
+     },
+     {
+         'query': 'Is there a cure for ALS?',
+         'search_terms': 'ALS cure breakthrough research treatment advances',
+         'use_claude': True
+     }
+ ]
+
+
+ def test_smart_cache():
+     """Test the smart cache functionality"""
+     print("Testing Smart Cache System")
+     print("=" * 60)
+
+     cache = SmartCache()
+
+     # Test query normalization
+     test_queries = [
+         ("What are the latest ALS gene therapy trials?", "ALS gene therapy trials"),
+         ("gene therapy ALS", "ALS gene therapy"),
+         ("What is ALS?", "ALS"),
+         ("HOW does riluzole work for ALS?", "ALS riluzole work"),
+     ]
+
+     print("\n1. Query Normalization Tests:")
+     for original, expected_words in test_queries:
+         normalized = cache.normalize_query(original)
+         print(f"   Original:   {original}")
+         print(f"   Normalized: {normalized}")
+         print(f"   Expected words present: {all(w in normalized for w in expected_words.lower().split())}")
+         print()
+
+     # Test similar query matching
+     print("\n2. Similar Query Matching:")
+     cache.store("What are the latest ALS treatments?", {"result": "Treatment data"})
+
+     similar_queries = [
+         "latest ALS treatments",
+         "ALS latest treatments",
+         "What are latest treatments for ALS?",
+         "treatments ALS latest"
+     ]
+
+     for query in similar_queries:
+         result = cache.find_similar_cached(query)
+         print(f"   Query: {query}")
+         print(f"   Found: {result is not None}")
+
+     # Test cache statistics
+     print("\n3. Cache Statistics:")
+     stats = cache.get_cache_stats()
+     print(f"   Total entries: {stats['total_entries']}")
+     print(f"   Valid entries: {stats['valid_entries']}")
+     print(f"   Normalized groups: {stats['normalized_groups']}")
+
+     print("\n✅ Smart cache tests completed!")
+
+
+ if __name__ == "__main__":
+     test_smart_cache()
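
A minimal end-to-end sketch of wiring the cache into the app, assuming this file is importable as smart_cache; the search coroutine below is a stand-in, not the app's real PubMed/bioRxiv search code:

    import asyncio
    from smart_cache import SmartCache, DEFAULT_PREWARM_QUERIES

    async def stub_search(terms: str) -> dict:
        # Stand-in for the app's real literature search
        return {"terms": terms, "results": []}

    async def main():
        cache = SmartCache(cache_dir=".cache", ttl_hours=24)

        # Warm the cache with the common ALS queries defined above
        # (with no llm_func supplied, 'use_claude' entries fall back to stub_search)
        await cache.pre_warm_cache(DEFAULT_PREWARM_QUERIES, search_func=stub_search)

        # A differently worded user question still hits the warmed entry
        hit = cache.find_similar_cached("clinical trials ALS")
        print("cache hit:", hit is not None)
        print(cache.get_cache_stats())

        # Let any background save task finish before the loop closes
        await asyncio.sleep(0.2)

    asyncio.run(main())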