diff --git a/.cursorrules b/.cursorrules
deleted file mode 100644
index 8fbe6def025d95d15c47f657eafbbbf0643a5ca5..0000000000000000000000000000000000000000
--- a/.cursorrules
+++ /dev/null
@@ -1,240 +0,0 @@
-# DeepCritical Project - Cursor Rules
-
-## Project-Wide Rules
-
-**Architecture**: Multi-agent research system using Pydantic AI for agent orchestration, supporting iterative and deep research patterns. Uses middleware for state management, budget tracking, and workflow coordination.
-
-**Type Safety**: ALWAYS use complete type hints. All functions must have parameter and return type annotations. Use `mypy --strict` compliance. Use `TYPE_CHECKING` imports for circular dependencies: `from typing import TYPE_CHECKING; if TYPE_CHECKING: from src.services.embeddings import EmbeddingService`
-
-**Async Patterns**: ALL I/O operations must be async (`async def`, `await`). Use `asyncio.gather()` for parallel operations. CPU-bound work must use `run_in_executor()`: `loop = asyncio.get_running_loop(); result = await loop.run_in_executor(None, cpu_bound_function, args)`. Never block the event loop.
-
-**Error Handling**: Use custom exceptions from `src/utils/exceptions.py`: `DeepCriticalError`, `SearchError`, `RateLimitError`, `JudgeError`, `ConfigurationError`. Always chain exceptions: `raise SearchError(...) from e`. Log with structlog: `logger.error("Operation failed", error=str(e), context=value)`.
-
-**Logging**: Use `structlog` for ALL logging (NOT `print` or `logging`). Import: `import structlog; logger = structlog.get_logger()`. Log with structured data: `logger.info("event", key=value)`. Use appropriate levels: DEBUG, INFO, WARNING, ERROR.
-
-**Pydantic Models**: All data exchange uses Pydantic models from `src/utils/models.py`. Models are frozen (`model_config = {"frozen": True}`) for immutability. Use `Field()` with descriptions. Validate with `ge=`, `le=`, `min_length=`, `max_length=` constraints.
-
-**Code Style**: Ruff with 100-char line length. Ignore rules: `PLR0913` (too many arguments), `PLR0912` (too many branches), `PLR0911` (too many returns), `PLR2004` (magic values), `PLW0603` (global statement), `PLC0415` (lazy imports).
-
-**Docstrings**: Google-style docstrings for all public functions. Include Args, Returns, Raises sections. Use type hints in docstrings only if needed for clarity.
-
-**Testing**: Unit tests in `tests/unit/` (mocked, fast). Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`). Use `respx` for httpx mocking, `pytest-mock` for general mocking.
-
-**State Management**: Use `ContextVar` in middleware for thread-safe isolation. Never use global mutable state (except singletons via `@lru_cache`). Use `WorkflowState` from `src/middleware/state_machine.py` for workflow state.
-
-**Citation Validation**: ALWAYS validate references before returning reports. Use `validate_references()` from `src/utils/citation_validator.py`. Remove hallucinated citations. Log warnings for removed citations.
-
----
-
-## src/agents/ - Agent Implementation Rules
-
-**Pattern**: All agents use Pydantic AI `Agent` class. Agents have structured output types (Pydantic models) or return strings. Use factory functions in `src/agent_factory/agents.py` for creation.
-
-**Agent Structure**:
-- System prompt as module-level constant (with date injection: `datetime.now().strftime("%Y-%m-%d")`)
-- Agent class with `__init__(model: Any | None = None)`
-- Main method (e.g., `async def evaluate()`, `async def write_report()`)
-- Factory function: `def create_agent_name(model: Any | None = None) -> AgentName`
-
-**Model Initialization**: Use `get_model()` from `src/agent_factory/judges.py` if no model provided. Support OpenAI/Anthropic/HF Inference via settings.
-
-**Error Handling**: Return fallback values (e.g., `KnowledgeGapOutput(research_complete=False, outstanding_gaps=[...])`) on failure. Log errors with context. Use retry logic (3 retries) in Pydantic AI Agent initialization.
-
-**Input Validation**: Validate query/inputs are not empty. Truncate very long inputs with warnings. Handle None values gracefully.
-
-**Output Types**: Use structured output types from `src/utils/models.py` (e.g., `KnowledgeGapOutput`, `AgentSelectionPlan`, `ReportDraft`). For text output (writer agents), return `str` directly.
-
-**Agent-Specific Rules**:
-- `knowledge_gap.py`: Outputs `KnowledgeGapOutput`. Evaluates research completeness.
-- `tool_selector.py`: Outputs `AgentSelectionPlan`. Selects tools (RAG/web/database).
-- `writer.py`: Returns markdown string. Includes citations in numbered format.
-- `long_writer.py`: Uses `ReportDraft` input/output. Handles section-by-section writing.
-- `proofreader.py`: Takes `ReportDraft`, returns polished markdown.
-- `thinking.py`: Returns observation string from conversation history.
-- `input_parser.py`: Outputs `ParsedQuery` with research mode detection.
-
----
-
-## src/tools/ - Search Tool Rules
-
-**Protocol**: All tools implement `SearchTool` protocol from `src/tools/base.py`: `name` property and `async def search(query, max_results) -> list[Evidence]`.
-
-**Rate Limiting**: Use `@retry` decorator from tenacity: `@retry(stop=stop_after_attempt(3), wait=wait_exponential(...))`. Implement `_rate_limit()` method for APIs with limits. Use shared rate limiters from `src/tools/rate_limiter.py`.
-
-**Error Handling**: Raise `SearchError` or `RateLimitError` on failures. Handle HTTP errors (429, 500, timeout). Return empty list on non-critical errors (log warning).
-
-**Query Preprocessing**: Use `preprocess_query()` from `src/tools/query_utils.py` to remove noise and expand synonyms.
-
-**Evidence Conversion**: Convert API responses to `Evidence` objects with `Citation`. Extract metadata (title, url, date, authors). Set relevance scores (0.0-1.0). Handle missing fields gracefully.
-
-**Tool-Specific Rules**:
-- `pubmed.py`: Use NCBI E-utilities (ESearch → EFetch). Rate limit: 0.34s between requests. Parse XML with `xmltodict`. Handle single vs. multiple articles.
-- `clinicaltrials.py`: Use `requests` library (NOT httpx - WAF blocks httpx). Run in thread pool: `await asyncio.to_thread(requests.get, ...)`. Filter: Only interventional studies, active/completed.
-- `europepmc.py`: Handle preprint markers: `[PREPRINT - Not peer-reviewed]`. Build URLs from DOI or PMID.
-- `rag_tool.py`: Wraps `LlamaIndexRAGService`. Returns Evidence from RAG results. Handles ingestion.
-- `search_handler.py`: Orchestrates parallel searches across multiple tools. Uses `asyncio.gather()` with `return_exceptions=True`. Aggregates results into `SearchResult`.
-
----
-
-## src/middleware/ - Middleware Rules
-
-**State Management**: Use `ContextVar` for thread-safe isolation. `WorkflowState` uses `ContextVar[WorkflowState | None]`. Initialize with `init_workflow_state(embedding_service)`. Access with `get_workflow_state()` (auto-initializes if missing).
-
-**WorkflowState**: Tracks `evidence: list[Evidence]`, `conversation: Conversation`, `embedding_service: Any`. Methods: `add_evidence()` (deduplicates by URL), `async search_related()` (semantic search).
-
-**WorkflowManager**: Manages parallel research loops. Methods: `add_loop()`, `run_loops_parallel()`, `update_loop_status()`, `sync_loop_evidence_to_state()`. Uses `asyncio.gather()` for parallel execution. Handles errors per loop (don't fail all if one fails).
-
-**BudgetTracker**: Tracks tokens, time, iterations per loop and globally. Methods: `create_budget()`, `add_tokens()`, `start_timer()`, `update_timer()`, `increment_iteration()`, `check_budget()`, `can_continue()`. Token estimation: `estimate_tokens(text)` (~4 chars per token), `estimate_llm_call_tokens(prompt, response)`.
-
-**Models**: All middleware models in `src/utils/models.py`. `IterationData`, `Conversation`, `ResearchLoop`, `BudgetStatus` are used by middleware.
-
----
-
-## src/orchestrator/ - Orchestration Rules
-
-**Research Flows**: Two patterns: `IterativeResearchFlow` (single loop) and `DeepResearchFlow` (plan → parallel loops → synthesis). Both support agent chains (`use_graph=False`) and graph execution (`use_graph=True`).
-
-**IterativeResearchFlow**: Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Judge → Continue/Complete. Uses `KnowledgeGapAgent`, `ToolSelectorAgent`, `ThinkingAgent`, `WriterAgent`, `JudgeHandler`. Tracks iterations, time, budget.
-
-**DeepResearchFlow**: Pattern: Planner → Parallel iterative loops per section → Synthesizer. Uses `PlannerAgent`, `IterativeResearchFlow` (per section), `LongWriterAgent` or `ProofreaderAgent`. Uses `WorkflowManager` for parallel execution.
-
-**Graph Orchestrator**: Uses Pydantic AI Graphs (when available) or agent chains (fallback). Routes based on research mode (iterative/deep/auto). Streams `AgentEvent` objects for UI.
-
-**State Initialization**: Always call `init_workflow_state()` before running flows. Initialize `BudgetTracker` per loop. Use `WorkflowManager` for parallel coordination.
-
-**Event Streaming**: Yield `AgentEvent` objects during execution. Event types: "started", "search_complete", "judge_complete", "hypothesizing", "synthesizing", "complete", "error". Include iteration numbers and data payloads.
-
----
-
-## src/services/ - Service Rules
-
-**EmbeddingService**: Local sentence-transformers (NO API key required). All operations async-safe via `run_in_executor()`. ChromaDB for vector storage. Deduplication threshold: 0.85 (85% similarity = duplicate).
-
-**LlamaIndexRAGService**: Uses OpenAI embeddings (requires `OPENAI_API_KEY`). Methods: `ingest_evidence()`, `retrieve()`, `query()`. Returns documents with metadata (source, title, url, date, authors). Lazy initialization with graceful fallback.
-
-**StatisticalAnalyzer**: Generates Python code via LLM. Executes in Modal sandbox (secure, isolated). Library versions pinned in `SANDBOX_LIBRARIES` dict. Returns `AnalysisResult` with verdict (SUPPORTED/REFUTED/INCONCLUSIVE).
-
-**Singleton Pattern**: Use `@lru_cache(maxsize=1)` for singletons: `@lru_cache(maxsize=1); def get_service() -> Service: return Service()`. Lazy initialization to avoid requiring dependencies at import time.
-
----
-
-## src/utils/ - Utility Rules
-
-**Models**: All Pydantic models in `src/utils/models.py`. Use frozen models (`model_config = {"frozen": True}`) except where mutation needed. Use `Field()` with descriptions. Validate with constraints.
-
-**Config**: Settings via Pydantic Settings (`src/utils/config.py`). Load from `.env` automatically. Use `settings` singleton: `from src.utils.config import settings`. Validate API keys with properties: `has_openai_key`, `has_anthropic_key`.
-
-**Exceptions**: Custom exception hierarchy in `src/utils/exceptions.py`. Base: `DeepCriticalError`. Specific: `SearchError`, `RateLimitError`, `JudgeError`, `ConfigurationError`. Always chain exceptions.
-
-**LLM Factory**: Centralized LLM model creation in `src/utils/llm_factory.py`. Supports OpenAI, Anthropic, HF Inference. Use `get_model()` or factory functions. Check requirements before initialization.
-
-**Citation Validator**: Use `validate_references()` from `src/utils/citation_validator.py`. Removes hallucinated citations (URLs not in evidence). Logs warnings. Returns validated report string.
-
----
-
-## src/orchestrator_factory.py Rules
-
-**Purpose**: Factory for creating orchestrators. Supports "simple" (legacy) and "advanced" (magentic) modes. Auto-detects mode based on API key availability.
-
-**Pattern**: Lazy import for optional dependencies (`_get_magentic_orchestrator_class()`). Handles `ImportError` gracefully with clear error messages.
-
-**Mode Detection**: `_determine_mode()` checks explicit mode or auto-detects: "advanced" if `settings.has_openai_key`, else "simple". Maps "magentic" → "advanced".
-
-**Function Signature**: `create_orchestrator(search_handler, judge_handler, config, mode) -> Any`. Simple mode requires handlers. Advanced mode uses MagenticOrchestrator.
-
-**Error Handling**: Raise `ValueError` with clear messages if requirements not met. Log mode selection with structlog.
-
----
-
-## src/orchestrator_hierarchical.py Rules
-
-**Purpose**: Hierarchical orchestrator using middleware and sub-teams. Adapts Magentic ChatAgent to SubIterationTeam protocol.
-
-**Pattern**: Uses `SubIterationMiddleware` with `ResearchTeam` and `LLMSubIterationJudge`. Event-driven via callback queue.
-
-**State Initialization**: Initialize embedding service with graceful fallback. Use `init_magentic_state()` (deprecated, but kept for compatibility).
-
-**Event Streaming**: Uses `asyncio.Queue` for event coordination. Yields `AgentEvent` objects. Handles event callback pattern with `asyncio.wait()`.
-
-**Error Handling**: Log errors with context. Yield error events. Process remaining events after task completion.
-
----
-
-## src/orchestrator_magentic.py Rules
-
-**Purpose**: Magentic-based orchestrator using ChatAgent pattern. Each agent has internal LLM. Manager orchestrates agents.
-
-**Pattern**: Uses `MagenticBuilder` with participants (searcher, hypothesizer, judge, reporter). Manager uses `OpenAIChatClient`. Workflow built in `_build_workflow()`.
-
-**Event Processing**: `_process_event()` converts Magentic events to `AgentEvent`. Handles: `MagenticOrchestratorMessageEvent`, `MagenticAgentMessageEvent`, `MagenticFinalResultEvent`, `MagenticAgentDeltaEvent`, `WorkflowOutputEvent`.
-
-**Text Extraction**: `_extract_text()` defensively extracts text from messages. Priority: `.content` → `.text` → `str(message)`. Handles buggy message objects.
-
-**State Initialization**: Initialize embedding service with graceful fallback. Use `init_magentic_state()` (deprecated).
-
-**Requirements**: Must call `check_magentic_requirements()` in `__init__`. Requires `agent-framework-core` and OpenAI API key.
-
-**Event Types**: Maps agent names to event types: "search" → "search_complete", "judge" → "judge_complete", "hypothes" → "hypothesizing", "report" → "synthesizing".
-
----
-
-## src/agent_factory/ - Factory Rules
-
-**Pattern**: Factory functions for creating agents and handlers. Lazy initialization for optional dependencies. Support OpenAI/Anthropic/HF Inference.
-
-**Judges**: `create_judge_handler()` creates `JudgeHandler` with structured output (`JudgeAssessment`). Supports `MockJudgeHandler`, `HFInferenceJudgeHandler` as fallbacks.
-
-**Agents**: Factory functions in `agents.py` for all Pydantic AI agents. Pattern: `create_agent_name(model: Any | None = None) -> AgentName`. Use `get_model()` if model not provided.
-
-**Graph Builder**: `graph_builder.py` contains utilities for building research graphs. Supports iterative and deep research graph construction.
-
-**Error Handling**: Raise `ConfigurationError` if required API keys missing. Log agent creation. Handle import errors gracefully.
-
----
-
-## src/prompts/ - Prompt Rules
-
-**Pattern**: System prompts stored as module-level constants. Include date injection: `datetime.now().strftime("%Y-%m-%d")`. Format evidence with truncation (1500 chars per item).
-
-**Judge Prompts**: In `judge.py`. Handle empty evidence case separately. Always request structured JSON output.
-
-**Hypothesis Prompts**: In `hypothesis.py`. Use diverse evidence selection (MMR algorithm). Sentence-aware truncation.
-
-**Report Prompts**: In `report.py`. Include full citation details. Use diverse evidence selection (n=20). Emphasize citation validation rules.
-
----
-
-## Testing Rules
-
-**Structure**: Unit tests in `tests/unit/` (mocked, fast). Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`).
-
-**Mocking**: Use `respx` for httpx mocking. Use `pytest-mock` for general mocking. Mock LLM calls in unit tests (use `MockJudgeHandler`).
-
-**Fixtures**: Common fixtures in `tests/conftest.py`: `mock_httpx_client`, `mock_llm_response`.
-
-**Coverage**: Aim for >80% coverage. Test error handling, edge cases, and integration paths.
-
----
-
-## File-Specific Agent Rules
-
-**knowledge_gap.py**: Outputs `KnowledgeGapOutput`. System prompt evaluates research completeness. Handles conversation history. Returns fallback on error.
-
-**writer.py**: Returns markdown string. System prompt includes citation format examples. Validates inputs. Truncates long findings. Retry logic for transient failures.
-
-**long_writer.py**: Uses `ReportDraft` input/output. Writes sections iteratively. Reformats references (deduplicates, renumbers). Reformats section headings.
-
-**proofreader.py**: Takes `ReportDraft`, returns polished markdown. Removes duplicates. Adds summary. Preserves references.
-
-**tool_selector.py**: Outputs `AgentSelectionPlan`. System prompt lists available agents (WebSearchAgent, SiteCrawlerAgent, RAGAgent). Guidelines for when to use each.
-
-**thinking.py**: Returns observation string. Generates observations from conversation history. Uses query and background context.
-
-**input_parser.py**: Outputs `ParsedQuery`. Detects research mode (iterative/deep). Extracts entities and research questions. Improves/refines query.
-
-
-
-
-
-
-
diff --git a/.env.example b/.env.example
deleted file mode 100644
index 442ff75d33f92422e78850b3c9d6d49af6f1d6e3..0000000000000000000000000000000000000000
--- a/.env.example
+++ /dev/null
@@ -1,107 +0,0 @@
-# HuggingFace
-HF_TOKEN=your_huggingface_token_here
-
-# OpenAI (optional)
-OPENAI_API_KEY=your_openai_key_here
-
-# Anthropic (optional)
-ANTHROPIC_API_KEY=your_anthropic_key_here
-
-# Model names (optional - sensible defaults set in config.py)
-# ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
-# OPENAI_MODEL=gpt-5.1
-
-
-# ============================================
-# Audio Processing Configuration (TTS)
-# ============================================
-# Kokoro TTS Model Configuration
-TTS_MODEL=hexgrad/Kokoro-82M
-TTS_VOICE=af_heart
-TTS_SPEED=1.0
-TTS_GPU=T4
-TTS_TIMEOUT=60
-
-# Available TTS Voices:
-# American English Female: af_heart, af_bella, af_nicole, af_aoede, af_kore, af_sarah, af_nova, af_sky, af_alloy, af_jessica, af_river
-# American English Male: am_michael, am_fenrir, am_puck, am_echo, am_eric, am_liam, am_onyx, am_santa, am_adam
-
-# Available GPU Types (Modal):
-# T4 - Cheapest, good for testing (default)
-# A10 - Good balance of cost/performance
-# A100 - Fastest, most expensive
-# L4 - NVIDIA L4 GPU
-# L40S - NVIDIA L40S GPU
-# Note: GPU type is set at function definition time. Changes require app restart.
-
-# ============================================
-# Audio Processing Configuration (STT)
-# ============================================
-# Speech-to-Text API Configuration
-STT_API_URL=nvidia/canary-1b-v2
-STT_SOURCE_LANG=English
-STT_TARGET_LANG=English
-
-# Available STT Languages:
-# English, Bulgarian, Croatian, Czech, Danish, Dutch, Estonian, Finnish, French, German, Greek, Hungarian, Italian, Latvian, Lithuanian, Maltese, Polish, Portuguese, Romanian, Slovak, Slovenian, Spanish, Swedish, Russian, Ukrainian
-
-# ============================================
-# Audio Feature Flags
-# ============================================
-ENABLE_AUDIO_INPUT=true
-ENABLE_AUDIO_OUTPUT=true
-
-# ============================================
-# Image OCR Configuration
-# ============================================
-OCR_API_URL=prithivMLmods/Multimodal-OCR3
-ENABLE_IMAGE_INPUT=true
-
-# ============== EMBEDDINGS ==============
-
-# OpenAI Embedding Model (used if LLM_PROVIDER is openai and performing RAG/Embeddings)
-OPENAI_EMBEDDING_MODEL=text-embedding-3-small
-
-# Local Embedding Model (used for local/offline embeddings)
-LOCAL_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
-
-# ============== HUGGINGFACE (FREE TIER) ==============
-
-# HuggingFace Token - enables Llama 3.1 (best quality free model)
-# Get yours at: https://huggingface.co/settings/tokens
-#
-# WITHOUT HF_TOKEN: Falls back to ungated models (zephyr-7b-beta)
-# WITH HF_TOKEN: Uses Llama 3.1 8B Instruct (requires accepting license)
-#
-# For HuggingFace Spaces deployment:
-# Set this as a "Secret" in Space Settings -> Variables and secrets
-# Users/judges don't need their own token - the Space secret is used
-#
-HF_TOKEN=hf_your-token-here
-
-# ============== AGENT CONFIGURATION ==============
-
-MAX_ITERATIONS=10
-SEARCH_TIMEOUT=30
-LOG_LEVEL=INFO
-
-# ============================================
-# Modal Configuration (Required for TTS)
-# ============================================
-# Modal credentials are required for TTS (Text-to-Speech) functionality
-# Get your credentials from: https://modal.com/
-MODAL_TOKEN_ID=your_modal_token_id_here
-MODAL_TOKEN_SECRET=your_modal_token_secret_here
-
-# ============== EXTERNAL SERVICES ==============
-
-# PubMed (optional - higher rate limits)
-NCBI_API_KEY=your-ncbi-key-here
-
-# Vector Database (optional - for LlamaIndex RAG)
-CHROMA_DB_PATH=./chroma_db
-# Neo4j Knowledge Graph
-NEO4J_URI=bolt://localhost:7687
-NEO4J_USER=neo4j
-NEO4J_PASSWORD=your_neo4j_password_here
-NEO4J_DATABASE=your_database_name
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index 9a982ecba57a76ced6b098ad431436049f052514..0000000000000000000000000000000000000000
--- a/.gitignore
+++ /dev/null
@@ -1,84 +0,0 @@
-folder/
-site/
-.cursor/
-.ruff_cache/
-# Python
-__pycache__/
-*.py[cod]
-*$py.class
-*.so
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-
-# Virtual environments
-.venv/
-venv/
-ENV/
-env/
-
-# IDE
-.vscode/
-.idea/
-*.swp
-*.swo
-
-# Environment
-.env
-.env.local
-*.local
-
-# Claude
-.claude/
-
-# Burner docs (working drafts, not for commit)
-burner_docs/
-
-# Reference repos (clone locally, don't commit)
-reference_repos/autogen-microsoft/
-reference_repos/claude-agent-sdk/
-reference_repos/pydanticai-research-agent/
-reference_repos/pubmed-mcp-server/
-reference_repos/DeepCritical/
-
-# Keep the README in reference_repos
-!reference_repos/README.md
-
-# Development directory
-dev/
-
-# OS
-.DS_Store
-Thumbs.db
-
-# Logs
-*.log
-logs/
-
-# Testing
-.pytest_cache/
-.mypy_cache/
-.coverage
-htmlcov/
-test_output*.txt
-
-# Database files
-chroma_db/
-*.sqlite3
-
-
-# Trigger rebuild Wed Nov 26 17:51:41 EST 2025
-.env
diff --git a/.pre-commit-hooks/run_pytest.ps1 b/.pre-commit-hooks/run_pytest.ps1
deleted file mode 100644
index 3df4f371b845a48ce3a1ea32e307218abbd5a033..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest.ps1
+++ /dev/null
@@ -1,19 +0,0 @@
-# PowerShell pytest runner for pre-commit (Windows)
-# Uses uv if available, otherwise falls back to pytest
-
-if (Get-Command uv -ErrorAction SilentlyContinue) {
- # Sync dependencies before running tests
- uv sync
- uv run pytest $args
-} else {
- Write-Warning "uv not found, using system pytest (may have missing dependencies)"
- pytest $args
-}
-
-
-
-
-
-
-
-
diff --git a/.pre-commit-hooks/run_pytest.sh b/.pre-commit-hooks/run_pytest.sh
deleted file mode 100644
index b2a4be920113fd340631f64602c24042e8c81086..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-# Cross-platform pytest runner for pre-commit
-# Uses uv if available, otherwise falls back to pytest
-
-if command -v uv >/dev/null 2>&1; then
- # Sync dependencies before running tests
- uv sync
- uv run pytest "$@"
-else
- echo "Warning: uv not found, using system pytest (may have missing dependencies)"
- pytest "$@"
-fi
-
-
-
-
-
-
-
-
diff --git a/.pre-commit-hooks/run_pytest_embeddings.ps1 b/.pre-commit-hooks/run_pytest_embeddings.ps1
deleted file mode 100644
index 47a3e32a202240c42e5a205d2afd778a23292db7..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_embeddings.ps1
+++ /dev/null
@@ -1,14 +0,0 @@
-# PowerShell wrapper to sync embeddings dependencies and run embeddings tests
-
-$ErrorActionPreference = "Stop"
-
-if (Get-Command uv -ErrorAction SilentlyContinue) {
- Write-Host "Syncing embeddings dependencies..."
- uv sync --extra embeddings
- Write-Host "Running embeddings tests..."
- uv run pytest tests/ -v -m local_embeddings --tb=short -p no:logfire
-} else {
- Write-Error "uv not found"
- exit 1
-}
-
diff --git a/.pre-commit-hooks/run_pytest_embeddings.sh b/.pre-commit-hooks/run_pytest_embeddings.sh
deleted file mode 100644
index 6f1b80746217244367ee86fcd7d69837df648b40..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_embeddings.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/bin/bash
-# Wrapper script to sync embeddings dependencies and run embeddings tests
-
-set -e
-
-if command -v uv >/dev/null 2>&1; then
- echo "Syncing embeddings dependencies..."
- uv sync --extra embeddings
- echo "Running embeddings tests..."
- uv run pytest tests/ -v -m local_embeddings --tb=short -p no:logfire
-else
- echo "Error: uv not found"
- exit 1
-fi
-
diff --git a/.pre-commit-hooks/run_pytest_unit.ps1 b/.pre-commit-hooks/run_pytest_unit.ps1
deleted file mode 100644
index c1196d22e86fe66a56d12f673c003ac88aa6b09f..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_unit.ps1
+++ /dev/null
@@ -1,14 +0,0 @@
-# PowerShell wrapper to sync dependencies and run unit tests
-
-$ErrorActionPreference = "Stop"
-
-if (Get-Command uv -ErrorAction SilentlyContinue) {
- Write-Host "Syncing dependencies..."
- uv sync
- Write-Host "Running unit tests..."
- uv run pytest tests/unit/ -v -m "not openai and not embedding_provider" --tb=short -p no:logfire
-} else {
- Write-Error "uv not found"
- exit 1
-}
-
diff --git a/.pre-commit-hooks/run_pytest_unit.sh b/.pre-commit-hooks/run_pytest_unit.sh
deleted file mode 100644
index 173ab1b607647ecf4b4a1de6b75abd47fc0130ec..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_unit.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/bin/bash
-# Wrapper script to sync dependencies and run unit tests
-
-set -e
-
-if command -v uv >/dev/null 2>&1; then
- echo "Syncing dependencies..."
- uv sync
- echo "Running unit tests..."
- uv run pytest tests/unit/ -v -m "not openai and not embedding_provider" --tb=short -p no:logfire
-else
- echo "Error: uv not found"
- exit 1
-fi
-
diff --git a/.pre-commit-hooks/run_pytest_with_sync.ps1 b/.pre-commit-hooks/run_pytest_with_sync.ps1
deleted file mode 100644
index 546a5096bc6e4b9a46d039f5761234022b8658dd..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_with_sync.ps1
+++ /dev/null
@@ -1,25 +0,0 @@
-# PowerShell wrapper for pytest runner
-# Ensures uv is available and runs the Python script
-
-param(
- [Parameter(Position=0)]
- [string]$TestType = "unit"
-)
-
-$ErrorActionPreference = "Stop"
-
-# Check if uv is available
-if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {
- Write-Error "uv not found. Please install uv: https://github.com/astral-sh/uv"
- exit 1
-}
-
-# Get the script directory
-$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
-$PythonScript = Join-Path $ScriptDir "run_pytest_with_sync.py"
-
-# Run the Python script using uv
-uv run python $PythonScript $TestType
-
-exit $LASTEXITCODE
-
diff --git a/.pre-commit-hooks/run_pytest_with_sync.py b/.pre-commit-hooks/run_pytest_with_sync.py
deleted file mode 100644
index 70c4eb52f8239d167868ba47b2d4bb80d9ac3173..0000000000000000000000000000000000000000
--- a/.pre-commit-hooks/run_pytest_with_sync.py
+++ /dev/null
@@ -1,235 +0,0 @@
-#!/usr/bin/env python3
-"""Cross-platform pytest runner that syncs dependencies before running tests."""
-
-import shutil
-import subprocess
-import sys
-from pathlib import Path
-
-
-def clean_caches(project_root: Path) -> None:
- """Remove pytest and Python cache directories and files.
-
- Comprehensively removes all cache files and directories to ensure
- clean test runs. Only scans specific directories to avoid resource
- exhaustion from scanning large directories like .venv on Windows.
- """
- # Directories to scan for caches (only project code, not dependencies)
- scan_dirs = ["src", "tests", ".pre-commit-hooks"]
-
- # Directories to exclude (to avoid resource issues)
- exclude_dirs = {
- ".venv",
- "venv",
- "ENV",
- "env",
- ".git",
- "node_modules",
- "dist",
- "build",
- ".eggs",
- "reference_repos",
- "folder",
- }
-
- # Comprehensive list of cache patterns to remove
- cache_patterns = [
- ".pytest_cache",
- "__pycache__",
- "*.pyc",
- "*.pyo",
- "*.pyd",
- ".mypy_cache",
- ".ruff_cache",
- ".coverage",
- "coverage.xml",
- "htmlcov",
- ".hypothesis", # Hypothesis testing framework cache
- ".tox", # Tox cache (if used)
- ".cache", # General Python cache
- ]
-
- def should_exclude(path: Path) -> bool:
- """Check if a path should be excluded from cache cleanup."""
- # Check if any parent directory is in exclude list
- for parent in path.parents:
- if parent.name in exclude_dirs:
- return True
- # Check if the path itself is excluded
- if path.name in exclude_dirs:
- return True
- return False
-
- cleaned = []
-
- # Only scan specific directories to avoid resource exhaustion
- for scan_dir in scan_dirs:
- scan_path = project_root / scan_dir
- if not scan_path.exists():
- continue
-
- for pattern in cache_patterns:
- if "*" in pattern:
- # Handle glob patterns for files
- try:
- for cache_file in scan_path.rglob(pattern):
- if should_exclude(cache_file):
- continue
- try:
- if cache_file.is_file():
- cache_file.unlink()
- cleaned.append(str(cache_file.relative_to(project_root)))
- except OSError:
- pass # Ignore errors (file might be locked or already deleted)
- except OSError:
- pass # Ignore errors during directory traversal
- else:
- # Handle directory patterns
- try:
- for cache_dir in scan_path.rglob(pattern):
- if should_exclude(cache_dir):
- continue
- try:
- if cache_dir.is_dir():
- shutil.rmtree(cache_dir, ignore_errors=True)
- cleaned.append(str(cache_dir.relative_to(project_root)))
- except OSError:
- pass # Ignore errors (directory might be locked)
- except OSError:
- pass # Ignore errors during directory traversal
-
- # Also clean root-level caches (like .pytest_cache in project root)
- root_cache_patterns = [
- ".pytest_cache",
- ".mypy_cache",
- ".ruff_cache",
- ".coverage",
- "coverage.xml",
- "htmlcov",
- ".hypothesis",
- ".tox",
- ".cache",
- ".pytest",
- ]
- for pattern in root_cache_patterns:
- cache_path = project_root / pattern
- if cache_path.exists():
- try:
- if cache_path.is_dir():
- shutil.rmtree(cache_path, ignore_errors=True)
- elif cache_path.is_file():
- cache_path.unlink()
- cleaned.append(pattern)
- except OSError:
- pass
-
- # Also remove any .pyc files in root directory
- try:
- for pyc_file in project_root.glob("*.pyc"):
- try:
- pyc_file.unlink()
- cleaned.append(pyc_file.name)
- except OSError:
- pass
- except OSError:
- pass
-
- if cleaned:
- print(
- f"Cleaned {len(cleaned)} cache items: {', '.join(cleaned[:10])}{'...' if len(cleaned) > 10 else ''}"
- )
- else:
- print("No cache files found to clean")
-
-
-def run_command(
- cmd: list[str], check: bool = True, shell: bool = False, cwd: str | None = None
-) -> int:
- """Run a command and return exit code."""
- try:
- result = subprocess.run(
- cmd,
- check=check,
- shell=shell,
- cwd=cwd,
- env=None, # Use current environment, uv will handle venv
- )
- return result.returncode
- except subprocess.CalledProcessError as e:
- return e.returncode
- except FileNotFoundError:
- print(f"Error: Command not found: {cmd[0]}")
- return 1
-
-
-def main() -> int:
- """Main entry point."""
- import os
-
- # Get the project root (where pyproject.toml is)
- script_dir = Path(__file__).parent
- project_root = script_dir.parent
-
- # Change to project root to ensure uv works correctly
- os.chdir(project_root)
-
- # Clean caches before running tests
- print("Cleaning pytest and Python caches...")
- clean_caches(project_root)
-
- # Check if uv is available
- if run_command(["uv", "--version"], check=False) != 0:
- print("Error: uv not found. Please install uv: https://github.com/astral-sh/uv")
- return 1
-
- # Parse arguments
- test_type = sys.argv[1] if len(sys.argv) > 1 else "unit"
- extra_args = sys.argv[2:] if len(sys.argv) > 2 else []
-
- # Sync dependencies - always include dev
- # Note: embeddings dependencies are now in main dependencies, not optional
- # Use --extra dev for [project.optional-dependencies].dev (not --dev which is for [dependency-groups])
- sync_cmd = ["uv", "sync", "--extra", "dev"]
-
- print(f"Syncing dependencies for {test_type} tests...")
- if run_command(sync_cmd, cwd=project_root) != 0:
- return 1
-
- # Build pytest command - use uv run to ensure correct environment
- if test_type == "unit":
- pytest_args = [
- "tests/unit/",
- "-v",
- "-m",
- "not openai and not embedding_provider",
- "--tb=short",
- "-p",
- "no:logfire",
- "--cache-clear", # Clear pytest cache before running
- ]
- elif test_type == "embeddings":
- pytest_args = [
- "tests/",
- "-v",
- "-m",
- "local_embeddings",
- "--tb=short",
- "-p",
- "no:logfire",
- "--cache-clear", # Clear pytest cache before running
- ]
- else:
- pytest_args = []
-
- pytest_args.extend(extra_args)
-
- # Use uv run python -m pytest to ensure we use the venv's pytest
- # This is more reliable than uv run pytest which might find system pytest
- pytest_cmd = ["uv", "run", "python", "-m", "pytest", *pytest_args]
-
- print(f"Running {test_type} tests...")
- return run_command(pytest_cmd, cwd=project_root)
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/.python-version b/.python-version
deleted file mode 100644
index 2c0733315e415bfb5e5b353f9996ecd964d395b2..0000000000000000000000000000000000000000
--- a/.python-version
+++ /dev/null
@@ -1 +0,0 @@
-3.11
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
deleted file mode 100644
index 9a93fd1f812752141d49e4e27efee17405ed9563..0000000000000000000000000000000000000000
--- a/CONTRIBUTING.md
+++ /dev/null
@@ -1,494 +0,0 @@
-# Contributing to The DETERMINATOR
-
-Thank you for your interest in contributing to The DETERMINATOR! This guide will help you get started.
-
-## Table of Contents
-
-- [Git Workflow](#git-workflow)
-- [Getting Started](#getting-started)
-- [Development Commands](#development-commands)
-- [MCP Integration](#mcp-integration)
-- [Common Pitfalls](#common-pitfalls)
-- [Key Principles](#key-principles)
-- [Pull Request Process](#pull-request-process)
-
-> **Note**: Additional sections (Code Style, Error Handling, Testing, Implementation Patterns, Code Quality, and Prompt Engineering) are available as separate pages in the [documentation](https://deepcritical.github.io/GradioDemo/contributing/).
->
-> **Note on Project Names**: "The DETERMINATOR" is the product name, "DeepCritical" is the organization/project name, and "determinator" is the Python package name.
-
-## Repository Information
-
-- **GitHub Repository**: [`DeepCritical/GradioDemo`](https://github.com/DeepCritical/GradioDemo) (source of truth, PRs, code review)
-- **HuggingFace Space**: [`DataQuests/DeepCritical`](https://huggingface.co/spaces/DataQuests/DeepCritical) (deployment/demo)
-- **Package Name**: `determinator` (Python package name in `pyproject.toml`)
-
-## Git Workflow
-
-- `main`: Production-ready (GitHub)
-- `dev`: Development integration (GitHub)
-- Use feature branches: `yourname-dev`
-- **NEVER** push directly to `main` or `dev` on HuggingFace
-- GitHub is source of truth; HuggingFace is for deployment
-
-### Dual Repository Setup
-
-This project uses a dual repository setup:
-
-- **GitHub (`DeepCritical/GradioDemo`)**: Source of truth for code, PRs, and code review
-- **HuggingFace (`DataQuests/DeepCritical`)**: Deployment target for the Gradio demo
-
-#### Remote Configuration
-
-When cloning, set up remotes as follows:
-
-```bash
-# Clone from GitHub
-git clone https://github.com/DeepCritical/GradioDemo.git
-cd GradioDemo
-
-# Add HuggingFace remote (optional, for deployment)
-git remote add huggingface-upstream https://huggingface.co/spaces/DataQuests/DeepCritical
-```
-
-**Important**: Never push directly to `main` or `dev` on HuggingFace. Always work through GitHub PRs. GitHub is the source of truth; HuggingFace is for deployment/demo only.
-
-## Getting Started
-
-1. **Fork the repository** on GitHub: [`DeepCritical/GradioDemo`](https://github.com/DeepCritical/GradioDemo)
-2. **Clone your fork**:
-
- ```bash
- git clone https://github.com/yourusername/GradioDemo.git
- cd GradioDemo
- ```
-
-3. **Install dependencies**:
-
- ```bash
- uv sync --all-extras
- uv run pre-commit install
- ```
-
-4. **Create a feature branch**:
-
- ```bash
- git checkout -b yourname-feature-name
- ```
-
-5. **Make your changes** following the guidelines below
-6. **Run checks**:
-
- ```bash
- uv run ruff check src tests
- uv run mypy src
- uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire
- ```
-
-7. **Commit and push**:
-
- ```bash
- git commit -m "Description of changes"
- git push origin yourname-feature-name
- ```
-
-8. **Create a pull request** on GitHub
-
-## Package Manager
-
-This project uses [`uv`](https://github.com/astral-sh/uv) as the package manager. All commands should be prefixed with `uv run` to ensure they run in the correct environment.
-
-### Installation
-
-```bash
-# Install uv if you haven't already (recommended: standalone installer)
-# Unix/macOS/Linux:
-curl -LsSf https://astral.sh/uv/install.sh | sh
-
-# Windows (PowerShell):
-powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
-
-# Alternative: pipx install uv
-# Or: pip install uv
-
-# Sync all dependencies including dev extras
-uv sync --all-extras
-
-# Install pre-commit hooks
-uv run pre-commit install
-```
-
-## Development Commands
-
-```bash
-# Installation
-uv sync --all-extras # Install all dependencies including dev
-uv run pre-commit install # Install pre-commit hooks
-
-# Code Quality Checks (run all before committing)
-uv run ruff check src tests # Lint with ruff
-uv run ruff format src tests # Format with ruff
-uv run mypy src # Type checking
-uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire # Tests with coverage
-
-# Testing Commands
-uv run pytest tests/unit/ -v -m "not openai" -p no:logfire # Run unit tests (excludes OpenAI tests)
-uv run pytest tests/ -v -m "huggingface" -p no:logfire # Run HuggingFace tests
-uv run pytest tests/ -v -p no:logfire # Run all tests
-uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire # Tests with terminal coverage
-uv run pytest --cov=src --cov-report=html -p no:logfire # Generate HTML coverage report (opens htmlcov/index.html)
-
-# Documentation Commands
-uv run mkdocs build # Build documentation
-uv run mkdocs serve # Serve documentation locally (http://127.0.0.1:8000)
-```
-
-### Test Markers
-
-The project uses pytest markers to categorize tests. See [Testing Guidelines](docs/contributing/testing.md) for details:
-
-- `unit`: Unit tests (mocked, fast)
-- `integration`: Integration tests (real APIs)
-- `slow`: Slow tests
-- `openai`: Tests requiring OpenAI API key
-- `huggingface`: Tests requiring HuggingFace API key
-- `embedding_provider`: Tests requiring API-based embedding providers
-- `local_embeddings`: Tests using local embeddings
-
-**Note**: The `-p no:logfire` flag disables the logfire plugin to avoid conflicts during testing.
-
-## Code Style & Conventions
-
-### Type Safety
-
-- **ALWAYS** use type hints for all function parameters and return types
-- Maintain `mypy --strict` compliance (no `Any` unless absolutely necessary)
-- Use `TYPE_CHECKING` imports for circular dependencies:
-
-
-[TYPE_CHECKING Import Pattern](../src/utils/citation_validator.py) start_line:8 end_line:11
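-
-A minimal sketch of the pattern (the `warm_up` function is purely illustrative):
-
-```python
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    # Resolved only by mypy; never imported at runtime, so no circular import
-    from src.services.embeddings import EmbeddingService
-
-
-def warm_up(service: "EmbeddingService") -> None:
-    """Quote the name so the annotation stays lazy at runtime."""
-    ...
-```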
-
-
-### Pydantic Models
-
-- All data exchange uses Pydantic models (`src/utils/models.py`)
-- Models are frozen (`model_config = {"frozen": True}`) for immutability
-- Use `Field()` with descriptions for all model fields
-- Validate with `ge=`, `le=`, `min_length=`, `max_length=` constraints
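-
-A minimal sketch of such a model (field names are illustrative, not copied from `src/utils/models.py`):
-
-```python
-from pydantic import BaseModel, Field
-
-
-class ExampleCitation(BaseModel):
-    model_config = {"frozen": True}  # immutable after construction
-
-    title: str = Field(..., min_length=1, description="Paper title")
-    url: str = Field(..., description="Source URL; must match the evidence exactly")
-    relevance: float = Field(0.0, ge=0.0, le=1.0, description="Relevance score")
-```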
-
-### Async Patterns
-
-- **ALL** I/O operations must be async (`async def`, `await`)
-- Use `asyncio.gather()` for parallel operations
-- CPU-bound work (embeddings, parsing) must use `run_in_executor()`:
-
-```python
-loop = asyncio.get_running_loop()
-result = await loop.run_in_executor(None, cpu_bound_function, args)
-```
-
-- Never block the event loop with synchronous I/O
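-
-For fan-out across several tools, the same rules apply (a sketch; `tools` is any list of objects
-with an async `search()` method):
-
-```python
-import asyncio
-
-
-async def run_searches(query: str, tools: list) -> list:
-    # One task per tool; return_exceptions=True keeps one failure from sinking the rest
-    results = await asyncio.gather(
-        *(tool.search(query, max_results=10) for tool in tools),
-        return_exceptions=True,
-    )
-    return [r for r in results if not isinstance(r, BaseException)]
-```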
-
-### Linting
-
-- Ruff with 100-char line length
-- Ignore rules documented in `pyproject.toml`:
- - `PLR0913`: Too many arguments (agents need many params)
- - `PLR0912`: Too many branches (complex orchestrator logic)
- - `PLR0911`: Too many return statements (complex agent logic)
- - `PLR2004`: Magic values (statistical constants)
- - `PLW0603`: Global statement (singleton pattern)
- - `PLC0415`: Lazy imports for optional dependencies
-
-### Pre-commit
-
-- Pre-commit hooks run automatically on commit
-- Must pass: lint + typecheck + test-cov
-- Install hooks with: `uv run pre-commit install`
-- Note: `uv sync --all-extras` installs the pre-commit package, but you must run `uv run pre-commit install` separately to set up the git hooks
-
-## Error Handling & Logging
-
-### Exception Hierarchy
-
-Use custom exception hierarchy (`src/utils/exceptions.py`):
-
-
-[Exception Hierarchy](../src/utils/exceptions.py) start_line:4 end_line:31
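-
-In outline, the hierarchy is a narrow base class plus specific subclasses (docstrings and exact
-parentage live in `src/utils/exceptions.py`):
-
-```python
-class DeepCriticalError(Exception):
-    """Base class for all project errors."""
-
-
-class SearchError(DeepCriticalError):
-    """A search tool failed (HTTP error, bad response, timeout)."""
-
-
-class RateLimitError(DeepCriticalError):
-    """An upstream API rejected the request due to rate limiting."""
-```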
-
-
-### Error Handling Rules
-
-- Always chain exceptions: `raise SearchError(...) from e`
-- Log errors with context using `structlog`:
-
-```python
-logger.error("Operation failed", error=str(e), context=value)
-```
-
-- Never silently swallow exceptions
-- Provide actionable error messages
-
-### Logging
-
-- Use `structlog` for all logging (NOT `print` or `logging`)
-- Import: `import structlog; logger = structlog.get_logger()`
-- Log with structured data: `logger.info("event", key=value)`
-- Use appropriate levels: DEBUG, INFO, WARNING, ERROR
-
-### Logging Examples
-
-```python
-logger.info("Starting search", query=query, tools=[t.name for t in tools])
-logger.warning("Search tool failed", tool=tool.name, error=str(result))
-logger.error("Assessment failed", error=str(e))
-```
-
-### Error Chaining
-
-Always preserve exception context:
-
-```python
-try:
- result = await api_call()
-except httpx.HTTPError as e:
- raise SearchError(f"API call failed: {e}") from e
-```
-
-## Testing Requirements
-
-### Test Structure
-
-- Unit tests in `tests/unit/` (mocked, fast)
-- Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`)
-- Use markers: `unit`, `integration`, `slow`
-
-### Mocking
-
-- Use `respx` for httpx mocking
-- Use `pytest-mock` for general mocking
-- Mock LLM calls in unit tests (use `MockJudgeHandler`)
-- Fixtures in `tests/conftest.py`: `mock_httpx_client`, `mock_llm_response`
-
-### TDD Workflow
-
-1. Write failing test in `tests/unit/`
-2. Implement in `src/`
-3. Ensure test passes
-4. Run checks: `uv run ruff check src tests && uv run mypy src && uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire`
-
-### Test Examples
-
-```python
-@pytest.mark.unit
-async def test_pubmed_search(mock_httpx_client):
- tool = PubMedTool()
- results = await tool.search("metformin", max_results=5)
- assert len(results) > 0
- assert all(isinstance(r, Evidence) for r in results)
-
-@pytest.mark.integration
-async def test_real_pubmed_search():
- tool = PubMedTool()
- results = await tool.search("metformin", max_results=3)
- assert len(results) <= 3
-```
-
-### Test Coverage
-
-- Run `uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire` for coverage report
-- Run `uv run pytest --cov=src --cov-report=html -p no:logfire` for HTML coverage report (opens `htmlcov/index.html`)
-- Aim for >80% coverage on critical paths
-- Exclude: `__init__.py`, `TYPE_CHECKING` blocks
-
-## Implementation Patterns
-
-### Search Tools
-
-All tools implement `SearchTool` protocol (`src/tools/base.py`):
-
-- Must have `name` property
-- Must implement `async def search(query, max_results) -> list[Evidence]`
-- Use `@retry` decorator from tenacity for resilience
-- Rate limiting: Implement `_rate_limit()` for APIs with limits (e.g., PubMed)
-- Error handling: Raise `SearchError` or `RateLimitError` on failures
-
-Example pattern:
-
-```python
-class MySearchTool:
- @property
- def name(self) -> str:
- return "mytool"
-
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(...))
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- # Implementation
- return evidence_list
-```
-
-### Judge Handlers
-
-- Implement `JudgeHandlerProtocol` (`async def assess(question, evidence) -> JudgeAssessment`)
-- Use pydantic-ai `Agent` with `output_type=JudgeAssessment`
-- System prompts in `src/prompts/judge.py`
-- Support fallback handlers: `MockJudgeHandler`, `HFInferenceJudgeHandler`
-- Always return valid `JudgeAssessment` (never raise exceptions)
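-
-A hedged sketch of the shape (prompt helper signatures and the pydantic-ai result accessor are
-assumptions and may differ by version):
-
-```python
-from pydantic_ai import Agent
-
-from src.prompts.judge import SYSTEM_PROMPT, format_user_prompt  # assumed names
-from src.utils.models import Evidence, JudgeAssessment
-
-
-class ExampleJudgeHandler:
-    def __init__(self, model: str) -> None:
-        self._agent = Agent(model, output_type=JudgeAssessment, system_prompt=SYSTEM_PROMPT)
-
-    async def assess(self, question: str, evidence: list[Evidence]) -> JudgeAssessment:
-        result = await self._agent.run(format_user_prompt(question, evidence))
-        return result.output  # structured output, validated against JudgeAssessment
-```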
-
-### Agent Factory Pattern
-
-- Use factory functions for creating agents (`src/agent_factory/`)
-- Lazy initialization for optional dependencies (e.g., embeddings, Modal)
-- Check requirements before initialization:
-
-
-[Check Magentic Requirements](../src/utils/llm_factory.py) start_line:152 end_line:170
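-
-In spirit, the requirements check is a fail-fast guard (a sketch, not the exact implementation):
-
-```python
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-
-def check_magentic_requirements() -> None:
-    """Raise an actionable error instead of failing later with an obscure ImportError."""
-    if not settings.has_openai_key:
-        raise ConfigurationError(
-            "Advanced mode requires OPENAI_API_KEY; set it in .env or use simple mode."
-        )
-```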
-
-
-### State Management
-
-- **Magentic Mode**: Use `ContextVar` for thread-safe state (`src/agents/state.py`)
-- **Simple Mode**: Pass state via function parameters
-- Never use global mutable state (except singletons via `@lru_cache`)
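-
-A minimal `ContextVar` sketch (the real `WorkflowState` is richer; this only shows the isolation
-pattern):
-
-```python
-from contextvars import ContextVar
-
-from src.utils.models import Evidence
-
-_evidence_store: ContextVar[list[Evidence] | None] = ContextVar("evidence_store", default=None)
-
-
-def get_evidence_store() -> list[Evidence]:
-    """Auto-initialize on first access; each execution context sees its own list."""
-    store = _evidence_store.get()
-    if store is None:
-        store = []
-        _evidence_store.set(store)
-    return store
-```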
-
-### Singleton Pattern
-
-Use `@lru_cache(maxsize=1)` for singletons:
-
-
-[Singleton Pattern Example](../src/services/statistical_analyzer.py) start_line:252 end_line:255
-
-
-- Lazy initialization to avoid requiring dependencies at import time
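-
-A minimal sketch combining both points (singleton + lazy import):
-
-```python
-from functools import lru_cache
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from src.services.embeddings import EmbeddingService
-
-
-@lru_cache(maxsize=1)
-def get_embedding_service() -> "EmbeddingService":
-    # Lazy import: sentence-transformers is only pulled in when the service is first used
-    from src.services.embeddings import EmbeddingService
-
-    return EmbeddingService()
-```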
-
-## Code Quality & Documentation
-
-### Docstrings
-
-- Google-style docstrings for all public functions
-- Include Args, Returns, Raises sections
-- Use type hints in docstrings only if needed for clarity
-
-Example:
-
-
-[Search Method Docstring Example](../src/tools/pubmed.py) start_line:51 end_line:58
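-
-An illustrative signature (not copied verbatim from `pubmed.py`):
-
-```python
-async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
-    """Search PubMed and convert hits to Evidence objects.
-
-    Args:
-        query: Free-text query; preprocessed before the ESearch call.
-        max_results: Upper bound on returned records.
-
-    Returns:
-        Evidence items with citations attached.
-
-    Raises:
-        SearchError: If the E-utilities request fails after retries.
-    """
-```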
-
-
-### Code Comments
-
-- Explain WHY, not WHAT
-- Document non-obvious patterns (e.g., why `requests` not `httpx` for ClinicalTrials)
-- Mark critical sections: `# CRITICAL: ...`
-- Document rate limiting rationale
-- Explain async patterns when non-obvious
-
-## Prompt Engineering & Citation Validation
-
-### Judge Prompts
-
-- System prompt in `src/prompts/judge.py`
-- Format evidence with truncation (1500 chars per item)
-- Handle empty evidence case separately
-- Always request structured JSON output
-- Use `format_user_prompt()` and `format_empty_evidence_prompt()` helpers
-
-### Hypothesis Prompts
-
-- Use diverse evidence selection (MMR algorithm)
-- Sentence-aware truncation (`truncate_at_sentence()`)
-- Format: Drug → Target → Pathway → Effect
-- System prompt emphasizes mechanistic reasoning
-- Use `format_hypothesis_prompt()` with embeddings for diversity
-
-### Report Prompts
-
-- Include full citation details for validation
-- Use diverse evidence selection (n=20)
-- **CRITICAL**: Emphasize citation validation rules
-- Format hypotheses with support/contradiction counts
-- System prompt includes explicit JSON structure requirements
-
-### Citation Validation
-
-- **ALWAYS** validate references before returning reports
-- Use `validate_references()` from `src/utils/citation_validator.py`
-- Remove hallucinated citations (URLs not in evidence)
-- Log warnings for removed citations
-- Never trust LLM-generated citations without validation
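-
-Usage is one call at the end of report generation (the exact signature is an assumption; it takes
-the draft plus the evidence used to write it):
-
-```python
-from src.utils.citation_validator import validate_references
-
-# report_markdown comes from the writer agent; evidence is the collected Evidence list
-report_markdown = validate_references(report_markdown, evidence)
-```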
-
-### Citation Validation Rules
-
-1. Every reference URL must EXACTLY match a provided evidence URL
-2. Do NOT invent, fabricate, or hallucinate any references
-3. Do NOT modify paper titles, authors, dates, or URLs
-4. If unsure about a citation, OMIT it rather than guess
-5. Copy URLs exactly as provided - do not create similar-looking URLs
-
-### Evidence Selection
-
-- Use `select_diverse_evidence()` for MMR-based selection
-- Balance relevance vs diversity (lambda=0.7 default)
-- Sentence-aware truncation preserves meaning
-- Limit evidence per prompt to avoid context overflow
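-
-Conceptually, MMR selection looks like the standalone sketch below; the project's
-`select_diverse_evidence()` operates on `Evidence` objects and their embeddings rather than raw
-vectors:
-
-```python
-import numpy as np
-
-
-def mmr_select(query_vec: np.ndarray, doc_vecs: np.ndarray, k: int, lam: float = 0.7) -> list[int]:
-    """Pick k indices balancing relevance to the query against redundancy among picks."""
-
-    def cos(a: np.ndarray, b: np.ndarray) -> float:
-        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
-
-    relevance = [cos(query_vec, d) for d in doc_vecs]
-    selected: list[int] = []
-    candidates = list(range(len(doc_vecs)))
-    while candidates and len(selected) < k:
-        def mmr_score(i: int) -> float:
-            redundancy = max((cos(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0)
-            return lam * relevance[i] - (1 - lam) * redundancy
-
-        best = max(candidates, key=mmr_score)
-        selected.append(best)
-        candidates.remove(best)
-    return selected
-```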
-
-## MCP Integration
-
-### MCP Tools
-
-- Functions in `src/mcp_tools.py` for Claude Desktop
-- Full type hints required
-- Google-style docstrings with Args/Returns sections
-- Formatted string returns (markdown)
-
-### Gradio MCP Server
-
-- Enable with `mcp_server=True` in `demo.launch()`
-- Endpoint: `/gradio_api/mcp/`
-- Use `ssr_mode=False` to fix hydration issues in HF Spaces
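-
-Concretely (a sketch; `demo` is whatever Blocks app `src/app.py` builds):
-
-```python
-demo.launch(
-    mcp_server=True,  # exposes tools at /gradio_api/mcp/
-    ssr_mode=False,   # avoids SSR hydration issues on HF Spaces
-)
-```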
-
-## Common Pitfalls
-
-1. **Blocking the event loop**: Never use sync I/O in async functions
-2. **Missing type hints**: All functions must have complete type annotations
-3. **Hallucinated citations**: Always validate references
-4. **Global mutable state**: Use ContextVar or pass via parameters
-5. **Import errors**: Lazy-load optional dependencies (magentic, modal, embeddings)
-6. **Rate limiting**: Always implement for external APIs
-7. **Error chaining**: Always use `from e` when raising exceptions
-
-## Key Principles
-
-1. **Type Safety First**: All code must pass `mypy --strict`
-2. **Async Everything**: All I/O must be async
-3. **Test-Driven**: Write tests before implementation
-4. **No Hallucinations**: Validate all citations
-5. **Graceful Degradation**: Support free tier (HF Inference) when no API keys
-6. **Lazy Loading**: Don't require optional dependencies at import time
-7. **Structured Logging**: Use structlog, never print()
-8. **Error Chaining**: Always preserve exception context
-
-## Pull Request Process
-
-1. Ensure all checks pass: `uv run ruff check src tests && uv run mypy src && uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire`
-2. Update documentation if needed
-3. Add tests for new features
-4. Update CHANGELOG if applicable
-5. Request review from maintainers
-6. Address review feedback
-7. Wait for approval before merging
-
-## Project Structure
-
-- `src/`: Main source code
-- `tests/`: Test files (`unit/` and `integration/`)
-- `docs/`: Documentation source files (MkDocs)
-- `examples/`: Example usage scripts
-- `pyproject.toml`: Project configuration and dependencies
-- `.pre-commit-config.yaml`: Pre-commit hook configuration
-
-## Questions?
-
-- Open an issue on [GitHub](https://github.com/DeepCritical/GradioDemo)
-- Check existing [documentation](https://deepcritical.github.io/GradioDemo/)
-- Review code examples in the codebase
-
-Thank you for contributing to The DETERMINATOR!
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 9d6fc14dce9d1bdbc102a1479304490324313167..0000000000000000000000000000000000000000
--- a/Dockerfile
+++ /dev/null
@@ -1,52 +0,0 @@
-# Dockerfile for DeepCritical
-FROM python:3.11-slim
-
-# Set working directory
-WORKDIR /app
-
-# Install system dependencies (curl needed for HEALTHCHECK)
-RUN apt-get update && apt-get install -y \
- git \
- curl \
- && rm -rf /var/lib/apt/lists/*
-
-# Install uv
-RUN pip install uv==0.5.4
-
-# Copy project files
-COPY pyproject.toml .
-COPY uv.lock .
-COPY src/ src/
-COPY README.md .
-
-# Install runtime dependencies only (no dev/test tools)
-RUN uv sync --frozen --no-dev --extra embeddings --extra magentic
-
-# Create non-root user BEFORE downloading models
-RUN useradd --create-home --shell /bin/bash appuser
-
-# Set cache directory for HuggingFace models (must be writable by appuser)
-ENV HF_HOME=/app/.cache
-ENV TRANSFORMERS_CACHE=/app/.cache
-
-# Create cache dir with correct ownership
-RUN mkdir -p /app/.cache && chown -R appuser:appuser /app/.cache
-
-# Pre-download the embedding model during build (as appuser to set correct ownership)
-USER appuser
-RUN uv run python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
-
-# Expose port
-EXPOSE 7860
-
-# Health check
-HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
- CMD curl -f http://localhost:7860/ || exit 1
-
-# Set environment variables
-ENV GRADIO_SERVER_NAME=0.0.0.0
-ENV GRADIO_SERVER_PORT=7860
-ENV PYTHONPATH=/app
-
-# Run the app
-CMD ["uv", "run", "python", "-m", "src.app"]
diff --git a/LICENSE.md b/LICENSE.md
deleted file mode 100644
index a1f9be9c2733fb22fc43dfc4e5f23c62dbfb02ad..0000000000000000000000000000000000000000
--- a/LICENSE.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# License
-
-DeepCritical is licensed under the MIT License.
-
-## MIT License
-
-Copyright (c) 2024 DeepCritical Team
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/README.md b/README.md
index a4c0e90eae994942af1b1ea0f5cdfe967a29b8f5..e4c9a9698223d89be48e808d1f511cc4a2182141 100644
--- a/README.md
+++ b/README.md
@@ -1,63 +1,15 @@
---
-title: The DETERMINATOR
-emoji: 🐉
-colorFrom: red
-colorTo: yellow
+title: DeepCritical
+emoji: 📈
+colorFrom: blue
+colorTo: purple
sdk: gradio
-sdk_version: "6.0.1"
-python_version: "3.11"
+sdk_version: 6.0.0
app_file: src/app.py
-hf_oauth: true
-hf_oauth_expiration_minutes: 480
-hf_oauth_scopes:
- # Required for HuggingFace Inference API (includes all third-party providers)
- # This scope grants access to:
- # - HuggingFace's own Inference API
- # - Third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.)
- # - All models available through the Inference Providers API
- - inference-api
- # Optional: Uncomment if you need to access user's billing information
- # - read-billing
-pinned: true
+pinned: false
license: mit
-tags:
- - mcp-in-action-track-enterprise
- - mcp-hackathon
- - deep-research
- - biomedical-ai
- - pydantic-ai
- - llamaindex
- - modal
- - building-mcp-track-enterprise
- - building-mcp-track-consumer
- - mcp-in-action-track-enterprise
- - mcp-in-action-track-consumer
- - building-mcp-track-modal
- - building-mcp-track-blaxel
- - building-mcp-track-llama-index
- - building-mcp-track-HUGGINGFACE
+short_description: Deep Search for Critical Research [BigData] -> [Actionable]
---
-> [!IMPORTANT]
-> **You are reading the Gradio Demo README!**
->
-> - 📚 **Documentation**: See our [technical documentation](https://deepcritical.github.io/GradioDemo/) for detailed information
-> - 📖 **Complete README**: Check out the [Github README](.github/README.md) for setup, configuration, and contribution guidelines
-> - ⚠️**This README is for our Gradio Demo Only !**
+### DeepCritical
-
-
-[GitHub Repository](https://github.com/DeepCritical/GradioDemo)
-[Documentation](deepcritical.github.io/GradioDemo/)
-[HuggingFace Space](https://huggingface.co/spaces/DataQuests/DeepCritical)
-[Code Coverage](https://codecov.io/gh/DeepCritical/GradioDemo)
-[Discord](https://discord.gg/qdfnvSPcqP)
-
-
-
-
-# The DETERMINATOR
-
-## About
-
-The DETERMINATOR is a powerful generalist deep research agent system that stops at nothing until finding precise answers to complex questions. It uses iterative search-and-judge loops to comprehensively investigate any research question from any domain.
diff --git a/deployments/README.md b/deployments/README.md
deleted file mode 100644
index 3a4f4a4d8d7a3cccf6beacd56ce135117f3ad07a..0000000000000000000000000000000000000000
--- a/deployments/README.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# Deployments
-
-This directory contains infrastructure deployment scripts for DeepCritical services.
-
-## Modal Deployments
-
-### TTS Service (`modal_tts.py`)
-
-Deploys the Kokoro TTS (Text-to-Speech) function to Modal's GPU infrastructure.
-
-**Deploy:**
-```bash
-modal deploy deployments/modal_tts.py
-```
-
-**Features:**
-- Kokoro 82M TTS model
-- GPU-accelerated (T4)
-- Voice options: af_heart, af_bella, am_michael, etc.
-- Configurable speech speed
-
-**Requirements:**
-- Modal account and credentials (`MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` in `.env`)
-- GPU quota on Modal
-
-**After Deployment:**
-The function will be available at:
-- App: `deepcritical-tts`
-- Function: `kokoro_tts_function`
-
-The main application (`src/services/tts_modal.py`) will call this deployed function.
-
----
-
-## Adding New Deployments
-
-When adding new deployment scripts:
-
-1. Create a new file: `deployments/<name>.py`
-2. Use Modal's app pattern:
- ```python
- import modal
-   app = modal.App("deepcritical-<name>")
- ```
-3. Document in this README
-4. Test deployment: `modal deploy deployments/<name>.py`
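For reference, a minimal sketch of the client-side call pattern this README describes: looking up the deployed `deepcritical-tts` app and invoking `kokoro_tts_function` remotely. The helper below is illustrative only (the real wiring lives in `src/services/tts_modal.py`, not shown here), and the lookup call assumes the Modal SDK pinned by this project (`modal.Function.lookup`; newer SDKs prefer `Function.from_name`).

```python
"""Illustrative sketch: calling the deployed Kokoro TTS function."""

import modal
import numpy as np


def synthesize(text: str, voice: str = "af_heart", speed: float = 1.0) -> tuple[int, np.ndarray]:
    """Invoke the deployed Modal function and return (sample_rate, audio)."""
    # Look up the function deployed under app "deepcritical-tts".
    # Newer Modal SDKs expose modal.Function.from_name() instead.
    fn = modal.Function.lookup("deepcritical-tts", "kokoro_tts_function")
    # .remote() runs on Modal's GPU worker; arguments mirror the deployed signature.
    sample_rate, audio = fn.remote(text, voice, speed)
    return sample_rate, audio


if __name__ == "__main__":
    rate, wav = synthesize("Hello from DeepCritical.")
    print(f"{rate} Hz, {wav.shape[0]} samples")
```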
diff --git a/deployments/modal_tts.py b/deployments/modal_tts.py
deleted file mode 100644
index 9987a339f6b89eb63cd512eb594dd6a6d488f42a..0000000000000000000000000000000000000000
--- a/deployments/modal_tts.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""Deploy Kokoro TTS function to Modal.
-
-This script deploys the TTS function to Modal so it can be called
-from the main DeepCritical application.
-
-Usage:
-    modal deploy deployments/modal_tts.py
-
-After deployment, the function will be available at:
- App: deepcritical-tts
- Function: kokoro_tts_function
-"""
-
-import modal
-import numpy as np
-
-# Create Modal app
-app = modal.App("deepcritical-tts")
-
-# Define Kokoro TTS dependencies
-KOKORO_DEPENDENCIES = [
- "torch>=2.0.0",
- "transformers>=4.30.0",
- "numpy<2.0",
-]
-
-# Create Modal image with Kokoro
-tts_image = (
- modal.Image.debian_slim(python_version="3.11")
- .apt_install("git") # Install git first for pip install from github
- .pip_install(*KOKORO_DEPENDENCIES)
- .pip_install("git+https://github.com/hexgrad/kokoro.git")
-)
-
-
-@app.function(
- image=tts_image,
- gpu="T4",
- timeout=60,
-)
-def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]:
- """Modal GPU function for Kokoro TTS.
-
- This function runs on Modal's GPU infrastructure.
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
-
- Args:
- text: Text to synthesize
- voice: Voice ID (e.g., af_heart, af_bella, am_michael)
- speed: Speech speed multiplier (0.5-2.0)
-
- Returns:
- Tuple of (sample_rate, audio_array)
- """
- import numpy as np
-
- try:
- import torch
- from kokoro import KModel, KPipeline
-
- # Initialize model (cached on GPU)
- model = KModel().to("cuda").eval()
- pipeline = KPipeline(lang_code=voice[0])
- pack = pipeline.load_voice(voice)
-
- # Generate audio - accumulate all chunks
- audio_chunks = []
- for _, ps, _ in pipeline(text, voice, speed):
- ref_s = pack[len(ps) - 1]
- audio = model(ps, ref_s, speed)
- audio_chunks.append(audio.numpy())
-
- # Concatenate all audio chunks
- if audio_chunks:
- full_audio = np.concatenate(audio_chunks)
- return (24000, full_audio)
-
- # If no audio generated, return empty
- return (24000, np.zeros(1, dtype=np.float32))
-
- except ImportError as e:
- raise RuntimeError(
- f"Kokoro not installed: {e}. "
- "Install with: pip install git+https://github.com/hexgrad/kokoro.git"
- ) from e
- except Exception as e:
- raise RuntimeError(f"TTS synthesis failed: {e}") from e
-
-
-# Optional: Add a test entrypoint
-@app.local_entrypoint()
-def test():
- """Test the TTS function."""
- print("Testing Modal TTS function...")
- sample_rate, audio = kokoro_tts_function.remote("Hello, this is a test.", "af_heart", 1.0)
- print(f"Generated audio: {sample_rate}Hz, shape={audio.shape}")
- print("✓ TTS function works!")
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index 6f59cac5edf5e957ee88a6fec36ad45db3040bd0..0000000000000000000000000000000000000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,205 +0,0 @@
-[project]
-name = "determinator"
-version = "0.1.0"
-description = "The DETERMINATOR - the Deep Research Agent that Stops at Nothing"
-readme = "README.md"
-requires-python = ">=3.11"
-dependencies = [
- "pydantic>=2.7",
- "pydantic-settings>=2.2", # For BaseSettings (config)
- "pydantic-ai>=0.0.16", # Agent framework
- "openai>=1.0.0",
- "anthropic>=0.18.0",
- "httpx>=0.27", # Async HTTP client (PubMed)
- "beautifulsoup4>=4.12", # HTML parsing
- "xmltodict>=0.13", # PubMed XML -> dict
- "huggingface-hub>=0.20.0", # Hugging Face Inference API
- "gradio[mcp,oauth]>=6.0.0", # Chat interface with MCP server support (6.0 required for css in launch())
- "python-dotenv>=1.0", # .env loading
- "tenacity>=8.2", # Retry logic
- "structlog>=24.1", # Structured logging
- "requests>=2.32.5", # ClinicalTrials.gov (httpx blocked by WAF)
- "pydantic-graph>=1.22.0",
- "limits>=3.0", # Web search
- "llama-index-llms-huggingface>=0.6.1",
- "llama-index-llms-huggingface-api>=0.6.1",
- "llama-index-vector-stores-chroma>=0.5.3",
- "llama-index>=0.14.8",
- "gradio-client>=1.0.0", # For STT/OCR API calls
- "soundfile>=0.12.0", # For audio file I/O
- "pillow>=10.0.0", # For image processing
- "torch>=2.0.0", # Required by Kokoro TTS
- "transformers>=4.57.2", # Required by Kokoro TTS
- "modal>=0.63.0", # Required for TTS GPU execution
- "tokenizers>=0.22.0,<=0.23.0",
- "rpds-py>=0.29.0",
- "pydantic-ai-slim[huggingface]>=0.0.18",
- "agent-framework-core>=1.0.0b251120,<2.0.0",
- "chromadb>=0.4.0",
- "sentence-transformers>=2.2.0",
- "numpy<2.0",
- "llama-index-llms-openai>=0.6.9",
- "llama-index-embeddings-openai>=0.5.1",
- "ddgs>=9.9.2",
- "aiohttp>=3.13.2",
- "lxml>=6.0.2",
- "fake-useragent==2.2.0",
- "socksio==1.0.0",
- "neo4j>=6.0.3",
- "md2pdf>=1.0.1",
-]
-
-[project.optional-dependencies]
-dev = [
- # Testing
- "pytest>=8.0",
- "pytest-asyncio>=0.23",
- "pytest-sugar>=1.0",
- "pytest-cov>=5.0",
- "pytest-mock>=3.12",
- "respx>=0.21", # Mock httpx requests
- "typer>=0.9.0", # Gradio CLI dependency for smoke tests
-
- # Quality
- "ruff>=0.4.0",
- "mypy>=1.10",
- "pre-commit>=3.7",
-
- # Documentation
- "mkdocs>=1.6.0",
- "mkdocs-material>=9.0.0",
- "mkdocs-mermaid2-plugin>=1.1.0",
- "mkdocs-codeinclude-plugin>=0.2.0",
- "mkdocs-git-revision-date-localized-plugin>=1.2.0",
- "mkdocs-minify-plugin>=0.8.0",
- "pymdown-extensions>=10.17.2",
-]
-magentic = [
- "agent-framework-core>=1.0.0b251120,<2.0.0", # Microsoft Agent Framework (PyPI)
-]
-embeddings = [
- "chromadb>=0.4.0",
- "sentence-transformers>=2.2.0",
- "numpy<2.0", # chromadb compatibility: uses np.float_ removed in NumPy 2.0
-]
-modal = [
- # Mario's Modal code execution + LlamaIndex RAG
- # Note: modal>=0.63.0 is now in main dependencies for TTS support
- "llama-index>=0.11.0",
- "llama-index-llms-openai>=0.6.9",
- "llama-index-embeddings-openai>=0.5.1",
- "llama-index-vector-stores-chroma",
- "chromadb>=0.4.0",
- "numpy<2.0", # chromadb compatibility: uses np.float_ removed in NumPy 2.0
-]
-
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
-
-[tool.hatch.build.targets.wheel]
-packages = ["src"]
-
-# ============== RUFF CONFIG ==============
-[tool.ruff]
-line-length = 100
-target-version = "py311"
-src = ["src"]
-exclude = [
- "tests/",
- "examples/",
- "reference_repos/",
- "folder/",
-]
-
-[tool.ruff.lint]
-select = [
- "E", # pycodestyle errors
- "F", # pyflakes
- "B", # flake8-bugbear
- "I", # isort
- "N", # pep8-naming
- "UP", # pyupgrade
- "PL", # pylint
- "RUF", # ruff-specific
-]
-ignore = [
- "PLR0913", # Too many arguments (agents need many params)
- "PLR0912", # Too many branches (complex orchestrator logic)
- "PLR0911", # Too many return statements (complex agent logic)
- "PLR0915", # Too many statements (Gradio UI setup functions)
- "PLR2004", # Magic values (statistical constants like p-values)
- "PLW0603", # Global statement (singleton pattern for Modal)
- "PLC0415", # Lazy imports for optional dependencies
- "E402", # Module level import not at top (needed for pytest.importorskip)
- "E501", # Line too long (ignore line length violations)
- "RUF100", # Unused noqa (version differences between local/CI)
-]
-
-[tool.ruff.lint.isort]
-known-first-party = ["src"]
-
-# ============== MYPY CONFIG ==============
-[tool.mypy]
-python_version = "3.11"
-strict = true
-ignore_missing_imports = true
-disallow_untyped_defs = true
-warn_return_any = true
-warn_unused_ignores = false
-explicit_package_bases = true
-mypy_path = "."
-exclude = [
- "^reference_repos/",
- "^examples/",
- "^folder/",
- "^src/app.py",
-]
-
-# ============== PYTEST CONFIG ==============
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-asyncio_mode = "auto"
-addopts = [
- "-v",
- "--tb=short",
- "--strict-markers",
- "-p",
- "no:logfire",
-]
-markers = [
- "unit: Unit tests (mocked)",
- "integration: Integration tests (real APIs)",
- "slow: Slow tests",
- "openai: Tests that require OpenAI API key",
- "huggingface: Tests that require HuggingFace API key or use HuggingFace models",
- "embedding_provider: Tests that require API-based embedding providers (OpenAI, etc.)",
- "local_embeddings: Tests that use local embeddings (sentence-transformers, ChromaDB)",
-]
-
-# ============== COVERAGE CONFIG ==============
-[tool.coverage.run]
-source = ["src"]
-omit = ["*/__init__.py"]
-
-[tool.coverage.report]
-exclude_lines = [
- "pragma: no cover",
- "if TYPE_CHECKING:",
- "raise NotImplementedError",
-]
-
-[dependency-groups]
-dev = [
- "mkdocs>=1.6.1",
- "mkdocs-codeinclude-plugin>=0.2.1",
- "mkdocs-material>=9.7.0",
- "mkdocs-mermaid2-plugin>=1.2.3",
- "mkdocs-git-revision-date-localized-plugin>=1.2.0",
- "mkdocs-minify-plugin>=0.8.0",
- "structlog>=25.5.0",
- "ty>=0.0.1a28",
-]
-
-# Note: agent-framework-core is optional for magentic mode (multi-agent orchestration)
-# Version pinned to 1.0.0b* to avoid breaking changes. CI skips tests via pytest.importorskip
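The trailing note mentions that CI skips magentic-mode tests via `pytest.importorskip` when the optional `agent-framework-core` dependency is not installed (which is also why ruff's `E402` is ignored above). A minimal sketch of that pattern — the module placement and assertion are hypothetical:

```python
"""Hypothetical tests/unit module showing the importorskip pattern."""

import pytest

# Skips the whole module at collection time when the optional magentic
# dependency (agent-framework-core, importable as `agent_framework`) is absent.
agent_framework = pytest.importorskip("agent_framework")


@pytest.mark.unit
def test_optional_dependency_exposes_base_agent() -> None:
    # Trivial placeholder check; real tests would exercise the magentic agents.
    assert hasattr(agent_framework, "BaseAgent")
```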
diff --git a/requirements.txt b/requirements.txt
index cf1b66bb960423122615ba3340799ef6e9d92435..3dfd393b3eb853eff9cee0d4e8ad62f73ba27ff0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,92 +1 @@
-##########################
-# DO NOT USE THIS FILE
-# FOR GRADIO DEMO ONLY
-##########################
-
-
-# Core dependencies for HuggingFace Spaces
-pydantic>=2.7
-pydantic-settings>=2.2
-pydantic-ai>=0.0.16
-
-# OPTIONAL AI Providers
-openai>=1.0.0
-anthropic>=0.18.0
-
-# HTTP & Parsing
-httpx>=0.27
-aiohttp>=3.13.2 # Required for website crawling
-beautifulsoup4>=4.12
-lxml>=6.0.2 # Required for BeautifulSoup lxml parser (faster than html.parser)
-xmltodict>=0.13
-
-# HuggingFace Hub
-huggingface-hub>=0.20.0
-
-# UI (Gradio with MCP server support)
-gradio[mcp,oauth]>=6.0.0
-
-# Utils
-python-dotenv>=1.0
-tenacity>=8.2
-structlog>=24.1
-requests>=2.32.5
-limits>=3.0 # Rate limiting
-pydantic-graph>=1.22.0
-
-# Web search
-ddgs>=9.9.2 # duckduckgo-search has been renamed to ddgs
-fake-useragent==2.2.0
-socksio==1.0.0
-# LlamaIndex RAG
-llama-index-llms-huggingface>=0.6.1
-llama-index-llms-huggingface-api>=0.6.1
-llama-index-vector-stores-chroma>=0.5.3
-llama-index>=0.14.8
-
-# Audio/Image processing
-gradio-client>=1.0.0 # For STT/OCR API calls
-soundfile>=0.12.0 # For audio file I/O
-pillow>=10.0.0 # For image processing
-
-# TTS dependencies (for Modal GPU TTS)
-torch>=2.0.0 # Required by Kokoro TTS
-transformers>=4.57.2 # Required by Kokoro TTS
-modal>=0.63.0 # Required for TTS GPU execution
-# Note: Kokoro is installed in Modal image from: git+https://github.com/hexgrad/kokoro.git
-
-# Embeddings & Vector Store
-tokenizers>=0.22.0,<=0.23.0
-rpds-py>=0.29.0 # Python implementation of rpds (required by chromadb on Windows)
-chromadb>=0.4.0
-sentence-transformers>=2.2.0
-numpy<2.0 # chromadb compatibility: uses np.float_ removed in NumPy 2.0
-neo4j>=6.0.3
-
-### DOCUMENT STUFF
-
-cssselect2==0.8.0
-docopt==0.6.2
-fonttools==4.61.0
-markdown2==2.5.4
-md2pdf==1.0.1
-pydyf==0.11.0
-pyphen==0.17.2
-tinycss2==1.5.1
-tinyhtml5==2.0.0
-weasyprint==66.0
-webencodings==0.5.1
-zopfli==0.4.0
-
-# Optional: Modal for code execution
-modal>=0.63.0
-
-# Pydantic AI with HuggingFace support
-pydantic-ai-slim[huggingface]>=0.0.18
-
-# Multi-agent orchestration (Advanced mode)
-agent-framework-core>=1.0.0b251120,<2.0.0
-
-# LlamaIndex RAG - OpenAI
-llama-index-llms-openai>=0.6.9
-llama-index-embeddings-openai>=0.5.1
+deepcritical
\ No newline at end of file
diff --git a/src/agent_factory/agents.py b/src/agent_factory/agents.py
index c676140e509d833431b9b280b5f2695de82ec181..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/agent_factory/agents.py
+++ b/src/agent_factory/agents.py
@@ -1,361 +0,0 @@
-"""Agent factory functions for creating research agents.
-
-Provides factory functions for creating all Pydantic AI agents used in
-the research workflows, following the pattern from judges.py.
-"""
-
-from typing import TYPE_CHECKING, Any
-
-import structlog
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-if TYPE_CHECKING:
- from src.agent_factory.graph_builder import GraphBuilder
- from src.agents.input_parser import InputParserAgent
- from src.agents.knowledge_gap import KnowledgeGapAgent
- from src.agents.long_writer import LongWriterAgent
- from src.agents.proofreader import ProofreaderAgent
- from src.agents.thinking import ThinkingAgent
- from src.agents.tool_selector import ToolSelectorAgent
- from src.agents.writer import WriterAgent
- from src.orchestrator.graph_orchestrator import GraphOrchestrator
- from src.orchestrator.planner_agent import PlannerAgent
- from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
-
-logger = structlog.get_logger()
-
-
-def create_input_parser_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "InputParserAgent":
- """
- Create input parser agent for query analysis and research mode detection.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured InputParserAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.input_parser import create_input_parser_agent as _create_agent
-
- try:
- logger.debug("Creating input parser agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create input parser agent", error=str(e))
- raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
-
-
-def create_planner_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "PlannerAgent":
- """
- Create planner agent with web search and crawl tools.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured PlannerAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- # Lazy import to avoid circular dependencies
- from src.orchestrator.planner_agent import create_planner_agent as _create_planner_agent
-
- try:
- logger.debug("Creating planner agent")
- return _create_planner_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create planner agent", error=str(e))
- raise ConfigurationError(f"Failed to create planner agent: {e}") from e
-
-
-def create_knowledge_gap_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "KnowledgeGapAgent":
- """
- Create knowledge gap agent for evaluating research completeness.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured KnowledgeGapAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.knowledge_gap import create_knowledge_gap_agent as _create_agent
-
- try:
- logger.debug("Creating knowledge gap agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create knowledge gap agent", error=str(e))
- raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
-
-
-def create_tool_selector_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "ToolSelectorAgent":
- """
- Create tool selector agent for choosing tools to address gaps.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ToolSelectorAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.tool_selector import create_tool_selector_agent as _create_agent
-
- try:
- logger.debug("Creating tool selector agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create tool selector agent", error=str(e))
- raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
-
-
-def create_thinking_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "ThinkingAgent":
- """
- Create thinking agent for generating observations.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ThinkingAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.thinking import create_thinking_agent as _create_agent
-
- try:
- logger.debug("Creating thinking agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create thinking agent", error=str(e))
- raise ConfigurationError(f"Failed to create thinking agent: {e}") from e
-
-
-def create_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> "WriterAgent":
- """
- Create writer agent for generating final reports.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured WriterAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.writer import create_writer_agent as _create_agent
-
- try:
- logger.debug("Creating writer agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create writer agent", error=str(e))
- raise ConfigurationError(f"Failed to create writer agent: {e}") from e
-
-
-def create_long_writer_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "LongWriterAgent":
- """
- Create long writer agent for iteratively writing report sections.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured LongWriterAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.long_writer import create_long_writer_agent as _create_agent
-
- try:
- logger.debug("Creating long writer agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create long writer agent", error=str(e))
- raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
-
-
-def create_proofreader_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> "ProofreaderAgent":
- """
- Create proofreader agent for finalizing report drafts.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ProofreaderAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- from src.agents.proofreader import create_proofreader_agent as _create_agent
-
- try:
- logger.debug("Creating proofreader agent")
- return _create_agent(model=model, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create proofreader agent", error=str(e))
- raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e
-
-
-def create_iterative_flow(
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- verbose: bool = True,
- use_graph: bool | None = None,
-) -> "IterativeResearchFlow":
- """
- Create iterative research flow.
-
- Args:
- max_iterations: Maximum number of iterations
- max_time_minutes: Maximum time in minutes
- verbose: Whether to log progress
- use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution
-
- Returns:
- Configured IterativeResearchFlow instance
- """
- from src.orchestrator.research_flow import IterativeResearchFlow
-
- try:
- # Use settings default if not explicitly provided
- if use_graph is None:
- use_graph = settings.use_graph_execution
-
- logger.debug("Creating iterative research flow", use_graph=use_graph)
- return IterativeResearchFlow(
- max_iterations=max_iterations,
- max_time_minutes=max_time_minutes,
- verbose=verbose,
- use_graph=use_graph,
- )
- except Exception as e:
- logger.error("Failed to create iterative flow", error=str(e))
- raise ConfigurationError(f"Failed to create iterative flow: {e}") from e
-
-
-def create_deep_flow(
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- verbose: bool = True,
- use_long_writer: bool = True,
- use_graph: bool | None = None,
-) -> "DeepResearchFlow":
- """
- Create deep research flow.
-
- Args:
- max_iterations: Maximum iterations per section
- max_time_minutes: Maximum time per section
- verbose: Whether to log progress
- use_long_writer: Whether to use long writer (True) or proofreader (False)
- use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution
-
- Returns:
- Configured DeepResearchFlow instance
- """
- from src.orchestrator.research_flow import DeepResearchFlow
-
- try:
- # Use settings default if not explicitly provided
- if use_graph is None:
- use_graph = settings.use_graph_execution
-
- logger.debug("Creating deep research flow", use_graph=use_graph)
- return DeepResearchFlow(
- max_iterations=max_iterations,
- max_time_minutes=max_time_minutes,
- verbose=verbose,
- use_long_writer=use_long_writer,
- use_graph=use_graph,
- )
- except Exception as e:
- logger.error("Failed to create deep flow", error=str(e))
- raise ConfigurationError(f"Failed to create deep flow: {e}") from e
-
-
-def create_graph_orchestrator(
- mode: str = "auto",
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- use_graph: bool = True,
-) -> "GraphOrchestrator":
- """
- Create graph orchestrator.
-
- Args:
- mode: Research mode ("iterative", "deep", or "auto")
- max_iterations: Maximum iterations per loop
- max_time_minutes: Maximum time per loop
- use_graph: Whether to use graph execution (True) or agent chains (False)
-
- Returns:
- Configured GraphOrchestrator instance
- """
- from src.orchestrator.graph_orchestrator import create_graph_orchestrator as _create
-
- try:
- logger.debug("Creating graph orchestrator", mode=mode, use_graph=use_graph)
- return _create(
- mode=mode, # type: ignore[arg-type]
- max_iterations=max_iterations,
- max_time_minutes=max_time_minutes,
- use_graph=use_graph,
- )
- except Exception as e:
- logger.error("Failed to create graph orchestrator", error=str(e))
- raise ConfigurationError(f"Failed to create graph orchestrator: {e}") from e
-
-
-def create_graph_builder() -> "GraphBuilder":
- """
- Create a graph builder instance.
-
- Returns:
- GraphBuilder instance
- """
- from src.agent_factory.graph_builder import GraphBuilder
-
- try:
- logger.debug("Creating graph builder")
- return GraphBuilder()
- except Exception as e:
- logger.error("Failed to create graph builder", error=str(e))
- raise ConfigurationError(f"Failed to create graph builder: {e}") from e
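For orientation, a sketch of how a caller might compose these factories. The parameter values shown are just the defaults documented in the docstrings above; constructing the agents requires a configured LLM provider, and running a flow (not shown) is an async operation.

```python
"""Illustrative sketch: wiring research components via the factory functions."""

from src.agent_factory.agents import (
    create_deep_flow,
    create_graph_orchestrator,
    create_writer_agent,
)

# Agents accept an optional Pydantic AI model and an optional HF OAuth token.
writer = create_writer_agent(model=None, oauth_token=None)

# Flows fall back to settings.use_graph_execution when use_graph is None.
deep_flow = create_deep_flow(max_iterations=5, max_time_minutes=10, use_long_writer=True)

# The orchestrator picks iterative vs. deep research when mode="auto".
orchestrator = create_graph_orchestrator(mode="auto", use_graph=True)
```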
diff --git a/src/agent_factory/graph_builder.py b/src/agent_factory/graph_builder.py
deleted file mode 100644
index c06c1b0ac0d38a943f957bc100bef6cdfc09f390..0000000000000000000000000000000000000000
--- a/src/agent_factory/graph_builder.py
+++ /dev/null
@@ -1,635 +0,0 @@
-"""Graph builder utilities for constructing research workflow graphs.
-
-Provides classes and utilities for building graph-based orchestration systems
-using Pydantic AI agents as nodes.
-"""
-
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Any, Literal
-
-import structlog
-from pydantic import BaseModel, Field
-
-if TYPE_CHECKING:
- from pydantic_ai import Agent
-
- from src.middleware.state_machine import WorkflowState
-
-logger = structlog.get_logger()
-
-
-# ============================================================================
-# Graph Node Models
-# ============================================================================
-
-
-class GraphNode(BaseModel):
- """Base class for graph nodes."""
-
- node_id: str = Field(description="Unique identifier for the node")
- node_type: Literal["agent", "state", "decision", "parallel"] = Field(description="Type of node")
- description: str = Field(default="", description="Human-readable description of the node")
-
- model_config = {"frozen": True}
-
-
-class AgentNode(GraphNode):
- """Node that executes a Pydantic AI agent."""
-
- node_type: Literal["agent"] = "agent"
- agent: Any = Field(description="Pydantic AI agent to execute")
- input_transformer: Callable[[Any], Any] | None = Field(
- default=None, description="Transform input before passing to agent"
- )
- output_transformer: Callable[[Any], Any] | None = Field(
- default=None, description="Transform output after agent execution"
- )
-
- model_config = {"arbitrary_types_allowed": True}
-
-
-class StateNode(GraphNode):
- """Node that updates or reads workflow state."""
-
- node_type: Literal["state"] = "state"
- state_updater: Callable[[Any, Any], Any] = Field(
- description="Function to update workflow state"
- )
- state_reader: Callable[[Any], Any] | None = Field(
- default=None, description="Function to read state (optional)"
- )
-
- model_config = {"arbitrary_types_allowed": True}
-
-
-class DecisionNode(GraphNode):
- """Node that makes routing decisions based on conditions."""
-
- node_type: Literal["decision"] = "decision"
- decision_function: Callable[[Any], str] = Field(
- description="Function that returns next node ID based on input"
- )
- options: list[str] = Field(description="List of possible next node IDs", min_length=1)
-
- model_config = {"arbitrary_types_allowed": True}
-
-
-class ParallelNode(GraphNode):
- """Node that executes multiple nodes in parallel."""
-
- node_type: Literal["parallel"] = "parallel"
- parallel_nodes: list[str] = Field(
- description="List of node IDs to run in parallel", min_length=0
- )
- aggregator: Callable[[list[Any]], Any] | None = Field(
- default=None, description="Function to aggregate parallel results"
- )
-
- model_config = {"arbitrary_types_allowed": True}
-
-
-# ============================================================================
-# Graph Edge Models
-# ============================================================================
-
-
-class GraphEdge(BaseModel):
- """Base class for graph edges."""
-
- from_node: str = Field(description="Source node ID")
- to_node: str = Field(description="Target node ID")
- condition: Callable[[Any], bool] | None = Field(
- default=None, description="Optional condition function"
- )
- weight: float = Field(default=1.0, description="Edge weight for routing decisions")
-
- model_config = {"arbitrary_types_allowed": True}
-
-
-class SequentialEdge(GraphEdge):
- """Edge that is always traversed (no condition)."""
-
- condition: None = None
-
-
-class ConditionalEdge(GraphEdge):
- """Edge that is traversed based on a condition."""
-
- condition: Callable[[Any], bool] = Field(description="Required condition function")
- condition_description: str = Field(
- default="", description="Human-readable description of condition"
- )
-
-
-class ParallelEdge(GraphEdge):
- """Edge used for parallel execution branches."""
-
- condition: None = None
-
-
-# ============================================================================
-# Research Graph Class
-# ============================================================================
-
-
-class ResearchGraph(BaseModel):
- """Represents a research workflow graph with nodes and edges."""
-
- nodes: dict[str, GraphNode] = Field(default_factory=dict, description="All nodes in the graph")
- edges: dict[str, list[GraphEdge]] = Field(
- default_factory=dict, description="Edges by source node ID"
- )
- entry_node: str = Field(description="Starting node ID")
- exit_nodes: list[str] = Field(default_factory=list, description="Terminal node IDs")
-
- model_config = {"arbitrary_types_allowed": True}
-
- def add_node(self, node: GraphNode) -> None:
- """Add a node to the graph.
-
- Args:
- node: The node to add
-
- Raises:
- ValueError: If node ID already exists
- """
- if node.node_id in self.nodes:
- raise ValueError(f"Node {node.node_id} already exists in graph")
- self.nodes[node.node_id] = node
- logger.debug("Node added to graph", node_id=node.node_id, type=node.node_type)
-
- def add_edge(self, edge: GraphEdge) -> None:
- """Add an edge to the graph.
-
- Args:
- edge: The edge to add
-
- Raises:
- ValueError: If source or target node doesn't exist
- """
- if edge.from_node not in self.nodes:
- raise ValueError(f"Source node {edge.from_node} not found in graph")
- if edge.to_node not in self.nodes:
- raise ValueError(f"Target node {edge.to_node} not found in graph")
-
- if edge.from_node not in self.edges:
- self.edges[edge.from_node] = []
- self.edges[edge.from_node].append(edge)
- logger.debug(
- "Edge added to graph",
- from_node=edge.from_node,
- to_node=edge.to_node,
- )
-
- def get_node(self, node_id: str) -> GraphNode | None:
- """Get a node by ID.
-
- Args:
- node_id: The node ID
-
- Returns:
- The node, or None if not found
- """
- return self.nodes.get(node_id)
-
- def get_next_nodes(self, node_id: str, context: Any = None) -> list[tuple[str, GraphEdge]]:
- """Get all possible next nodes from a given node.
-
- Args:
- node_id: The current node ID
- context: Optional context for evaluating conditions
-
- Returns:
- List of (node_id, edge) tuples for valid next nodes
- """
- if node_id not in self.edges:
- return []
-
- next_nodes = []
- for edge in self.edges[node_id]:
- # Evaluate condition if present
- if edge.condition is None or edge.condition(context):
- next_nodes.append((edge.to_node, edge))
-
- return next_nodes
-
- def validate_structure(self) -> list[str]:
- """Validate the graph structure.
-
- Returns:
- List of validation error messages (empty if valid)
- """
- errors = []
-
- # Check entry node exists
- if self.entry_node not in self.nodes:
- errors.append(f"Entry node {self.entry_node} not found in graph")
-
- # Check exit nodes exist and at least one is defined
- if not self.exit_nodes:
- errors.append("At least one exit node must be defined")
- for exit_node in self.exit_nodes:
- if exit_node not in self.nodes:
- errors.append(f"Exit node {exit_node} not found in graph")
-
- # Check all edges reference valid nodes
- for from_node, edge_list in self.edges.items():
- if from_node not in self.nodes:
- errors.append(f"Edge source node {from_node} not found")
- for edge in edge_list:
- if edge.to_node not in self.nodes:
- errors.append(f"Edge target node {edge.to_node} not found")
-
- # Check all nodes are reachable from entry node (basic check)
- if self.entry_node in self.nodes:
- reachable = {self.entry_node}
- queue = [self.entry_node]
- while queue:
- current = queue.pop(0)
- for next_node, _ in self.get_next_nodes(current):
- if next_node not in reachable:
- reachable.add(next_node)
- queue.append(next_node)
-
- unreachable = set(self.nodes.keys()) - reachable
- if unreachable:
- errors.append(f"Unreachable nodes from entry node: {', '.join(unreachable)}")
-
- return errors
-
-
-# ============================================================================
-# Graph Builder Class
-# ============================================================================
-
-
-class GraphBuilder:
- """Builder for constructing research workflow graphs."""
-
- def __init__(self) -> None:
- """Initialize the graph builder."""
- self.graph = ResearchGraph(entry_node="", exit_nodes=[])
-
- def add_agent_node(
- self,
- node_id: str,
- agent: "Agent[Any, Any]",
- description: str = "",
- input_transformer: Callable[[Any], Any] | None = None,
- output_transformer: Callable[[Any], Any] | None = None,
- ) -> "GraphBuilder":
- """Add an agent node to the graph.
-
- Args:
- node_id: Unique identifier for the node
- agent: Pydantic AI agent to execute
- description: Human-readable description
- input_transformer: Optional input transformation function
- output_transformer: Optional output transformation function
-
- Returns:
- Self for method chaining
- """
- node = AgentNode(
- node_id=node_id,
- agent=agent,
- description=description,
- input_transformer=input_transformer,
- output_transformer=output_transformer,
- )
- self.graph.add_node(node)
- return self
-
- def add_state_node(
- self,
- node_id: str,
- state_updater: Callable[["WorkflowState", Any], "WorkflowState"],
- description: str = "",
- state_reader: Callable[["WorkflowState"], Any] | None = None,
- ) -> "GraphBuilder":
- """Add a state node to the graph.
-
- Args:
- node_id: Unique identifier for the node
- state_updater: Function to update workflow state
- description: Human-readable description
- state_reader: Optional function to read state
-
- Returns:
- Self for method chaining
- """
- node = StateNode(
- node_id=node_id,
- state_updater=state_updater,
- description=description,
- state_reader=state_reader,
- )
- self.graph.add_node(node)
- return self
-
- def add_decision_node(
- self,
- node_id: str,
- decision_function: Callable[[Any], str],
- options: list[str],
- description: str = "",
- ) -> "GraphBuilder":
- """Add a decision node to the graph.
-
- Args:
- node_id: Unique identifier for the node
- decision_function: Function that returns next node ID
- options: List of possible next node IDs
- description: Human-readable description
-
- Returns:
- Self for method chaining
- """
- node = DecisionNode(
- node_id=node_id,
- decision_function=decision_function,
- options=options,
- description=description,
- )
- self.graph.add_node(node)
- return self
-
- def add_parallel_node(
- self,
- node_id: str,
- parallel_nodes: list[str],
- description: str = "",
- aggregator: Callable[[list[Any]], Any] | None = None,
- ) -> "GraphBuilder":
- """Add a parallel node to the graph.
-
- Args:
- node_id: Unique identifier for the node
- parallel_nodes: List of node IDs to run in parallel
- description: Human-readable description
- aggregator: Optional function to aggregate results
-
- Returns:
- Self for method chaining
- """
- node = ParallelNode(
- node_id=node_id,
- parallel_nodes=parallel_nodes,
- description=description,
- aggregator=aggregator,
- )
- self.graph.add_node(node)
- return self
-
- def connect_nodes(
- self,
- from_node: str,
- to_node: str,
- condition: Callable[[Any], bool] | None = None,
- condition_description: str = "",
- ) -> "GraphBuilder":
- """Connect two nodes with an edge.
-
- Args:
- from_node: Source node ID
- to_node: Target node ID
- condition: Optional condition function
- condition_description: Description of condition (if conditional)
-
- Returns:
- Self for method chaining
- """
- if condition is None:
- edge: GraphEdge = SequentialEdge(from_node=from_node, to_node=to_node)
- else:
- edge = ConditionalEdge(
- from_node=from_node,
- to_node=to_node,
- condition=condition,
- condition_description=condition_description,
- )
- self.graph.add_edge(edge)
- return self
-
- def set_entry_node(self, node_id: str) -> "GraphBuilder":
- """Set the entry node for the graph.
-
- Args:
- node_id: The entry node ID
-
- Returns:
- Self for method chaining
- """
- self.graph.entry_node = node_id
- return self
-
- def set_exit_nodes(self, node_ids: list[str]) -> "GraphBuilder":
- """Set the exit nodes for the graph.
-
- Args:
- node_ids: List of exit node IDs
-
- Returns:
- Self for method chaining
- """
- self.graph.exit_nodes = node_ids
- return self
-
- def build(self) -> ResearchGraph:
- """Finalize graph construction and validate.
-
- Returns:
- The constructed ResearchGraph
-
- Raises:
- ValueError: If graph validation fails
- """
- errors = self.graph.validate_structure()
- if errors:
- error_msg = "Graph validation failed:\n" + "\n".join(f" - {e}" for e in errors)
- logger.error("Graph validation failed", errors=errors)
- raise ValueError(error_msg)
-
- logger.info(
- "Graph built successfully",
- nodes=len(self.graph.nodes),
- edges=sum(len(edges) for edges in self.graph.edges.values()),
- entry_node=self.graph.entry_node,
- exit_nodes=self.graph.exit_nodes,
- )
- return self.graph
-
-
-# ============================================================================
-# Factory Functions
-# ============================================================================
-
-
-def create_iterative_graph(
- knowledge_gap_agent: "Agent[Any, Any]",
- tool_selector_agent: "Agent[Any, Any]",
- thinking_agent: "Agent[Any, Any]",
- writer_agent: "Agent[Any, Any]",
-) -> ResearchGraph:
- """Create a graph for iterative research flow.
-
- Args:
- knowledge_gap_agent: Agent for evaluating knowledge gaps
- tool_selector_agent: Agent for selecting tools
- thinking_agent: Agent for generating observations
- writer_agent: Agent for writing final report
-
- Returns:
- Constructed ResearchGraph for iterative research
- """
- builder = GraphBuilder()
-
- # Add nodes
- builder.add_agent_node("thinking", thinking_agent, "Generate observations")
- builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps")
-
- def _decision_function(result: Any) -> str:
- """Decision function for continue_decision node.
-
- Args:
- result: Result from knowledge_gap node (KnowledgeGapOutput or tuple)
-
- Returns:
- Next node ID: "writer" if research complete, "tool_selector" otherwise
- """
- # Handle case where result might be a tuple (validation error)
- if isinstance(result, tuple):
- # Try to extract research_complete from tuple
- if len(result) == 2 and isinstance(result[0], str) and result[0] == "research_complete":
- # Format: ('research_complete', False)
- return "writer" if result[1] else "tool_selector"
- # Try to find boolean value in tuple
- for item in result:
- if isinstance(item, bool):
- return "writer" if item else "tool_selector"
- elif isinstance(item, dict) and "research_complete" in item:
- return "writer" if item["research_complete"] else "tool_selector"
- # Default to continuing research if we can't determine
- return "tool_selector"
-
- # Normal case: result is KnowledgeGapOutput object
- research_complete = getattr(result, "research_complete", False)
- return "writer" if research_complete else "tool_selector"
-
- builder.add_decision_node(
- "continue_decision",
- decision_function=_decision_function,
- options=["tool_selector", "writer"],
- description="Decide whether to continue research or write report",
- )
- builder.add_agent_node("tool_selector", tool_selector_agent, "Select tools to address gap")
- builder.add_state_node(
- "execute_tools",
-        # Placeholder - actual execution handled separately
-        state_updater=lambda state, tasks: state,
- description="Execute selected tools",
- )
- builder.add_agent_node("writer", writer_agent, "Write final report")
-
- # Add edges
- builder.connect_nodes("thinking", "knowledge_gap")
- builder.connect_nodes("knowledge_gap", "continue_decision")
- builder.connect_nodes("continue_decision", "tool_selector")
- builder.connect_nodes("continue_decision", "writer")
- builder.connect_nodes("tool_selector", "execute_tools")
- builder.connect_nodes("execute_tools", "thinking") # Loop back
-
- # Set entry and exit
- builder.set_entry_node("thinking")
- builder.set_exit_nodes(["writer"])
-
- return builder.build()
-
-
-def create_deep_graph(
- planner_agent: "Agent[Any, Any]",
- knowledge_gap_agent: "Agent[Any, Any]",
- tool_selector_agent: "Agent[Any, Any]",
- thinking_agent: "Agent[Any, Any]",
- writer_agent: "Agent[Any, Any]",
- long_writer_agent: "Agent[Any, Any]",
-) -> ResearchGraph:
- """Create a graph for deep research flow.
-
- The graph structure: planner → store_plan → parallel_loops → collect_drafts → synthesizer
-
- Args:
- planner_agent: Agent for creating report plan
- knowledge_gap_agent: Agent for evaluating knowledge gaps (not used directly, but needed for iterative flows)
- tool_selector_agent: Agent for selecting tools (not used directly, but needed for iterative flows)
- thinking_agent: Agent for generating observations (not used directly, but needed for iterative flows)
- writer_agent: Agent for writing section reports (not used directly, but needed for iterative flows)
- long_writer_agent: Agent for synthesizing final report
-
- Returns:
- Constructed ResearchGraph for deep research
- """
- from src.utils.models import ReportPlan
-
- builder = GraphBuilder()
-
- # Add nodes
- # 1. Planner agent - creates report plan
- builder.add_agent_node("planner", planner_agent, "Create report plan with sections")
-
- # 2. State node - store report plan in workflow state
- def store_plan(state: "WorkflowState", plan: ReportPlan) -> "WorkflowState":
- """Store report plan in state for parallel loops to access."""
- # Store plan in a custom attribute (we'll need to extend WorkflowState or use a dict)
- # For now, we'll store it in the context's node_results
- # The actual storage will happen in the graph execution
- return state
-
- builder.add_state_node(
- "store_plan",
- state_updater=store_plan,
- description="Store report plan in state",
- )
-
- # 3. Parallel node - will execute iterative research flows for each section
- # The actual execution will be handled dynamically in _execute_parallel_node()
- # We use a special node ID that the executor will recognize
- builder.add_parallel_node(
- "parallel_loops",
- parallel_nodes=[], # Will be populated dynamically based on report plan
- description="Execute parallel iterative research loops for each section",
- aggregator=lambda results: results, # Collect all section drafts
- )
-
- # 4. State node - collect section drafts into ReportDraft
- def collect_drafts(state: "WorkflowState", section_drafts: list[str]) -> "WorkflowState":
- """Collect section drafts into state for synthesizer."""
- # Store drafts in state (will be accessed by synthesizer)
- return state
-
- builder.add_state_node(
- "collect_drafts",
- state_updater=collect_drafts,
- description="Collect section drafts for synthesis",
- )
-
- # 5. Synthesizer agent - creates final report from drafts
- builder.add_agent_node(
- "synthesizer", long_writer_agent, "Synthesize final report from section drafts"
- )
-
- # Add edges
- builder.connect_nodes("planner", "store_plan")
- builder.connect_nodes("store_plan", "parallel_loops")
- builder.connect_nodes("parallel_loops", "collect_drafts")
- builder.connect_nodes("collect_drafts", "synthesizer")
-
- # Set entry and exit
- builder.set_entry_node("planner")
- builder.set_exit_nodes(["synthesizer"])
-
- return builder.build()
-
-
-# No need to rebuild models since we're using Any types
-# The models will work correctly with arbitrary_types_allowed=True
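A small sketch of the builder's fluent API, independent of the prebuilt iterative and deep graphs above. The node IDs and the no-op updaters are illustrative; in practice the agent nodes would wrap Pydantic AI agents from `src/agent_factory/agents.py`.

```python
"""Illustrative sketch: building and validating a tiny graph with GraphBuilder."""

from src.agent_factory.graph_builder import GraphBuilder


def route(_result: object) -> str:
    # Illustrative decision: always proceed to the terminal node.
    return "finish"


graph = (
    GraphBuilder()
    .add_state_node("start", state_updater=lambda state, _: state, description="Entry point")
    .add_decision_node("route", decision_function=route, options=["finish"])
    .add_state_node("finish", state_updater=lambda state, _: state, description="Terminal node")
    .connect_nodes("start", "route")
    .connect_nodes("route", "finish")
    .set_entry_node("start")
    .set_exit_nodes(["finish"])
    .build()  # raises ValueError if validate_structure() reports problems
)

print(sorted(graph.nodes))            # ['finish', 'route', 'start']
print(graph.get_next_nodes("start"))  # [('route', SequentialEdge(...))]
```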
diff --git a/src/agent_factory/judges.py b/src/agent_factory/judges.py
index 59ccdc08f3d8292be1161908a47b9451672121d2..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/agent_factory/judges.py
+++ b/src/agent_factory/judges.py
@@ -1,537 +0,0 @@
-"""Judge handler for evidence assessment using PydanticAI."""
-
-import asyncio
-import json
-from typing import Any, ClassVar
-
-import structlog
-from huggingface_hub import InferenceClient
-from pydantic_ai import Agent
-from pydantic_ai.models.anthropic import AnthropicModel
-from pydantic_ai.models.huggingface import HuggingFaceModel
-from pydantic_ai.models.openai import OpenAIChatModel as OpenAIModel
-from pydantic_ai.providers.anthropic import AnthropicProvider
-from pydantic_ai.providers.huggingface import HuggingFaceProvider
-from pydantic_ai.providers.openai import OpenAIProvider
-from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
-
-from src.prompts.judge import (
- SYSTEM_PROMPT,
- format_empty_evidence_prompt,
- format_user_prompt,
-)
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import AssessmentDetails, Evidence, JudgeAssessment
-
-logger = structlog.get_logger()
-
-
-def get_model(oauth_token: str | None = None) -> Any:
- """Get the LLM model based on configuration.
-
- Explicitly passes API keys from settings to avoid requiring
- users to export environment variables manually.
-
- Priority order:
- 1. HuggingFace (if OAuth token or API key available - preferred for free tier)
- 2. OpenAI (if API key available)
- 3. Anthropic (if API key available)
-
- If OAuth token is available, prefer HuggingFace (even if provider is set to OpenAI).
- This ensures users logged in via HuggingFace Spaces get the free tier.
-
- Args:
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured Pydantic AI model
-
- Raises:
- ConfigurationError: If no LLM provider is available
- """
- from src.utils.hf_error_handler import log_token_info, validate_hf_token
-
- # Priority: oauth_token > settings.hf_token > settings.huggingface_api_key
- effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key
-
- # Validate and log token information
- if effective_hf_token:
- log_token_info(effective_hf_token, context="get_model")
- is_valid, error_msg = validate_hf_token(effective_hf_token)
- if not is_valid:
- logger.warning(
- "Token validation failed",
- error=error_msg,
- has_oauth=bool(oauth_token),
- )
- # Continue anyway - let the API call fail with a clear error
-
- # Try HuggingFace first (preferred for free tier)
- if effective_hf_token:
- model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
- hf_provider = HuggingFaceProvider(api_key=effective_hf_token)
- logger.info(
- "using_huggingface_with_token",
- has_oauth=bool(oauth_token),
- has_settings_token=bool(settings.hf_token or settings.huggingface_api_key),
- model=model_name,
- )
- return HuggingFaceModel(model_name, provider=hf_provider)
-
- # Fallback to OpenAI if available
- if settings.has_openai_key:
- assert settings.openai_api_key is not None # Type narrowing
- model_name = settings.openai_model
- openai_provider = OpenAIProvider(api_key=settings.openai_api_key)
- logger.info("using_openai", model=model_name)
- return OpenAIModel(model_name, provider=openai_provider)
-
- # Fallback to Anthropic if available
- if settings.has_anthropic_key:
- assert settings.anthropic_api_key is not None # Type narrowing
- model_name = settings.anthropic_model
- anthropic_provider = AnthropicProvider(api_key=settings.anthropic_api_key)
- logger.info("using_anthropic", model=model_name)
- return AnthropicModel(model_name, provider=anthropic_provider)
-
- # No provider available
- raise ConfigurationError(
- "No LLM provider available. Please configure one of:\n"
- "1. HuggingFace: Log in via OAuth (recommended for Spaces) or set HF_TOKEN\n"
- "2. OpenAI: Set OPENAI_API_KEY environment variable\n"
- "3. Anthropic: Set ANTHROPIC_API_KEY environment variable"
- )
-
-
-class JudgeHandler:
- """
- Handles evidence assessment using an LLM with structured output.
-
- Uses PydanticAI to ensure responses match the JudgeAssessment schema.
- """
-
- def __init__(self, model: Any = None) -> None:
- """
- Initialize the JudgeHandler.
-
- Args:
- model: Optional PydanticAI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.agent = Agent(
- model=self.model,
- output_type=JudgeAssessment,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def assess(
- self,
- question: str,
- evidence: list[Evidence],
- ) -> JudgeAssessment:
- """
- Assess evidence and determine if it's sufficient.
-
- Args:
- question: The user's research question
- evidence: List of Evidence objects from search
-
- Returns:
- JudgeAssessment with evaluation results
-
- Raises:
- JudgeError: If assessment fails after retries
- """
- logger.info(
- "Starting evidence assessment",
- question=question[:100],
- evidence_count=len(evidence),
- )
-
- # Format the prompt based on whether we have evidence
- if evidence:
- user_prompt = format_user_prompt(question, evidence)
- else:
- user_prompt = format_empty_evidence_prompt(question)
-
- try:
- # Run the agent with structured output
- result = await self.agent.run(user_prompt)
- assessment = result.output
-
- logger.info(
- "Assessment complete",
- sufficient=assessment.sufficient,
- recommendation=assessment.recommendation,
- confidence=assessment.confidence,
- )
-
- return assessment
-
- except Exception as e:
- # Extract error details for better logging and handling
- from src.utils.hf_error_handler import (
- extract_error_details,
- get_user_friendly_error_message,
- )
-
- error_details = extract_error_details(e)
- logger.error(
- "Assessment failed",
- error=str(e),
- status_code=error_details.get("status_code"),
- model_name=error_details.get("model_name"),
- is_auth_error=error_details.get("is_auth_error"),
- is_model_error=error_details.get("is_model_error"),
- )
-
- # Log user-friendly message for debugging
- if error_details.get("is_auth_error") or error_details.get("is_model_error"):
- user_msg = get_user_friendly_error_message(e, error_details.get("model_name"))
- logger.warning("API error details", user_message=user_msg[:200])
-
- # Return a safe default assessment on failure
- return self._create_fallback_assessment(question, str(e))
-
- def _create_fallback_assessment(
- self,
- question: str,
- error: str,
- ) -> JudgeAssessment:
- """
- Create a fallback assessment when LLM fails.
-
- Args:
- question: The original question
- error: The error message
-
- Returns:
- Safe fallback JudgeAssessment
- """
- return JudgeAssessment(
- details=AssessmentDetails(
- mechanism_score=0,
- mechanism_reasoning="Assessment failed due to LLM error",
- clinical_evidence_score=0,
- clinical_reasoning="Assessment failed due to LLM error",
- drug_candidates=[],
- key_findings=[],
- ),
- sufficient=False,
- confidence=0.0,
- recommendation="continue",
- next_search_queries=[
- f"{question} mechanism",
- f"{question} clinical trials",
- f"{question} drug candidates",
- ],
- reasoning=f"Assessment failed: {error}. Recommend retrying with refined queries.",
- )
-
-
-class HFInferenceJudgeHandler:
- """
- JudgeHandler using HuggingFace Inference API for FREE LLM calls.
- Defaults to Llama-3.1-8B-Instruct (requires HF_TOKEN) or falls back to public models.
- """
-
- FALLBACK_MODELS: ClassVar[list[str]] = [
- "meta-llama/Llama-3.1-8B-Instruct", # Primary (Gated)
- "mistralai/Mistral-7B-Instruct-v0.3", # Secondary
- "HuggingFaceH4/zephyr-7b-beta", # Fallback (Ungated)
- ]
-
- def __init__(self, model_id: str | None = None, api_key: str | None = None) -> None:
- """
- Initialize with HF Inference client.
-
- Args:
- model_id: Optional specific model ID. If None, uses FALLBACK_MODELS chain.
- api_key: Optional HuggingFace API key/token. If None, uses HF_TOKEN from env.
- """
- self.model_id = model_id
- # Pass api_key to InferenceClient if provided, otherwise it will use HF_TOKEN from env
- self.client = InferenceClient(api_key=api_key) if api_key else InferenceClient()
- self.call_count = 0
- self.last_question: str | None = None
- self.last_evidence: list[Evidence] | None = None
-
- async def assess(
- self,
- question: str,
- evidence: list[Evidence],
- ) -> JudgeAssessment:
- """
- Assess evidence using HuggingFace Inference API.
- Attempts models in order until one succeeds.
- """
- self.call_count += 1
- self.last_question = question
- self.last_evidence = evidence
-
- # Format the user prompt
- if evidence:
- user_prompt = format_user_prompt(question, evidence)
- else:
- user_prompt = format_empty_evidence_prompt(question)
-
- models_to_try: list[str] = [self.model_id] if self.model_id else self.FALLBACK_MODELS
- last_error: Exception | None = None
-
- for model in models_to_try:
- try:
- return await self._call_with_retry(model, user_prompt, question)
- except Exception as e:
- logger.warning("Model failed", model=model, error=str(e))
- last_error = e
- continue
-
- # All models failed
- logger.error("All HF models failed", error=str(last_error))
- return self._create_fallback_assessment(question, str(last_error))
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=1, max=4),
- retry=retry_if_exception_type(Exception),
- reraise=True,
- )
- async def _call_with_retry(self, model: str, prompt: str, question: str) -> JudgeAssessment:
- """Make API call with retry logic using chat_completion."""
- loop = asyncio.get_running_loop()
-
- # Build messages for chat_completion (model-agnostic)
- messages = [
- {
- "role": "system",
- "content": f"""{SYSTEM_PROMPT}
-
-IMPORTANT: Respond with ONLY valid JSON matching this schema:
-{{
- "details": {{
-    "mechanism_score": <integer 0-10>,
-    "mechanism_reasoning": "<reasoning>",
-    "clinical_evidence_score": <integer 0-10>,
-    "clinical_reasoning": "<reasoning>",
-    "drug_candidates": ["<drug name>", ...],
-    "key_findings": ["<finding>", ...]
-  }},
-  "sufficient": <true|false>,
-  "confidence": <float 0.0-1.0>,
-  "recommendation": "continue" | "synthesize",
-  "next_search_queries": ["<query>", ...],
-  "reasoning": "<overall reasoning>"
-}}""",
- },
- {"role": "user", "content": prompt},
- ]
-
- # Use chat_completion (conversational task - supported by all models)
- response = await loop.run_in_executor(
- None,
- lambda: self.client.chat_completion(
- messages=messages,
- model=model,
- max_tokens=1024,
- temperature=0.1,
- ),
- )
-
- # Extract content from response
- content = response.choices[0].message.content
- if not content:
- raise ValueError("Empty response from model")
-
- # Extract and parse JSON
- json_data = self._extract_json(content)
- if not json_data:
- raise ValueError("No valid JSON found in response")
-
- return JudgeAssessment(**json_data)
-
- def _extract_json(self, text: str) -> dict[str, Any] | None:
- """
- Robust JSON extraction that handles markdown blocks and nested braces.
- """
- text = text.strip()
-
- # Remove markdown code blocks if present (with bounds checking)
- if "```json" in text:
- parts = text.split("```json", 1)
- if len(parts) > 1:
- inner_parts = parts[1].split("```", 1)
- text = inner_parts[0]
- elif "```" in text:
- parts = text.split("```", 1)
- if len(parts) > 1:
- inner_parts = parts[1].split("```", 1)
- text = inner_parts[0]
-
- text = text.strip()
-
- # Find first '{'
- start_idx = text.find("{")
- if start_idx == -1:
- return None
-
- # Stack-based parsing ignoring chars in strings
- count = 0
- in_string = False
- escape = False
-
- for i, char in enumerate(text[start_idx:], start=start_idx):
- if in_string:
- if escape:
- escape = False
- elif char == "\\":
- escape = True
- elif char == '"':
- in_string = False
- elif char == '"':
- in_string = True
- elif char == "{":
- count += 1
- elif char == "}":
- count -= 1
- if count == 0:
- try:
- result = json.loads(text[start_idx : i + 1])
- if isinstance(result, dict):
- return result
- return None
- except json.JSONDecodeError:
- return None
-
- return None
-
- def _create_fallback_assessment(
- self,
- question: str,
- error: str,
- ) -> JudgeAssessment:
- """Create a fallback assessment when inference fails."""
- return JudgeAssessment(
- details=AssessmentDetails(
- mechanism_score=0,
- mechanism_reasoning=f"Assessment failed: {error}",
- clinical_evidence_score=0,
- clinical_reasoning=f"Assessment failed: {error}",
- drug_candidates=[],
- key_findings=[],
- ),
- sufficient=False,
- confidence=0.0,
- recommendation="continue",
- next_search_queries=[
- f"{question} mechanism",
- f"{question} clinical trials",
- f"{question} drug candidates",
- ],
- reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.",
- )
-
-
-def create_judge_handler() -> JudgeHandler:
- """Create a judge handler based on configuration.
-
- Returns:
- Configured JudgeHandler instance
- """
- return JudgeHandler()
-
-
-class MockJudgeHandler:
- """
- Mock JudgeHandler for demo mode without LLM calls.
-
- Extracts meaningful information from real search results
- to provide a useful demo experience without requiring API keys.
- """
-
- def __init__(self, mock_response: JudgeAssessment | None = None) -> None:
- """
- Initialize with optional mock response.
-
- Args:
- mock_response: The assessment to return. If None, extracts from evidence.
- """
- self.mock_response = mock_response
- self.call_count = 0
- self.last_question: str | None = None
- self.last_evidence: list[Evidence] | None = None
-
- def _extract_key_findings(self, evidence: list[Evidence], max_findings: int = 5) -> list[str]:
- """Extract key findings from evidence titles."""
- findings = []
- for e in evidence[:max_findings]:
- # Use first 150 chars of title as a finding
- title = e.citation.title
- if len(title) > 150:
- title = title[:147] + "..."
- findings.append(title)
- return findings if findings else ["No specific findings extracted (demo mode)"]
-
- def _extract_drug_candidates(self, question: str, evidence: list[Evidence]) -> list[str]:
- """Extract drug candidates - demo mode returns honest message."""
- # Don't attempt heuristic extraction - it produces garbage like "Oral", "Kidney"
- # Real drug extraction requires LLM analysis
- return [
- "Drug identification requires AI analysis",
- "Enter API key above for full results",
- ]
-
- async def assess(
- self,
- question: str,
- evidence: list[Evidence],
- ) -> JudgeAssessment:
- """Return assessment based on actual evidence (demo mode)."""
- self.call_count += 1
- self.last_question = question
- self.last_evidence = evidence
-
- if self.mock_response:
- return self.mock_response
-
- min_evidence = 3
- evidence_count = len(evidence)
-
- # Extract meaningful data from actual evidence
- drug_candidates = self._extract_drug_candidates(question, evidence)
- key_findings = self._extract_key_findings(evidence)
-
- # Calculate scores based on evidence quantity
- mechanism_score = min(10, evidence_count * 2) if evidence_count > 0 else 0
- clinical_score = min(10, evidence_count) if evidence_count > 0 else 0
-
- return JudgeAssessment(
- details=AssessmentDetails(
- mechanism_score=mechanism_score,
- mechanism_reasoning=(
- f"Demo mode: Found {evidence_count} sources. "
- "Configure LLM API key for detailed mechanism analysis."
- ),
- clinical_evidence_score=clinical_score,
- clinical_reasoning=(
- f"Demo mode: {evidence_count} sources retrieved from PubMed, "
- "ClinicalTrials.gov, and Europe PMC. Full analysis requires LLM API key."
- ),
- drug_candidates=drug_candidates,
- key_findings=key_findings,
- ),
- sufficient=evidence_count >= min_evidence,
- confidence=min(0.5, evidence_count * 0.1) if evidence_count > 0 else 0.0,
- recommendation="synthesize" if evidence_count >= min_evidence else "continue",
- next_search_queries=(
- [f"{question} mechanism", f"{question} clinical trials"]
- if evidence_count < min_evidence
- else []
- ),
- reasoning=(
- f"Demo mode assessment based on {evidence_count} real search results. "
- "For AI-powered analysis with drug candidate identification and "
- "evidence synthesis, configure OPENAI_API_KEY or ANTHROPIC_API_KEY."
- ),
- )
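
The workhorse of `_extract_json` above is the brace-depth scan, which tolerates markdown fences and braces embedded in string values where a naive `find`/`rfind` slice would fail. A minimal standalone sketch of the same technique (the sample input is illustrative):

```python
# Standalone sketch of the brace-depth JSON extraction used by
# JudgeHandler._extract_json; the sample string below is illustrative.
import json
from typing import Any


def extract_json(text: str) -> dict[str, Any] | None:
    start = text.find("{")
    if start == -1:
        return None
    depth, in_string, escape = 0, False, False
    for i, ch in enumerate(text[start:], start=start):
        if in_string:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:  # first balanced closing brace ends the object
                try:
                    obj = json.loads(text[start : i + 1])
                except json.JSONDecodeError:
                    return None
                return obj if isinstance(obj, dict) else None
    return None


# Nested braces and a "}" inside a string value are both handled:
assert extract_json('noise {"a": {"b": "}"}, "c": 1} trailing') == {"a": {"b": "}"}, "c": 1}
```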
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/src/agents/analysis_agent.py b/src/agents/analysis_agent.py
deleted file mode 100644
index e8bbc7fec996050d96af1f9edd3c8605744fc830..0000000000000000000000000000000000000000
--- a/src/agents/analysis_agent.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Analysis agent for statistical analysis using Modal code execution.
-
-This agent wraps StatisticalAnalyzer for use in magentic multi-agent mode.
-The core logic is in src/services/statistical_analyzer.py to avoid
-coupling agent_framework to the simple orchestrator.
-"""
-
-from collections.abc import AsyncIterable
-from typing import TYPE_CHECKING, Any
-
-from agent_framework import (
- AgentRunResponse,
- AgentRunResponseUpdate,
- AgentThread,
- BaseAgent,
- ChatMessage,
- Role,
-)
-
-from src.services.statistical_analyzer import (
- AnalysisResult,
- get_statistical_analyzer,
-)
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-
-class AnalysisAgent(BaseAgent): # type: ignore[misc]
- """Wraps StatisticalAnalyzer for magentic multi-agent mode."""
-
- def __init__(
- self,
- evidence_store: dict[str, Any],
- embedding_service: "EmbeddingService | None" = None,
- ) -> None:
- super().__init__(
- name="AnalysisAgent",
- description="Performs statistical analysis using Modal sandbox",
- )
- self._evidence_store = evidence_store
- self._embeddings = embedding_service
- self._analyzer = get_statistical_analyzer()
-
- async def run(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AgentRunResponse:
- """Analyze evidence and return verdict."""
- query = self._extract_query(messages)
- hypotheses = self._evidence_store.get("hypotheses", [])
- evidence = self._evidence_store.get("current", [])
-
- if not evidence:
- return self._error_response("No evidence available.")
-
- # Get primary hypothesis if available
- hypothesis_dict = None
- if hypotheses:
- h = hypotheses[0]
- hypothesis_dict = {
- "drug": getattr(h, "drug", "Unknown"),
- "target": getattr(h, "target", "?"),
- "pathway": getattr(h, "pathway", "?"),
- "effect": getattr(h, "effect", "?"),
- "confidence": getattr(h, "confidence", 0.5),
- }
-
- # Delegate to StatisticalAnalyzer
- result = await self._analyzer.analyze(
- query=query,
- evidence=evidence,
- hypothesis=hypothesis_dict,
- )
-
- # Store in shared context
- self._evidence_store["analysis"] = result.model_dump()
-
- # Format response
- response_text = self._format_response(result)
-
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
- response_id=f"analysis-{result.verdict.lower()}",
- additional_properties={"analysis": result.model_dump()},
- )
-
- def _format_response(self, result: AnalysisResult) -> str:
- """Format analysis result as markdown."""
- lines = [
- "## Statistical Analysis Complete\n",
- f"### Verdict: **{result.verdict}**",
- f"**Confidence**: {result.confidence:.0%}\n",
- "### Key Findings",
- ]
- for finding in result.key_findings:
- lines.append(f"- {finding}")
-
- lines.extend(
- [
- "\n### Statistical Evidence",
- "```",
- result.statistical_evidence,
- "```",
- ]
- )
- return "\n".join(lines)
-
- def _error_response(self, message: str) -> AgentRunResponse:
- """Create error response."""
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=f"**Error**: {message}")],
- response_id="analysis-error",
- )
-
- def _extract_query(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None,
- ) -> str:
- """Extract query from messages."""
- if isinstance(messages, str):
- return messages
- elif isinstance(messages, ChatMessage):
- return messages.text or ""
- elif isinstance(messages, list):
- for msg in reversed(messages):
- if isinstance(msg, ChatMessage) and msg.role == Role.USER:
- return msg.text or ""
- elif isinstance(msg, str):
- return msg
- return ""
-
- async def run_stream(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AsyncIterable[AgentRunResponseUpdate]:
- """Streaming wrapper."""
- result = await self.run(messages, thread=thread, **kwargs)
- yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
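
For context on how this wrapper is driven, here is a hedged wiring sketch: `AnalysisAgent` reads `store["current"]` (and optionally `store["hypotheses"]`) and mirrors its structured verdict back into `store["analysis"]`. The empty evidence list is a stand-in for what the search agents would normally populate.

```python
# Hedged wiring sketch for AnalysisAgent; in the real pipeline the search
# agents fill store["current"] with Evidence objects before this runs.
import asyncio
from typing import Any

from src.agents.analysis_agent import AnalysisAgent


async def main() -> None:
    store: dict[str, Any] = {
        "current": [],      # Evidence from search agents (empty -> error reply)
        "hypotheses": [],   # optionally filled by HypothesisAgent first
    }
    agent = AnalysisAgent(evidence_store=store)
    response = await agent.run("Does metformin slow neurodegeneration?")
    print(response.messages[0].text)
    # Once evidence exists, the structured result is mirrored here:
    print(store.get("analysis"))  # AnalysisResult.model_dump() or None


asyncio.run(main())
```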
diff --git a/src/agents/audio_refiner.py b/src/agents/audio_refiner.py
deleted file mode 100644
index 257c6f5e6df42b16687f181a98ab36178dc9b26a..0000000000000000000000000000000000000000
--- a/src/agents/audio_refiner.py
+++ /dev/null
@@ -1,402 +0,0 @@
-"""Audio Refiner Agent - Cleans markdown reports for TTS audio clarity.
-
-This agent transforms markdown-formatted research reports into clean,
-audio-friendly plain text suitable for text-to-speech synthesis.
-"""
-
-import re
-
-import structlog
-from pydantic_ai import Agent
-
-from src.utils.llm_factory import get_pydantic_ai_model
-
-logger = structlog.get_logger(__name__)
-
-
-class AudioRefiner:
- """Refines markdown reports for optimal TTS audio output.
-
- Handles common formatting issues that make text difficult to listen to:
- - Markdown syntax (headers, bold, italic, links)
- - Citations and reference markers
- - Roman numerals in medical contexts
- - Multiple References sections
- - Special characters and formatting artifacts
- """
-
- # Roman numeral to integer mapping
- ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
-
- # Number to word mapping (1-20, common in medical literature)
- NUMBER_TO_WORD = {
- 1: "One",
- 2: "Two",
- 3: "Three",
- 4: "Four",
- 5: "Five",
- 6: "Six",
- 7: "Seven",
- 8: "Eight",
- 9: "Nine",
- 10: "Ten",
- 11: "Eleven",
- 12: "Twelve",
- 13: "Thirteen",
- 14: "Fourteen",
- 15: "Fifteen",
- 16: "Sixteen",
- 17: "Seventeen",
- 18: "Eighteen",
- 19: "Nineteen",
- 20: "Twenty",
- }
-
- async def refine_for_audio(self, markdown_text: str, use_llm_polish: bool = False) -> str:
- """Transform markdown report into audio-friendly plain text.
-
- Args:
- markdown_text: Markdown-formatted research report
- use_llm_polish: If True, apply LLM-based final polish (optional)
-
- Returns:
- Clean plain text optimized for TTS audio
- """
- logger.info("Refining report for audio output", use_llm_polish=use_llm_polish)
-
- text = markdown_text
-
- # Step 1: Remove References sections first (before other processing)
- text = self._remove_references_sections(text)
-
- # Step 2: Remove markdown formatting
- text = self._remove_markdown_syntax(text)
-
- # Step 3: Convert roman numerals to words
- text = self._convert_roman_numerals(text)
-
- # Step 4: Remove citations
- text = self._remove_citations(text)
-
- # Step 5: Clean up special characters and artifacts
- text = self._clean_special_characters(text)
-
- # Step 6: Normalize whitespace
- text = self._normalize_whitespace(text)
-
- # Step 7 (Optional): LLM polish for edge cases
- if use_llm_polish:
- text = await self._llm_polish(text)
-
- logger.info(
- "Audio refinement complete",
- original_length=len(markdown_text),
- refined_length=len(text),
- llm_polish_applied=use_llm_polish,
- )
-
- return text.strip()
-
- def _remove_references_sections(self, text: str) -> str:
- """Remove References sections while preserving other content.
-
- Removes the References section and its content until the next section
- heading or end of document. Handles multiple References sections.
-
- Matches various References heading formats:
- - # References
- - ## References
- - **References:**
- - **Additional References:**
- - References: (plain text)
- """
- # Pattern to match References section heading (case-insensitive)
- # Matches: markdown headers (# References), bold (**References:**), or plain text (References:)
- references_pattern = r"\n(?:#+\s*References?:?\s*\n|\*\*\s*(?:Additional\s+)?References?:?\s*\*\*\s*\n|References?:?\s*\n)"
-
- # Find all References sections
- while True:
- match = re.search(references_pattern, text, re.IGNORECASE)
- if not match:
- break
-
- # Find the start of the References section
- section_start = match.start()
-
- # Find the next section (markdown header or bold heading) or end of document
- # Match: "# Header", "## Header", or "**Header**"
- next_section_patterns = [
- r"\n#+\s+\w+", # Markdown headers (# Section, ## Section)
- r"\n\*\*[A-Z][^*]+\*\*", # Bold headings (**Section Name**)
- ]
-
-            remaining_text = text[match.end() :]
-
-            # Try all patterns and keep the earliest match
-            next_section_match = None
-            for pattern in next_section_patterns:
-                m = re.search(pattern, remaining_text)
-                if m and (next_section_match is None or m.start() < next_section_match.start()):
-                    next_section_match = m
-
- if next_section_match:
- # Remove from References heading to next section
- section_end = match.end() + next_section_match.start()
- else:
- # No next section - remove to end of document
- section_end = len(text)
-
- # Remove the References section
- text = text[:section_start] + text[section_end:]
- logger.debug("Removed References section", removed_chars=section_end - section_start)
-
- return text
-
- def _remove_markdown_syntax(self, text: str) -> str:
- """Remove markdown formatting syntax."""
-
- # Headers (# ## ###)
- text = re.sub(r"^\s*#+\s+", "", text, flags=re.MULTILINE)
-
- # Bold (**text** or __text__)
- text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
- text = re.sub(r"__([^_]+)__", r"\1", text)
-
- # Italic (*text* or _text_)
- text = re.sub(r"\*([^*]+)\*", r"\1", text)
- text = re.sub(r"_([^_]+)_", r"\1", text)
-
- # Links [text](url) → text
- text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
-
- # Inline code `code` → code
- text = re.sub(r"`([^`]+)`", r"\1", text)
-
- # Strikethrough ~~text~~
- text = re.sub(r"~~([^~]+)~~", r"\1", text)
-
- # Blockquotes (> text)
- text = re.sub(r"^\s*>\s+", "", text, flags=re.MULTILINE)
-
- # Horizontal rules (---, ***, ___)
- text = re.sub(r"^\s*[-*_]{3,}\s*$", "", text, flags=re.MULTILINE)
-
- # List markers (-, *, 1., 2.)
- text = re.sub(r"^\s*[-*]\s+", "", text, flags=re.MULTILINE)
- text = re.sub(r"^\s*\d+\.\s+", "", text, flags=re.MULTILINE)
-
- return text
-
- def _roman_to_int(self, roman: str) -> int | None:
- """Convert roman numeral string to integer.
-
- Args:
- roman: Roman numeral string (e.g., 'IV', 'XII')
-
- Returns:
- Integer value, or None if invalid roman numeral
- """
- roman = roman.upper()
- result = 0
- prev_value = 0
-
- for char in reversed(roman):
- if char not in self.ROMAN_VALUES:
- return None
-
- value = self.ROMAN_VALUES[char]
-
- # Subtractive notation (IV = 4, IX = 9)
- if value < prev_value:
- result -= value
- else:
- result += value
-
- prev_value = value
-
- return result
-
- def _int_to_word(self, num: int) -> str:
- """Convert integer to word representation.
-
- Args:
- num: Integer to convert (1-20 supported)
-
- Returns:
- Word representation (e.g., 'One', 'Twelve')
- """
- if num in self.NUMBER_TO_WORD:
- return self.NUMBER_TO_WORD[num]
- else:
-            # For numbers > 20, fall back to the digit string
- return str(num)
-
- def _convert_roman_numerals(self, text: str) -> str:
- """Convert roman numerals to words for better TTS pronunciation.
-
- Handles patterns like:
- - Phase I, Phase II, Phase III
- - Trial I, Trial II
- - Type I, Type II
- - Stage I, Stage II
-        - Standalone II, III, IV (a lone "I" is assumed to be the pronoun and kept)
- """
-
- def replace_roman(match: re.Match[str]) -> str:
- """Callback to replace matched roman numeral."""
- prefix = match.group(1) # Word before roman numeral (if any)
- roman = match.group(2) # The roman numeral
-
- # Convert to integer
- num = self._roman_to_int(roman)
- if num is None:
- return match.group(0) # Return original if invalid
-
- # Convert to word
- word = self._int_to_word(num)
-
- # Return with prefix if present
- if prefix:
- return f"{prefix} {word}"
- else:
- return word
-
-        # Prefixed numerals ("Phase II", "Stage IV") are unambiguous: convert them all
-        prefixed = r"\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)\s+([IVXLCDM]+)\b"
-        text = re.sub(prefixed, replace_roman, text)
-
-        # Standalone numerals: only short I/V/X forms, so English words spelled with
-        # roman-numeral letters ("MIX", "DIM") and the pronoun "I" are left untouched
-        standalone = r"\b()([IVX]{2,4})\b"
-        text = re.sub(standalone, replace_roman, text)
-
- return text
-
- def _remove_citations(self, text: str) -> str:
- """Remove citation markers and references."""
-
- # Numbered citations [1], [2], [1,2], [1-3]
- text = re.sub(r"\[\d+(?:[-,]\d+)*\]", "", text)
-
- # Author citations (Smith et al., 2023) or (Smith et al. 2023)
- text = re.sub(r"\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)", "", text)
-
- # Simple year citations (2023)
- text = re.sub(r"\(\d{4}\)", "", text)
-
- # Author-year (Smith, 2023)
- text = re.sub(r"\([A-Z][a-z]+,?\s+\d{4}\)", "", text)
-
- # Footnote markers (¹, ², ³)
- text = re.sub(r"[¹²³⁴⁵⁶⁷⁸⁹⁰]+", "", text)
-
- return text
-
- def _clean_special_characters(self, text: str) -> str:
- """Clean up special characters and formatting artifacts."""
-
- # Replace em dashes with regular dashes
- text = text.replace("\u2014", "-") # em dash
- text = text.replace("\u2013", "-") # en dash
-
- # Replace smart quotes with regular quotes
- text = text.replace("\u201c", '"') # left double quote
- text = text.replace("\u201d", '"') # right double quote
- text = text.replace("\u2018", "'") # left single quote
- text = text.replace("\u2019", "'") # right single quote
-
- # Remove excessive punctuation (!!!, ???)
- text = re.sub(r"([!?]){2,}", r"\1", text)
-
- # Remove asterisks used for footnotes
- text = re.sub(r"\*+", "", text)
-
- # Remove hash symbols (from headers)
- text = text.replace("#", "")
-
- # Remove excessive dots (...)
- text = re.sub(r"\.{4,}", "...", text)
-
- return text
-
- def _normalize_whitespace(self, text: str) -> str:
- """Normalize whitespace for clean audio output."""
-
- # Replace multiple spaces with single space
- text = re.sub(r" {2,}", " ", text)
-
- # Replace multiple newlines with double newline (paragraph break)
- text = re.sub(r"\n{3,}", "\n\n", text)
-
- # Remove trailing/leading whitespace from lines
- text = "\n".join(line.strip() for line in text.split("\n"))
-
- # Remove empty lines at start/end
- text = text.strip()
-
- return text
-
- async def _llm_polish(self, text: str) -> str:
- """Apply LLM-based final polish to catch edge cases.
-
- This is a lightweight pass that removes any remaining formatting
- artifacts the rule-based methods might have missed.
-
- Args:
- text: Pre-cleaned text from rule-based methods
-
- Returns:
- Final polished text ready for TTS
- """
- try:
- # Create a simple agent for text cleanup
- model = get_pydantic_ai_model()
- polish_agent = Agent(
- model=model,
- system_prompt=(
- "You are a text cleanup assistant. Your ONLY job is to remove "
- "any remaining formatting artifacts (markdown, citations, special "
- "characters) that make text unsuitable for text-to-speech audio. "
- "DO NOT rewrite, improve, or change the content. "
- "DO NOT add explanations. "
- "ONLY output the cleaned text."
- ),
- )
-
- # Run asynchronously
- result = await polish_agent.run(
- f"Clean this text for audio (remove any formatting artifacts):\n\n{text}"
- )
-
- polished_text = result.output.strip()
-
- logger.info(
- "llm_polish_applied", original_length=len(text), polished_length=len(polished_text)
- )
-
- return polished_text
-
- except Exception as e:
- logger.warning(
- "llm_polish_failed", error=str(e), message="Falling back to rule-based output"
- )
- # Graceful fallback: return original text if LLM fails
- return text
-
-
-# Singleton instance for easy import
-audio_refiner = AudioRefiner()
-
-
-async def refine_text_for_audio(markdown_text: str, use_llm_polish: bool = False) -> str:
- """Convenience function to refine markdown text for audio.
-
- Args:
- markdown_text: Markdown-formatted text
- use_llm_polish: If True, apply LLM-based final polish (optional)
-
- Returns:
- Audio-friendly plain text
- """
- return await audio_refiner.refine_for_audio(markdown_text, use_llm_polish=use_llm_polish)
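
Because every step before the optional LLM polish is rule-based, the convenience function can be exercised with no API key at all. An illustrative run (input and expected output are approximate):

```python
# Illustrative run of the rule-based pipeline (use_llm_polish stays False,
# so no LLM credentials are needed); expected output is approximate.
import asyncio

from src.agents.audio_refiner import refine_text_for_audio

MARKDOWN = """# Findings

**Metformin** showed benefits in Phase II trials [1,2] (Smith et al., 2023).

## References
[1] https://example.com/a
"""

clean = asyncio.run(refine_text_for_audio(MARKDOWN))
print(clean)
# Roughly:
# Findings
#
# Metformin showed benefits in Phase Two trials .
```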
diff --git a/src/agents/code_executor_agent.py b/src/agents/code_executor_agent.py
deleted file mode 100644
index e86eab2225b9071e3421ed4f62173867d43ae1ea..0000000000000000000000000000000000000000
--- a/src/agents/code_executor_agent.py
+++ /dev/null
@@ -1,67 +0,0 @@
-"""Code execution agent using Modal."""
-
-import asyncio
-from typing import Any
-
-import structlog
-from agent_framework import ChatAgent, ai_function
-
-from src.tools.code_execution import get_code_executor
-from src.utils.llm_factory import get_chat_client_for_agent
-
-logger = structlog.get_logger()
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def execute_python_code(code: str) -> str:
- """Execute Python code in a secure sandbox.
-
- Args:
- code: The Python code to execute.
-
- Returns:
- The standard output and standard error of the execution.
- """
- logger.info("Code execution starting", code_length=len(code))
- executor = get_code_executor()
- loop = asyncio.get_running_loop()
-
- # Run in executor to avoid blocking
- try:
- result = await loop.run_in_executor(None, lambda: executor.execute(code))
- if result["success"]:
- logger.info("Code execution succeeded")
- return f"Stdout:\n{result['stdout']}"
- else:
- logger.warning("Code execution failed", error=result.get("error"))
- return f"Error:\n{result['error']}\nStderr:\n{result['stderr']}"
- except Exception as e:
- logger.error("Code execution exception", error=str(e))
- return f"Execution failed: {e}"
-
-
-def create_code_executor_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a code executor agent.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for code execution.
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="CodeExecutorAgent",
- description="Executes Python code for data analysis, calculation, and simulation.",
- instructions="""You are a code execution expert.
-When asked to analyze data or perform calculations, write Python code and execute it.
-Use libraries like pandas, numpy, scipy, matplotlib.
-
-Always output the code you want to execute using the `execute_python_code` tool.
-Check the output and interpret the results.""",
- chat_client=client,
- tools=[execute_python_code],
- temperature=0.0, # Strict code generation
- )
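
The `run_in_executor` hop in `execute_python_code` is what keeps the synchronous sandbox round-trip from stalling the event loop. A self-contained sketch of that offloading pattern, with a `time.sleep` standing in for the blocking `executor.execute` call:

```python
# Sketch of the offloading pattern used by execute_python_code above;
# blocking_execute stands in for the synchronous executor.execute call.
import asyncio
import time
from typing import Any


def blocking_execute(code: str) -> dict[str, Any]:
    time.sleep(0.5)  # stands in for a synchronous sandbox round-trip
    return {"success": True, "stdout": f"ran {len(code)} bytes", "stderr": ""}


async def main() -> None:
    loop = asyncio.get_running_loop()
    # The event loop stays free while a worker thread blocks on the call.
    result, _ = await asyncio.gather(
        loop.run_in_executor(None, lambda: blocking_execute("print(42)")),
        asyncio.sleep(0.1),  # another coroutine making progress meanwhile
    )
    print(result["stdout"])


asyncio.run(main())
```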
diff --git a/src/agents/hypothesis_agent.py b/src/agents/hypothesis_agent.py
deleted file mode 100644
index 6619fae1da264e80f296d7aa56528a17c1aa6815..0000000000000000000000000000000000000000
--- a/src/agents/hypothesis_agent.py
+++ /dev/null
@@ -1,144 +0,0 @@
-"""Hypothesis agent for mechanistic reasoning."""
-
-from collections.abc import AsyncIterable
-from typing import TYPE_CHECKING, Any
-
-from agent_framework import (
- AgentRunResponse,
- AgentRunResponseUpdate,
- AgentThread,
- BaseAgent,
- ChatMessage,
- Role,
-)
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.prompts.hypothesis import SYSTEM_PROMPT, format_hypothesis_prompt
-from src.utils.models import HypothesisAssessment
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-
-class HypothesisAgent(BaseAgent): # type: ignore[misc]
- """Generates mechanistic hypotheses based on evidence."""
-
- def __init__(
- self,
- evidence_store: dict[str, Any],
-        embedding_service: "EmbeddingService | None" = None,  # enables diverse (MMR) selection
- ) -> None:
- super().__init__(
- name="HypothesisAgent",
- description="Generates scientific hypotheses about drug mechanisms to guide research",
- )
- self._evidence_store = evidence_store
- self._embeddings = embedding_service # Used for MMR evidence selection
- self._agent: Agent[None, HypothesisAssessment] | None = None # Lazy init
-
- def _get_agent(self) -> Agent[None, HypothesisAssessment]:
- """Lazy initialization of LLM agent to avoid requiring API keys at import."""
- if self._agent is None:
- self._agent = Agent(
- model=get_model(), # Uses configured LLM (OpenAI/Anthropic)
- output_type=HypothesisAssessment,
- system_prompt=SYSTEM_PROMPT,
- )
- return self._agent
-
- async def run(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AgentRunResponse:
- """Generate hypotheses based on current evidence."""
- # Extract query
- query = self._extract_query(messages)
-
- # Get current evidence
- evidence = self._evidence_store.get("current", [])
-
- if not evidence:
- return AgentRunResponse(
- messages=[
- ChatMessage(
- role=Role.ASSISTANT,
- text="No evidence available yet. Search for evidence first.",
- )
- ],
- response_id="hypothesis-no-evidence",
- )
-
- # Generate hypotheses with diverse evidence selection
- prompt = await format_hypothesis_prompt(query, evidence, embeddings=self._embeddings)
- result = await self._get_agent().run(prompt)
- assessment = result.output # pydantic-ai returns .output for structured output
-
- # Store hypotheses in shared context
- existing = self._evidence_store.get("hypotheses", [])
- self._evidence_store["hypotheses"] = existing + assessment.hypotheses
-
- # Format response
- response_text = self._format_response(assessment)
-
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
- response_id=f"hypothesis-{len(assessment.hypotheses)}",
- additional_properties={"assessment": assessment.model_dump()},
- )
-
- def _format_response(self, assessment: HypothesisAssessment) -> str:
- """Format hypothesis assessment as markdown."""
- lines = ["## Generated Hypotheses\n"]
-
- for i, h in enumerate(assessment.hypotheses, 1):
- lines.append(f"### Hypothesis {i} (Confidence: {h.confidence:.0%})")
- lines.append(f"**Mechanism**: {h.drug} -> {h.target} -> {h.pathway} -> {h.effect}")
- lines.append(f"**Suggested searches**: {', '.join(h.search_suggestions)}\n")
-
- if assessment.primary_hypothesis:
- lines.append("### Primary Hypothesis")
- h = assessment.primary_hypothesis
- lines.append(f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}\n")
-
- if assessment.knowledge_gaps:
- lines.append("### Knowledge Gaps")
- for gap in assessment.knowledge_gaps:
- lines.append(f"- {gap}")
-
- if assessment.recommended_searches:
- lines.append("\n### Recommended Next Searches")
- for search in assessment.recommended_searches:
- lines.append(f"- `{search}`")
-
- return "\n".join(lines)
-
- def _extract_query(
- self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None
- ) -> str:
- """Extract query from messages."""
- if isinstance(messages, str):
- return messages
- elif isinstance(messages, ChatMessage):
- return messages.text or ""
- elif isinstance(messages, list):
- for msg in reversed(messages):
- if isinstance(msg, ChatMessage) and msg.role == Role.USER:
- return msg.text or ""
- elif isinstance(msg, str):
- return msg
- return ""
-
- async def run_stream(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AsyncIterable[AgentRunResponseUpdate]:
- """Streaming wrapper."""
- result = await self.run(messages, thread=thread, **kwargs)
- yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
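
One detail worth noting in `HypothesisAgent` is `_get_agent`: the pydantic-ai `Agent` is constructed on first use rather than in `__init__`, so importing the module never touches API keys. The same pattern in isolation (class names here are hypothetical):

```python
# Minimal sketch of the lazy-initialization pattern used by _get_agent;
# ExpensiveClient is a hypothetical stand-in for the pydantic-ai Agent.
class ExpensiveClient:
    def __init__(self) -> None:
        print("credentials read, client built")  # runs once, on first use


class Service:
    def __init__(self) -> None:
        self._client: ExpensiveClient | None = None  # cheap at import time

    def _get_client(self) -> ExpensiveClient:
        if self._client is None:
            self._client = ExpensiveClient()
        return self._client


svc = Service()      # nothing constructed yet
svc._get_client()    # first call pays the construction cost
svc._get_client()    # cached thereafter
```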
diff --git a/src/agents/input_parser.py b/src/agents/input_parser.py
deleted file mode 100644
index bb6b920d415287e35d8a471a1c7aaf64f371ec6c..0000000000000000000000000000000000000000
--- a/src/agents/input_parser.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""Input parser agent for analyzing and improving user queries.
-
-Determines research mode (iterative vs deep) and extracts key information
-from user queries to improve research quality.
-"""
-
-from typing import TYPE_CHECKING, Any, Literal
-
-import structlog
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError, JudgeError
-from src.utils.models import ParsedQuery
-
-if TYPE_CHECKING:
- pass
-
-logger = structlog.get_logger()
-
-# System prompt for the input parser agent
-SYSTEM_PROMPT = """
-You are an expert research query analyzer for a generalist deep research agent. Your job is to:
-1. Determine whether the query requires iterative research (single focused question) or deep research (multiple sections/topics)
-2. Determine whether the query requires medical/biomedical knowledge sources (PubMed, ClinicalTrials.gov) or general knowledge sources (web search)
-3. Improve and refine the query for better research results
-4. Extract key entities (drugs, diseases, companies, technologies, concepts, etc.)
-5. Extract specific research questions
-
-Guidelines for determining research mode:
-- **Iterative mode**: Single focused question, straightforward research goal, can be answered with a focused search loop
- Examples: "What is the mechanism of metformin?", "How does quantum computing work?", "What are the latest AI models?"
-
-- **Deep mode**: Complex query requiring multiple sections, comprehensive report, multiple related topics
- Examples: "Write a comprehensive report on diabetes treatment", "Analyze the market for quantum computing", "Review the state of AI in healthcare"
- Indicators: words like "comprehensive", "report", "sections", "analyze", "market analysis", "overview"
-
-Guidelines for determining if medical knowledge is needed:
-- **Medical knowledge needed**: Queries about diseases, treatments, drugs, clinical trials, medical conditions, biomedical mechanisms, health outcomes, etc.
- Examples: "Alzheimer's treatment", "metformin mechanism", "cancer clinical trials", "diabetes research"
-
-- **General knowledge sufficient**: Queries about technology, business, science (non-medical), history, current events, etc.
- Examples: "quantum computing", "AI models", "market analysis", "historical events"
-
-Your output must be valid JSON matching the ParsedQuery schema. Always provide:
-- original_query: The exact input query
-- improved_query: A refined, clearer version of the query
-- research_mode: Either "iterative" or "deep"
-- key_entities: List of important entities (drugs, diseases, companies, technologies, etc.)
-- research_questions: List of specific questions to answer
-
-Only output JSON. Do not output anything else.
-"""
-
-
-class InputParserAgent:
- """
- Input parser agent that analyzes queries and determines research mode.
-
- Uses Pydantic AI to generate structured ParsedQuery output with research
- mode detection, query improvement, and entity extraction.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the input parser agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent
- self.agent = Agent(
- model=self.model,
- output_type=ParsedQuery,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def parse(self, query: str) -> ParsedQuery:
- """
- Parse and analyze a user query.
-
- Args:
- query: The user's research query
-
- Returns:
- ParsedQuery with research mode, improved query, entities, and questions
-
- Raises:
- JudgeError: If parsing fails after retries
- ConfigurationError: If agent configuration is invalid
- """
- self.logger.info("Parsing user query", query=query[:100])
-
- user_message = f"QUERY: {query}"
-
- try:
- # Run the agent
- result = await self.agent.run(user_message)
- parsed_query = result.output
-
- # Validate parsed query
- if not parsed_query.original_query:
- self.logger.warning("Parsed query missing original_query", query=query[:100])
- raise JudgeError("Parsed query must have original_query")
-
- if not parsed_query.improved_query:
- self.logger.warning("Parsed query missing improved_query", query=query[:100])
- # Use original as fallback
- parsed_query = ParsedQuery(
- original_query=parsed_query.original_query,
- improved_query=parsed_query.original_query,
- research_mode=parsed_query.research_mode,
- key_entities=parsed_query.key_entities,
- research_questions=parsed_query.research_questions,
- )
-
- self.logger.info(
- "Query parsed successfully",
- mode=parsed_query.research_mode,
- entities=len(parsed_query.key_entities),
- questions=len(parsed_query.research_questions),
- )
-
- return parsed_query
-
- except Exception as e:
- self.logger.error("Query parsing failed", error=str(e), query=query[:100])
-
- # Fallback: return basic parsed query with heuristic mode detection
- if isinstance(e, JudgeError | ConfigurationError):
- raise
-
- # Heuristic fallback
- query_lower = query.lower()
- research_mode: Literal["iterative", "deep"] = "iterative"
- if any(
- keyword in query_lower
- for keyword in [
- "comprehensive",
- "report",
- "sections",
- "analyze",
- "analysis",
- "overview",
- "market",
- ]
- ):
- research_mode = "deep"
-
- return ParsedQuery(
- original_query=query,
- improved_query=query,
- research_mode=research_mode,
- key_entities=[],
- research_questions=[],
- )
-
-
-def create_input_parser_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> InputParserAgent:
- """
- Factory function to create an input parser agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured InputParserAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- # Get model from settings if not provided
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- # Create and return input parser agent
- return InputParserAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create input parser agent", error=str(e))
- raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
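
The heuristic fallback at the end of `parse` is plain keyword spotting. Pulled out for illustration (keyword list copied from the fallback above):

```python
# The keyword heuristic from InputParserAgent.parse's fallback path,
# extracted for illustration.
DEEP_KEYWORDS = (
    "comprehensive", "report", "sections", "analyze", "analysis",
    "overview", "market",
)


def detect_research_mode(query: str) -> str:
    """Return "deep" if any deep-research keyword appears, else "iterative"."""
    q = query.lower()
    return "deep" if any(keyword in q for keyword in DEEP_KEYWORDS) else "iterative"


assert detect_research_mode("What is the mechanism of metformin?") == "iterative"
assert detect_research_mode("Write a comprehensive report on diabetes") == "deep"
```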
diff --git a/src/agents/judge_agent.py b/src/agents/judge_agent.py
deleted file mode 100644
index 9bab6dbbfcf9850f33aac8c0a8b267fc7a16fc58..0000000000000000000000000000000000000000
--- a/src/agents/judge_agent.py
+++ /dev/null
@@ -1,98 +0,0 @@
-"""Judge agent wrapper for Magentic integration."""
-
-from collections.abc import AsyncIterable
-from typing import Any
-
-from agent_framework import (
- AgentRunResponse,
- AgentRunResponseUpdate,
- AgentThread,
- BaseAgent,
- ChatMessage,
- Role,
-)
-
-from src.legacy_orchestrator import JudgeHandlerProtocol
-from src.utils.models import Evidence, JudgeAssessment
-
-
-class JudgeAgent(BaseAgent): # type: ignore[misc]
- """Wraps JudgeHandler as an AgentProtocol for Magentic."""
-
- def __init__(
- self,
- judge_handler: JudgeHandlerProtocol,
- evidence_store: dict[str, list[Evidence]],
- ) -> None:
- super().__init__(
- name="JudgeAgent",
- description="Evaluates evidence quality and determines if sufficient for synthesis",
- )
- self._handler = judge_handler
- self._evidence_store = evidence_store # Shared state for evidence
-
- async def run(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AgentRunResponse:
- """Assess evidence quality."""
- # Extract original question from messages
- question = ""
- if isinstance(messages, list):
- for msg in reversed(messages):
- if isinstance(msg, ChatMessage) and msg.role == Role.USER and msg.text:
- question = msg.text
- break
- elif isinstance(msg, str):
- question = msg
- break
- elif isinstance(messages, str):
- question = messages
- elif isinstance(messages, ChatMessage) and messages.text:
- question = messages.text
-
- # Get evidence from shared store
- evidence = self._evidence_store.get("current", [])
-
- # Assess
- assessment: JudgeAssessment = await self._handler.assess(question, evidence)
-
- # Format response
- response_text = f"""## Assessment
-
-**Sufficient**: {assessment.sufficient}
-**Confidence**: {assessment.confidence:.0%}
-**Recommendation**: {assessment.recommendation}
-
-### Scores
-- Mechanism: {assessment.details.mechanism_score}/10
-- Clinical: {assessment.details.clinical_evidence_score}/10
-
-### Reasoning
-{assessment.reasoning}
-"""
-
- if assessment.next_search_queries:
- response_text += "\n### Next Queries\n" + "\n".join(
- f"- {q}" for q in assessment.next_search_queries
- )
-
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
- response_id=f"judge-{assessment.recommendation}",
- additional_properties={"assessment": assessment.model_dump()},
- )
-
- async def run_stream(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AsyncIterable[AgentRunResponseUpdate]:
- """Streaming wrapper for judge."""
- result = await self.run(messages, thread=thread, **kwargs)
- yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
diff --git a/src/agents/judge_agent_llm.py b/src/agents/judge_agent_llm.py
deleted file mode 100644
index 12453e09489d9379fdf2e61c213591dcc27a5a27..0000000000000000000000000000000000000000
--- a/src/agents/judge_agent_llm.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""LLM Judge for sub-iterations."""
-
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.utils.models import JudgeAssessment
-
-logger = structlog.get_logger()
-
-
-class LLMSubIterationJudge:
- """Judge that uses an LLM to assess sub-iteration results."""
-
- def __init__(self) -> None:
- self.model = get_model()
- self.agent = Agent(
- model=self.model,
- output_type=JudgeAssessment,
- system_prompt="""You are a strict judge evaluating a research task.
-
-Evaluate if the result is sufficient to answer the task.
-Provide scores and detailed reasoning.
-If not sufficient, suggest next steps.""",
- retries=3,
- )
-
- async def assess(self, task: str, result: Any, history: list[Any]) -> JudgeAssessment:
- """Assess the result using LLM."""
- logger.info("LLM judge assessing result", task=task[:100], history_len=len(history))
-
- prompt = f"""Task: {task}
-
-Current Result:
-{str(result)[:4000]}
-
-History of previous attempts: {len(history)}
-
-Evaluate validity and sufficiency."""
-
- run_result = await self.agent.run(prompt)
- logger.info("LLM judge assessment complete", sufficient=run_result.output.sufficient)
- return run_result.output
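
A hedged usage sketch for the judge above; it assumes an LLM key is configured so `get_model()` succeeds in the constructor, and the task and result strings are illustrative:

```python
# Hedged usage sketch for LLMSubIterationJudge; assumes get_model() can
# resolve a configured LLM. Task/result strings are illustrative.
import asyncio

from src.agents.judge_agent_llm import LLMSubIterationJudge


async def main() -> None:
    judge = LLMSubIterationJudge()
    assessment = await judge.assess(
        task="Summarize evidence for metformin in Alzheimer's disease",
        result="Draft summary text ...",
        history=[],
    )
    if not assessment.sufficient:
        print("retry with:", assessment.next_search_queries)


asyncio.run(main())
```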
diff --git a/src/agents/knowledge_gap.py b/src/agents/knowledge_gap.py
deleted file mode 100644
index 4ceab6cc2740eb22df06212ca3dc29bdb94dbc62..0000000000000000000000000000000000000000
--- a/src/agents/knowledge_gap.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""Knowledge gap agent for evaluating research completeness.
-
-Converts the folder/knowledge_gap_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import KnowledgeGapOutput
-
-logger = structlog.get_logger()
-
-
-# System prompt for the knowledge gap agent
-SYSTEM_PROMPT = f"""
-You are a Research State Evaluator. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-Your job is to critically analyze the current state of a research report,
-identify what knowledge gaps still exist and determine the best next step to take.
-
-You will be given:
-1. The original user query and any relevant background context to the query
-2. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
-
-Your task is to:
-1. Carefully review the findings and thoughts, particularly from the latest iteration, and assess their completeness in answering the original query
-2. Determine if the findings are sufficiently complete to end the research loop
-3. If not, identify up to 3 knowledge gaps that need to be addressed in sequence in order to continue with research - these should be relevant to the original query
-
-Be specific in the gaps you identify and include relevant information, as this will be passed on to another agent to process without additional context.
-
-Only output JSON. Follow the JSON schema for KnowledgeGapOutput. Do not output anything else.
-"""
-
-
-class KnowledgeGapAgent:
- """
- Agent that evaluates research state and identifies knowledge gaps.
-
- Uses Pydantic AI to generate structured KnowledgeGapOutput indicating
- whether research is complete and what gaps remain.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the knowledge gap agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent
- self.agent = Agent(
- model=self.model,
- output_type=KnowledgeGapOutput,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def evaluate(
- self,
- query: str,
- background_context: str = "",
- conversation_history: str = "",
- message_history: list[ModelMessage] | None = None,
- iteration: int = 0,
- time_elapsed_minutes: float = 0.0,
- max_time_minutes: int = 10,
- ) -> KnowledgeGapOutput:
- """
- Evaluate research state and identify knowledge gaps.
-
- Args:
- query: The original research query
- background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts (backward compat)
- message_history: Optional user conversation history (Pydantic AI format)
- iteration: Current iteration number
- time_elapsed_minutes: Time elapsed so far
- max_time_minutes: Maximum time allowed
-
- Returns:
- KnowledgeGapOutput with research completeness and outstanding gaps
-
- Raises:
- JudgeError: If evaluation fails after retries
- """
- self.logger.info(
- "Evaluating knowledge gaps",
- query=query[:100],
- iteration=iteration,
- )
-
- background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
-
- user_message = f"""
-Current Iteration Number: {iteration}
-Time Elapsed: {time_elapsed_minutes:.2f} minutes of maximum {max_time_minutes} minutes
-
-ORIGINAL QUERY:
-{query}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- try:
- # Run the agent with message_history if provided
- if message_history:
- result = await self.agent.run(user_message, message_history=message_history)
- else:
- result = await self.agent.run(user_message)
- evaluation = result.output
-
- self.logger.info(
- "Knowledge gap evaluation complete",
- research_complete=evaluation.research_complete,
- gaps_count=len(evaluation.outstanding_gaps),
- )
-
- return evaluation
-
- except Exception as e:
- self.logger.error("Knowledge gap evaluation failed", error=str(e))
- # Return fallback: research not complete, suggest continuing
- return KnowledgeGapOutput(
- research_complete=False,
- outstanding_gaps=[f"Continue research on: {query}"],
- )
-
-
-def create_knowledge_gap_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> KnowledgeGapAgent:
- """
- Factory function to create a knowledge gap agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured KnowledgeGapAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return KnowledgeGapAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create knowledge gap agent", error=str(e))
- raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
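
Calling `evaluate` directly shows how the iteration and time budget are threaded into the prompt. A hedged sketch with illustrative values (again assuming `get_model()` can resolve credentials):

```python
# Hedged usage sketch for KnowledgeGapAgent; all values are illustrative.
import asyncio

from src.agents.knowledge_gap import create_knowledge_gap_agent


async def main() -> None:
    agent = create_knowledge_gap_agent()
    evaluation = await agent.evaluate(
        query="What drugs modulate AMPK in neurodegeneration?",
        conversation_history="Iteration 1: searched PubMed, found 4 papers ...",
        iteration=2,
        time_elapsed_minutes=3.5,
        max_time_minutes=10,
    )
    if evaluation.research_complete:
        print("research complete")
    else:
        print("outstanding gaps:", evaluation.outstanding_gaps)


asyncio.run(main())
```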
diff --git a/src/agents/long_writer.py b/src/agents/long_writer.py
deleted file mode 100644
index 5b0553aae60d4daf89f7c6879d024a44a1ea1a42..0000000000000000000000000000000000000000
--- a/src/agents/long_writer.py
+++ /dev/null
@@ -1,466 +0,0 @@
-"""Long writer agent for iteratively writing report sections.
-
-Converts the folder/long_writer_agent.py implementation to use Pydantic AI.
-"""
-
-import re
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic import BaseModel, Field
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import ReportDraft
-
-logger = structlog.get_logger()
-
-
-# LongWriterOutput model for structured output
-class LongWriterOutput(BaseModel):
- """Output from the long writer agent for a single section."""
-
- next_section_markdown: str = Field(
- description="The final draft of the next section in markdown format"
- )
- references: list[str] = Field(
- description="A list of URLs and their corresponding reference numbers for the section"
- )
-
- model_config = {"frozen": True}
-
-
-# System prompt for the long writer agent
-SYSTEM_PROMPT = f"""
-You are an expert report writer tasked with iteratively writing each section of a report.
-Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-You will be provided with:
-1. The original research query
-2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet)
-3. A first draft of the next section of the report to be written
-
-OBJECTIVE:
-1. Write a final draft of the next section of the report with numbered citations in square brackets in the body of the report
-2. Produce a list of references to be appended to the end of the report
-
-CITATIONS/REFERENCES:
-The citations should be in numerical order, written in numbered square brackets in the body of the report.
-Separately, a list of all URLs and their corresponding reference numbers will be included at the end of the report.
-Follow the example below for formatting.
-
-LongWriterOutput(
- next_section_markdown="The company specializes in IT consulting [1]. It operates in the software services market which is expected to grow at 10% per year [2].",
- references=["[1] https://example.com/first-source-url", "[2] https://example.com/second-source-url"]
-)
-
-GUIDELINES:
-- You can reformat and reorganize the flow of the content and headings within a section to flow logically, but DO NOT remove details that were included in the first draft
-- Only remove text from the first draft if it is already mentioned earlier in the report, or if it should be covered in a later section per the table of contents
-- Ensure the heading for the section matches the table of contents
-- Format the final output and references section as markdown
-- Do not include a title for the reference section, just a list of numbered references
-
-Only output JSON. Follow the JSON schema for LongWriterOutput. Do not output anything else.
-"""
-
-
-class LongWriterAgent:
- """
- Agent that iteratively writes report sections with proper citations.
-
- Uses Pydantic AI to generate structured LongWriterOutput for each section.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the long writer agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent
- self.agent = Agent(
- model=self.model,
- output_type=LongWriterOutput,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def write_next_section(
- self,
- original_query: str,
- report_draft: str,
- next_section_title: str,
- next_section_draft: str,
- ) -> LongWriterOutput:
- """
- Write the next section of the report.
-
- Args:
- original_query: The original research query
- report_draft: Current report draft (all sections written so far)
- next_section_title: Title of the section to write
- next_section_draft: Draft content for the next section
-
- Returns:
-        Returns:
-            LongWriterOutput with formatted section and references. On failure,
-            returns a fallback built from the unpolished draft (with evidence
-            citations when available) instead of raising.
- # Input validation
- if not original_query or not original_query.strip():
- self.logger.warning("Empty query provided, using default")
- original_query = "Research query"
-
- if not next_section_title or not next_section_title.strip():
- self.logger.warning("Empty section title provided, using default")
- next_section_title = "Section"
-
- if next_section_draft is None:
- next_section_draft = ""
-
- if report_draft is None:
- report_draft = ""
-
- # Truncate very long inputs
- max_draft_length = 30000
- if len(report_draft) > max_draft_length:
- self.logger.warning(
- "Report draft too long, truncating",
- original_length=len(report_draft),
- )
- report_draft = report_draft[:max_draft_length] + "\n\n[Content truncated]"
-
- if len(next_section_draft) > max_draft_length:
- self.logger.warning(
- "Section draft too long, truncating",
- original_length=len(next_section_draft),
- )
- next_section_draft = next_section_draft[:max_draft_length] + "\n\n[Content truncated]"
-
- self.logger.info(
- "Writing next section",
- section_title=next_section_title,
- query=original_query[:100],
- )
-
-        user_message = f"""
-<ORIGINAL QUERY>
-{original_query}
-</ORIGINAL QUERY>
-
-<CURRENT REPORT DRAFT>
-{report_draft or "No draft yet"}
-</CURRENT REPORT DRAFT>
-
-<TITLE OF NEXT SECTION>
-{next_section_title}
-</TITLE OF NEXT SECTION>
-
-<DRAFT OF NEXT SECTION>
-{next_section_draft}
-</DRAFT OF NEXT SECTION>
-"""
-
- # Retry logic for transient failures
- max_retries = 3
- last_exception: Exception | None = None
-
- for attempt in range(max_retries):
- try:
- # Run the agent
- result = await self.agent.run(user_message)
- output = result.output
-
- # Validate output
- if not output or not isinstance(output, LongWriterOutput):
- raise ValueError("Invalid output format")
-
- if not output.next_section_markdown or not output.next_section_markdown.strip():
- self.logger.warning("Empty section generated, using fallback")
- raise ValueError("Empty section generated")
-
- self.logger.info(
- "Section written",
- section_title=next_section_title,
- references_count=len(output.references),
- attempt=attempt + 1,
- )
-
- return output
-
- except (TimeoutError, ConnectionError) as e:
- # Transient errors - retry
- last_exception = e
- if attempt < max_retries - 1:
- self.logger.warning(
- "Transient error, retrying",
- error=str(e),
- attempt=attempt + 1,
- max_retries=max_retries,
- )
- continue
- else:
- self.logger.error("Max retries exceeded for transient error", error=str(e))
- break
-
- except Exception as e:
- # Non-transient errors - don't retry
- last_exception = e
- self.logger.error(
- "Section writing failed",
- error=str(e),
- error_type=type(e).__name__,
- )
- break
-
- # Return fallback section if all attempts failed
- self.logger.error(
- "Section writing failed after all attempts",
- error=str(last_exception) if last_exception else "Unknown error",
- )
-
- # Try to enhance fallback with evidence if available
- try:
- from src.middleware.state_machine import get_workflow_state
-
- state = get_workflow_state()
- if state and state.evidence:
- # Include evidence citations in fallback
- evidence_refs: list[str] = []
- for i, ev in enumerate(state.evidence[:10], 1): # Limit to 10
- authors = (
- ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown"
- )
- evidence_refs.append(
- f"[{i}] {authors}. *{ev.citation.title}*. {ev.citation.url}"
- )
-
- enhanced_draft = f"## {next_section_title}\n\n{next_section_draft}"
- if evidence_refs:
- enhanced_draft += "\n\n### Sources\n\n" + "\n".join(evidence_refs)
-
- return LongWriterOutput(
- next_section_markdown=enhanced_draft,
- references=evidence_refs,
- )
- except Exception as e:
- self.logger.warning(
- "Failed to enhance fallback with evidence",
- error=str(e),
- )
-
- # Basic fallback
- return LongWriterOutput(
- next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}",
- references=[],
- )
-
- async def write_report(
- self,
- original_query: str,
- report_title: str,
- report_draft: ReportDraft,
- ) -> str:
- """
- Write the final report by iteratively writing each section.
-
- Args:
- original_query: The original research query
- report_title: Title of the report
- report_draft: ReportDraft with all sections
-
- Returns:
-            Complete markdown report string. Section-level failures fall back
-            to their unpolished drafts rather than raising.
- """
- # Input validation
- if not original_query or not original_query.strip():
- self.logger.warning("Empty query provided, using default")
- original_query = "Research query"
-
- if not report_title or not report_title.strip():
- self.logger.warning("Empty report title provided, using default")
- report_title = "Research Report"
-
- if not report_draft or not report_draft.sections:
- self.logger.warning("Empty report draft provided, returning minimal report")
- return f"# {report_title}\n\n## Query\n{original_query}\n\n*No sections available.*"
-
- self.logger.info(
- "Writing full report",
- report_title=report_title,
- sections_count=len(report_draft.sections),
- )
-
- # Initialize the final draft with title and table of contents
- final_draft = (
- f"# {report_title}\n\n## Table of Contents\n\n"
- + "\n".join(
- [
- f"{i + 1}. {section.section_title}"
- for i, section in enumerate(report_draft.sections)
- ]
- )
- + "\n\n"
- )
- all_references: list[str] = []
-
- for section in report_draft.sections:
- # Write each section
- next_section_output = await self.write_next_section(
- original_query,
- final_draft,
- section.section_title,
- section.section_content,
- )
-
- # Reformat references and update section markdown
- section_markdown, all_references = self._reformat_references(
- next_section_output.next_section_markdown,
- next_section_output.references,
- all_references,
- )
-
- # Reformat section headings
- section_markdown = self._reformat_section_headings(section_markdown)
-
- # Add to final draft
- final_draft += section_markdown + "\n\n"
-
- # Add final references
-        # Two trailing spaces before each newline force a markdown line break
-        final_draft += "## References:\n\n" + "  \n".join(all_references)
-
- self.logger.info("Full report written", length=len(final_draft))
-
- return final_draft
-
- def _reformat_references(
- self,
- section_markdown: str,
- section_references: list[str],
- all_references: list[str],
- ) -> tuple[str, list[str]]:
- """
- Reformat references: re-number, de-duplicate, and update markdown.
-
- Args:
- section_markdown: Markdown content with inline references [1], [2]
- section_references: List of references for this section
- all_references: Accumulated references from previous sections
-
- Returns:
- Tuple of (updated markdown, updated all_references)
- """
-
- # Convert reference lists to maps (URL -> ref_num)
- def convert_ref_list_to_map(ref_list: list[str]) -> dict[str, int]:
- ref_map: dict[str, int] = {}
- for ref in ref_list:
- try:
- # Parse "[1] https://example.com" format
- parts = ref.split("]", 1)
- if len(parts) == 2:
- ref_num = int(parts[0].strip("["))
- url = parts[1].strip()
- ref_map[url] = ref_num
- except (ValueError, IndexError):
- logger.warning("Invalid reference format", ref=ref)
- continue
- return ref_map
-
- section_ref_map = convert_ref_list_to_map(section_references)
- report_ref_map = convert_ref_list_to_map(all_references)
- section_to_report_ref_map: dict[int, int] = {}
-
- report_urls = set(report_ref_map.keys())
- ref_count = max(report_ref_map.values() or [0])
-
- # Map section references to report references
- for url, section_ref_num in section_ref_map.items():
- if url in report_urls:
- # URL already exists - reuse its reference number
- section_to_report_ref_map[section_ref_num] = report_ref_map[url]
- else:
- # New URL - assign next reference number
- ref_count += 1
- section_to_report_ref_map[section_ref_num] = ref_count
- all_references.append(f"[{ref_count}] {url}")
-
- # Replace reference numbers in markdown
- def replace_reference(match: re.Match[str]) -> str:
- ref_num = int(match.group(1))
- mapped_ref_num = section_to_report_ref_map.get(ref_num)
- if mapped_ref_num:
- return f"[{mapped_ref_num}]"
- return ""
-
- updated_markdown = re.sub(r"\[(\d+)\]", replace_reference, section_markdown)
-
- return updated_markdown, all_references
-
- def _reformat_section_headings(self, section_markdown: str) -> str:
- """
- Reformat section headings to be consistent (level-2 for main heading).
-
- Args:
- section_markdown: Markdown content with headings
-
- Returns:
- Updated markdown with adjusted heading levels
- """
- if not section_markdown.strip():
- return section_markdown
-
- # Find first heading level
- first_heading_match = re.search(r"^(#+)\s", section_markdown, re.MULTILINE)
- if not first_heading_match:
- return section_markdown
-
- # Calculate level adjustment needed (target is level 2)
- first_heading_level = len(first_heading_match.group(1))
- level_adjustment = 2 - first_heading_level
-
- def adjust_heading_level(match: re.Match[str]) -> str:
- hashes = match.group(1)
- content = match.group(2)
- new_level = max(2, len(hashes) + level_adjustment)
- return "#" * new_level + " " + content
-
- # Apply heading adjustment
- return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE)
-
-
-def create_long_writer_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> LongWriterAgent:
- """
- Factory function to create a long writer agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured LongWriterAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return LongWriterAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create long writer agent", error=str(e))
- raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
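
The most intricate piece of `LongWriterAgent` is `_reformat_references`, which merges per-section citation numbering into one report-wide list, reusing numbers for URLs cited earlier. A standalone sketch of the same merge (for clarity it assumes every inline `[n]` marker has a matching reference entry):

```python
# Standalone sketch of the reference merge in _reformat_references:
# section-local [n] markers are rewritten to report-wide numbers,
# reusing the existing number when a URL was already cited.
import re


def merge_references(
    section_md: str, section_refs: list[str], all_refs: list[str]
) -> tuple[str, list[str]]:
    def to_map(refs: list[str]) -> dict[str, int]:
        out: dict[str, int] = {}
        for ref in refs:  # parse "[1] https://..." entries
            num, _, url = ref.partition("]")
            out[url.strip()] = int(num.strip("["))
        return out

    sec, rep = to_map(section_refs), to_map(all_refs)
    remap: dict[int, int] = {}
    count = max(rep.values(), default=0)
    for url, n in sec.items():
        if url in rep:
            remap[n] = rep[url]  # duplicate URL: reuse the old number
        else:
            count += 1
            remap[n] = count     # new URL: next free number
            all_refs.append(f"[{count}] {url}")
    new_md = re.sub(r"\[(\d+)\]", lambda m: f"[{remap[int(m.group(1))]}]", section_md)
    return new_md, all_refs


md, refs = merge_references(
    "Growth is 10% [1]; see also [2].",
    ["[1] https://example.com/b", "[2] https://example.com/c"],
    ["[1] https://example.com/a", "[2] https://example.com/b"],
)
assert md == "Growth is 10% [2]; see also [3]."
assert refs[-1] == "[3] https://example.com/c"
```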
diff --git a/src/agents/magentic_agents.py b/src/agents/magentic_agents.py
deleted file mode 100644
index 926e84a9b3c4bf00cf36bfa99d643b637ece3d1f..0000000000000000000000000000000000000000
--- a/src/agents/magentic_agents.py
+++ /dev/null
@@ -1,177 +0,0 @@
-"""Magentic-compatible agents using ChatAgent pattern."""
-
-from typing import Any
-
-from agent_framework import ChatAgent
-
-from src.agents.tools import (
- get_bibliography,
- search_clinical_trials,
- search_preprints,
- search_pubmed,
-)
-from src.utils.llm_factory import get_chat_client_for_agent
-
-
-def create_search_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a search agent with internal LLM and search tools.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for biomedical search
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="SearchAgent",
- description=(
- "Searches biomedical databases (PubMed, ClinicalTrials.gov, Europe PMC) "
- "for research evidence"
- ),
- instructions="""You are a biomedical search specialist. When asked to find evidence:
-
-1. Analyze the request to determine what to search for
-2. Extract key search terms (drug names, disease names, mechanisms)
-3. Use the appropriate search tools:
- - search_pubmed for peer-reviewed papers
- - search_clinical_trials for clinical studies
- - search_preprints for cutting-edge findings
-4. Summarize what you found and highlight key evidence
-
-Be thorough - search multiple databases when appropriate.
-Focus on finding: mechanisms of action, clinical evidence, and specific drug candidates.""",
- chat_client=client,
- tools=[search_pubmed, search_clinical_trials, search_preprints],
- temperature=0.3, # More deterministic for tool use
- )
-
-
-def create_judge_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a judge agent that evaluates evidence quality.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for evidence assessment
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="JudgeAgent",
- description="Evaluates evidence quality and determines if sufficient for synthesis",
- instructions="""You are an evidence quality assessor. When asked to evaluate:
-
-1. Review all evidence presented in the conversation
-2. Score on two dimensions (0-10 each):
- - Mechanism Score: How well is the biological mechanism explained?
- - Clinical Score: How strong is the clinical/preclinical evidence?
-3. Determine if evidence is SUFFICIENT for a final report:
- - Sufficient: Clear mechanism + supporting clinical data
- - Insufficient: Gaps in mechanism OR weak clinical evidence
-4. If insufficient, suggest specific search queries to fill gaps
-
-Be rigorous but fair. Look for:
-- Molecular targets and pathways
-- Animal model studies
-- Human clinical trials
-- Safety data
-- Drug-drug interactions""",
- chat_client=client,
- temperature=0.2, # Consistent judgments
- )
-
-
-def create_hypothesis_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a hypothesis generation agent.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for hypothesis generation
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="HypothesisAgent",
- description="Generates mechanistic hypotheses for research investigation",
- instructions="""You are a biomedical hypothesis generator. Based on evidence:
-
-1. Identify the key molecular targets involved
-2. Map the biological pathways affected
-3. Generate testable hypotheses in this format:
-
- DRUG -> TARGET -> PATHWAY -> THERAPEUTIC EFFECT
-
- Example:
- Metformin -> AMPK activation -> mTOR inhibition -> Reduced tau phosphorylation
-
-4. Explain the rationale for each hypothesis
-5. Suggest what additional evidence would support or refute it
-
-Focus on mechanistic plausibility and existing evidence.""",
- chat_client=client,
- temperature=0.5, # Some creativity for hypothesis generation
- )
-
-
-def create_report_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a report synthesis agent.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for report generation
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="ReportAgent",
- description="Synthesizes research findings into structured reports",
- instructions="""You are a scientific report writer. When asked to synthesize:
-
-Generate a structured report with these sections:
-
-## Executive Summary
-Brief overview of findings and recommendation
-
-## Methodology
-Databases searched, queries used, evidence reviewed
-
-## Key Findings
-### Mechanism of Action
-- Molecular targets
-- Biological pathways
-- Proposed mechanism
-
-### Clinical Evidence
-- Preclinical studies
-- Clinical trials
-- Safety profile
-
-## Drug Candidates
-List specific drugs with repurposing potential
-
-## Limitations
-Gaps in evidence, conflicting data, caveats
-
-## Conclusion
-Final recommendation with confidence level
-
-## References
-Use the 'get_bibliography' tool to fetch the complete list of citations.
-Format them as a numbered list.
-
-Be comprehensive but concise. Cite evidence for all claims.""",
- chat_client=client,
- tools=[get_bibliography],
- temperature=0.3,
- )
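A brief wiring sketch for the factories above: they were written so that a single chat client (from `get_chat_client_for_agent()`) can be shared across the whole agent team, keeping provider and credential configuration in one place. How the orchestrator consumed these agents is defined elsewhere and is not shown here.

```python
from src.agents.magentic_agents import (
    create_hypothesis_agent,
    create_judge_agent,
    create_report_agent,
    create_search_agent,
)
from src.utils.llm_factory import get_chat_client_for_agent

# One client shared across agents keeps provider/auth configuration in one place.
client = get_chat_client_for_agent()

agents = {
    "search": create_search_agent(chat_client=client),
    "judge": create_judge_agent(chat_client=client),
    "hypothesis": create_hypothesis_agent(chat_client=client),
    "report": create_report_agent(chat_client=client),
}
```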
diff --git a/src/agents/proofreader.py b/src/agents/proofreader.py
deleted file mode 100644
index 7a209c365ea04313dceee645e23ff0dd8b846817..0000000000000000000000000000000000000000
--- a/src/agents/proofreader.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Proofreader agent for finalizing report drafts.
-
-Converts the folder/proofreader_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import ReportDraft
-
-logger = structlog.get_logger()
-
-
-# System prompt for the proofreader agent
-SYSTEM_PROMPT = f"""
-You are a research expert who proofreads and edits research reports.
-Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-
-You are given:
-1. The original query topic for the report
-2. A first draft of the report in ReportDraft format containing each section in sequence
-
-Your task is to:
-1. **Combine sections:** Concatenate the sections into a single string
-2. **Add section titles:** Add the section titles to the beginning of each section in markdown format, as well as a main title for the report
-3. **De-duplicate:** Remove duplicate content across sections to avoid repetition
-4. **Remove irrelevant sections:** If any sections or sub-sections are completely irrelevant to the query, remove them
-5. **Refine wording:** Edit the wording of the report to be polished, concise and punchy, but **without eliminating any detail** or large chunks of text
-6. **Add a summary:** Add a short report summary / outline to the beginning of the report to provide an overview of the sections and what is discussed
-7. **Preserve sources:** Preserve all sources / references - move the long list of references to the end of the report
-8. **Update reference numbers:** Continue to include reference numbers in square brackets ([1], [2], [3], etc.) in the main body of the report, but update the numbering to match the new order of references at the end of the report
-9. **Output final report:** Output the final report in markdown format (do not wrap it in a code block)
-
-Guidelines:
-- Do not add any new facts or data to the report
-- Do not remove any content from the report unless it is very clearly wrong, contradictory or irrelevant
-- Remove or reformat any redundant or excessive headings, and ensure that the final nesting of heading levels is correct
-- Ensure that the final report flows well and has a logical structure
-- Ensure that all sources and references from the draft are included in the final report
-"""
-
-
-class ProofreaderAgent:
- """
- Agent that proofreads and finalizes report drafts.
-
- Uses Pydantic AI to generate polished markdown reports from draft sections.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the proofreader agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent (no structured output - returns markdown text)
- self.agent = Agent(
- model=self.model,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def proofread(
- self,
- query: str,
- report_draft: ReportDraft,
- ) -> str:
- """
- Proofread and finalize a report draft.
-
- Args:
- query: The original research query
- report_draft: ReportDraft with all sections
-
- Returns:
- Final polished markdown report string
-
- Raises:
- ConfigurationError: If proofreading fails
- """
- # Input validation
- if not query or not query.strip():
- self.logger.warning("Empty query provided, using default")
- query = "Research query"
-
- if not report_draft or not report_draft.sections:
- self.logger.warning("Empty report draft provided, returning minimal report")
- return f"# Research Report\n\n## Query\n{query}\n\n*No sections available.*"
-
- # Validate section structure
- valid_sections = []
- for section in report_draft.sections:
- if section.section_title and section.section_title.strip():
- valid_sections.append(section)
- else:
- self.logger.warning("Skipping section with empty title")
-
- if not valid_sections:
- self.logger.warning("No valid sections in draft, returning minimal report")
- return f"# Research Report\n\n## Query\n{query}\n\n*No valid sections available.*"
-
- self.logger.info(
- "Proofreading report",
- query=query[:100],
- sections_count=len(valid_sections),
- )
-
- # Create validated draft
- validated_draft = ReportDraft(sections=valid_sections)
-
- user_message = f"""
-QUERY:
-{query}
-
-REPORT DRAFT:
-{validated_draft.model_dump_json()}
-"""
-
- # Retry logic for transient failures
- max_retries = 3
- last_exception: Exception | None = None
-
- for attempt in range(max_retries):
- try:
- # Run the agent
- result = await self.agent.run(user_message)
- final_report = result.output
-
- # Validate output
- if not final_report or not final_report.strip():
- self.logger.warning("Empty report generated, using fallback")
- raise ValueError("Empty report generated")
-
- self.logger.info("Report proofread", length=len(final_report), attempt=attempt + 1)
-
- return final_report
-
- except (TimeoutError, ConnectionError) as e:
- # Transient errors - retry
- last_exception = e
- if attempt < max_retries - 1:
- self.logger.warning(
- "Transient error, retrying",
- error=str(e),
- attempt=attempt + 1,
- max_retries=max_retries,
- )
- continue
- else:
- self.logger.error("Max retries exceeded for transient error", error=str(e))
- break
-
- except Exception as e:
- # Non-transient errors - don't retry
- last_exception = e
- self.logger.error(
- "Proofreading failed",
- error=str(e),
- error_type=type(e).__name__,
- )
- break
-
- # Return fallback: combine sections manually
- self.logger.error(
- "Proofreading failed after all attempts",
- error=str(last_exception) if last_exception else "Unknown error",
- )
- sections = [
- f"## {section.section_title}\n\n{section.section_content or 'Content unavailable.'}"
- for section in valid_sections
- ]
- return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections)
-
-
-def create_proofreader_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> ProofreaderAgent:
- """
- Factory function to create a proofreader agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ProofreaderAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return ProofreaderAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create proofreader agent", error=str(e))
- raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e
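A hedged usage sketch for the deleted proofreader. The per-section model's exact class name does not appear in this diff, so the draft below is built with `model_validate` from plain dicts, which assumes `ReportDraft` validates nested section dicts carrying `section_title` and `section_content` fields.

```python
import asyncio

from src.agents.proofreader import create_proofreader_agent
from src.utils.models import ReportDraft


async def main() -> None:
    # Assumption: ReportDraft accepts plain dicts for its sections.
    draft = ReportDraft.model_validate(
        {
            "sections": [
                {"section_title": "Mechanism", "section_content": "..."},
                {"section_title": "Clinical Evidence", "section_content": "..."},
            ]
        }
    )
    agent = create_proofreader_agent()  # raises ConfigurationError without API keys
    report_md = await agent.proofread(query="metformin for Alzheimer's", report_draft=draft)
    print(report_md)


asyncio.run(main())
```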
diff --git a/src/agents/report_agent.py b/src/agents/report_agent.py
deleted file mode 100644
index 5454fed54f67356dca0674985c6c90a7f34517c6..0000000000000000000000000000000000000000
--- a/src/agents/report_agent.py
+++ /dev/null
@@ -1,140 +0,0 @@
-"""Report agent for generating structured research reports."""
-
-from collections.abc import AsyncIterable
-from typing import TYPE_CHECKING, Any
-
-from agent_framework import (
- AgentRunResponse,
- AgentRunResponseUpdate,
- AgentThread,
- BaseAgent,
- ChatMessage,
- Role,
-)
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.prompts.report import SYSTEM_PROMPT, format_report_prompt
-from src.utils.citation_validator import validate_references
-from src.utils.models import Evidence, ResearchReport
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-
-class ReportAgent(BaseAgent): # type: ignore[misc]
- """Generates structured scientific reports from evidence and hypotheses."""
-
- def __init__(
- self,
- evidence_store: dict[str, Any],
- embedding_service: "EmbeddingService | None" = None, # For diverse selection
- ) -> None:
- super().__init__(
- name="ReportAgent",
- description="Generates structured scientific research reports with citations",
- )
- self._evidence_store = evidence_store
- self._embeddings = embedding_service
- self._agent: Agent[None, ResearchReport] | None = None # Lazy init
-
- def _get_agent(self) -> Agent[None, ResearchReport]:
- """Lazy initialization of LLM agent to avoid requiring API keys at import."""
- if self._agent is None:
- self._agent = Agent(
- model=get_model(),
- output_type=ResearchReport,
- system_prompt=SYSTEM_PROMPT,
- )
- return self._agent
-
- async def run(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AgentRunResponse:
- """Generate research report."""
- query = self._extract_query(messages)
-
- # Gather all context
- evidence: list[Evidence] = self._evidence_store.get("current", [])
- hypotheses = self._evidence_store.get("hypotheses", [])
- assessment = self._evidence_store.get("last_assessment", {})
-
- if not evidence:
- return AgentRunResponse(
- messages=[
- ChatMessage(
- role=Role.ASSISTANT,
- text="Cannot generate report: No evidence collected.",
- )
- ],
- response_id="report-no-evidence",
- )
-
- # Build metadata
- metadata = {
- "sources": list(set(e.citation.source for e in evidence)),
- "iterations": self._evidence_store.get("iteration_count", 0),
- }
-
- # Generate report (format_report_prompt is now async)
- prompt = await format_report_prompt(
- query=query,
- evidence=evidence,
- hypotheses=hypotheses,
- assessment=assessment,
- metadata=metadata,
- embeddings=self._embeddings,
- )
-
- result = await self._get_agent().run(prompt)
- report = result.output
-
- # ═══════════════════════════════════════════════════════════════════
- # 🚨 CRITICAL: Validate citations to prevent hallucination
- # ═══════════════════════════════════════════════════════════════════
- report = validate_references(report, evidence)
-
- # Store validated report
- self._evidence_store["final_report"] = report
-
- # Return markdown version
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=report.to_markdown())],
- response_id="report-complete",
- additional_properties={"report": report.model_dump()},
- )
-
- def _extract_query(
- self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None
- ) -> str:
- """Extract query from messages."""
- if isinstance(messages, str):
- return messages
- elif isinstance(messages, ChatMessage):
- return messages.text or ""
- elif isinstance(messages, list):
- for msg in reversed(messages):
- if isinstance(msg, ChatMessage) and msg.role == Role.USER:
- return msg.text or ""
- elif isinstance(msg, str):
- return msg
- return ""
-
- async def run_stream(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AsyncIterable[AgentRunResponseUpdate]:
- """Streaming wrapper."""
- result = await self.run(messages, thread=thread, **kwargs)
- yield AgentRunResponseUpdate(
- messages=result.messages,
- response_id=result.response_id,
- additional_properties=result.additional_properties,
- )
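The `evidence_store` contract above is implicit; the sketch below spells out the keys `ReportAgent.run()` reads and the one it writes back. The placeholder values are assumptions for illustration only.

```python
from src.agents.report_agent import ReportAgent

# Keys read by ReportAgent.run(): "current", "hypotheses", "last_assessment",
# "iteration_count". The validated report is written back under "final_report".
evidence_store: dict = {
    "current": [],          # list[Evidence] gathered by the search agents
    "hypotheses": [],       # mechanistic hypotheses collected so far, if any
    "last_assessment": {},  # most recent judge output
    "iteration_count": 3,
}

agent = ReportAgent(evidence_store=evidence_store, embedding_service=None)
# After `await agent.run(query)`, evidence_store["final_report"] holds the ResearchReport.
```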
diff --git a/src/agents/retrieval_agent.py b/src/agents/retrieval_agent.py
deleted file mode 100644
index 63a3a6bff9c86342e8f932a97cad51e8b9c5f58b..0000000000000000000000000000000000000000
--- a/src/agents/retrieval_agent.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Retrieval agent for web search and context management."""
-
-from typing import Any
-
-import structlog
-from agent_framework import ChatAgent, ai_function
-
-from src.agents.state import get_magentic_state
-from src.tools.web_search import WebSearchTool
-from src.utils.llm_factory import get_chat_client_for_agent
-
-logger = structlog.get_logger()
-
-_web_search = WebSearchTool()
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def search_web(query: str, max_results: int = 10) -> str:
- """Search the web using DuckDuckGo.
-
- Args:
- query: Search keywords.
- max_results: Maximum results to return (default 10).
-
- Returns:
- Formatted search results.
- """
- logger.info("Web search starting", query=query, max_results=max_results)
- state = get_magentic_state()
-
- evidence = await _web_search.search(query, max_results)
- if not evidence:
- logger.info("Web search returned no results", query=query)
- return f"No web results found for: {query}"
-
- # Update state
- # We add *all* found results to state
- new_count = state.add_evidence(evidence)
- logger.info(
- "Web search complete",
- query=query,
- results_found=len(evidence),
- new_evidence=new_count,
- )
-
- # Use embedding service for deduplication/indexing if available
- if state.embedding_service:
- # This method also adds to vector DB as a side effect for unique items
- await state.embedding_service.deduplicate(evidence)
-
- output = [f"Found {len(evidence)} web results ({new_count} new stored):\n"]
- for i, r in enumerate(evidence[:max_results], 1):
- output.append(f"{i}. **{r.citation.title}**")
- output.append(f" Source: {r.citation.url}")
- output.append(f" {r.content[:300]}...\n")
-
- return "\n".join(output)
-
-
-def create_retrieval_agent(chat_client: Any | None = None) -> ChatAgent:
- """Create a retrieval agent.
-
- Args:
- chat_client: Optional custom chat client. If None, uses factory default
- (HuggingFace preferred, OpenAI fallback).
-
- Returns:
- ChatAgent configured for retrieval.
- """
- client = chat_client or get_chat_client_for_agent()
-
- return ChatAgent(
- name="RetrievalAgent",
- description="Searches the web and manages context/evidence.",
- instructions="""You are a retrieval specialist.
-Use `search_web` to find information on the internet.
-Your goal is to gather relevant evidence for the research task.
-Always summarize what you found.""",
- chat_client=client,
- tools=[search_web],
- )
diff --git a/src/agents/search_agent.py b/src/agents/search_agent.py
deleted file mode 100644
index 9c28242458109bebc7824d122b54f9b4b1d1ea0a..0000000000000000000000000000000000000000
--- a/src/agents/search_agent.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from collections.abc import AsyncIterable
-from typing import TYPE_CHECKING, Any
-
-from agent_framework import (
- AgentRunResponse,
- AgentRunResponseUpdate,
- AgentThread,
- BaseAgent,
- ChatMessage,
- Role,
-)
-
-from src.legacy_orchestrator import SearchHandlerProtocol
-from src.utils.models import Citation, Evidence, SearchResult
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-
-class SearchAgent(BaseAgent): # type: ignore[misc]
- """Wraps SearchHandler as an AgentProtocol for Magentic."""
-
- def __init__(
- self,
- search_handler: SearchHandlerProtocol,
- evidence_store: dict[str, list[Evidence]],
- embedding_service: "EmbeddingService | None" = None,
- ) -> None:
- super().__init__(
- name="SearchAgent",
- description="Searches PubMed for biomedical research evidence",
- )
- self._handler = search_handler
- self._evidence_store = evidence_store
- self._embeddings = embedding_service
-
- async def run(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AgentRunResponse:
- """Execute search based on the last user message."""
- # Extract query from messages
- query = ""
- if isinstance(messages, list):
- for msg in reversed(messages):
- if isinstance(msg, ChatMessage) and msg.role == Role.USER and msg.text:
- query = msg.text
- break
- elif isinstance(msg, str):
- query = msg
- break
- elif isinstance(messages, str):
- query = messages
- elif isinstance(messages, ChatMessage) and messages.text:
- query = messages.text
-
- if not query:
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text="No query provided")],
- response_id="search-no-query",
- )
-
- # Execute search
- result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
-
- # Track what to show in response (initialized to search results as default)
- evidence_to_show: list[Evidence] = result.evidence
- total_new = 0
-
- # Update shared evidence store
- if self._embeddings:
- # Deduplicate by semantic similarity (async-safe)
- unique_evidence = await self._embeddings.deduplicate(result.evidence)
-
- # Also search for semantically related evidence (async-safe)
- related = await self._embeddings.search_similar(query, n_results=5)
-
- # Merge related evidence not already in results
- existing_urls = {e.citation.url for e in unique_evidence}
-
- # Reconstruct Evidence objects from stored vector DB data
- related_evidence: list[Evidence] = []
- for item in related:
- if item["id"] not in existing_urls:
- meta = item.get("metadata", {})
- # Parse authors (stored as comma-separated string)
- authors_str = meta.get("authors", "")
- authors = [a.strip() for a in authors_str.split(",") if a.strip()]
-
- ev = Evidence(
- content=item["content"],
- citation=Citation(
- title=meta.get("title", "Related Evidence"),
- url=item["id"],
- source="pubmed",
- date=meta.get("date", "n.d."),
- authors=authors,
- ),
- # Convert distance to relevance (lower distance = higher relevance)
- relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
- )
- related_evidence.append(ev)
-
- # Combine unique from search + related from vector DB
- final_new_evidence = unique_evidence + related_evidence
-
- # Add to global store (deduping against global store)
- global_urls = {e.citation.url for e in self._evidence_store["current"]}
- really_new = [e for e in final_new_evidence if e.citation.url not in global_urls]
- self._evidence_store["current"].extend(really_new)
-
- total_new = len(really_new)
- evidence_to_show = unique_evidence + related_evidence
-
- else:
- # Fallback to URL-based deduplication (no embeddings)
- existing_urls = {e.citation.url for e in self._evidence_store["current"]}
- new_unique = [e for e in result.evidence if e.citation.url not in existing_urls]
- self._evidence_store["current"].extend(new_unique)
- total_new = len(new_unique)
- evidence_to_show = result.evidence
-
- evidence_text = "\n".join(
- [
- f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
- for e in evidence_to_show[:5]
- ]
- )
-
- response_text = (
- f"Found {result.total_found} sources ({total_new} new added to context):\n\n"
- f"{evidence_text}"
- )
-
- return AgentRunResponse(
- messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
- response_id=f"search-{result.total_found}",
- additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]},
- )
-
- async def run_stream(
- self,
- messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
- *,
- thread: AgentThread | None = None,
- **kwargs: Any,
- ) -> AsyncIterable[AgentRunResponseUpdate]:
- """Streaming wrapper for search (search itself isn't streaming)."""
- result = await self.run(messages, thread=thread, **kwargs)
- # Yield single update with full result
- yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
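One detail of the vector-DB merge above worth spelling out: stored hits carry a `distance`, relevance is recovered as `max(0.0, 1.0 - distance)` (so 0.35 becomes 0.65, and anything past 1.0 clamps to 0.0), and authors round-trip through a comma-separated string. A minimal sketch with a made-up hit dict:

```python
# Rebuilding a stored vector-DB hit into Evidence-like fields (made-up hit dict).
hit = {
    "id": "https://pubmed.example/2",
    "content": "Abstract text...",
    "distance": 0.35,
    "metadata": {"title": "Some paper", "date": "2024", "authors": "A. Smith, B. Jones"},
}

meta = hit.get("metadata", {})
authors = [a.strip() for a in meta.get("authors", "").split(",") if a.strip()]
relevance = max(0.0, 1.0 - hit.get("distance", 0.5))  # lower distance -> higher relevance

print(authors)    # ['A. Smith', 'B. Jones']
print(relevance)  # 0.65
```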
diff --git a/src/agents/state.py b/src/agents/state.py
deleted file mode 100644
index 8bab5f7b9563b5c263c376780c2856c8a058c2f0..0000000000000000000000000000000000000000
--- a/src/agents/state.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""Thread-safe state management for Magentic agents.
-
-DEPRECATED: This module is deprecated. Use src.middleware.state_machine instead.
-
-This file is kept for backward compatibility and will be removed in a future version.
-"""
-
-import warnings
-from contextvars import ContextVar
-from typing import TYPE_CHECKING, Any
-
-from pydantic import BaseModel, Field
-
-from src.utils.models import Citation, Evidence
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-
-def _deprecation_warning() -> None:
- """Emit deprecation warning for this module."""
- warnings.warn(
- "src.agents.state is deprecated. Use src.middleware.state_machine instead.",
- DeprecationWarning,
- stacklevel=3,
- )
-
-
-class MagenticState(BaseModel):
- """Mutable state for a Magentic workflow session.
-
- DEPRECATED: Use WorkflowState from src.middleware.state_machine instead.
- """
-
- evidence: list[Evidence] = Field(default_factory=list)
- # Type as Any to avoid circular imports/runtime resolution issues
- # The actual object injected will be an EmbeddingService instance
- embedding_service: Any = None
-
- model_config = {"arbitrary_types_allowed": True}
-
- def add_evidence(self, new_evidence: list[Evidence]) -> int:
- """Add new evidence, deduplicating by URL.
-
- Returns:
- Number of *new* items added.
- """
- existing_urls = {e.citation.url for e in self.evidence}
- count = 0
- for item in new_evidence:
- if item.citation.url not in existing_urls:
- self.evidence.append(item)
- existing_urls.add(item.citation.url)
- count += 1
- return count
-
- async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]:
- """Search for semantically related evidence using the embedding service."""
- if not self.embedding_service:
- return []
-
- results = await self.embedding_service.search_similar(query, n_results=n_results)
-
- # Convert dict results back to Evidence objects
- evidence_list = []
- for item in results:
- meta = item.get("metadata", {})
- authors_str = meta.get("authors", "")
- authors = [a.strip() for a in authors_str.split(",") if a.strip()]
-
- ev = Evidence(
- content=item["content"],
- citation=Citation(
- title=meta.get("title", "Related Evidence"),
- url=item["id"],
- source="pubmed", # Defaulting to pubmed if unknown
- date=meta.get("date", "n.d."),
- authors=authors,
- ),
- relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
- )
- evidence_list.append(ev)
-
- return evidence_list
-
-
-# The ContextVar holds the MagenticState for the current execution context
-_magentic_state_var: ContextVar[MagenticState | None] = ContextVar("magentic_state", default=None)
-
-
-def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState:
- """Initialize a new state for the current context.
-
- DEPRECATED: Use init_workflow_state from src.middleware.state_machine instead.
- """
- _deprecation_warning()
- state = MagenticState(embedding_service=embedding_service)
- _magentic_state_var.set(state)
- return state
-
-
-def get_magentic_state() -> MagenticState:
- """Get the current state. Raises RuntimeError if not initialized.
-
- DEPRECATED: Use get_workflow_state from src.middleware.state_machine instead.
- """
- _deprecation_warning()
- state = _magentic_state_var.get()
- if state is None:
- # Auto-initialize if missing (e.g. during tests or simple scripts)
- return init_magentic_state()
- return state
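Although deprecated, the ContextVar pattern above is the interesting part: state initialised inside one asyncio task is invisible to sibling tasks, which is what made concurrent tool calls safe. The sketch below assumes `Evidence` and `Citation` accept the fields used elsewhere in this diff; additional defaults may apply, and the module will emit its deprecation warning when imported this way.

```python
import asyncio

from src.agents.state import get_magentic_state, init_magentic_state
from src.utils.models import Citation, Evidence


def make_evidence(url: str) -> Evidence:
    # Field names taken from the reconstruction code above; other defaults may apply.
    return Evidence(
        content="...",
        citation=Citation(title="t", url=url, source="pubmed", date="2024", authors=[]),
        relevance=0.5,
    )


async def worker(url: str) -> int:
    init_magentic_state()  # state is scoped to this task's context
    state = get_magentic_state()
    state.add_evidence([make_evidence(url)])
    return len(state.evidence)


async def main() -> None:
    counts = await asyncio.gather(worker("https://x.test/1"), worker("https://x.test/2"))
    print(counts)  # [1, 1]: each task saw only its own evidence


asyncio.run(main())
```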
diff --git a/src/agents/thinking.py b/src/agents/thinking.py
deleted file mode 100644
index eff08e5111e7700ea2d4940d730ee1dfd3225324..0000000000000000000000000000000000000000
--- a/src/agents/thinking.py
+++ /dev/null
@@ -1,161 +0,0 @@
-"""Thinking agent for generating observations and reflections.
-
-Converts the folder/thinking_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-
-# System prompt for the thinking agent
-SYSTEM_PROMPT = f"""
-You are a research expert who is managing a research process in iterations. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-
-You are given:
-1. The original research query along with some supporting background context
-2. A history of the tasks, actions, findings and thoughts you've made up until this point in the research process (on iteration 1 you will be at the start of the research process, so this will be empty)
-
-Your objective is to reflect on the research process so far and share your latest thoughts.
-
-Specifically, your thoughts should include reflections on questions such as:
-- What have you learned from the last iteration?
-- What new areas would you like to explore next, or existing topics you'd like to go deeper into?
-- Were you able to retrieve the information you were looking for in the last iteration?
-- If not, should we change our approach or move to the next topic?
-- Is there any info that is contradictory or conflicting?
-
-Guidelines:
-- Share your stream of consciousness on the above questions as raw text
-- Keep your response concise and informal
-- Focus most of your thoughts on the most recent iteration and how that influences this next iteration
-- Our aim is to do very deep and thorough research - bear this in mind when reflecting on the research process
-- DO NOT produce a draft of the final report. This is not your job.
-- If this is the first iteration (i.e. no data from prior iterations), provide thoughts on what info we need to gather in the first iteration to get started
-"""
-
-
-class ThinkingAgent:
- """
- Agent that generates observations and reflections on the research process.
-
- Uses Pydantic AI to generate unstructured text observations about
- the current state of research and next steps.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the thinking agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent (no structured output - returns text)
- self.agent = Agent(
- model=self.model,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def generate_observations(
- self,
- query: str,
- background_context: str = "",
- conversation_history: str = "",
- message_history: list[ModelMessage] | None = None,
- iteration: int = 1,
- ) -> str:
- """
- Generate observations about the research process.
-
- Args:
- query: The original research query
- background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts (backward compat)
- message_history: Optional user conversation history (Pydantic AI format)
- iteration: Current iteration number
-
- Returns:
- String containing observations and reflections
-
- Raises:
- ConfigurationError: If generation fails
- """
- self.logger.info(
- "Generating observations",
- query=query[:100],
- iteration=iteration,
- )
-
- background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
-
- user_message = f"""
-You are starting iteration {iteration} of your research process.
-
-ORIGINAL QUERY:
-{query}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- try:
- # Run the agent with message_history if provided
- if message_history:
- result = await self.agent.run(user_message, message_history=message_history)
- else:
- result = await self.agent.run(user_message)
- observations = result.output
-
- self.logger.info("Observations generated", length=len(observations))
-
- return observations
-
- except Exception as e:
- self.logger.error("Observation generation failed", error=str(e))
- # Return fallback observations
- return f"Starting iteration {iteration}. Need to gather information about: {query}"
-
-
-def create_thinking_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> ThinkingAgent:
- """
- Factory function to create a thinking agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ThinkingAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return ThinkingAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create thinking agent", error=str(e))
- raise ConfigurationError(f"Failed to create thinking agent: {e}") from e
diff --git a/src/agents/tool_selector.py b/src/agents/tool_selector.py
deleted file mode 100644
index 0da35f43696306bac1093e1722ccc2fc981e598b..0000000000000000000000000000000000000000
--- a/src/agents/tool_selector.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""Tool selector agent for choosing which tools to use for knowledge gaps.
-
-Converts the folder/tool_selector_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import AgentSelectionPlan
-
-logger = structlog.get_logger()
-
-
-# System prompt for the tool selector agent
-SYSTEM_PROMPT = f"""
-You are a Tool Selector responsible for determining which specialized agents should address a knowledge gap in a research project.
-Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-
-You will be given:
-1. The original user query
-2. A knowledge gap identified in the research
-3. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
-
-Your task is to decide:
-1. Which specialized agents are best suited to address the gap
-2. What specific queries should be given to the agents (keep this short - 3-6 words)
-
-Available specialized agents:
-- WebSearchAgent: General web search for broad topics (can be called multiple times with different queries)
-- SiteCrawlerAgent: Crawl the pages of a specific website to retrieve information about it - use this if you want to find out something about a particular company, entity or product
-- RAGAgent: Semantic search within previously collected evidence - use when you need to find information from evidence already gathered in this research session. Best for finding connections, summarizing collected evidence, or retrieving specific details from earlier findings.
-
-Guidelines:
-- Aim to call at most 3 agents at a time in your final output
-- You can list the WebSearchAgent multiple times with different queries if needed to cover the full scope of the knowledge gap
-- Be specific and concise (3-6 words) with the agent queries - they should target exactly what information is needed
-- If you know the website or domain name of an entity being researched, always include it in the query
-- Use RAGAgent when: (1) You need to search within evidence already collected, (2) You want to find connections between different findings, (3) You need to retrieve specific details from earlier research iterations
-- Use WebSearchAgent or SiteCrawlerAgent when: (1) You need fresh information from the web, (2) You're starting a new research direction, (3) You need information not yet in the collected evidence
-- If a gap doesn't clearly match any agent's capability, default to the WebSearchAgent
-- Use the history of actions / tool calls as a guide - try not to repeat yourself if an approach didn't work previously
-
-Only output JSON. Follow the JSON schema for AgentSelectionPlan. Do not output anything else.
-"""
-
-
-class ToolSelectorAgent:
- """
- Agent that selects appropriate tools to address knowledge gaps.
-
- Uses Pydantic AI to generate structured AgentSelectionPlan with
- specific tasks for web search and crawl agents.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the tool selector agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent
- self.agent = Agent(
- model=self.model,
- output_type=AgentSelectionPlan,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def select_tools(
- self,
- gap: str,
- query: str,
- background_context: str = "",
- conversation_history: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> AgentSelectionPlan:
- """
- Select tools to address a knowledge gap.
-
- Args:
- gap: The knowledge gap to address
- query: The original research query
- background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts (backward compat)
- message_history: Optional user conversation history (Pydantic AI format)
-
- Returns:
- AgentSelectionPlan with tasks for selected agents
-
- Raises:
- ConfigurationError: If selection fails
- """
- self.logger.info("Selecting tools for gap", gap=gap[:100], query=query[:100])
-
- background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
-
- user_message = f"""
-ORIGINAL QUERY:
-{query}
-
-KNOWLEDGE GAP TO ADDRESS:
-{gap}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- try:
- # Run the agent with message_history if provided
- if message_history:
- result = await self.agent.run(user_message, message_history=message_history)
- else:
- result = await self.agent.run(user_message)
- selection_plan = result.output
-
- self.logger.info(
- "Tool selection complete",
- tasks_count=len(selection_plan.tasks),
- agents=[task.agent for task in selection_plan.tasks],
- )
-
- return selection_plan
-
- except Exception as e:
- self.logger.error("Tool selection failed", error=str(e))
- # Return fallback: use web search
- from src.utils.models import AgentTask
-
- return AgentSelectionPlan(
- tasks=[
- AgentTask(
- gap=gap,
- agent="WebSearchAgent",
- query=gap[:50], # Use gap as query
- entity_website=None,
- )
- ]
- )
-
-
-def create_tool_selector_agent(
- model: Any | None = None, oauth_token: str | None = None
-) -> ToolSelectorAgent:
- """
- Factory function to create a tool selector agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured ToolSelectorAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return ToolSelectorAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create tool selector agent", error=str(e))
- raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
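For reference, the structured output the selector produces looks roughly like the sketch below. Field names are taken from the fallback path above; the real models in `src/utils/models.py` may carry extra fields or validators.

```python
from src.utils.models import AgentSelectionPlan, AgentTask

# Shape of the structured output the selector returns (field names from the fallback path).
plan = AgentSelectionPlan(
    tasks=[
        AgentTask(
            gap="Clinical trial evidence for metformin in Alzheimer's",
            agent="WebSearchAgent",
            query="metformin alzheimer clinical trial",
            entity_website=None,
        ),
        AgentTask(
            gap="Connections between earlier AMPK findings",
            agent="RAGAgent",
            query="AMPK tau phosphorylation",
            entity_website=None,
        ),
    ]
)
print(len(plan.tasks))  # 2
```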
diff --git a/src/agents/tools.py b/src/agents/tools.py
deleted file mode 100644
index 9ce9f908eba4d07e539b971f00dab8a420697956..0000000000000000000000000000000000000000
--- a/src/agents/tools.py
+++ /dev/null
@@ -1,175 +0,0 @@
-"""Tool functions for Magentic agents.
-
-These functions are decorated with @ai_function to be callable by the ChatAgent's internal LLM.
-They also interact with the thread-safe MagenticState to persist evidence.
-"""
-
-from agent_framework import ai_function
-
-from src.agents.state import get_magentic_state
-from src.tools.clinicaltrials import ClinicalTrialsTool
-from src.tools.europepmc import EuropePMCTool
-from src.tools.pubmed import PubMedTool
-
-# Singleton tool instances (stateless wrappers)
-_pubmed = PubMedTool()
-_clinicaltrials = ClinicalTrialsTool()
-_europepmc = EuropePMCTool()
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def search_pubmed(query: str, max_results: int = 10) -> str:
- """Search PubMed for biomedical research papers.
-
- Use this tool to find peer-reviewed scientific literature about
- drugs, diseases, mechanisms of action, and clinical studies.
-
- Args:
- query: Search keywords (e.g., "metformin alzheimer mechanism")
- max_results: Maximum results to return (default 10)
-
- Returns:
- Formatted list of papers with titles, abstracts, and citations
- """
- state = get_magentic_state()
-
- # 1. Execute raw search
- results = await _pubmed.search(query, max_results)
- if not results:
- return f"No PubMed results found for: {query}"
-
- # 2. Semantic Deduplication & Expansion (The "Digital Twin" Brain)
- display_results = results
- if state.embedding_service:
- # Deduplicate against what we just found vs what's in the DB
- unique_results = await state.embedding_service.deduplicate(results)
-
- # Search for related context in the vector DB (previous searches)
- related = await state.search_related(query, n_results=3)
-
- # Combine unique new results + relevant historical results
- display_results = unique_results + related
-
- # 3. Update State (Persist for ReportAgent)
- # We add *all* found results to state, not just the displayed ones
- new_count = state.add_evidence(results)
-
- # 4. Format Output for LLM
- output = [f"Found {len(results)} results ({new_count} new stored):\n"]
-
- # Limit display to avoid context window overflow, but state has everything
- limit = min(len(display_results), max_results)
-
- for i, r in enumerate(display_results[:limit], 1):
- title = r.citation.title
- date = r.citation.date
- source = r.citation.source
- content_clean = r.content[:300].replace("\n", " ")
- url = r.citation.url
-
- output.append(f"{i}. **{title}** ({date})")
- output.append(f" Source: {source} | {url}")
- output.append(f" {content_clean}...")
- output.append("")
-
- return "\n".join(output)
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def search_clinical_trials(query: str, max_results: int = 10) -> str:
- """Search ClinicalTrials.gov for clinical studies.
-
- Use this tool to find ongoing and completed clinical trials
- for research investigation.
-
- Args:
- query: Search terms (e.g., "metformin cancer phase 3")
- max_results: Maximum results to return (default 10)
-
- Returns:
- Formatted list of clinical trials with status and details
- """
- state = get_magentic_state()
-
- results = await _clinicaltrials.search(query, max_results)
- if not results:
- return f"No clinical trials found for: {query}"
-
- # Update state
- new_count = state.add_evidence(results)
-
- output = [f"Found {len(results)} clinical trials ({new_count} new stored):\n"]
- for i, r in enumerate(results[:max_results], 1):
- title = r.citation.title
- date = r.citation.date
- source = r.citation.source
- content_clean = r.content[:300].replace("\n", " ")
- url = r.citation.url
-
- output.append(f"{i}. **{title}**")
- output.append(f" Status: {source} | Date: {date}")
- output.append(f" {content_clean}...")
- output.append(f" URL: {url}\n")
-
- return "\n".join(output)
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def search_preprints(query: str, max_results: int = 10) -> str:
- """Search Europe PMC for preprints and papers.
-
- Use this tool to find the latest research including preprints
- from bioRxiv, medRxiv, and peer-reviewed papers.
-
- Args:
- query: Search terms (e.g., "long covid treatment")
- max_results: Maximum results to return (default 10)
-
- Returns:
- Formatted list of papers with abstracts and links
- """
- state = get_magentic_state()
-
- results = await _europepmc.search(query, max_results)
- if not results:
- return f"No papers found for: {query}"
-
- # Update state
- new_count = state.add_evidence(results)
-
- output = [f"Found {len(results)} papers ({new_count} new stored):\n"]
- for i, r in enumerate(results[:max_results], 1):
- title = r.citation.title
- date = r.citation.date
- source = r.citation.source
- content_clean = r.content[:300].replace("\n", " ")
- url = r.citation.url
-
- output.append(f"{i}. **{title}**")
- output.append(f" Source: {source} | Date: {date}")
- output.append(f" {content_clean}...")
- output.append(f" URL: {url}\n")
-
- return "\n".join(output)
-
-
-@ai_function # type: ignore[arg-type, misc]
-async def get_bibliography() -> str:
- """Get the full list of collected evidence for the bibliography.
-
- Use this tool when generating the final report to get the complete
- list of references.
-
- Returns:
- Formatted bibliography string.
- """
- state = get_magentic_state()
- if not state.evidence:
- return "No evidence collected."
-
- output = ["## References"]
- for i, ev in enumerate(state.evidence, 1):
- output.append(f"{i}. {ev.citation.formatted}")
- output.append(f" URL: {ev.citation.url}")
-
- return "\n".join(output)
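To make the bibliography format concrete, here is a standalone sketch of the formatting loop in `get_bibliography`, using throwaway dataclasses in place of the real `Evidence`/`Citation` models (the stub citation strings are invented for illustration):

```python
from dataclasses import dataclass


# Stand-ins for the real Citation/Evidence models, just to show the output shape.
@dataclass
class FakeCitation:
    formatted: str
    url: str


@dataclass
class FakeEvidence:
    citation: FakeCitation


evidence = [
    FakeEvidence(FakeCitation("Smith A. et al. (2024). Metformin and tau.", "https://pubmed.example/1")),
    FakeEvidence(FakeCitation("NCT01234567 (Phase 3, recruiting).", "https://clinicaltrials.example/2")),
]

lines = ["## References"]
for i, ev in enumerate(evidence, 1):
    lines.append(f"{i}. {ev.citation.formatted}")
    lines.append(f"   URL: {ev.citation.url}")
print("\n".join(lines))
```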
diff --git a/src/agents/writer.py b/src/agents/writer.py
deleted file mode 100644
index 9fc2edffc55ebf5420c920827745ba878ffe7842..0000000000000000000000000000000000000000
--- a/src/agents/writer.py
+++ /dev/null
@@ -1,234 +0,0 @@
-"""Writer agent for generating final reports from findings.
-
-Converts the folder/writer_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-
-# System prompt for the writer agent
-SYSTEM_PROMPT = f"""
-You are a senior researcher tasked with comprehensively answering a research query.
-Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-You will be provided with the original query along with research findings put together by a research assistant.
-Your objective is to generate the final response in markdown format.
-The response should be as lengthy and detailed as possible with the information provided, focusing on answering the original query.
-In your final output, include references to the source URLs for all information and data gathered.
-These should be formatted as numbered citations in square brackets next to the relevant information,
-followed by a list of URLs at the end of the response, per the example below.
-
-EXAMPLE REFERENCE FORMAT:
-The company has XYZ products [1]. It operates in the software services market which is expected to grow at 10% per year [2].
-
-References:
-[1] https://example.com/first-source-url
-[2] https://example.com/second-source-url
-
-GUIDELINES:
-* Answer the query directly, do not include unrelated or tangential information.
-* Adhere to any instructions on the length of your final response if provided in the user prompt.
-* If any additional guidelines are provided in the user prompt, follow them exactly and give them precedence over these system instructions.
-"""
-
-
-class WriterAgent:
- """
- Agent that generates final reports from research findings.
-
- Uses Pydantic AI to generate markdown reports with citations.
- """
-
- def __init__(self, model: Any | None = None) -> None:
- """
- Initialize the writer agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- """
- self.model = model or get_model()
- self.logger = logger
-
- # Initialize Pydantic AI Agent (no structured output - returns markdown text)
- self.agent = Agent(
- model=self.model,
- system_prompt=SYSTEM_PROMPT,
- retries=3,
- )
-
- async def write_report(
- self,
- query: str,
- findings: str,
- output_length: str = "",
- output_instructions: str = "",
- ) -> str:
- """
- Write a final report from findings.
-
- Args:
- query: The original research query
- findings: All findings collected during research
- output_length: Optional description of desired output length
- output_instructions: Optional additional instructions
-
- Returns:
- Markdown formatted report string
-
- Raises:
- ConfigurationError: If writing fails
- """
- # Input validation
- if not query or not query.strip():
- self.logger.warning("Empty query provided, using default")
- query = "Research query"
-
- if findings is None:
- self.logger.warning("None findings provided, using empty string")
- findings = "No findings available."
-
- # Truncate very long inputs to prevent context overflow
- max_findings_length = 50000 # ~12k tokens
- if len(findings) > max_findings_length:
- self.logger.warning(
- "Findings too long, truncating",
- original_length=len(findings),
- truncated_length=max_findings_length,
- )
- findings = findings[:max_findings_length] + "\n\n[Content truncated due to length]"
-
- self.logger.info("Writing final report", query=query[:100], findings_length=len(findings))
-
- length_str = (
- f"* The full response should be approximately {output_length}.\n"
- if output_length
- else ""
- )
- instructions_str = f"* {output_instructions}" if output_instructions else ""
- guidelines_str = (
- ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
- if length_str or instructions_str
- else ""
- )
-
- user_message = f"""
-Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}
-
-QUERY: {query}
-
-FINDINGS:
-{findings}
-"""
-
- # Retry logic for transient failures
- max_retries = 3
- last_exception: Exception | None = None
-
- for attempt in range(max_retries):
- try:
- # Run the agent
- result = await self.agent.run(user_message)
- report = result.output
-
- # Validate output
- if not report or not report.strip():
- self.logger.warning("Empty report generated, using fallback")
- raise ValueError("Empty report generated")
-
- self.logger.info("Report written", length=len(report), attempt=attempt + 1)
-
- return report
-
- except (TimeoutError, ConnectionError) as e:
- # Transient errors - retry
- last_exception = e
- if attempt < max_retries - 1:
- self.logger.warning(
- "Transient error, retrying",
- error=str(e),
- attempt=attempt + 1,
- max_retries=max_retries,
- )
- continue
- else:
- self.logger.error("Max retries exceeded for transient error", error=str(e))
- break
-
- except Exception as e:
- # Non-transient errors - don't retry
- last_exception = e
- self.logger.error(
- "Report writing failed", error=str(e), error_type=type(e).__name__
- )
- break
-
- # Return fallback report if all attempts failed
- self.logger.error(
- "Report writing failed after all attempts",
- error=str(last_exception) if last_exception else "Unknown error",
- )
-
- # Try to use evidence-based report generator for better fallback
- try:
- from src.middleware.state_machine import get_workflow_state
- from src.utils.report_generator import generate_report_from_evidence
-
- state = get_workflow_state()
- if state and state.evidence:
- self.logger.info(
- "Using evidence-based report generator for fallback",
- evidence_count=len(state.evidence),
- )
- return generate_report_from_evidence(
- query=query,
- evidence=state.evidence,
- findings=findings,
- )
- except Exception as e:
- self.logger.warning(
- "Failed to use evidence-based report generator",
- error=str(e),
- )
-
- # Fallback to simple report if evidence generator fails
- # Truncate findings in fallback if too long
- fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings
- return (
- f"# Research Report\n\n"
- f"## Query\n{query}\n\n"
- f"## Findings\n{fallback_findings}\n\n"
- f"*Note: Report generation encountered an error. This is a fallback report.*"
- )
-
-
-def create_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> WriterAgent:
- """
- Factory function to create a writer agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured WriterAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- return WriterAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create writer agent", error=str(e))
- raise ConfigurationError(f"Failed to create writer agent: {e}") from e
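Finally, a hedged sketch of driving the deleted writer agent end to end; the factory raises `ConfigurationError` when no model or API key is configured, so the call below assumes credentials are available in the environment and uses invented query/findings strings.

```python
import asyncio

from src.agents.writer import create_writer_agent


async def main() -> None:
    agent = create_writer_agent()  # raises ConfigurationError if no model/API key is configured
    report_md = await agent.write_report(
        query="Can metformin be repurposed for Alzheimer's disease?",
        findings="[1] AMPK activation ... (https://pubmed.example/1)\n[2] ...",
        output_length="about 1000 words",
        output_instructions="Use UK spelling.",
    )
    print(report_md)


asyncio.run(main())
```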
diff --git a/src/app.py b/src/app.py
index ba6795ef78ae066b31a6a96a44c6895802f93544..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/app.py
+++ b/src/app.py
@@ -1,1323 +0,0 @@
-"""Main Gradio application for DeepCritical research agent.
-
-This module provides the Gradio interface with:
-- OAuth authentication via HuggingFace
-- Multimodal input support (text, images, audio)
-- Research agent orchestration
-- Real-time event streaming
-- MCP server integration
-"""
-
-import os
-from collections.abc import AsyncGenerator
-from typing import Any
-
-import gradio as gr
-import numpy as np
-import structlog
-
-from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
-from src.orchestrator_factory import create_orchestrator
-from src.services.multimodal_processing import get_multimodal_service
-from src.utils.config import settings
-from src.utils.models import AgentEvent, OrchestratorConfig
-
-# Import ModelMessage from pydantic_ai with fallback
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- from typing import Any
-
- ModelMessage = Any # type: ignore[assignment, misc]
-
-# Type alias for Gradio multimodal input
-MultimodalPostprocess = dict[str, Any] | str
-
-# Import HuggingFace components with graceful fallback
-try:
- from pydantic_ai.models.huggingface import HuggingFaceModel
- from pydantic_ai.providers.huggingface import HuggingFaceProvider
-
- _HUGGINGFACE_AVAILABLE = True
-except ImportError:
- _HUGGINGFACE_AVAILABLE = False
- HuggingFaceModel = None # type: ignore[assignment, misc]
- HuggingFaceProvider = None # type: ignore[assignment, misc]
-
-try:
- from huggingface_hub import AsyncInferenceClient
-
- _ASYNC_INFERENCE_AVAILABLE = True
-except ImportError:
- _ASYNC_INFERENCE_AVAILABLE = False
- AsyncInferenceClient = None # type: ignore[assignment, misc]
-
-logger = structlog.get_logger()
-
-
-def configure_orchestrator(
- use_mock: bool = False,
- mode: str = "simple",
- oauth_token: str | None = None,
- hf_model: str | None = None,
- hf_provider: str | None = None,
- graph_mode: str | None = None,
- use_graph: bool = True,
- web_search_provider: str | None = None,
-) -> tuple[Any, str]:
- """
- Configure and create the research orchestrator.
-
- Args:
- use_mock: Force mock judge handler (for testing)
- mode: Orchestrator mode ("simple", "iterative", "deep", "auto", "advanced")
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- hf_model: Optional HuggingFace model ID (overrides settings)
- hf_provider: Optional inference provider (currently not used by HuggingFaceProvider)
- graph_mode: Optional graph execution mode
- use_graph: Whether to use graph execution
- web_search_provider: Optional web search provider ("auto", "serper", "duckduckgo")
-
- Returns:
- Tuple of (orchestrator, backend_info_string)
- """
- from src.tools.clinicaltrials import ClinicalTrialsTool
- from src.tools.europepmc import EuropePMCTool
- from src.tools.neo4j_search import Neo4jSearchTool
- from src.tools.pubmed import PubMedTool
- from src.tools.search_handler import SearchHandler
- from src.tools.web_search_factory import create_web_search_tool
-
- # Create search handler with tools
- tools = []
-
- # Add biomedical search tools (always available, no API keys required)
- tools.append(PubMedTool())
- logger.info("PubMed tool added to search handler")
-
- tools.append(ClinicalTrialsTool())
- logger.info("ClinicalTrials tool added to search handler")
-
- tools.append(EuropePMCTool())
- logger.info("EuropePMC tool added to search handler")
-
- # Add Neo4j knowledge graph search tool (if Neo4j is configured)
- neo4j_tool = Neo4jSearchTool()
- tools.append(neo4j_tool)
- logger.info("Neo4j search tool added to search handler")
-
- # Add web search tool
- web_search_tool = create_web_search_tool(provider=web_search_provider or "auto")
- if web_search_tool:
- tools.append(web_search_tool)
- logger.info("Web search tool added to search handler", provider=web_search_tool.name)
-
- # Create config if not provided
- config = OrchestratorConfig()
-
- search_handler = SearchHandler(
- tools=tools,
- timeout=config.search_timeout,
- include_rag=True,
- auto_ingest_to_rag=True,
- oauth_token=oauth_token,
- )
-
- # Create judge (mock, real, or free tier)
- judge_handler: JudgeHandler | MockJudgeHandler | HFInferenceJudgeHandler
- backend_info = "Unknown"
-
- # 1. Forced Mock (Unit Testing)
- if use_mock:
- judge_handler = MockJudgeHandler()
- backend_info = "Mock (Testing)"
-
- # 2. API Key (OAuth or Env) - HuggingFace only (OAuth provides HF token)
- # Priority: oauth_token > env vars
- # On HuggingFace Spaces, OAuth token is available via request.oauth_token
- #
- # OAuth Scope Requirements:
- # - 'inference-api': Required for HuggingFace Inference API access
- # This scope grants access to:
- # * HuggingFace's own Inference API
- # * All third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.)
- # * All models available through the Inference Providers API
- # See: https://huggingface.co/docs/hub/oauth#currently-supported-scopes
- #
- # Note: The hf_provider parameter is accepted but not used here because HuggingFaceProvider
- # from pydantic-ai doesn't support provider selection. Provider selection happens at the
- # InferenceClient level (used in HuggingFaceChatClient for advanced mode).
- effective_api_key = oauth_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
-
- # Log which authentication source is being used
- if effective_api_key:
- auth_source = (
- "OAuth token"
- if oauth_token
- else ("HF_TOKEN env var" if os.getenv("HF_TOKEN") else "HUGGINGFACE_API_KEY env var")
- )
- logger.info(
- "Using HuggingFace authentication",
- source=auth_source,
- has_token=bool(effective_api_key),
- )
-
- if effective_api_key:
- # We have an API key (OAuth or env) - use pydantic-ai with JudgeHandler
- # This uses HuggingFace Inference API, which includes access to all third-party providers
- # via the Inference Providers API (router.huggingface.co)
- model: Any | None = None
- # Use selected model or fall back to env var/settings
- model_name = (
- hf_model
- or os.getenv("HF_MODEL")
- or settings.huggingface_model
- or "Qwen/Qwen3-Next-80B-A3B-Thinking"
- )
- if not _HUGGINGFACE_AVAILABLE:
- raise ImportError(
- "HuggingFace models are not available in this version of pydantic-ai. "
- "Please install with: uv add 'pydantic-ai[huggingface]' to use HuggingFace inference providers."
- )
- # Inference API - uses HuggingFace Inference API
- # Per https://ai.pydantic.dev/models/huggingface/#configure-the-provider
- # HuggingFaceProvider accepts api_key parameter directly
- # This is consistent with usage in src/utils/llm_factory.py and src/agent_factory/judges.py
- # The OAuth token with 'inference-api' scope provides access to all inference providers
- provider = HuggingFaceProvider(api_key=effective_api_key) # type: ignore[misc]
- model = HuggingFaceModel(model_name, provider=provider) # type: ignore[misc]
- backend_info = "API (HuggingFace OAuth)" if oauth_token else "API (Env Config)"
-
- judge_handler = JudgeHandler(model=model)
-
- # 3. Free Tier (HuggingFace Inference) - NO API KEY AVAILABLE
- else:
- # No API key available - use HFInferenceJudgeHandler with public models
- # HFInferenceJudgeHandler will use HF_TOKEN from env if available, otherwise public models
- # Note: OAuth token should have been caught in effective_api_key check above
- # If we reach here, we truly have no API key, so use public models
- judge_handler = HFInferenceJudgeHandler(
- model_id=hf_model if hf_model else None,
- api_key=None, # Will use HF_TOKEN from env if available, otherwise public models
- )
- model_display = hf_model.split("/")[-1] if hf_model else "Default (Public Models)"
- backend_info = f"Free Tier ({model_display} - Public Models Only)"
-
- # Determine effective mode
- # If mode is already iterative/deep/auto, use it directly
- # If mode is "graph" or "simple", use graph_mode if provided
- effective_mode = mode
- if mode in ("graph", "simple") and graph_mode:
- effective_mode = graph_mode
- elif mode == "graph" and not graph_mode:
- effective_mode = "auto" # Default to auto if graph mode but no graph_mode specified
-
- orchestrator = create_orchestrator(
- search_handler=search_handler,
- judge_handler=judge_handler,
- config=config,
- mode=effective_mode, # type: ignore
- oauth_token=oauth_token,
- )
-
- return orchestrator, backend_info
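-
-
-# Hedged usage sketch: configure_orchestrator is normally invoked from
-# research_agent below; a minimal standalone call (assuming HF_TOKEN is set in
-# the environment) would look like:
-#
-# orchestrator, backend = configure_orchestrator(mode="iterative")
-# async for event in orchestrator.run("What drives Long COVID fatigue?"):
-#     print(event.to_markdown())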
-
-
-def _is_file_path(text: str) -> bool:
- """Check if text appears to be a file path.
-
- Args:
- text: Text to check
-
- Returns:
- True if text looks like a file path
- """
- return ("/" in text or "\\" in text) and (
- "." in text.split("/")[-1] or "." in text.split("\\")[-1]
- )
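-
-
-# Behaviour sketch (hedged): _is_file_path("data/report.pdf") -> True,
-# _is_file_path("C:\\tmp\\scan.png") -> True, _is_file_path("What is CRISPR?") -> False.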
-
-
-def event_to_chat_message(event: AgentEvent) -> dict[str, Any]:
- """Convert AgentEvent to Gradio chat message format.
-
- Args:
- event: AgentEvent to convert
-
- Returns:
- Dictionary with 'role' and 'content' keys for Gradio Chatbot
- """
- result: dict[str, Any] = {
- "role": "assistant",
- "content": event.to_markdown(),
- }
-
- # Add metadata if available
- if event.data:
- metadata: dict[str, Any] = {}
-
- # Extract file path if present
- if isinstance(event.data, dict):
- file_path = event.data.get("file_path")
- if file_path:
- metadata["file_path"] = file_path
-
- if metadata:
- result["metadata"] = metadata
- return result
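-
-
-# Shape sketch (hedged): an AgentEvent whose data contains
-# {"file_path": "/tmp/report.md"} converts to
-# {"role": "assistant", "content": "<event markdown>", "metadata": {"file_path": "/tmp/report.md"}}.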
-
-
-def extract_oauth_info(request: gr.Request | None) -> tuple[str | None, str | None]:
- """
- Extract OAuth token and username from Gradio request.
-
- Args:
- request: Gradio request object containing OAuth information
-
- Returns:
- Tuple of (oauth_token, oauth_username)
- """
- oauth_token: str | None = None
- oauth_username: str | None = None
-
- if request is None:
- return oauth_token, oauth_username
-
- # Try multiple ways to access OAuth token (Gradio API may vary)
- # Pattern 1: request.oauth_token.token
- if hasattr(request, "oauth_token") and request.oauth_token is not None:
- if hasattr(request.oauth_token, "token"):
- oauth_token = request.oauth_token.token
- elif isinstance(request.oauth_token, str):
- oauth_token = request.oauth_token
- # Pattern 2: request.headers (fallback)
- elif hasattr(request, "headers"):
- # OAuth token might be in headers
- auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
- if auth_header and auth_header.startswith("Bearer "):
- # removeprefix strips only the leading marker; str.replace would also
- # mangle any later occurrences inside the token
- oauth_token = auth_header.removeprefix("Bearer ")
-
- # Access username from request
- if hasattr(request, "username") and request.username:
- oauth_username = request.username
- # Also try accessing via oauth_profile if available
- elif hasattr(request, "oauth_profile") and request.oauth_profile is not None:
- if hasattr(request.oauth_profile, "username") and request.oauth_profile.username:
- oauth_username = request.oauth_profile.username
- elif hasattr(request.oauth_profile, "name") and request.oauth_profile.name:
- oauth_username = request.oauth_profile.name
-
- return oauth_token, oauth_username
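-
-
-# Sketch (hedged): on a Space with OAuth enabled this returns something like
-# ("hf_oauth_...", "some-user"); with no logged-in user it returns (None, None).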
-
-
-async def yield_auth_messages(
- oauth_username: str | None,
- oauth_token: str | None,
- has_huggingface: bool,
- mode: str,
-) -> AsyncGenerator[dict[str, Any], None]:
- """
- Yield authentication status messages.
-
- Args:
- oauth_username: OAuth username if available
- oauth_token: OAuth token if available
- has_huggingface: Whether HuggingFace authentication is available
- mode: Research mode
-
- Yields:
- Chat message dictionaries
- """
- if oauth_username:
- yield {
- "role": "assistant",
- "content": f"👋 **Welcome, {oauth_username}!**\n\nAuthenticated via HuggingFace OAuth.",
- }
-
- if oauth_token:
- yield {
- "role": "assistant",
- "content": (
- "🔐 **Authentication Status**: ✅ Authenticated\n\n"
- "Your OAuth token has been validated. You can now use all AI models and research tools."
- ),
- }
- elif has_huggingface:
- yield {
- "role": "assistant",
- "content": (
- "🔐 **Authentication Status**: ✅ Using environment token\n\n"
- "Using HF_TOKEN from environment variables."
- ),
- }
- else:
- yield {
- "role": "assistant",
- "content": (
- "⚠️ **Authentication Status**: ❌ No authentication\n\n"
- "Please sign in with HuggingFace or set HF_TOKEN environment variable."
- ),
- }
-
- yield {
- "role": "assistant",
- "content": f"🚀 **Mode**: {mode.upper()}\n\nStarting research agent...",
- }
-
-
-def _extract_oauth_token(oauth_token: gr.OAuthToken | None) -> str | None:
- """Extract token value from OAuth token object."""
- if oauth_token is None:
- return None
-
- if hasattr(oauth_token, "token"):
- token_value: str | None = getattr(oauth_token, "token", None) # type: ignore[assignment]
- if token_value is None:
- return None
- logger.debug("OAuth token extracted from oauth_token.token attribute")
-
- # Validate token format
- from src.utils.hf_error_handler import log_token_info, validate_hf_token
-
- log_token_info(token_value, context="research_agent")
- is_valid, error_msg = validate_hf_token(token_value)
- if not is_valid:
- logger.warning(
- "OAuth token validation failed",
- error=error_msg,
- oauth_token_type=type(oauth_token).__name__,
- )
- return token_value
-
- if isinstance(oauth_token, str):
- logger.debug("OAuth token extracted as string")
-
- # Validate token format (mirrors the attribute path above)
- from src.utils.hf_error_handler import log_token_info, validate_hf_token
-
- log_token_info(oauth_token, context="research_agent")
- is_valid, error_msg = validate_hf_token(oauth_token)
- if not is_valid:
- logger.warning("OAuth token validation failed", error=error_msg)
- return oauth_token
-
- logger.warning(
- "OAuth token object present but token extraction failed",
- oauth_token_type=type(oauth_token).__name__,
- )
- return None
-
-
-def _extract_username(oauth_profile: gr.OAuthProfile | None) -> str | None:
- """Extract username from OAuth profile."""
- if oauth_profile is None:
- return None
-
- username: str | None = None
- if hasattr(oauth_profile, "username") and oauth_profile.username:
- username = str(oauth_profile.username)
- elif hasattr(oauth_profile, "name") and oauth_profile.name:
- username = str(oauth_profile.name)
-
- if username:
- logger.info("OAuth user authenticated", username=username)
- return username
- return None
-
-
-async def _process_multimodal_input(
- message: str | MultimodalPostprocess,
- enable_image_input: bool,
- enable_audio_input: bool,
- token_value: str | None,
-) -> tuple[str, tuple[int, np.ndarray[Any, Any]] | None]: # type: ignore[type-arg]
- """Process multimodal input and return processed text and audio data."""
- processed_text = ""
- audio_input_data: tuple[int, np.ndarray[Any, Any]] | None = None # type: ignore[type-arg]
-
- if isinstance(message, dict):
- processed_text = message.get("text", "") or ""
- files = message.get("files", []) or []
- audio_input_data = message.get("audio") or None
-
- if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input):
- try:
- multimodal_service = get_multimodal_service()
- processed_text = await multimodal_service.process_multimodal_input(
- processed_text,
- files=files if enable_image_input else [],
- audio_input=audio_input_data if enable_audio_input else None,
- hf_token=token_value,
- prepend_multimodal=True,
- )
- except Exception as e:
- logger.warning("multimodal_processing_failed", error=str(e))
- else:
- processed_text = str(message) if message else ""
-
- return processed_text, audio_input_data
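-
-
-# Input shape sketch (hedged): Gradio's multimodal textbox submits a dict such
-# as {"text": "summarize this scan", "files": ["/tmp/xray.png"], "audio": None},
-# while plain submissions arrive as a bare str and pass through unchanged.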
-
-
-async def research_agent(
- message: str | MultimodalPostprocess,
- history: list[dict[str, Any]],
- mode: str = "simple",
- hf_model: str | None = None,
- hf_provider: str | None = None,
- graph_mode: str = "auto",
- use_graph: bool = True,
- enable_image_input: bool = True,
- enable_audio_input: bool = True,
- web_search_provider: str = "auto",
- oauth_token: gr.OAuthToken | None = None,
- oauth_profile: gr.OAuthProfile | None = None,
-) -> AsyncGenerator[dict[str, Any], None]:
- """
- Main research agent function that processes queries and streams results.
-
- Args:
- message: User message (text, image, or audio)
- history: Conversation history
- mode: Orchestrator mode
- hf_model: Optional HuggingFace model ID
- hf_provider: Optional inference provider
- graph_mode: Graph execution mode
- use_graph: Whether to use graph execution
- enable_image_input: Whether to process image inputs
- enable_audio_input: Whether to process audio inputs
- web_search_provider: Web search provider selection
- oauth_token: Gradio OAuth token (None if user not logged in)
- oauth_profile: Gradio OAuth profile (None if user not logged in)
-
- Yields:
- Chat message dictionaries
- """
- # Extract OAuth token and username
- token_value = _extract_oauth_token(oauth_token)
- username = _extract_username(oauth_profile)
-
- # Check if user is logged in (OAuth token or env var)
- # Fallback to env vars for local development or Spaces with HF_TOKEN secret
- has_authentication = bool(
- token_value or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
- )
-
- if not has_authentication:
- yield {
- "role": "assistant",
- "content": (
- "🔐 **Authentication Required**\n\n"
- "Please **sign in with HuggingFace** using the login button at the top of the page "
- "before using this application.\n\n"
- "The login button is required to access the AI models and research tools."
- ),
- }
- return
-
- # Process multimodal input
- processed_text, audio_input_data = await _process_multimodal_input(
- message, enable_image_input, enable_audio_input, token_value
- )
-
- if not processed_text.strip():
- yield {
- "role": "assistant",
- "content": "Please enter a research question or provide an image/audio input.",
- }
- return
-
- # HuggingFace auth availability (same sources as has_authentication above)
- has_huggingface = has_authentication
-
- # Adjust mode if needed: "advanced" (multi-agent, OpenAI-based) is not wired
- # into this app, so fall back to the simple orchestrator
- effective_mode = mode
- if mode == "advanced":
- effective_mode = "simple"
-
- # Yield authentication and mode status messages
- async for msg in yield_auth_messages(username, token_value, has_huggingface, mode):
- yield msg
-
- # Run the agent and stream events
- try:
- # use_mock=False - let configure_orchestrator decide based on available keys
- # It will use: OAuth token > Env vars > HF Inference (free tier)
- # Convert empty strings from Textbox to None for defaults
- model_id = hf_model if hf_model and hf_model.strip() else None
- provider_name = hf_provider if hf_provider and hf_provider.strip() else None
-
- # Log authentication source for debugging
- auth_source = (
- "OAuth"
- if token_value
- else (
- "Env (HF_TOKEN)"
- if os.getenv("HF_TOKEN")
- else ("Env (HUGGINGFACE_API_KEY)" if os.getenv("HUGGINGFACE_API_KEY") else "None")
- )
- )
- logger.info(
- "Configuring orchestrator",
- mode=effective_mode,
- auth_source=auth_source,
- has_oauth_token=bool(token_value),
- model=model_id or "default",
- provider=provider_name or "auto",
- )
-
- # Convert empty string to None for web_search_provider
- web_search_provider_value = (
- web_search_provider if web_search_provider and web_search_provider.strip() else None
- )
-
- orchestrator, backend_name = configure_orchestrator(
- use_mock=False, # Never use mock in production - HF Inference is the free fallback
- mode=effective_mode,
- oauth_token=token_value, # Use extracted token value - passed to all agents and services
- hf_model=model_id, # None will use defaults in configure_orchestrator
- hf_provider=provider_name, # None will use defaults in configure_orchestrator
- graph_mode=graph_mode if graph_mode else None,
- use_graph=use_graph,
- web_search_provider=web_search_provider_value, # None will use settings default
- )
-
- yield {
- "role": "assistant",
- "content": f"🔧 **Backend**: {backend_name}\n\nProcessing your query...",
- }
-
- # Convert history to ModelMessage format if needed.
- # ModelMessage is a union alias in pydantic-ai, so direct construction only
- # works with the Any fallback above; guard so real installs don't crash here.
- message_history: list[ModelMessage] = []
- if history:
- for msg in history:
- role = msg.get("role", "user")
- content = msg.get("content", "")
- if isinstance(content, str) and content.strip():
- try:
- message_history.append(ModelMessage(role=role, content=content)) # type: ignore[operator]
- except TypeError:
- logger.debug("ModelMessage not constructible; skipping history conversion")
- break
-
- # Run orchestrator and stream events
- async for event in orchestrator.run(
- processed_text, message_history=message_history if message_history else None
- ):
- chat_msg = event_to_chat_message(event)
- yield chat_msg
-
- # Note: Audio output is now handled via on-demand TTS button
- # Users click "Generate Audio" button to create TTS for the last response
-
- except Exception as e:
- # Return error message without metadata to avoid issues during example caching
- # Metadata can cause validation errors when Gradio caches examples
- # Gradio Chatbot requires plain text - remove all markdown and special characters
- error_msg = str(e).replace("**", "").replace("*", "").replace("`", "")
- # Ensure content is a simple string without any special formatting
- yield {
- "role": "assistant",
- "content": f"Error: {error_msg}. Please check your configuration and try again.",
- }
-
-
-async def update_model_provider_dropdowns(
- oauth_token: gr.OAuthToken | None = None,
- oauth_profile: gr.OAuthProfile | None = None,
-) -> tuple[dict[str, Any], dict[str, Any], str]:
- """Update model and provider dropdowns based on OAuth token.
-
- This function is called when OAuth token/profile changes (user logs in/out).
- It queries HuggingFace API to get available models and providers.
-
- Args:
- oauth_token: Gradio OAuth token
- oauth_profile: Gradio OAuth profile
-
- Returns:
- Tuple of (model_dropdown_update, provider_dropdown_update, status_message)
- """
- from src.utils.hf_model_validator import (
- get_available_models,
- get_available_providers,
- validate_oauth_token,
- )
-
- # Extract token value
- token_value: str | None = None
- if oauth_token is not None:
- if hasattr(oauth_token, "token"):
- token_value = oauth_token.token
- elif isinstance(oauth_token, str):
- token_value = oauth_token
-
- # Default values (empty = use default)
- default_models = [""]
- default_providers = [""]
- status_msg = "⚠️ Not authenticated - using default models"
-
- if not token_value:
- # No token - return defaults
- return (
- gr.update(choices=default_models, value=""),
- gr.update(choices=default_providers, value=""),
- status_msg,
- )
-
- try:
- # Validate token and get available resources
- validation_result = await validate_oauth_token(token_value)
-
- if not validation_result["is_valid"]:
- status_msg = (
- f"❌ Token validation failed: {validation_result.get('error', 'Unknown error')}"
- )
- return (
- gr.update(choices=default_models, value=""),
- gr.update(choices=default_providers, value=""),
- status_msg,
- )
-
- # Get available models and providers
- models = await get_available_models(token=token_value, limit=50)
- providers = await get_available_providers(token=token_value)
-
- # Combine with defaults
- model_choices = ["", *models[:49]] # Keep first 49 + empty option
- provider_choices = providers # Already includes "auto"
-
- username = validation_result.get("username", "User")
-
- # Build status message with warning if scope is missing
- scope_warning = ""
- if not validation_result["has_inference_api_scope"]:
- scope_warning = (
- "⚠️ Token may not have 'inference-api' scope - some models may not work\n\n"
- )
-
- status_msg = (
- f"{scope_warning}✅ Authenticated as {username}\n\n"
- f"📊 Found {len(models)} available models\n"
- f"🔧 Found {len(providers)} available providers"
- )
-
- logger.info(
- "Updated model/provider dropdowns",
- model_count=len(model_choices),
- provider_count=len(provider_choices),
- username=username,
- )
-
- return (
- gr.update(choices=model_choices, value=""),
- gr.update(choices=provider_choices, value=""),
- status_msg,
- )
-
- except Exception as e:
- logger.error("Failed to update dropdowns", error=str(e))
- status_msg = f"⚠️ Failed to load models: {e!s}"
- return (
- gr.update(choices=default_models, value=""),
- gr.update(choices=default_providers, value=""),
- status_msg,
- )
-
-
-def create_demo() -> gr.Blocks:
- """
- Create the Gradio demo interface with MCP support and OAuth login.
-
- Returns:
- Configured Gradio Blocks interface with MCP server and OAuth enabled
- """
- with gr.Blocks(title="🔬 The DETERMINATOR", fill_height=True) as demo:
- # Add sidebar with login button and information
- # Reference: Working implementation pattern from Gradio docs
- with gr.Sidebar():
- gr.Markdown("# 🔐 Authentication")
- gr.Markdown(
- "**Sign in with Hugging Face** to access AI models and research tools.\n\n"
- "This application requires authentication to use the inference API."
- )
- gr.LoginButton("Sign in with Hugging Face")
- gr.Markdown("---")
-
- # About Section - Collapsible with details
- with gr.Accordion("ℹ️ About", open=False):
- gr.Markdown(
- "**The DETERMINATOR** - Generalist Deep Research Agent\n\n"
- "Stops at nothing until finding precise answers to complex questions.\n\n"
- "**How It Works**:\n"
- "- 🔍 Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n"
- "- 🧠 Automatic medical knowledge detection\n"
- "- 🔄 Iterative refinement with search-judge loops\n"
- "- ⏹️ Continues until budget/time/iteration limits\n"
- "- 📊 Evidence synthesis with citations\n\n"
- "**Multimodal Input**:\n"
- "- 📷 **Images**: Click image icon in textbox (OCR)\n"
- "- 🎤 **Audio**: Click microphone icon (speech-to-text)\n"
- "- 📄 **Files**: Drag & drop or click to upload\n\n"
- "**MCP Server**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n"
- "⚠️ **Research tool only** - Synthesizes evidence but cannot provide medical advice."
- )
-
- gr.Markdown("---")
-
- # Settings Section - Organized in Accordions
- gr.Markdown("## ⚙️ Settings")
-
- # Research Configuration Accordion
- with gr.Accordion("🔬 Research Configuration", open=True):
- mode_radio = gr.Radio(
- choices=["simple", "advanced", "iterative", "deep", "auto"],
- value="simple",
- label="Orchestrator Mode",
- info=(
- "Simple: Linear search-judge loop | "
- "Advanced: Multi-agent (OpenAI) | "
- "Iterative: Knowledge-gap driven | "
- "Deep: Parallel sections | "
- "Auto: Smart routing"
- ),
- )
-
- graph_mode_radio = gr.Radio(
- choices=["iterative", "deep", "auto"],
- value="auto",
- label="Graph Research Mode",
- info="Iterative: Single loop | Deep: Parallel sections | Auto: Detect from query",
- )
-
- use_graph_checkbox = gr.Checkbox(
- value=True,
- label="Use Graph Execution",
- info="Enable graph-based workflow execution",
- )
-
- # Model and Provider selection
- gr.Markdown("### 🤖 Model & Provider")
-
- # Status message for model/provider loading
- model_provider_status = gr.Markdown(
- value="⚠️ Sign in to see available models and providers",
- visible=True,
- )
-
- # Popular models list (will be updated by validator)
- popular_models = [
- "", # Empty = use default
- "Qwen/Qwen3-Next-80B-A3B-Thinking",
- "Qwen/Qwen3-235B-A22B-Instruct-2507",
- "zai-org/GLM-4.5-Air",
- "meta-llama/Llama-3.1-8B-Instruct",
- "meta-llama/Llama-3.1-70B-Instruct",
- "mistralai/Mistral-7B-Instruct-v0.2",
- "google/gemma-2-9b-it",
- ]
-
- hf_model_dropdown = gr.Dropdown(
- choices=popular_models,
- value="", # Empty string - will be converted to None in research_agent
- label="Reasoning Model",
- info="Select a HuggingFace model (leave empty for default). Sign in to see all available models.",
- allow_custom_value=True, # Allow users to type custom model IDs
- )
-
- # Provider list from README (will be updated by validator)
- providers = [
- "", # Empty string = auto-select
- "nebius",
- "together",
- "scaleway",
- "hyperbolic",
- "novita",
- "nscale",
- "sambanova",
- "ovh",
- "fireworks",
- ]
-
- hf_provider_dropdown = gr.Dropdown(
- choices=providers,
- value="", # Empty string - will be converted to None in research_agent
- label="Inference Provider",
- info="Select inference provider (leave empty for auto-select). Sign in to see all available providers.",
- )
-
- # Refresh button handler for updating models/providers after login.
- # Gradio supports async event handlers, so await the update directly
- # instead of spinning up a private event loop in a worker thread.
- async def refresh_models_and_providers(
- request: gr.Request,
- ) -> tuple[dict[str, Any], dict[str, Any], str]:
- """Handle refresh button click and update dropdowns."""
- # Extract OAuth token and profile from the request, if present
- oauth_token: gr.OAuthToken | None = getattr(request, "oauth_token", None)
- oauth_profile: gr.OAuthProfile | None = getattr(request, "oauth_profile", None)
-
- return await update_model_provider_dropdowns(oauth_token, oauth_profile)
-
- refresh_models_btn = gr.Button(
- value="🔄 Refresh Available Models",
- visible=True,
- size="sm",
- )
-
- # Pass request to get OAuth token from Gradio context
- refresh_models_btn.click(
- fn=refresh_models_and_providers,
- inputs=[], # Request is automatically available in Gradio context
- outputs=[hf_model_dropdown, hf_provider_dropdown, model_provider_status],
- )
-
- # Web Search Provider selection
- gr.Markdown("### 🔍 Web Search Provider")
-
- # Available providers with labels indicating availability
- # Format: (display_label, value) - Gradio Dropdown supports tuples
- web_search_provider_options = [
- ("Auto-detect (Recommended)", "auto"),
- ("Serper (Google Search + Full Content)", "serper"),
- ("DuckDuckGo (Free, Snippets Only)", "duckduckgo"),
- ("SearchXNG (Self-hosted) - Coming Soon", "searchxng"), # Not fully implemented
- ("Brave - Coming Soon", "brave"), # Not implemented
- ("Tavily - Coming Soon", "tavily"), # Not implemented
- ]
-
- # Create Dropdown with label-value pairs
- # Gradio will display labels but return values
- # Disabled options are marked with "Coming Soon" in the label
- # The factory will handle "not implemented" cases gracefully
- web_search_provider_dropdown = gr.Dropdown(
- choices=web_search_provider_options,
- value="auto",
- label="Web Search Provider",
- info="Select web search provider. 'Auto' detects best available.",
- )
-
- # Multimodal Input Configuration
- gr.Markdown("### 📷🎤 Multimodal Input")
-
- enable_image_input_checkbox = gr.Checkbox(
- value=settings.enable_image_input,
- label="Enable Image Input (OCR)",
- info="Process uploaded images with OCR",
- )
-
- enable_audio_input_checkbox = gr.Checkbox(
- value=settings.enable_audio_input,
- label="Enable Audio Input (STT)",
- info="Process uploaded/recorded audio with speech-to-text",
- )
-
- # Audio Output Configuration - Collapsible
- with gr.Accordion("🔊 Audio Output (TTS)", open=False):
- gr.Markdown(
- "**Generate audio for research responses on-demand.**\n\n"
- "Enter Modal keys below or set `MODAL_TOKEN_ID`/`MODAL_TOKEN_SECRET` in `.env` for local development."
- )
-
- with gr.Accordion("🔑 Modal Credentials (Optional)", open=False):
- modal_token_id_input = gr.Textbox(
- label="Modal Token ID",
- placeholder="ak-... (leave empty to use .env)",
- type="password",
- value="",
- )
-
- modal_token_secret_input = gr.Textbox(
- label="Modal Token Secret",
- placeholder="as-... (leave empty to use .env)",
- type="password",
- value="",
- )
-
- with gr.Accordion("🎚️ Voice & Quality Settings", open=False):
- tts_voice_dropdown = gr.Dropdown(
- choices=[
- "af_heart",
- "af_bella",
- "af_sarah",
- "af_sky",
- "af_nova",
- "af_shimmer",
- "af_echo",
- "af_fable",
- "af_onyx",
- "af_angel",
- "af_asteria",
- "af_jessica",
- "af_elli",
- "af_domi",
- "af_gigi",
- "af_freya",
- "af_glinda",
- "af_cora",
- "af_serena",
- "af_liv",
- "af_naomi",
- "af_rachel",
- "af_antoni",
- "af_thomas",
- "af_charlie",
- "af_emily",
- "af_george",
- "af_arnold",
- "af_adam",
- "af_sam",
- "af_paul",
- "af_josh",
- "af_daniel",
- "af_liam",
- "af_dave",
- "af_fin",
- "af_sarah",
- "af_glinda",
- "af_grace",
- "af_dorothy",
- "af_michael",
- "af_james",
- "af_joseph",
- "af_jeremy",
- "af_ryan",
- "af_oliver",
- "af_harry",
- "af_kyle",
- "af_leo",
- "af_otto",
- "af_owen",
- "af_pepper",
- "af_phil",
- "af_raven",
- "af_rocky",
- "af_rusty",
- "af_serena",
- "af_sky",
- "af_spark",
- "af_stella",
- "af_storm",
- "af_taylor",
- "af_vera",
- "af_will",
- "af_aria",
- "af_ash",
- "af_ballad",
- "af_bella",
- "af_breeze",
- "af_cove",
- "af_dusk",
- "af_ember",
- "af_flash",
- "af_flow",
- "af_glow",
- "af_harmony",
- "af_journey",
- "af_lullaby",
- "af_lyra",
- "af_melody",
- "af_midnight",
- "af_moon",
- "af_muse",
- "af_music",
- "af_narrator",
- "af_nightingale",
- "af_poet",
- "af_rain",
- "af_redwood",
- "af_rewind",
- "af_river",
- "af_sage",
- "af_seashore",
- "af_shadow",
- "af_silver",
- "af_song",
- "af_starshine",
- "af_story",
- "af_summer",
- "af_sun",
- "af_thunder",
- "af_tide",
- "af_time",
- "af_valentino",
- "af_verdant",
- "af_verse",
- "af_vibrant",
- "af_vivid",
- "af_warmth",
- "af_whisper",
- "af_wilderness",
- "af_willow",
- "af_winter",
- "af_wit",
- "af_witness",
- "af_wren",
- "af_writer",
- "af_zara",
- "af_zeus",
- "af_ziggy",
- "af_zoom",
- "af_river",
- "am_michael",
- "am_fenrir",
- "am_puck",
- "am_echo",
- "am_eric",
- "am_liam",
- "am_onyx",
- "am_santa",
- "am_adam",
- ],
- value=settings.tts_voice,
- label="TTS Voice",
- info="Select TTS voice (American English voices: af_*, am_*)",
- )
-
- tts_speed_slider = gr.Slider(
- minimum=0.5,
- maximum=2.0,
- value=settings.tts_speed,
- step=0.1,
- label="TTS Speech Speed",
- info="Adjust TTS speech speed (0.5x to 2.0x)",
- )
-
- gr.Dropdown(
- choices=["T4", "A10", "A100", "L4", "L40S"],
- value=settings.tts_gpu or "T4",
- label="TTS GPU Type",
- info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). Note: GPU changes require app restart.",
- visible=settings.modal_available,
- interactive=False, # GPU type set at function definition time, requires restart
- )
-
- tts_use_llm_polish_checkbox = gr.Checkbox(
- value=settings.tts_use_llm_polish,
- label="Use LLM Polish for Audio",
- info="Apply LLM-based final polish to remove remaining formatting artifacts (costs API calls)",
- )
-
- tts_generate_button = gr.Button(
- "🎵 Generate Audio for Last Response",
- variant="primary",
- size="lg",
- )
-
- tts_status_text = gr.Markdown(
- "Click the button above to generate audio for the last research response.",
- elem_classes="tts-status",
- )
-
- # Audio output component (for TTS response)
- audio_output = gr.Audio(
- label="🔊 Audio Output",
- visible=True,
- )
-
- # TTS on-demand generation handler
- async def handle_tts_generation(
- history: list[dict[str, Any]],
- modal_token_id: str,
- modal_token_secret: str,
- voice: str,
- speed: float,
- use_llm_polish: bool,
- ) -> tuple[Any | None, str]:
- """Generate audio on-demand for the last response.
-
- Args:
- history: Chat history
- modal_token_id: Modal token ID from UI
- modal_token_secret: Modal token secret from UI
- voice: TTS voice selection
- speed: TTS speed
- use_llm_polish: Enable LLM polish
-
- Returns:
- Tuple of (audio_output, status_message)
- """
- from src.services.tts_modal import generate_audio_on_demand
-
- # Get last assistant message from history.
- # Depending on the Chatbot format, history is either a list of
- # (user_msg, assistant_msg) tuples or a list of {"role", "content"} dicts;
- # both shapes are handled below.
- if not history:
- logger.warning("tts_no_history", history=history)
- return None, "❌ No messages in history to generate audio for"
-
- # Debug: Log history format
- logger.info(
- "tts_history_debug",
- history_type=type(history).__name__,
- history_length=len(history) if isinstance(history, list) else 0,
- first_entry_type=type(history[0]).__name__
- if isinstance(history, list) and len(history) > 0
- else None,
- first_entry_sample=str(history[0])[:200]
- if isinstance(history, list) and len(history) > 0
- else None,
- )
-
- # Get the last assistant message (second element of last tuple)
- last_message = None
- if isinstance(history, list) and len(history) > 0:
- last_entry = history[-1]
- # ChatInterface format: (user_message, assistant_message)
- if isinstance(last_entry, (tuple, list)) and len(last_entry) >= 2:
- last_message = last_entry[1]
- logger.info(
- "tts_extracted_from_tuple", message_type=type(last_message).__name__
- )
- # Dict format: {"role": "assistant", "content": "..."}
- elif isinstance(last_entry, dict):
- if last_entry.get("role") == "assistant":
- content = last_entry.get("content", "")
- # Content might be a list (multimodal) or string
- if isinstance(content, list):
- # Extract text from multimodal content list
- last_message = " ".join(str(item) for item in content if item)
- else:
- last_message = content
- logger.info(
- "tts_extracted_from_dict",
- message_type=type(content).__name__,
- message_length=len(last_message)
- if isinstance(last_message, str)
- else 0,
- )
- else:
- logger.warning(
- "tts_unknown_format",
- entry_type=type(last_entry).__name__,
- entry=str(last_entry)[:200],
- )
-
- # Also handle if last_message itself is a list
- if isinstance(last_message, list):
- last_message = " ".join(str(item) for item in last_message if item)
-
- if not last_message or not isinstance(last_message, str) or not last_message.strip():
- logger.error(
- "tts_no_message_found",
- last_message_type=type(last_message).__name__ if last_message else None,
- last_message_value=str(last_message)[:100] if last_message else None,
- )
- return None, "❌ No assistant response found in history"
-
- # Generate audio (local name kept distinct from the gr.Audio component)
- audio_data, status_message = await generate_audio_on_demand(
- text=last_message,
- modal_token_id=modal_token_id,
- modal_token_secret=modal_token_secret,
- voice=voice,
- speed=speed,
- use_llm_polish=use_llm_polish,
- )
-
- return audio_data, status_message
-
- # Chat interface with multimodal support
- # Examples are provided but will NOT run at startup (cache_examples=False)
- # Users must log in first before using examples or submitting queries
- chat_interface = gr.ChatInterface(
- fn=research_agent,
- multimodal=True, # Enable multimodal input (text + images + audio)
- title="🔬 The DETERMINATOR",
- description=(
- "*Generalist Deep Research Agent — stops at nothing until finding precise answers*\n\n"
- "💡 **Quick Start**: Type your research question below. Use 📷 for images, 🎤 for audio.\n\n"
- "⚠️ **Sign in with HuggingFace** (sidebar) before starting."
- ),
- examples=[
- # When additional_inputs are provided, examples must be lists of lists
- # Each inner list: [message, mode, hf_model, hf_provider, graph_mode, use_graph]
- # Using actual model IDs and provider names from inference_models.py
- # Note: Provider is optional - if empty, HF will auto-select
- # These examples will NOT run at startup - users must click them after logging in
- # All examples require deep iterative search and information retrieval across multiple sources
- [
- # Medical research example (only one medical example)
- "Create a comprehensive report on Long COVID treatments including clinical trials, mechanisms, and safety.",
- "deep",
- "zai-org/GLM-4.5-Air",
- "nebius",
- "deep",
- True,
- ],
- [
- # Technical/Engineering example requiring deep research
- "Analyze the current state of quantum computing architectures: compare different qubit technologies, error correction methods, and scalability challenges across major platforms including IBM, Google, and IonQ.",
- "deep",
- "Qwen/Qwen3-Next-80B-A3B-Thinking",
- "nebius",
- "deep",
- True,
- ],
- [
- # Historical/Social Science example
- "Research and synthesize information about the economic impact of the Industrial Revolution on European social structures, including changes in class dynamics, urbanization patterns, and labor movements from 1750-1900.",
- "deep",
- "meta-llama/Llama-3.1-70B-Instruct",
- "together",
- "deep",
- True,
- ],
- [
- # Scientific/Physics example
- "Investigate the latest developments in fusion energy research: compare ITER, SPARC, and other major projects, analyze recent breakthroughs in plasma confinement, and assess the timeline to commercial fusion power.",
- "deep",
- "Qwen/Qwen3-235B-A22B-Instruct-2507",
- "hyperbolic",
- "deep",
- True,
- ],
- [
- # Technology/Business example
- "Research the competitive landscape of AI chip manufacturers: analyze NVIDIA, AMD, Intel, and emerging players, compare architectures (GPU vs. TPU vs. NPU), and assess market positioning and future trends.",
- "deep",
- "zai-org/GLM-4.5-Air",
- "fireworks",
- "deep",
- True,
- ],
- ],
- additional_inputs=[
- mode_radio,
- hf_model_dropdown,
- hf_provider_dropdown,
- graph_mode_radio,
- use_graph_checkbox,
- enable_image_input_checkbox,
- enable_audio_input_checkbox,
- web_search_provider_dropdown,
- # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters
- ],
- cache_examples=False, # Don't cache examples - requires authentication
- )
-
- # Wire up TTS generation button
- tts_generate_button.click(
- fn=handle_tts_generation,
- inputs=[
- chat_interface.chatbot, # Get chat history from ChatInterface
- modal_token_id_input,
- modal_token_secret_input,
- tts_voice_dropdown,
- tts_speed_slider,
- tts_use_llm_polish_checkbox,
- ],
- outputs=[audio_output, tts_status_text],
- )
-
- return demo # type: ignore[no-any-return]
-
-
-if __name__ == "__main__":
- demo = create_demo()
- # mcp_server=True exposes the tools at /gradio_api/mcp/ as referenced in the
- # sidebar (assumes a Gradio version with MCP support; setting the
- # GRADIO_MCP_SERVER env var achieves the same)
- demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
diff --git a/src/legacy_orchestrator.py b/src/legacy_orchestrator.py
deleted file mode 100644
index b41ba8aad566bee202c37ddbc142e677a7a9fce4..0000000000000000000000000000000000000000
--- a/src/legacy_orchestrator.py
+++ /dev/null
@@ -1,453 +0,0 @@
-"""Orchestrator - the agent loop connecting Search and Judge."""
-
-import asyncio
-from collections.abc import AsyncGenerator
-from typing import Any, Protocol
-
-import structlog
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.utils.config import settings
-from src.utils.models import (
- AgentEvent,
- Evidence,
- JudgeAssessment,
- OrchestratorConfig,
- SearchResult,
-)
-
-logger = structlog.get_logger()
-
-
-class SearchHandlerProtocol(Protocol):
- """Protocol for search handler."""
-
- async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult: ...
-
-
-class JudgeHandlerProtocol(Protocol):
- """Protocol for judge handler."""
-
- async def assess(self, question: str, evidence: list[Evidence]) -> JudgeAssessment: ...
-
-
-class Orchestrator:
- """
- The agent orchestrator - runs the Search -> Judge -> Loop cycle.
-
- This is a generator-based design that yields events for real-time UI updates.
- """
-
- def __init__(
- self,
- search_handler: SearchHandlerProtocol,
- judge_handler: JudgeHandlerProtocol,
- config: OrchestratorConfig | None = None,
- enable_analysis: bool = False,
- enable_embeddings: bool = True,
- ):
- """
- Initialize the orchestrator.
-
- Args:
- search_handler: Handler for executing searches
- judge_handler: Handler for assessing evidence
- config: Optional configuration (uses defaults if not provided)
- enable_analysis: Whether to perform statistical analysis (if Modal available)
- enable_embeddings: Whether to use semantic search for ranking/dedup
- """
- self.search = search_handler
- self.judge = judge_handler
- self.config = config or OrchestratorConfig()
- self.history: list[dict[str, Any]] = []
- self._enable_analysis = enable_analysis and settings.modal_available
- self._enable_embeddings = enable_embeddings
-
- # Lazy-load services
- self._analyzer: Any = None
- self._embeddings: Any = None
-
- def _get_analyzer(self) -> Any:
- """Lazy initialization of StatisticalAnalyzer.
-
- Note: This imports from src.services, NOT src.agents,
- so it works without the magentic optional dependency.
- """
- if self._analyzer is None:
- from src.services.statistical_analyzer import get_statistical_analyzer
-
- self._analyzer = get_statistical_analyzer()
- return self._analyzer
-
- def _get_embeddings(self) -> Any:
- """Lazy initialization of EmbeddingService.
-
- Uses local sentence-transformers - NO API key required.
- """
- if self._embeddings is None and self._enable_embeddings:
- try:
- from src.services.embeddings import get_embedding_service
-
- self._embeddings = get_embedding_service()
- logger.info("Embedding service enabled for semantic ranking")
- except Exception as e:
- logger.warning("Embeddings unavailable, using basic ranking", error=str(e))
- self._enable_embeddings = False
- return self._embeddings
-
- async def _deduplicate_and_rank(self, evidence: list[Evidence], query: str) -> list[Evidence]:
- """Use embeddings to deduplicate and rank evidence by relevance."""
- embeddings = self._get_embeddings()
- if not embeddings or not evidence:
- return evidence
-
- try:
- # Deduplicate using semantic similarity
- unique_evidence: list[Evidence] = await embeddings.deduplicate(evidence, threshold=0.85)
- logger.info(
- "Deduplicated evidence",
- before=len(evidence),
- after=len(unique_evidence),
- )
- return unique_evidence
- except Exception as e:
- logger.warning("Deduplication failed, using original", error=str(e))
- return evidence
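-
- # Example (hedged): with threshold=0.85, near-duplicate abstracts of the
- # same paper returned by both PubMed and Europe PMC collapse to one entry.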
-
- async def _run_analysis_phase(
- self, query: str, evidence: list[Evidence], iteration: int
- ) -> AsyncGenerator[AgentEvent, None]:
- """Run the optional analysis phase."""
- if not self._enable_analysis:
- return
-
- yield AgentEvent(
- type="analyzing",
- message="Running statistical analysis in Modal sandbox...",
- data={},
- iteration=iteration,
- )
-
- try:
- analyzer = self._get_analyzer()
-
- # Run Modal analysis (no agent_framework needed!)
- analysis_result = await analyzer.analyze(
- query=query,
- evidence=evidence,
- hypothesis=None, # Could add hypothesis generation later
- )
-
- yield AgentEvent(
- type="analysis_complete",
- message=f"Analysis verdict: {analysis_result.verdict}",
- data=analysis_result.model_dump(),
- iteration=iteration,
- )
-
- except Exception as e:
- logger.error("Modal analysis failed", error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Modal analysis failed: {e}",
- data={"error": str(e)},
- iteration=iteration,
- )
-
- async def run(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915
- """
- Run the agent loop for a query.
-
- Yields AgentEvent objects for each step, allowing real-time UI updates.
-
- Args:
- query: The user's research question
- message_history: Optional user conversation history (for compatibility)
-
- Yields:
- AgentEvent objects for each step of the process
- """
- logger.info(
- "Starting orchestrator",
- query=query,
- has_history=bool(message_history),
- )
-
- yield AgentEvent(
- type="started",
- message=f"Starting research for: {query}",
- iteration=0,
- )
-
- all_evidence: list[Evidence] = []
- current_queries = [query]
- iteration = 0
-
- while iteration < self.config.max_iterations:
- iteration += 1
- logger.info("Iteration", iteration=iteration, queries=current_queries)
-
- # === SEARCH PHASE ===
- yield AgentEvent(
- type="searching",
- message=f"Searching for: {', '.join(current_queries[:3])}...",
- iteration=iteration,
- )
-
- try:
- # Execute searches for all current queries
- search_tasks = [
- self.search.execute(q, self.config.max_results_per_tool)
- for q in current_queries[:3] # Limit to 3 queries per iteration
- ]
- search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
-
- # Collect evidence from successful searches
- new_evidence: list[Evidence] = []
- errors: list[str] = []
-
- for q, result in zip(current_queries[:3], search_results, strict=False):
- if isinstance(result, Exception):
- errors.append(f"Search for '{q}' failed: {result!s}")
- elif isinstance(result, SearchResult):
- new_evidence.extend(result.evidence)
- errors.extend(result.errors)
- else:
- # Should not happen with return_exceptions=True but safe fallback
- errors.append(f"Unknown result type for '{q}': {type(result)}")
-
- # Deduplicate evidence by URL (fast, basic)
- seen_urls = {e.citation.url for e in all_evidence}
- unique_new = [e for e in new_evidence if e.citation.url not in seen_urls]
- all_evidence.extend(unique_new)
-
- # Semantic deduplication (if embeddings available)
- all_evidence = await self._deduplicate_and_rank(all_evidence, query)
-
- yield AgentEvent(
- type="search_complete",
- message=f"Found {len(unique_new)} new sources ({len(all_evidence)} total)",
- data={
- "new_count": len(unique_new),
- "total_count": len(all_evidence),
- },
- iteration=iteration,
- )
-
- if errors:
- logger.warning("Search errors", errors=errors)
-
- except Exception as e:
- logger.error("Search phase failed", error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Search failed: {e!s}",
- iteration=iteration,
- )
- continue
-
- # === JUDGE PHASE ===
- yield AgentEvent(
- type="judging",
- message=f"Evaluating {len(all_evidence)} sources...",
- iteration=iteration,
- )
-
- try:
- assessment = await self.judge.assess(query, all_evidence)
-
- yield AgentEvent(
- type="judge_complete",
- message=(
- f"Assessment: {assessment.recommendation} "
- f"(confidence: {assessment.confidence:.0%})"
- ),
- data={
- "sufficient": assessment.sufficient,
- "confidence": assessment.confidence,
- "mechanism_score": assessment.details.mechanism_score,
- "clinical_score": assessment.details.clinical_evidence_score,
- },
- iteration=iteration,
- )
-
- # Record this iteration in history
- self.history.append(
- {
- "iteration": iteration,
- "queries": current_queries,
- "evidence_count": len(all_evidence),
- "assessment": assessment.model_dump(),
- }
- )
-
- # === DECISION PHASE ===
- if assessment.sufficient and assessment.recommendation == "synthesize":
- # Optional Analysis Phase
- async for event in self._run_analysis_phase(query, all_evidence, iteration):
- yield event
-
- yield AgentEvent(
- type="synthesizing",
- message="Evidence sufficient! Preparing synthesis...",
- iteration=iteration,
- )
-
- # Generate final response
- final_response = self._generate_synthesis(query, all_evidence, assessment)
-
- yield AgentEvent(
- type="complete",
- message=final_response,
- data={
- "evidence_count": len(all_evidence),
- "iterations": iteration,
- "drug_candidates": assessment.details.drug_candidates,
- "key_findings": assessment.details.key_findings,
- },
- iteration=iteration,
- )
- return
-
- else:
- # Need more evidence - prepare next queries
- current_queries = assessment.next_search_queries or [
- f"{query} mechanism of action",
- f"{query} clinical evidence",
- ]
-
- yield AgentEvent(
- type="looping",
- message=(
- f"Need more evidence. "
- f"Next searches: {', '.join(current_queries[:2])}..."
- ),
- data={"next_queries": current_queries},
- iteration=iteration,
- )
-
- except Exception as e:
- logger.error("Judge phase failed", error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Assessment failed: {e!s}",
- iteration=iteration,
- )
- continue
-
- # Max iterations reached
- yield AgentEvent(
- type="complete",
- message=self._generate_partial_synthesis(query, all_evidence),
- data={
- "evidence_count": len(all_evidence),
- "iterations": iteration,
- "max_reached": True,
- },
- iteration=iteration,
- )
-
- def _generate_synthesis(
- self,
- query: str,
- evidence: list[Evidence],
- assessment: JudgeAssessment,
- ) -> str:
- """
- Generate the final synthesis response.
-
- Args:
- query: The original question
- evidence: All collected evidence
- assessment: The final assessment
-
- Returns:
- Formatted synthesis as markdown
- """
- drug_list = (
- "\n".join([f"- **{d}**" for d in assessment.details.drug_candidates])
- or "- No specific candidates identified"
- )
- findings_list = (
- "\n".join([f"- {f}" for f in assessment.details.key_findings]) or "- See evidence below"
- )
-
- citations = "\n".join(
- [
- f"{i + 1}. [{e.citation.title}]({e.citation.url}) "
- f"({e.citation.source.upper()}, {e.citation.date})"
- for i, e in enumerate(evidence[:10]) # Limit to 10 citations
- ]
- )
-
- return f"""## Research Analysis
-
-### Question
-{query}
-
-### Drug Candidates
-{drug_list}
-
-### Key Findings
-{findings_list}
-
-### Assessment
-- **Mechanism Score**: {assessment.details.mechanism_score}/10
-- **Clinical Evidence Score**: {assessment.details.clinical_evidence_score}/10
-- **Confidence**: {assessment.confidence:.0%}
-
-### Reasoning
-{assessment.reasoning}
-
-### Citations ({len(evidence)} sources)
-{citations}
-
----
-*Analysis based on {len(evidence)} sources across {len(self.history)} iterations.*
-"""
-
- def _generate_partial_synthesis(
- self,
- query: str,
- evidence: list[Evidence],
- ) -> str:
- """
- Generate a partial synthesis when max iterations reached.
-
- Args:
- query: The original question
- evidence: All collected evidence
-
- Returns:
- Formatted partial synthesis as markdown
- """
- citations = "\n".join(
- [
- f"{i + 1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
- for i, e in enumerate(evidence[:10])
- ]
- )
-
- return f"""## Partial Analysis (Max Iterations Reached)
-
-### Question
-{query}
-
-### Status
-Maximum search iterations reached. The evidence gathered may be incomplete.
-
-### Evidence Collected
-Found {len(evidence)} sources. Consider refining your query for more specific results.
-
-### Citations
-{citations}
-
----
-*Consider searching with more specific terms or drug names.*
-"""
diff --git a/src/mcp_tools.py b/src/mcp_tools.py
deleted file mode 100644
index 7b59f9be0c47a3c9ae4ebd29f2d633202d21c8ef..0000000000000000000000000000000000000000
--- a/src/mcp_tools.py
+++ /dev/null
@@ -1,301 +0,0 @@
-"""MCP tool wrappers for The DETERMINATOR search tools.
-
-These functions expose our search tools via MCP protocol.
-Each function follows the MCP tool contract:
-- Full type hints
-- Google-style docstrings with Args section
-- Formatted string returns
-"""
-
-from src.tools.clinicaltrials import ClinicalTrialsTool
-from src.tools.europepmc import EuropePMCTool
-from src.tools.pubmed import PubMedTool
-
-# Singleton instances (avoid recreating on each call)
-_pubmed = PubMedTool()
-_trials = ClinicalTrialsTool()
-_europepmc = EuropePMCTool()
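-
-
-# Registration sketch (hedged): these coroutines follow the MCP tool contract,
-# so they can be mounted on any server that accepts async callables; with the
-# official MCP Python SDK that might look like (exact API is an assumption):
-#
-# from mcp.server.fastmcp import FastMCP
-# server = FastMCP("determinator-search")
-# server.tool()(search_pubmed)
-# server.tool()(search_all_sources)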
-
-
-async def search_pubmed(query: str, max_results: int = 10) -> str:
- """Search PubMed for peer-reviewed biomedical literature.
-
- Searches NCBI PubMed database for scientific papers matching your query.
- Returns titles, authors, abstracts, and citation information.
-
- Args:
- query: Search query (e.g., "metformin alzheimer", "cancer treatment mechanisms")
- max_results: Maximum results to return (1-50, default 10)
-
- Returns:
- Formatted search results with paper titles, authors, dates, and abstracts
- """
- max_results = max(1, min(50, max_results)) # Clamp to valid range
-
- results = await _pubmed.search(query, max_results)
-
- if not results:
- return f"No PubMed results found for: {query}"
-
- formatted = [f"## PubMed Results for: {query}\n"]
- for i, evidence in enumerate(results, 1):
- formatted.append(f"### {i}. {evidence.citation.title}")
- formatted.append(f"**Authors**: {', '.join(evidence.citation.authors[:3])}")
- formatted.append(f"**Date**: {evidence.citation.date}")
- formatted.append(f"**URL**: {evidence.citation.url}")
- formatted.append(f"\n{evidence.content}\n")
-
- return "\n".join(formatted)
-
-
-async def search_clinical_trials(query: str, max_results: int = 10) -> str:
- """Search ClinicalTrials.gov for clinical trial data.
-
- Searches the ClinicalTrials.gov database for trials matching your query.
- Returns trial titles, phases, status, conditions, and interventions.
-
- Args:
- query: Search query (e.g., "metformin alzheimer", "diabetes phase 3")
- max_results: Maximum results to return (1-50, default 10)
-
- Returns:
- Formatted clinical trial information with NCT IDs, phases, and status
- """
- max_results = max(1, min(50, max_results))
-
- results = await _trials.search(query, max_results)
-
- if not results:
- return f"No clinical trials found for: {query}"
-
- formatted = [f"## Clinical Trials for: {query}\n"]
- for i, evidence in enumerate(results, 1):
- formatted.append(f"### {i}. {evidence.citation.title}")
- formatted.append(f"**URL**: {evidence.citation.url}")
- formatted.append(f"**Date**: {evidence.citation.date}")
- formatted.append(f"\n{evidence.content}\n")
-
- return "\n".join(formatted)
-
-
-async def search_europepmc(query: str, max_results: int = 10) -> str:
- """Search Europe PMC for preprints and papers.
-
- Searches Europe PMC, which includes bioRxiv, medRxiv, and peer-reviewed content.
- Useful for finding cutting-edge preprints and open access papers.
-
- Args:
- query: Search query (e.g., "metformin neuroprotection", "long covid treatment")
- max_results: Maximum results to return (1-50, default 10)
-
- Returns:
- Formatted results with titles, authors, and abstracts
- """
- max_results = max(1, min(50, max_results))
-
- results = await _europepmc.search(query, max_results)
-
- if not results:
- return f"No Europe PMC results found for: {query}"
-
- formatted = [f"## Europe PMC Results for: {query}\n"]
- for i, evidence in enumerate(results, 1):
- formatted.append(f"### {i}. {evidence.citation.title}")
- formatted.append(f"**Authors**: {', '.join(evidence.citation.authors[:3])}")
- formatted.append(f"**Date**: {evidence.citation.date}")
- formatted.append(f"**URL**: {evidence.citation.url}")
- formatted.append(f"\n{evidence.content}\n")
-
- return "\n".join(formatted)
-
-
-async def search_all_sources(query: str, max_per_source: int = 5) -> str:
- """Search all biomedical sources simultaneously.
-
- Performs parallel search across PubMed, ClinicalTrials.gov, and Europe PMC.
- This is the most comprehensive search option for deep medical research.
-
- Args:
- query: Search query (e.g., "metformin alzheimer", "aspirin cancer prevention")
- max_per_source: Maximum results per source (1-20, default 5)
-
- Returns:
- Combined results from all sources with source labels
- """
- import asyncio
-
- max_per_source = max(1, min(20, max_per_source))
-
- # Run all searches in parallel
- pubmed_task = search_pubmed(query, max_per_source)
- trials_task = search_clinical_trials(query, max_per_source)
- europepmc_task = search_europepmc(query, max_per_source)
-
- pubmed_results, trials_results, europepmc_results = await asyncio.gather(
- pubmed_task, trials_task, europepmc_task, return_exceptions=True
- )
-
- formatted = [f"# Comprehensive Search: {query}\n"]
-
- # Add each result section (handle exceptions gracefully)
- if isinstance(pubmed_results, str):
- formatted.append(pubmed_results)
- else:
- formatted.append(f"## PubMed\n*Error: {pubmed_results}*\n")
-
- if isinstance(trials_results, str):
- formatted.append(trials_results)
- else:
- formatted.append(f"## Clinical Trials\n*Error: {trials_results}*\n")
-
- if isinstance(europepmc_results, str):
- formatted.append(europepmc_results)
- else:
- formatted.append(f"## Europe PMC\n*Error: {europepmc_results}*\n")
-
- return "\n---\n".join(formatted)
-
-
-async def analyze_hypothesis(
- drug: str,
- condition: str,
- evidence_summary: str,
-) -> str:
- """Perform statistical analysis of research hypothesis using Modal.
-
- Executes AI-generated Python code in a secure Modal sandbox to analyze
- the statistical evidence for a research hypothesis.
-
- Args:
- drug: The drug being evaluated (e.g., "metformin")
- condition: The target condition (e.g., "Alzheimer's disease")
- evidence_summary: Summary of evidence to analyze
-
- Returns:
- Analysis result with verdict (SUPPORTED/REFUTED/INCONCLUSIVE) and statistics
- """
- from src.services.statistical_analyzer import get_statistical_analyzer
- from src.utils.config import settings
- from src.utils.models import Citation, Evidence
-
- if not settings.modal_available:
- return "Error: Modal credentials not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET."
-
- # Create evidence from summary
- evidence = [
- Evidence(
- content=evidence_summary,
- citation=Citation(
- source="pubmed",
- title=f"Evidence for {drug} in {condition}",
- url="https://example.com",
- date="2024-01-01",
- authors=["User Provided"],
- ),
- relevance=0.9,
- )
- ]
-
- analyzer = get_statistical_analyzer()
- result = await analyzer.analyze(
- query=f"Can {drug} treat {condition}?",
- evidence=evidence,
- hypothesis={"drug": drug, "target": "unknown", "pathway": "unknown", "effect": condition},
- )
-
- return f"""## Statistical Analysis: {drug} for {condition}
-
-### Verdict: **{result.verdict}**
-**Confidence**: {result.confidence:.0%}
-
-### Key Findings
-{chr(10).join(f"- {f}" for f in result.key_findings) or "- No specific findings extracted"}
-
-### Execution Output
-```
-{result.execution_output}
-```
-
-### Generated Code
-```python
-{result.code_generated}
-```
-
-**Executed in Modal Sandbox** - Isolated, secure, reproducible.
-"""
-
-
-async def extract_text_from_image(
- image_path: str, model: str | None = None, hf_token: str | None = None
-) -> str:
- """Extract text from an image using OCR.
-
- Uses the Multimodal-OCR3 Gradio Space to extract text from images.
- Supports various image formats (PNG, JPG, etc.) and can extract text
- from scanned documents, screenshots, and other image types.
-
- Args:
- image_path: Path to image file
-        model: Optional model selection (default: None, uses API default)
-        hf_token: Optional HuggingFace token (default: None, falls back to settings)
-
- Returns:
- Extracted text from the image
- """
- from src.services.image_ocr import get_image_ocr_service
- from src.utils.config import settings
-
- try:
- ocr_service = get_image_ocr_service()
- # Use provided token or fallback to env vars
- token = hf_token or settings.hf_token or settings.huggingface_api_key
- extracted_text = await ocr_service.extract_text(image_path, model=model, hf_token=token)
-
- if not extracted_text:
- return f"No text found in image: {image_path}"
-
- return f"## Extracted Text from Image\n\n{extracted_text}"
-
- except Exception as e:
- return f"Error extracting text from image: {e}"
-
-
-async def transcribe_audio_file(
- audio_path: str,
- source_lang: str | None = None,
- target_lang: str | None = None,
- hf_token: str | None = None,
-) -> str:
- """Transcribe audio file to text using speech-to-text.
-
- Uses the NVIDIA Canary Gradio Space to transcribe audio files.
- Supports various audio formats (WAV, MP3, etc.) and multiple languages.
-
- Args:
- audio_path: Path to audio file
-        source_lang: Source language (default: None; the service falls back to English)
-        target_lang: Target language (default: None; the service falls back to English)
-        hf_token: Optional HuggingFace token (default: None, falls back to settings)
-
- Returns:
- Transcribed text from the audio file
- """
- from src.services.stt_gradio import get_stt_service
- from src.utils.config import settings
-
- try:
- stt_service = get_stt_service()
- # Use provided token or fallback to env vars
- token = hf_token or settings.hf_token or settings.huggingface_api_key
- transcribed_text = await stt_service.transcribe_file(
- audio_path,
- source_lang=source_lang,
- target_lang=target_lang,
- hf_token=token,
- )
-
- if not transcribed_text:
- return f"No transcription found in audio: {audio_path}"
-
- return f"## Audio Transcription\n\n{transcribed_text}"
-
- except Exception as e:
- return f"Error transcribing audio: {e}"
diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py
deleted file mode 100644
index 2d296a27b28bd09955d82289397d4d889ece5b62..0000000000000000000000000000000000000000
--- a/src/middleware/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Middleware for workflow state management, parallel loop coordination, and budget tracking.
-
-This module provides:
-- WorkflowState: Thread-safe state management using ContextVar
-- WorkflowManager: Coordination of parallel research loops
-- BudgetTracker: Token, time, and iteration budget tracking
-"""
-
-from src.middleware.budget_tracker import BudgetStatus, BudgetTracker
-from src.middleware.state_machine import (
- WorkflowState,
- get_workflow_state,
- init_workflow_state,
-)
-from src.middleware.workflow_manager import (
- LoopStatus,
- ResearchLoop,
- WorkflowManager,
-)
-
-__all__ = [
- "BudgetStatus",
- "BudgetTracker",
- "LoopStatus",
- "ResearchLoop",
- "WorkflowManager",
- "WorkflowState",
- "get_workflow_state",
- "init_workflow_state",
-]
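-
-# Typical consumer import, e.g.:
-#
-#     from src.middleware import BudgetTracker, get_workflow_state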
diff --git a/src/middleware/budget_tracker.py b/src/middleware/budget_tracker.py
deleted file mode 100644
index 7f7c89ece57b73c3c1a526a43238480f5e65fed5..0000000000000000000000000000000000000000
--- a/src/middleware/budget_tracker.py
+++ /dev/null
@@ -1,390 +0,0 @@
-"""Budget tracking for research loops.
-
-Tracks token usage, time elapsed, and iteration counts per loop and globally.
-Enforces budget constraints to prevent infinite loops and excessive resource usage.
-"""
-
-import time
-
-import structlog
-from pydantic import BaseModel, Field
-
-logger = structlog.get_logger()
-
-
-class BudgetStatus(BaseModel):
- """Status of a budget (tokens, time, iterations)."""
-
- tokens_used: int = Field(default=0, description="Total tokens used")
- tokens_limit: int = Field(default=100000, description="Token budget limit", ge=0)
- time_elapsed_seconds: float = Field(default=0.0, description="Time elapsed", ge=0.0)
- time_limit_seconds: float = Field(
- default=600.0, description="Time budget limit (10 min default)", ge=0.0
- )
- iterations: int = Field(default=0, description="Number of iterations completed", ge=0)
- iterations_limit: int = Field(default=10, description="Maximum iterations", ge=1)
- iteration_tokens: dict[int, int] = Field(
- default_factory=dict,
- description="Tokens used per iteration (iteration number -> token count)",
- )
-
- def is_exceeded(self) -> bool:
- """Check if any budget limit has been exceeded.
-
- Returns:
- True if any limit is exceeded, False otherwise.
- """
- return (
- self.tokens_used >= self.tokens_limit
- or self.time_elapsed_seconds >= self.time_limit_seconds
- or self.iterations >= self.iterations_limit
- )
-
- def remaining_tokens(self) -> int:
- """Get remaining token budget.
-
- Returns:
- Remaining tokens (may be negative if exceeded).
- """
- return self.tokens_limit - self.tokens_used
-
- def remaining_time_seconds(self) -> float:
- """Get remaining time budget.
-
- Returns:
- Remaining time in seconds (may be negative if exceeded).
- """
- return self.time_limit_seconds - self.time_elapsed_seconds
-
- def remaining_iterations(self) -> int:
- """Get remaining iteration budget.
-
- Returns:
- Remaining iterations (may be negative if exceeded).
- """
- return self.iterations_limit - self.iterations
-
- def add_iteration_tokens(self, iteration: int, tokens: int) -> None:
- """Add tokens for a specific iteration.
-
- Args:
- iteration: Iteration number (1-indexed).
- tokens: Number of tokens to add.
- """
- if iteration not in self.iteration_tokens:
- self.iteration_tokens[iteration] = 0
- self.iteration_tokens[iteration] += tokens
- # Also add to total tokens
- self.tokens_used += tokens
-
- def get_iteration_tokens(self, iteration: int) -> int:
- """Get tokens used for a specific iteration.
-
- Args:
- iteration: Iteration number.
-
- Returns:
- Token count for the iteration, or 0 if not found.
- """
- return self.iteration_tokens.get(iteration, 0)
-
-
-class BudgetTracker:
- """Tracks budgets per loop and globally."""
-
- def __init__(self) -> None:
- """Initialize the budget tracker."""
- self._budgets: dict[str, BudgetStatus] = {}
- self._start_times: dict[str, float] = {}
- self._global_budget: BudgetStatus | None = None
-
- def create_budget(
- self,
- loop_id: str,
- tokens_limit: int = 100000,
- time_limit_seconds: float = 600.0,
- iterations_limit: int = 10,
- ) -> BudgetStatus:
- """Create a budget for a specific loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- tokens_limit: Maximum tokens allowed.
- time_limit_seconds: Maximum time allowed in seconds.
- iterations_limit: Maximum iterations allowed.
-
- Returns:
- The created BudgetStatus instance.
- """
- budget = BudgetStatus(
- tokens_limit=tokens_limit,
- time_limit_seconds=time_limit_seconds,
- iterations_limit=iterations_limit,
- )
- self._budgets[loop_id] = budget
- logger.debug(
- "Budget created",
- loop_id=loop_id,
- tokens_limit=tokens_limit,
- time_limit=time_limit_seconds,
- iterations_limit=iterations_limit,
- )
- return budget
-
- def get_budget(self, loop_id: str) -> BudgetStatus | None:
- """Get the budget for a specific loop.
-
- Args:
- loop_id: Unique identifier for the loop.
-
- Returns:
- The BudgetStatus instance, or None if not found.
- """
- return self._budgets.get(loop_id)
-
- def add_tokens(self, loop_id: str, tokens: int) -> None:
- """Add tokens to a loop's budget.
-
- Args:
- loop_id: Unique identifier for the loop.
- tokens: Number of tokens to add (can be negative).
- """
- if loop_id not in self._budgets:
- logger.warning("Budget not found for loop", loop_id=loop_id)
- return
- self._budgets[loop_id].tokens_used += tokens
- logger.debug("Tokens added", loop_id=loop_id, tokens=tokens)
-
- def add_iteration_tokens(self, loop_id: str, iteration: int, tokens: int) -> None:
- """Add tokens for a specific iteration.
-
- Args:
- loop_id: Loop identifier.
- iteration: Iteration number (1-indexed).
- tokens: Number of tokens to add.
- """
- if loop_id not in self._budgets:
- logger.warning("Budget not found for loop", loop_id=loop_id)
- return
-
- budget = self._budgets[loop_id]
- budget.add_iteration_tokens(iteration, tokens)
-
- logger.debug(
- "Iteration tokens added",
- loop_id=loop_id,
- iteration=iteration,
- tokens=tokens,
- total_iteration=budget.get_iteration_tokens(iteration),
- )
-
- def get_iteration_tokens(self, loop_id: str, iteration: int) -> int:
- """Get tokens used for a specific iteration.
-
- Args:
- loop_id: Loop identifier.
- iteration: Iteration number.
-
- Returns:
- Token count for the iteration, or 0 if not found.
- """
- if loop_id not in self._budgets:
- return 0
-
- return self._budgets[loop_id].get_iteration_tokens(iteration)
-
- def start_timer(self, loop_id: str) -> None:
- """Start the timer for a loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- self._start_times[loop_id] = time.time()
- logger.debug("Timer started", loop_id=loop_id)
-
- def update_timer(self, loop_id: str) -> None:
- """Update the elapsed time for a loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- if loop_id not in self._start_times:
- logger.warning("Timer not started for loop", loop_id=loop_id)
- return
- if loop_id not in self._budgets:
- logger.warning("Budget not found for loop", loop_id=loop_id)
- return
-
- elapsed = time.time() - self._start_times[loop_id]
- self._budgets[loop_id].time_elapsed_seconds = elapsed
- logger.debug("Timer updated", loop_id=loop_id, elapsed=elapsed)
-
- def increment_iteration(self, loop_id: str) -> None:
- """Increment the iteration count for a loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- if loop_id not in self._budgets:
- logger.warning("Budget not found for loop", loop_id=loop_id)
- return
- self._budgets[loop_id].iterations += 1
- logger.debug(
- "Iteration incremented",
- loop_id=loop_id,
- iterations=self._budgets[loop_id].iterations,
- )
-
- def check_budget(self, loop_id: str) -> tuple[bool, str]:
- """Check if a loop's budget has been exceeded.
-
- Args:
- loop_id: Unique identifier for the loop.
-
- Returns:
- Tuple of (exceeded: bool, reason: str). Reason is empty if not exceeded.
- """
- if loop_id not in self._budgets:
- return False, ""
-
- budget = self._budgets[loop_id]
- self.update_timer(loop_id) # Update time before checking
-
- if budget.is_exceeded():
- reasons = []
- if budget.tokens_used >= budget.tokens_limit:
- reasons.append("tokens")
- if budget.time_elapsed_seconds >= budget.time_limit_seconds:
- reasons.append("time")
- if budget.iterations >= budget.iterations_limit:
- reasons.append("iterations")
- reason = f"Budget exceeded: {', '.join(reasons)}"
- logger.warning("Budget exceeded", loop_id=loop_id, reason=reason)
- return True, reason
-
- return False, ""
-
- def can_continue(self, loop_id: str) -> bool:
- """Check if a loop can continue based on budget.
-
- Args:
- loop_id: Unique identifier for the loop.
-
- Returns:
- True if the loop can continue, False if budget is exceeded.
- """
- exceeded, _ = self.check_budget(loop_id)
- return not exceeded
-
- def get_budget_summary(self, loop_id: str) -> str:
- """Get a formatted summary of a loop's budget status.
-
- Args:
- loop_id: Unique identifier for the loop.
-
- Returns:
- Formatted string summary.
- """
- if loop_id not in self._budgets:
- return f"Budget not found for loop: {loop_id}"
-
- budget = self._budgets[loop_id]
- self.update_timer(loop_id)
-
- return (
- f"Loop {loop_id}: "
- f"Tokens: {budget.tokens_used}/{budget.tokens_limit} "
- f"({budget.remaining_tokens()} remaining), "
- f"Time: {budget.time_elapsed_seconds:.1f}/{budget.time_limit_seconds:.1f}s "
- f"({budget.remaining_time_seconds():.1f}s remaining), "
- f"Iterations: {budget.iterations}/{budget.iterations_limit} "
- f"({budget.remaining_iterations()} remaining)"
- )
-
- def reset_budget(self, loop_id: str) -> None:
- """Reset the budget for a loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- if loop_id in self._budgets:
- old_budget = self._budgets[loop_id]
- # Preserve iteration_tokens when resetting
- old_iteration_tokens = old_budget.iteration_tokens
- self._budgets[loop_id] = BudgetStatus(
- tokens_limit=old_budget.tokens_limit,
- time_limit_seconds=old_budget.time_limit_seconds,
- iterations_limit=old_budget.iterations_limit,
- iteration_tokens=old_iteration_tokens, # Restore old iteration tokens
- )
- if loop_id in self._start_times:
- self._start_times[loop_id] = time.time()
- logger.debug("Budget reset", loop_id=loop_id)
-
- def set_global_budget(
- self,
- tokens_limit: int = 100000,
- time_limit_seconds: float = 600.0,
- iterations_limit: int = 10,
- ) -> None:
- """Set a global budget that applies to all loops.
-
- Args:
- tokens_limit: Maximum tokens allowed globally.
- time_limit_seconds: Maximum time allowed in seconds.
- iterations_limit: Maximum iterations allowed globally.
- """
- self._global_budget = BudgetStatus(
- tokens_limit=tokens_limit,
- time_limit_seconds=time_limit_seconds,
- iterations_limit=iterations_limit,
- )
- logger.debug(
- "Global budget set",
- tokens_limit=tokens_limit,
- time_limit=time_limit_seconds,
- iterations_limit=iterations_limit,
- )
-
- def get_global_budget(self) -> BudgetStatus | None:
- """Get the global budget.
-
- Returns:
- The global BudgetStatus instance, or None if not set.
- """
- return self._global_budget
-
- def add_global_tokens(self, tokens: int) -> None:
- """Add tokens to the global budget.
-
- Args:
- tokens: Number of tokens to add (can be negative).
- """
- if self._global_budget is None:
- logger.warning("Global budget not set")
- return
- self._global_budget.tokens_used += tokens
- logger.debug("Global tokens added", tokens=tokens)
-
- def estimate_tokens(self, text: str) -> int:
- """Estimate token count from text (rough estimate: ~4 chars per token).
-
- Args:
- text: Text to estimate tokens for.
-
- Returns:
- Estimated token count.
- """
- return len(text) // 4
-
- def estimate_llm_call_tokens(self, prompt: str, response: str) -> int:
- """Estimate token count for an LLM call.
-
- Args:
- prompt: The prompt text.
- response: The response text.
-
- Returns:
- Estimated total token count (prompt + response).
- """
- return self.estimate_tokens(prompt) + self.estimate_tokens(response)
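-
-
-# Minimal usage sketch of the per-loop budget lifecycle: create a budget with
-# small limits, record one iteration's tokens, then check enforcement.
-if __name__ == "__main__":
-    tracker = BudgetTracker()
-    tracker.create_budget("loop-1", tokens_limit=500, time_limit_seconds=5.0, iterations_limit=3)
-    tracker.start_timer("loop-1")
-    tracker.increment_iteration("loop-1")
-    tokens = tracker.estimate_tokens("x" * 2000)  # ~2000 chars -> 500 tokens
-    tracker.add_iteration_tokens("loop-1", iteration=1, tokens=tokens)
-    print(tracker.get_budget_summary("loop-1"))
-    print(tracker.check_budget("loop-1"))  # (True, "Budget exceeded: tokens")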
diff --git a/src/middleware/state_machine.py b/src/middleware/state_machine.py
deleted file mode 100644
index 8fbdf793b5ebc2d64a98501f543de4a3979a8132..0000000000000000000000000000000000000000
--- a/src/middleware/state_machine.py
+++ /dev/null
@@ -1,171 +0,0 @@
-"""Thread-safe state management for workflow agents.
-
-Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
-searching simultaneously via Gradio). Refactored from MagenticState to support both
-iterative and deep research patterns.
-"""
-
-from contextvars import ContextVar
-from typing import TYPE_CHECKING, Any
-
-import structlog
-from pydantic import BaseModel, Field
-
-try:
-    from pydantic_ai.messages import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.utils.models import Citation, Conversation, Evidence
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-logger = structlog.get_logger()
-
-
-class WorkflowState(BaseModel):
- """Mutable state for a workflow session.
-
- Supports both iterative and deep research patterns by tracking evidence,
- conversation history, and providing semantic search capabilities.
- """
-
- evidence: list[Evidence] = Field(default_factory=list)
- conversation: Conversation = Field(default_factory=Conversation)
- user_message_history: list[ModelMessage] = Field(
- default_factory=list,
- description="User conversation history (multi-turn interactions)",
- )
- # Type as Any to avoid circular imports/runtime resolution issues
- # The actual object injected will be an EmbeddingService instance
- embedding_service: Any = Field(default=None)
-
- model_config = {"arbitrary_types_allowed": True}
-
- def add_evidence(self, new_evidence: list[Evidence]) -> int:
- """Add new evidence, deduplicating by URL.
-
- Args:
- new_evidence: List of Evidence objects to add.
-
- Returns:
- Number of *new* items added (excluding duplicates).
- """
- existing_urls = {e.citation.url for e in self.evidence}
- count = 0
- for item in new_evidence:
- if item.citation.url not in existing_urls:
- self.evidence.append(item)
- existing_urls.add(item.citation.url)
- count += 1
- return count
-
- async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]:
- """Search for semantically related evidence using the embedding service.
-
- Args:
- query: Search query string.
- n_results: Maximum number of results to return.
-
- Returns:
- List of Evidence objects, ordered by relevance.
- """
- if not self.embedding_service:
- logger.warning("Embedding service not available, returning empty results")
- return []
-
- results = await self.embedding_service.search_similar(query, n_results=n_results)
-
- # Convert dict results back to Evidence objects
- evidence_list = []
- for item in results:
- meta = item.get("metadata", {})
- authors_str = meta.get("authors", "")
- authors = [a.strip() for a in authors_str.split(",") if a.strip()]
-
- ev = Evidence(
- content=item["content"],
- citation=Citation(
- title=meta.get("title", "Related Evidence"),
- url=item["id"],
- source="pubmed", # Defaulting to pubmed if unknown
- date=meta.get("date", "n.d."),
- authors=authors,
- ),
- relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
- )
- evidence_list.append(ev)
-
- return evidence_list
-
- def add_user_message(self, message: ModelMessage) -> None:
- """Add a user message to conversation history.
-
- Args:
- message: Message to add
- """
- self.user_message_history.append(message)
-
- def get_user_history(self, max_messages: int | None = None) -> list[ModelMessage]:
- """Get user conversation history.
-
- Args:
- max_messages: Maximum messages to return (None for all)
-
- Returns:
- List of messages
- """
- if max_messages is None:
- return self.user_message_history.copy()
- return (
- self.user_message_history[-max_messages:]
- if len(self.user_message_history) > max_messages
- else self.user_message_history.copy()
- )
-
-
-# The ContextVar holds the WorkflowState for the current execution context
-_workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None)
-
-
-def init_workflow_state(
- embedding_service: "EmbeddingService | None" = None,
- message_history: list[ModelMessage] | None = None,
-) -> WorkflowState:
- """Initialize a new state for the current context.
-
- Args:
- embedding_service: Optional embedding service for semantic search.
- message_history: Optional user conversation history.
-
- Returns:
- The initialized WorkflowState instance.
- """
- state = WorkflowState(embedding_service=embedding_service)
- if message_history:
- state.user_message_history = message_history.copy()
- _workflow_state_var.set(state)
- logger.debug(
- "Workflow state initialized",
- has_embeddings=embedding_service is not None,
- has_history=bool(message_history),
- )
- return state
-
-
-def get_workflow_state() -> WorkflowState:
- """Get the current state. Auto-initializes if not set.
-
- Returns:
- The current WorkflowState instance.
-
- Raises:
- RuntimeError: If state is not initialized and auto-initialization fails.
- """
- state = _workflow_state_var.get()
- if state is None:
- # Auto-initialize if missing (e.g. during tests or simple scripts)
- logger.debug("Workflow state not found, auto-initializing")
- return init_workflow_state()
- return state
\ No newline at end of file
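-
-
-# Minimal usage sketch: because the state lives in a ContextVar, concurrent
-# asyncio tasks that call init_workflow_state() each get an isolated state.
-if __name__ == "__main__":
-    import asyncio
-
-    async def _session() -> int:
-        state = init_workflow_state()
-        return len(state.evidence)
-
-    async def _demo() -> None:
-        counts = await asyncio.gather(_session(), _session())
-        print(counts)  # [0, 0]: each task saw its own fresh state
-
-    asyncio.run(_demo())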
diff --git a/src/middleware/sub_iteration.py b/src/middleware/sub_iteration.py
deleted file mode 100644
index 801a3686a6d023c39615d01548766e4c24098c66..0000000000000000000000000000000000000000
--- a/src/middleware/sub_iteration.py
+++ /dev/null
@@ -1,135 +0,0 @@
-"""Middleware for orchestrating sub-iterations with research teams and judges."""
-
-from typing import Any, Protocol
-
-import structlog
-
-from src.utils.models import AgentEvent, JudgeAssessment
-
-logger = structlog.get_logger()
-
-
-class SubIterationTeam(Protocol):
- """Protocol for a research team that executes a sub-task."""
-
- async def execute(self, task: str) -> Any:
- """Execute the sub-task and return a result."""
- ...
-
-
-class SubIterationJudge(Protocol):
- """Protocol for a judge that evaluates the sub-task result."""
-
- async def assess(self, task: str, result: Any, history: list[Any]) -> JudgeAssessment:
- """Assess the quality of the result."""
- ...
-
-
-class SubIterationMiddleware:
- """
- Middleware that manages a sub-iteration loop:
- 1. Orchestrator delegates to a Research Team.
- 2. Research Team produces a result.
- 3. Judge evaluates the result.
- 4. Loop continues until Judge approves or max iterations reached.
- """
-
- def __init__(
- self,
- team: SubIterationTeam,
- judge: SubIterationJudge,
- max_iterations: int = 3,
- ):
- self.team = team
- self.judge = judge
- self.max_iterations = max_iterations
-
- async def run(
- self,
- task: str,
- event_callback: Any = None, # Optional callback for streaming events
- ) -> tuple[Any, JudgeAssessment | None]:
- """
- Run the sub-iteration loop.
-
- Args:
- task: The research task or question.
- event_callback: Async callable to report events (e.g. to UI).
-
- Returns:
- Tuple of (best_result, final_assessment).
- """
- history: list[Any] = []
- best_result: Any = None
- final_assessment: JudgeAssessment | None = None
-
- for i in range(1, self.max_iterations + 1):
- logger.info("Sub-iteration starting", iteration=i, task=task)
-
- if event_callback:
- await event_callback(
- AgentEvent(
- type="looping",
- message=f"Sub-iteration {i}: Executing task...",
- iteration=i,
- )
- )
-
- # 1. Team Execution
- try:
- result = await self.team.execute(task)
- history.append(result)
- best_result = result # Assume latest is best for now
- except Exception as e:
- logger.error("Sub-iteration execution failed", error=str(e))
- if event_callback:
- await event_callback(
- AgentEvent(
- type="error",
- message=f"Sub-iteration execution failed: {e}",
- iteration=i,
- )
- )
- return best_result, final_assessment
-
- # 2. Judge Assessment
- try:
- assessment = await self.judge.assess(task, result, history)
- final_assessment = assessment
- except Exception as e:
- logger.error("Sub-iteration judge failed", error=str(e))
- if event_callback:
- await event_callback(
- AgentEvent(
- type="error",
- message=f"Sub-iteration judge failed: {e}",
- iteration=i,
- )
- )
- return best_result, final_assessment
-
- # 3. Decision
- if assessment.sufficient:
- logger.info("Sub-iteration sufficient", iteration=i)
- return best_result, assessment
-
-            # If the result is insufficient, the task could be refined for the
-            # next iteration. In this implementation the task stays the same,
-            # but the judge's feedback could be appended to it.
-
- feedback = assessment.reasoning
- logger.info("Sub-iteration insufficient", feedback=feedback)
-
- if event_callback:
- await event_callback(
- AgentEvent(
- type="looping",
- message=(
- f"Sub-iteration {i} result insufficient. Feedback: {feedback[:100]}..."
- ),
- iteration=i,
- )
- )
-
- logger.warning("Sub-iteration max iterations reached", task=task)
- return best_result, final_assessment
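-
-
-# Minimal usage sketch with stub implementations of the two protocols. The
-# JudgeAssessment construction assumes `sufficient` and `reasoning` are its
-# only required fields (only those two are used above).
-if __name__ == "__main__":
-    import asyncio
-
-    class _EchoTeam:
-        async def execute(self, task: str) -> str:
-            return f"draft answer for: {task}"
-
-    class _LenientJudge:
-        async def assess(self, task: str, result: Any, history: list[Any]) -> JudgeAssessment:
-            return JudgeAssessment(sufficient=True, reasoning="acceptable on first pass")
-
-    async def _demo() -> None:
-        middleware = SubIterationMiddleware(_EchoTeam(), _LenientJudge(), max_iterations=2)
-        result, assessment = await middleware.run("example research task")
-        print(result, assessment.sufficient if assessment else None)
-
-    asyncio.run(_demo())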
diff --git a/src/middleware/workflow_manager.py b/src/middleware/workflow_manager.py
deleted file mode 100644
index b817cbf27a49fd19b4ad682416eae61cf4d90a1d..0000000000000000000000000000000000000000
--- a/src/middleware/workflow_manager.py
+++ /dev/null
@@ -1,322 +0,0 @@
-"""Workflow manager for coordinating parallel research loops.
-
-Manages multiple research loops running in parallel, tracks their status,
-and synchronizes evidence between loops and the global state.
-"""
-
-import asyncio
-from collections.abc import Callable
-from typing import Any, Literal
-
-import structlog
-from pydantic import BaseModel, Field
-
-from src.middleware.state_machine import get_workflow_state
-from src.utils.models import Evidence
-
-logger = structlog.get_logger()
-
-LoopStatus = Literal["pending", "running", "completed", "failed", "cancelled"]
-
-
-class ResearchLoop(BaseModel):
- """Represents a single research loop."""
-
- loop_id: str = Field(description="Unique identifier for the loop")
- query: str = Field(description="The research query for this loop")
- status: LoopStatus = Field(default="pending")
- evidence: list[Evidence] = Field(default_factory=list)
- iteration_count: int = Field(default=0, ge=0)
- error: str | None = Field(default=None)
-
- model_config = {"frozen": False} # Mutable for status updates
-
-
-class WorkflowManager:
- """Manages parallel research loops and state synchronization."""
-
- def __init__(self) -> None:
- """Initialize the workflow manager."""
- self._loops: dict[str, ResearchLoop] = {}
-
- async def add_loop(self, loop_id: str, query: str) -> ResearchLoop:
- """Add a new research loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- query: The research query for this loop.
-
- Returns:
- The created ResearchLoop instance.
- """
- loop = ResearchLoop(loop_id=loop_id, query=query, status="pending")
- self._loops[loop_id] = loop
- logger.info("Loop added", loop_id=loop_id, query=query)
- return loop
-
- async def get_loop(self, loop_id: str) -> ResearchLoop | None:
- """Get a research loop by ID.
-
- Args:
- loop_id: Unique identifier for the loop.
-
- Returns:
- The ResearchLoop instance, or None if not found.
- """
- return self._loops.get(loop_id)
-
- async def update_loop_status(
- self, loop_id: str, status: LoopStatus, error: str | None = None
- ) -> None:
- """Update the status of a research loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- status: New status for the loop.
- error: Optional error message if status is "failed".
- """
- if loop_id not in self._loops:
- logger.warning("Loop not found", loop_id=loop_id)
- return
-
- self._loops[loop_id].status = status
- if error:
- self._loops[loop_id].error = error
- logger.info("Loop status updated", loop_id=loop_id, status=status)
-
- async def add_loop_evidence(self, loop_id: str, evidence: list[Evidence]) -> None:
- """Add evidence to a research loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- evidence: List of Evidence objects to add.
- """
- if loop_id not in self._loops:
- logger.warning("Loop not found", loop_id=loop_id)
- return
-
- self._loops[loop_id].evidence.extend(evidence)
- logger.debug(
- "Evidence added to loop",
- loop_id=loop_id,
- evidence_count=len(evidence),
- )
-
- async def increment_loop_iteration(self, loop_id: str) -> None:
- """Increment the iteration count for a research loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- if loop_id not in self._loops:
- logger.warning("Loop not found", loop_id=loop_id)
- return
-
- self._loops[loop_id].iteration_count += 1
- logger.debug(
- "Iteration incremented",
- loop_id=loop_id,
- iteration=self._loops[loop_id].iteration_count,
- )
-
- async def run_loops_parallel(
- self,
- loop_configs: list[dict[str, Any]],
- loop_func: Callable[[dict[str, Any]], Any],
- judge_handler: Any | None = None,
- budget_tracker: Any | None = None,
- ) -> list[Any]:
- """Run multiple research loops in parallel.
-
- Args:
- loop_configs: List of configuration dicts, each must contain 'loop_id' and 'query'.
- loop_func: Async function that takes a config dict and returns loop results.
- judge_handler: Optional JudgeHandler for early termination based on evidence sufficiency.
- budget_tracker: Optional BudgetTracker for budget enforcement.
-
- Returns:
-            List of results from each loop, in the same order as loop_configs
-            (failed loops yield their exception, since return_exceptions=True).
- """
- logger.info("Starting parallel loops", loop_count=len(loop_configs))
-
- # Create loops
- for config in loop_configs:
- loop_id = config.get("loop_id")
- query = config.get("query", "")
- if loop_id:
- await self.add_loop(loop_id, query)
- await self.update_loop_status(loop_id, "running")
-
- # Run loops in parallel
- async def run_single_loop(config: dict[str, Any]) -> Any:
- loop_id = config.get("loop_id", "unknown")
- query = config.get("query", "")
- try:
- # Check budget before starting
- if budget_tracker:
- exceeded, reason = budget_tracker.check_budget(loop_id)
- if exceeded:
- await self.update_loop_status(loop_id, "cancelled", error=reason)
- logger.warning(
- "Loop cancelled due to budget", loop_id=loop_id, reason=reason
- )
- return None
-
- # If loop_func supports periodic checkpoints, we could check judge here
- # For now, the loop_func itself handles judge checks internally
- result = await loop_func(config)
-
- # Final check with judge if available
- if judge_handler and query:
- should_complete, reason = await self.check_loop_completion(
- loop_id, query, judge_handler
- )
- if should_complete:
- logger.info(
- "Loop completed early based on judge assessment",
- loop_id=loop_id,
- reason=reason,
- )
-
- await self.update_loop_status(loop_id, "completed")
- return result
- except Exception as e:
- error_msg = str(e)
- await self.update_loop_status(loop_id, "failed", error=error_msg)
- logger.error("Loop failed", loop_id=loop_id, error=error_msg)
- raise
-
- results = await asyncio.gather(
- *(run_single_loop(config) for config in loop_configs),
- return_exceptions=True,
- )
-
- # Log completion
- completed = sum(1 for r in results if not isinstance(r, Exception))
- failed = len(results) - completed
- logger.info(
- "Parallel loops completed",
- total=len(loop_configs),
- completed=completed,
- failed=failed,
- )
-
- return results
-
- async def wait_for_loops(
- self, loop_ids: list[str], timeout: float | None = None
- ) -> list[ResearchLoop]:
- """Wait for loops to complete.
-
- Args:
- loop_ids: List of loop IDs to wait for.
- timeout: Optional timeout in seconds.
-
- Returns:
- List of ResearchLoop instances (may be incomplete if timeout occurs).
- """
-        start_time = asyncio.get_running_loop().time()
-
- while True:
- loops = [self._loops.get(loop_id) for loop_id in loop_ids]
- all_complete = all(
- loop and loop.status in ("completed", "failed", "cancelled") for loop in loops
- )
-
- if all_complete:
- return [loop for loop in loops if loop is not None]
-
- if timeout is not None:
-                elapsed = asyncio.get_running_loop().time() - start_time
- if elapsed >= timeout:
- logger.warning("Timeout waiting for loops", timeout=timeout)
- return [loop for loop in loops if loop is not None]
-
- await asyncio.sleep(0.1) # Small delay to avoid busy waiting
-
- async def cancel_loop(self, loop_id: str) -> None:
- """Cancel a research loop.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- await self.update_loop_status(loop_id, "cancelled")
- logger.info("Loop cancelled", loop_id=loop_id)
-
- async def get_all_loops(self) -> list[ResearchLoop]:
- """Get all research loops.
-
- Returns:
- List of all ResearchLoop instances.
- """
- return list(self._loops.values())
-
- async def sync_loop_evidence_to_state(self, loop_id: str) -> None:
- """Synchronize evidence from a loop to the global state.
-
- Args:
- loop_id: Unique identifier for the loop.
- """
- if loop_id not in self._loops:
- logger.warning("Loop not found", loop_id=loop_id)
- return
-
- loop = self._loops[loop_id]
- state = get_workflow_state()
- added_count = state.add_evidence(loop.evidence)
- logger.debug(
- "Loop evidence synced to state",
- loop_id=loop_id,
- evidence_count=len(loop.evidence),
- added_count=added_count,
- )
-
- async def get_shared_evidence(self) -> list[Evidence]:
- """Get evidence from the global state.
-
- Returns:
- List of Evidence objects from the global state.
- """
- state = get_workflow_state()
- return state.evidence
-
- async def get_loop_evidence(self, loop_id: str) -> list[Evidence]:
- """Get evidence collected by a specific loop.
-
- Args:
- loop_id: Loop identifier.
-
- Returns:
- List of Evidence objects from the loop.
- """
- if loop_id not in self._loops:
- return []
-
- return self._loops[loop_id].evidence
-
- async def check_loop_completion(
- self, loop_id: str, query: str, judge_handler: Any
- ) -> tuple[bool, str]:
- """Check if a loop should complete using judge assessment.
-
- Args:
- loop_id: Loop identifier.
- query: Research query.
- judge_handler: JudgeHandler instance.
-
- Returns:
- Tuple of (should_complete: bool, reason: str).
- """
- evidence = await self.get_loop_evidence(loop_id)
-
- if not evidence:
- return False, "No evidence collected yet"
-
- try:
- assessment = await judge_handler.assess(query, evidence)
- if assessment.sufficient:
- return True, f"Judge assessment: {assessment.reasoning}"
- return False, f"Judge assessment: {assessment.reasoning}"
- except Exception as e:
- logger.error("Judge assessment failed", error=str(e), loop_id=loop_id)
- return False, f"Judge assessment failed: {e!s}"
diff --git a/src/__init__.py b/src/orchestrator.py
similarity index 100%
rename from src/__init__.py
rename to src/orchestrator.py
diff --git a/src/orchestrator/__init__.py b/src/orchestrator/__init__.py
deleted file mode 100644
index e71f737ab54ccc3391d909d170f6938ac0a9836b..0000000000000000000000000000000000000000
--- a/src/orchestrator/__init__.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Orchestrator module for research flows and planner agent.
-
-This module provides:
-- PlannerAgent: Creates report plans with sections
-- IterativeResearchFlow: Single research loop pattern
-- DeepResearchFlow: Parallel research loops pattern
-- GraphOrchestrator: Stub for Phase 4 (uses agent chains for now)
-- Protocols: SearchHandlerProtocol, JudgeHandlerProtocol (re-exported from legacy_orchestrator)
-- Orchestrator: Legacy orchestrator class (re-exported from legacy_orchestrator)
-"""
-
-# Re-export protocols and Orchestrator from legacy_orchestrator for backward compatibility
-from src.legacy_orchestrator import (
- JudgeHandlerProtocol,
- Orchestrator,
- SearchHandlerProtocol,
-)
-
-# Public exports
-from src.orchestrator.graph_orchestrator import (
- GraphOrchestrator,
- create_graph_orchestrator,
-)
-from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
-from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
-
-__all__ = [
- "DeepResearchFlow",
- "GraphOrchestrator",
- "IterativeResearchFlow",
- "JudgeHandlerProtocol",
- "Orchestrator",
- "PlannerAgent",
- "SearchHandlerProtocol",
- "create_graph_orchestrator",
- "create_planner_agent",
-]
diff --git a/src/orchestrator/graph_orchestrator.py b/src/orchestrator/graph_orchestrator.py
deleted file mode 100644
index d71a337dde1929301a70217f99e560c5413e34fc..0000000000000000000000000000000000000000
--- a/src/orchestrator/graph_orchestrator.py
+++ /dev/null
@@ -1,1751 +0,0 @@
-"""Graph orchestrator for Phase 4.
-
-Implements graph-based orchestration using Pydantic AI agents as nodes.
-Supports both iterative and deep research patterns with parallel execution.
-"""
-
-import asyncio
-from collections.abc import AsyncGenerator, Callable
-from typing import TYPE_CHECKING, Any, Literal
-
-import structlog
-
-try:
-    from pydantic_ai.messages import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agent_factory.agents import (
- create_input_parser_agent,
- create_knowledge_gap_agent,
- create_long_writer_agent,
- create_planner_agent,
- create_thinking_agent,
- create_tool_selector_agent,
- create_writer_agent,
-)
-from src.agent_factory.graph_builder import (
- AgentNode,
- DecisionNode,
- ParallelNode,
- ResearchGraph,
- StateNode,
- create_deep_graph,
- create_iterative_graph,
-)
-from src.legacy_orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol
-from src.middleware.budget_tracker import BudgetTracker
-from src.middleware.state_machine import WorkflowState, init_workflow_state
-from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
-from src.services.report_file_service import ReportFileService, get_report_file_service
-from src.utils.models import AgentEvent
-
-if TYPE_CHECKING:
- pass
-
-logger = structlog.get_logger()
-
-
-class GraphExecutionContext:
- """Context for managing graph execution state."""
-
- def __init__(
- self,
- state: WorkflowState,
- budget_tracker: BudgetTracker,
- message_history: list[ModelMessage] | None = None,
- ) -> None:
- """Initialize execution context.
-
- Args:
- state: Current workflow state
- budget_tracker: Budget tracker instance
- message_history: Optional user conversation history
- """
- self.current_node: str = ""
- self.visited_nodes: set[str] = set()
- self.node_results: dict[str, Any] = {}
- self.state = state
- self.budget_tracker = budget_tracker
- self.iteration_count = 0
- self.message_history: list[ModelMessage] = message_history or []
-
- def set_node_result(self, node_id: str, result: Any) -> None:
- """Store result from node execution.
-
- Args:
- node_id: The node ID
- result: The execution result
- """
- self.node_results[node_id] = result
-
- def get_node_result(self, node_id: str) -> Any:
- """Get result from node execution.
-
- Args:
- node_id: The node ID
-
- Returns:
- The stored result, or None if not found
- """
- return self.node_results.get(node_id)
-
- def has_visited(self, node_id: str) -> bool:
- """Check if node was visited.
-
- Args:
- node_id: The node ID
-
- Returns:
- True if visited, False otherwise
- """
- return node_id in self.visited_nodes
-
- def mark_visited(self, node_id: str) -> None:
- """Mark node as visited.
-
- Args:
- node_id: The node ID
- """
- self.visited_nodes.add(node_id)
-
- def update_state(
- self, updater: Callable[[WorkflowState, Any], WorkflowState], data: Any
- ) -> None:
- """Update workflow state.
-
- Args:
- updater: Function to update state
- data: Data to pass to updater
- """
- self.state = updater(self.state, data)
-
- def add_message(self, message: ModelMessage) -> None:
- """Add a message to the history.
-
- Args:
- message: Message to add
- """
- self.message_history.append(message)
-
- def get_message_history(self, max_messages: int | None = None) -> list[ModelMessage]:
- """Get message history, optionally truncated.
-
- Args:
- max_messages: Maximum messages to return (None for all)
-
- Returns:
- List of messages
- """
- if max_messages is None:
- return self.message_history.copy()
- return (
- self.message_history[-max_messages:]
- if len(self.message_history) > max_messages
- else self.message_history.copy()
- )
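-
-
-# Sketch of how a GraphExecutionContext is typically assembled (mirrors
-# _run_with_graph below; the state and tracker come from src.middleware):
-#
-#     state = init_workflow_state()
-#     tracker = BudgetTracker()
-#     tracker.create_budget("graph_execution", tokens_limit=100000)
-#     context = GraphExecutionContext(state, tracker)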
-
-
-class GraphOrchestrator:
- """
- Graph orchestrator using Pydantic AI Graphs.
-
- Executes research workflows as graphs with nodes (agents) and edges (transitions).
- Supports parallel execution, conditional routing, and state management.
- """
-
- def __init__(
- self,
- mode: Literal["iterative", "deep", "auto"] = "auto",
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- use_graph: bool = True,
- search_handler: SearchHandlerProtocol | None = None,
- judge_handler: JudgeHandlerProtocol | None = None,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize graph orchestrator.
-
- Args:
- mode: Research mode ("iterative", "deep", or "auto" to detect)
- max_iterations: Maximum iterations per loop
- max_time_minutes: Maximum time per loop
- use_graph: Whether to use graph execution (True) or agent chains (False)
- search_handler: Optional search handler for tool execution
- judge_handler: Optional judge handler for evidence assessment
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- """
- self.mode = mode
- self.max_iterations = max_iterations
- self.max_time_minutes = max_time_minutes
- self.use_graph = use_graph
- self.search_handler = search_handler
- self.judge_handler = judge_handler
- self.oauth_token = oauth_token
- self.logger = logger
-
- # Initialize file service (lazy if not provided)
- self._file_service: ReportFileService | None = None
-
- # Initialize flows (for backward compatibility)
- self._iterative_flow: IterativeResearchFlow | None = None
- self._deep_flow: DeepResearchFlow | None = None
-
- # Graph execution components (lazy initialization)
- self._graph: ResearchGraph | None = None
- self._budget_tracker: BudgetTracker | None = None
-
- def _get_file_service(self) -> ReportFileService | None:
- """
- Get file service instance (lazy initialization).
-
- Returns:
- ReportFileService instance or None if disabled
- """
- if self._file_service is None:
- try:
- self._file_service = get_report_file_service()
- except Exception as e:
- self.logger.warning("Failed to initialize file service", error=str(e))
- return None
- return self._file_service
-
- async def run(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> AsyncGenerator[AgentEvent, None]:
- """
- Run the research workflow.
-
- Args:
- query: The user's research query
- message_history: Optional user conversation history
-
- Yields:
- AgentEvent objects for real-time UI updates
- """
- self.logger.info(
- "Starting graph orchestrator",
- query=query[:100],
- mode=self.mode,
- use_graph=self.use_graph,
- has_history=bool(message_history),
- )
-
- yield AgentEvent(
- type="started",
- message=f"Starting research ({self.mode} mode): {query}",
- iteration=0,
- )
-
- try:
- # Determine research mode
- research_mode = self.mode
- if research_mode == "auto":
- research_mode = await self._detect_research_mode(query)
-
- # Use graph execution if enabled, otherwise fall back to agent chains
- if self.use_graph:
- async for event in self._run_with_graph(query, research_mode, message_history):
- yield event
- else:
- async for event in self._run_with_chains(query, research_mode, message_history):
- yield event
-
- except Exception as e:
- self.logger.error("Graph orchestrator failed", error=str(e), exc_info=True)
- yield AgentEvent(
- type="error",
- message=f"Research failed: {e!s}",
- iteration=0,
- )
-
- async def _run_with_graph(
- self,
- query: str,
- research_mode: Literal["iterative", "deep"],
- message_history: list[ModelMessage] | None = None,
- ) -> AsyncGenerator[AgentEvent, None]:
- """Run workflow using graph execution.
-
- Args:
- query: The research query
- research_mode: The research mode
- message_history: Optional user conversation history
-
- Yields:
- AgentEvent objects
- """
- # Initialize state and budget tracker
- from src.services.embeddings import get_embedding_service
-
- embedding_service = get_embedding_service()
- state = init_workflow_state(
- embedding_service=embedding_service,
- message_history=message_history,
- )
- budget_tracker = BudgetTracker()
- budget_tracker.create_budget(
- loop_id="graph_execution",
- tokens_limit=100000,
- time_limit_seconds=self.max_time_minutes * 60,
- iterations_limit=self.max_iterations,
- )
- budget_tracker.start_timer("graph_execution")
-
- context = GraphExecutionContext(
- state,
- budget_tracker,
- message_history=message_history or [],
- )
-
- # Build graph
- self._graph = await self._build_graph(research_mode)
-
- # Execute graph
- async for event in self._execute_graph(query, context):
- yield event
-
- async def _run_with_chains(
- self,
- query: str,
- research_mode: Literal["iterative", "deep"],
- message_history: list[ModelMessage] | None = None,
- ) -> AsyncGenerator[AgentEvent, None]:
- """Run workflow using agent chains (backward compatibility).
-
- Args:
- query: The research query
- research_mode: The research mode
- message_history: Optional user conversation history
-
- Yields:
- AgentEvent objects
- """
- if research_mode == "iterative":
- yield AgentEvent(
- type="searching",
- message="Running iterative research flow...",
- iteration=1,
- )
-
- if self._iterative_flow is None:
- self._iterative_flow = IterativeResearchFlow(
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- judge_handler=self.judge_handler,
- oauth_token=self.oauth_token,
- )
-
- try:
- final_report = await self._iterative_flow.run(
- query, message_history=message_history
- )
- except Exception as e:
- self.logger.error("Iterative flow failed", error=str(e), exc_info=True)
- # Yield error event - outer handler will also catch and yield error event
- yield AgentEvent(
- type="error",
- message=f"Iterative research failed: {e!s}",
- iteration=1,
- )
- # Re-raise so outer handler can also yield error event for consistency
- raise
-
- yield AgentEvent(
- type="complete",
- message=final_report,
- data={"mode": "iterative"},
- iteration=1,
- )
-
- elif research_mode == "deep":
- yield AgentEvent(
- type="searching",
- message="Running deep research flow...",
- iteration=1,
- )
-
- if self._deep_flow is None:
- # DeepResearchFlow creates its own judge_handler internally
- # The judge_handler is passed to IterativeResearchFlow in parallel loops
- self._deep_flow = DeepResearchFlow(
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- oauth_token=self.oauth_token,
- )
-
- try:
- final_report = await self._deep_flow.run(query, message_history=message_history)
- except Exception as e:
- self.logger.error("Deep flow failed", error=str(e), exc_info=True)
- # Yield error event before re-raising so test can capture it
- yield AgentEvent(
- type="error",
- message=f"Deep research failed: {e!s}",
- iteration=1,
- )
- raise
-
- yield AgentEvent(
- type="complete",
- message=final_report,
- data={"mode": "deep"},
- iteration=1,
- )
-
- async def _build_graph(self, mode: Literal["iterative", "deep"]) -> ResearchGraph:
- """Build graph for the specified mode.
-
- Args:
- mode: Research mode
-
- Returns:
- Constructed ResearchGraph
- """
- if mode == "iterative":
- # Get agents - pass OAuth token for HuggingFace authentication
- knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token)
- tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token)
- thinking_agent = create_thinking_agent(oauth_token=self.oauth_token)
- writer_agent = create_writer_agent(oauth_token=self.oauth_token)
-
- # Create graph
- graph = create_iterative_graph(
- knowledge_gap_agent=knowledge_gap_agent.agent,
- tool_selector_agent=tool_selector_agent.agent,
- thinking_agent=thinking_agent.agent,
- writer_agent=writer_agent.agent,
- )
- else: # deep
- # Get agents - pass OAuth token for HuggingFace authentication
- planner_agent = create_planner_agent(oauth_token=self.oauth_token)
- knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token)
- tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token)
- thinking_agent = create_thinking_agent(oauth_token=self.oauth_token)
- writer_agent = create_writer_agent(oauth_token=self.oauth_token)
- long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token)
-
- # Create graph
- graph = create_deep_graph(
- planner_agent=planner_agent.agent,
- knowledge_gap_agent=knowledge_gap_agent.agent,
- tool_selector_agent=tool_selector_agent.agent,
- thinking_agent=thinking_agent.agent,
- writer_agent=writer_agent.agent,
- long_writer_agent=long_writer_agent.agent,
- )
-
- return graph
-
- def _emit_start_event(
- self, node: Any, current_node_id: str, iteration: int, context: GraphExecutionContext
- ) -> AgentEvent:
- """Emit start event for a node.
-
- Args:
- node: The node being executed
- current_node_id: Current node ID
- iteration: Current iteration number
- context: Execution context
-
- Returns:
- AgentEvent for the start of node execution
- """
- if node and node.node_id == "planner":
- return AgentEvent(
- type="searching",
- message="Creating report plan...",
- iteration=iteration,
- )
- elif node and node.node_id == "parallel_loops":
- # Get report plan to show section count
- report_plan = context.get_node_result("planner")
- if report_plan and hasattr(report_plan, "report_outline"):
- section_count = len(report_plan.report_outline)
- return AgentEvent(
- type="looping",
- message=f"Running parallel research loops for {section_count} sections...",
- iteration=iteration,
- data={"sections": section_count},
- )
- return AgentEvent(
- type="looping",
- message="Running parallel research loops...",
- iteration=iteration,
- )
- elif node and node.node_id == "synthesizer":
- return AgentEvent(
- type="synthesizing",
- message="Synthesizing final report from section drafts...",
- iteration=iteration,
- )
- return AgentEvent(
- type="looping",
- message=f"Executing node: {current_node_id}",
- iteration=iteration,
- )
-
- def _emit_completion_event(
- self, node: Any, current_node_id: str, result: Any, iteration: int
- ) -> AgentEvent:
- """Emit completion event for a node.
-
- Args:
- node: The node that was executed
- current_node_id: Current node ID
- result: Node execution result
- iteration: Current iteration number
-
- Returns:
- AgentEvent for the completion of node execution
- """
- if not node:
- return AgentEvent(
- type="looping",
- message=f"Completed node: {current_node_id}",
- iteration=iteration,
- )
-
- if node.node_id == "planner":
- if isinstance(result, dict) and "report_outline" in result:
- section_count = len(result["report_outline"])
- return AgentEvent(
- type="search_complete",
- message=f"Report plan created with {section_count} sections",
- iteration=iteration,
- data={"sections": section_count},
- )
- return AgentEvent(
- type="search_complete",
- message="Report plan created",
- iteration=iteration,
- )
- elif node.node_id == "parallel_loops":
- if isinstance(result, list):
- return AgentEvent(
- type="search_complete",
- message=f"Completed parallel research for {len(result)} sections",
- iteration=iteration,
- data={"sections_completed": len(result)},
- )
- return AgentEvent(
- type="search_complete",
- message="Parallel research loops completed",
- iteration=iteration,
- )
- elif node.node_id == "synthesizer":
- return AgentEvent(
- type="synthesizing",
- message="Final report synthesis completed",
- iteration=iteration,
- )
- return AgentEvent(
- type="searching" if node.node_type == "agent" else "looping",
- message=f"Completed {node.node_type} node: {current_node_id}",
- iteration=iteration,
- )
-
- def _get_final_result_from_exit_nodes(
- self, context: GraphExecutionContext, current_node_id: str | None
- ) -> tuple[Any, str | None]:
- """Get final result from exit nodes, prioritizing synthesizer/writer."""
- if not self._graph:
- return None, current_node_id
-
- final_result = None
- result_node_id = current_node_id
-
- # First try to get result from current node (if it's an exit node)
- if current_node_id and current_node_id in self._graph.exit_nodes:
- final_result = context.get_node_result(current_node_id)
- self.logger.debug(
- "Final result from current exit node",
- node_id=current_node_id,
- has_result=final_result is not None,
- result_type=type(final_result).__name__ if final_result else None,
- )
-
- # If no result from current node, check all exit nodes for results
- # Prioritize synthesizer (deep research) or writer (iterative research)
- if not final_result:
- exit_node_priority = ["synthesizer", "writer"]
- for exit_node_id in exit_node_priority:
- if exit_node_id in self._graph.exit_nodes:
- result = context.get_node_result(exit_node_id)
- if result:
- final_result = result
- result_node_id = exit_node_id
- self.logger.debug(
- "Final result from priority exit node",
- node_id=exit_node_id,
- result_type=type(final_result).__name__,
- )
- break
-
- # If still no result, check all exit nodes
- if not final_result:
- for exit_node_id in self._graph.exit_nodes:
- result = context.get_node_result(exit_node_id)
- if result:
- final_result = result
- result_node_id = exit_node_id
- self.logger.debug(
- "Final result from any exit node",
- node_id=exit_node_id,
- result_type=type(final_result).__name__,
- )
- break
-
- # Log warning if no result found
- if not final_result:
- self.logger.warning(
- "No final result found in exit nodes",
- exit_nodes=list(self._graph.exit_nodes),
- visited_nodes=list(context.visited_nodes),
- all_node_results=list(context.node_results.keys()),
- )
-
- return final_result, result_node_id
-
- def _extract_final_message_and_files(self, final_result: Any) -> tuple[str, dict[str, Any]]:
- """Extract message and file information from final result."""
- event_data: dict[str, Any] = {"mode": self.mode}
- message: str = "Research completed"
-
- if isinstance(final_result, str):
- message = final_result
- self.logger.debug("Final message extracted from string result", length=len(message))
- elif isinstance(final_result, dict):
- # First check for message key (most important)
- if "message" in final_result:
- message = final_result["message"]
- self.logger.debug(
- "Final message extracted from dict 'message' key",
- length=len(message) if isinstance(message, str) else 0,
- )
-
- # Then check for file paths
- if "file" in final_result:
- file_path = final_result["file"]
- if isinstance(file_path, str):
- event_data["file"] = file_path
- # Only override message if not already set from "message" key
- if "message" not in final_result:
- message = "Report generated. Download available."
- self.logger.debug("File path added to event data", file_path=file_path)
-
- # Check for multiple files
- if "files" in final_result:
- files = final_result["files"]
- if isinstance(files, list):
- event_data["files"] = files
- self.logger.debug("Multiple files added to event data", count=len(files))
-
- return message, event_data
-
- async def _execute_graph(
- self, query: str, context: GraphExecutionContext
- ) -> AsyncGenerator[AgentEvent, None]:
- """Execute the graph from entry node.
-
- Args:
- query: The research query
- context: Execution context
-
- Yields:
- AgentEvent objects
- """
- if not self._graph:
- raise ValueError("Graph not built")
-
- current_node_id = self._graph.entry_node
- iteration = 0
-
- # Execute nodes until we reach an exit node
- while current_node_id:
- # Check budget
- if not context.budget_tracker.can_continue("graph_execution"):
- self.logger.warning("Budget exceeded, exiting graph execution")
- break
-
- # Execute current node
- iteration += 1
- context.current_node = current_node_id
- node = self._graph.get_node(current_node_id)
-
- # Emit start event
- yield self._emit_start_event(node, current_node_id, iteration, context)
-
- try:
- result = await self._execute_node(current_node_id, query, context)
- context.set_node_result(current_node_id, result)
- context.mark_visited(current_node_id)
-
- # Yield completion event
- yield self._emit_completion_event(node, current_node_id, result, iteration)
-
- except Exception as e:
- self.logger.error("Node execution failed", node_id=current_node_id, error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Node {current_node_id} failed: {e!s}",
- iteration=iteration,
- )
- break
-
- # Check if current node is an exit node - if so, we're done
- if current_node_id in self._graph.exit_nodes:
- break
-
- # Get next node(s)
- next_nodes = self._get_next_node(current_node_id, context)
-
- if not next_nodes:
- # No more nodes, we've reached a dead end
- self.logger.warning("Reached dead end in graph", node_id=current_node_id)
- break
-
- current_node_id = next_nodes[0] # For now, take first next node (handle parallel later)
-
- # Final event - get result from exit nodes (prioritize synthesizer/writer nodes)
- final_result, result_node_id = self._get_final_result_from_exit_nodes(
- context, current_node_id
- )
-
- # Check if final result contains file information
- event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration}
- message, file_event_data = self._extract_final_message_and_files(final_result)
- event_data.update(file_event_data)
-
- yield AgentEvent(
- type="complete",
- message=message,
- data=event_data,
- iteration=iteration,
- )
-
- async def _execute_node(self, node_id: str, query: str, context: GraphExecutionContext) -> Any:
- """Execute a single node.
-
- Args:
- node_id: The node ID
- query: The research query
- context: Execution context
-
- Returns:
- Node execution result
- """
- if not self._graph:
- raise ValueError("Graph not built")
-
- node = self._graph.get_node(node_id)
- if not node:
- raise ValueError(f"Node {node_id} not found")
-
- if isinstance(node, AgentNode):
- return await self._execute_agent_node(node, query, context)
- elif isinstance(node, StateNode):
- return await self._execute_state_node(node, query, context)
- elif isinstance(node, DecisionNode):
- return await self._execute_decision_node(node, query, context)
- elif isinstance(node, ParallelNode):
- return await self._execute_parallel_node(node, query, context)
- else:
- raise ValueError(f"Unknown node type: {type(node)}")
-
- async def _execute_synthesizer_node(self, query: str, context: GraphExecutionContext) -> Any:
- """Execute synthesizer node for deep research."""
- from src.agent_factory.agents import create_long_writer_agent
- from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan
-
- report_plan = context.get_node_result("planner")
- section_drafts = context.get_node_result("parallel_loops") or []
-
- if not isinstance(report_plan, ReportPlan):
- raise ValueError("ReportPlan not found for synthesizer")
-
- if not section_drafts:
- raise ValueError("Section drafts not found for synthesizer")
-
- # Create ReportDraft from section drafts
- report_draft = ReportDraft(
- sections=[
- ReportDraftSection(
- section_title=section.title,
- section_content=draft,
- )
- for section, draft in zip(report_plan.report_outline, section_drafts, strict=False)
- ]
- )
-
- # Get LongWriterAgent instance and call write_report directly
- long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token)
- final_report = await long_writer_agent.write_report(
- original_query=query,
- report_title=report_plan.report_title,
- report_draft=report_draft,
- )
-
- # Estimate tokens (rough estimate)
- estimated_tokens = len(final_report) // 4 # Rough token estimate
- context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
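-        # Illustrative: a ~8,000-character report is counted as ~2,000 tokens here; this is a
-        # deliberately coarse approximation (roughly 4 characters per token for English text).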
-
- # Save report to file if enabled (may generate multiple formats)
- return self._save_report_and_return_result(final_report, query)
-
- def _save_report_and_return_result(self, final_report: str, query: str) -> dict[str, Any] | str:
- """Save report to file and return result with file paths if available."""
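-        # Illustrative return values (paths are placeholders):
-        #   - saving succeeded:           {"message": "<report text>", "file": "reports/report.md",
-        #                                  "files": ["reports/report.md", "reports/report.pdf"]}
-        #   - saving disabled or failed:  the plain report string (backward compatible)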
- file_path: str | None = None
- pdf_path: str | None = None
- try:
- file_service = self._get_file_service()
- if file_service:
- # Use save_report_multiple_formats to get both MD and PDF if enabled
- saved_files = file_service.save_report_multiple_formats(
- report_content=final_report,
- query=query,
- )
- file_path = saved_files.get("md")
- pdf_path = saved_files.get("pdf")
- self.logger.info(
- "Report saved to file",
- md_path=file_path,
- pdf_path=pdf_path,
- )
- except Exception as e:
- # Don't fail the entire operation if file saving fails
- self.logger.warning("Failed to save report to file", error=str(e))
- file_path = None
- pdf_path = None
-
- # Return dict with file paths if available, otherwise return string (backward compatible)
- if file_path:
- result: dict[str, Any] = {
- "message": final_report,
- "file": file_path,
- }
- # Add PDF path if generated
- if pdf_path:
- result["files"] = [file_path, pdf_path]
- return result
- return final_report
-
- async def _execute_writer_node(self, query: str, context: GraphExecutionContext) -> Any:
- """Execute writer node for iterative research."""
- from src.agent_factory.agents import create_writer_agent
-
- # Get all evidence from workflow state and convert to findings string
- evidence = context.state.evidence
- if evidence:
- # Convert evidence to findings format (similar to conversation.get_all_findings())
- findings_parts: list[str] = []
- for ev in evidence:
- finding = f"**{ev.citation.title}**\n{ev.content}"
- if ev.citation.url:
- finding += f"\nSource: {ev.citation.url}"
- findings_parts.append(finding)
- all_findings = "\n\n".join(findings_parts)
- else:
- all_findings = "No findings available yet."
-
- # Get WriterAgent instance and call write_report directly
- writer_agent = create_writer_agent(oauth_token=self.oauth_token)
- final_report = await writer_agent.write_report(
- query=query,
- findings=all_findings,
- output_length="",
- output_instructions="",
- )
-
- # Estimate tokens (rough estimate)
- estimated_tokens = len(final_report) // 4 # Rough token estimate
- context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
-
- # Save report to file if enabled (may generate multiple formats)
- return self._save_report_and_return_result(final_report, query)
-
- def _prepare_agent_input(
- self, node: AgentNode, query: str, context: GraphExecutionContext
- ) -> Any:
- """Prepare input data for agent execution."""
- if node.node_id == "planner":
- # Planner takes the original query
- input_data = query
- else:
- # Standard: use previous node result or query
- prev_result = context.get_node_result(context.current_node)
- input_data = prev_result if prev_result is not None else query
-
- # Apply input transformer if provided
- if node.input_transformer:
- input_data = node.input_transformer(input_data)
-
- return input_data
-
- async def _execute_standard_agent(
- self, node: AgentNode, input_data: Any, query: str, context: GraphExecutionContext
- ) -> Any:
- """Execute standard agent with error handling and fallback models."""
- # Get message history from context (limit to most recent 10 messages for token efficiency)
- message_history = context.get_message_history(max_messages=10)
-
- # Try with the original agent first
- try:
- # Pass message_history if available (Pydantic AI agents support this)
- if message_history:
- result = await node.agent.run(input_data, message_history=message_history)
- else:
- result = await node.agent.run(input_data)
-
- # Accumulate new messages from agent result if available
- if hasattr(result, "new_messages"):
- try:
- new_messages = result.new_messages()
- for msg in new_messages:
- context.add_message(msg)
- except Exception as e:
- # Don't fail if message accumulation fails
- self.logger.debug(
- "Failed to accumulate messages from agent result", error=str(e)
- )
- return result
- except Exception as e:
- # Check if we should retry with fallback models
- from src.utils.hf_error_handler import (
- extract_error_details,
- should_retry_with_fallback,
- )
-
- error_details = extract_error_details(e)
- should_retry = should_retry_with_fallback(e)
-
- # Handle validation errors and API errors for planner node (with fallback)
- if node.node_id == "planner":
- if should_retry:
- self.logger.warning(
- "Planner failed, trying fallback models",
- original_error=str(e),
- status_code=error_details.get("status_code"),
- )
- # Try fallback models for planner
- fallback_result = await self._try_fallback_models(
- node, input_data, message_history, query, context, e
- )
- if fallback_result is not None:
- return fallback_result
- # If fallback failed or not applicable, use fallback plan
- return self._create_fallback_plan(query, input_data)
-
- # For other nodes, try fallback models if applicable
- if should_retry:
- self.logger.warning(
- "Agent node failed, trying fallback models",
- node_id=node.node_id,
- original_error=str(e),
- status_code=error_details.get("status_code"),
- )
- fallback_result = await self._try_fallback_models(
- node, input_data, message_history, query, context, e
- )
- if fallback_result is not None:
- return fallback_result
-
- # If fallback didn't work or wasn't applicable, re-raise the exception
- raise
-
- async def _try_fallback_models(
- self,
- node: AgentNode,
- input_data: Any,
- message_history: list[Any],
- query: str,
- context: GraphExecutionContext,
- original_error: Exception,
- ) -> Any | None:
- """Try executing agent with fallback models.
-
- Args:
- node: The agent node that failed
- input_data: Input data for the agent
- message_history: Message history for the agent
- query: The research query
- context: Execution context
- original_error: The original error that triggered fallback
-
- Returns:
- Agent result if successful, None if all fallbacks failed
- """
- from src.utils.hf_error_handler import extract_error_details, get_fallback_models
-
- error_details = extract_error_details(original_error)
- original_model = error_details.get("model_name")
- fallback_models = get_fallback_models(original_model)
-
- # Also try models from settings fallback list
- from src.utils.config import settings
-
- settings_fallbacks = settings.get_hf_fallback_models_list()
- for model in settings_fallbacks:
- if model not in fallback_models:
- fallback_models.append(model)
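-        # Illustrative: after merging, fallback_models might look like
-        # ["org-a/model-1", "org-b/model-2"] (placeholder names), with error-handler
-        # suggestions ordered before the settings-configured fallbacks.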
-
- self.logger.info(
- "Trying fallback models",
- node_id=node.node_id,
- original_model=original_model,
- fallback_count=len(fallback_models),
- )
-
- # Try each fallback model
- for fallback_model in fallback_models:
- try:
- # Recreate agent with fallback model
- fallback_agent = self._recreate_agent_with_model(node.node_id, fallback_model)
- if fallback_agent is None:
- continue
-
- # Try running with fallback agent
- if message_history:
- result = await fallback_agent.run(input_data, message_history=message_history)
- else:
- result = await fallback_agent.run(input_data)
-
- self.logger.info(
- "Fallback model succeeded",
- node_id=node.node_id,
- fallback_model=fallback_model,
- )
-
- # Accumulate new messages from agent result if available
- if hasattr(result, "new_messages"):
- try:
- new_messages = result.new_messages()
- for msg in new_messages:
- context.add_message(msg)
- except Exception as e:
- self.logger.debug(
- "Failed to accumulate messages from fallback agent result", error=str(e)
- )
-
- return result
-
- except Exception as e:
- self.logger.warning(
- "Fallback model failed",
- node_id=node.node_id,
- fallback_model=fallback_model,
- error=str(e),
- )
- continue
-
- # All fallback models failed
- self.logger.error(
- "All fallback models failed",
- node_id=node.node_id,
- fallback_count=len(fallback_models),
- )
- return None
-
- def _recreate_agent_with_model(self, node_id: str, model_name: str) -> Any | None:
- """Recreate an agent with a specific model.
-
- Args:
- node_id: The node ID (e.g., "thinking", "knowledge_gap")
- model_name: The model name to use
-
- Returns:
- Agent instance or None if recreation failed
- """
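-        # Illustrative call (the model name is a placeholder, not a recommendation):
-        #   self._recreate_agent_with_model("planner", "some-org/some-instruct-model")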
- try:
- from pydantic_ai.models.huggingface import HuggingFaceModel
- from pydantic_ai.providers.huggingface import HuggingFaceProvider
-
- # Create model with fallback model name
- hf_provider = HuggingFaceProvider(api_key=self.oauth_token)
- model = HuggingFaceModel(model_name, provider=hf_provider)
-
- # Recreate agent based on node_id
- if node_id == "thinking":
- from src.agent_factory.agents import create_thinking_agent
-
- agent_wrapper = create_thinking_agent(model=model, oauth_token=self.oauth_token)
- return agent_wrapper.agent
- elif node_id == "knowledge_gap":
- from src.agent_factory.agents import create_knowledge_gap_agent
-
- agent_wrapper = create_knowledge_gap_agent( # type: ignore[assignment]
- model=model, oauth_token=self.oauth_token
- )
- return agent_wrapper.agent
- elif node_id == "tool_selector":
- from src.agent_factory.agents import create_tool_selector_agent
-
- agent_wrapper = create_tool_selector_agent( # type: ignore[assignment]
- model=model, oauth_token=self.oauth_token
- )
- return agent_wrapper.agent
- elif node_id == "planner":
- from src.agent_factory.agents import create_planner_agent
-
- agent_wrapper = create_planner_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment]
- return agent_wrapper.agent
- elif node_id == "writer":
- from src.agent_factory.agents import create_writer_agent
-
- agent_wrapper = create_writer_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment]
- return agent_wrapper.agent
- else:
- self.logger.warning("Unknown node_id for agent recreation", node_id=node_id)
- return None
-
- except Exception as e:
- self.logger.error(
- "Failed to recreate agent with fallback model",
- node_id=node_id,
- model_name=model_name,
- error=str(e),
- )
- return None
-
- def _create_fallback_plan(self, query: str, input_data: Any) -> Any:
- """Create fallback ReportPlan when planner fails."""
- from src.utils.models import ReportPlan, ReportPlanSection
-
-        self.logger.error(
-            "Planner agent execution failed, using fallback plan",
-            input_type=type(input_data).__name__,
-        )
-
- # Extract query from input_data if possible
- fallback_query = query
- if isinstance(input_data, str):
- # Try to extract query from input string
- if "QUERY:" in input_data:
- fallback_query = input_data.split("QUERY:")[-1].strip()
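-        # Illustrative: input_data of "QUERY: effects of caffeine on sleep quality" yields
-        # fallback_query == "effects of caffeine on sleep quality" (placeholder query text).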
-
- return ReportPlan(
- background_context="",
- report_outline=[
- ReportPlanSection(
- title="Research Findings",
- key_question=fallback_query,
- )
- ],
- report_title=f"Research Report: {fallback_query[:50]}",
- )
-
- def _extract_agent_output(self, node: AgentNode, result: Any) -> Any:
- """Extract and transform output from agent result."""
- # Defensively extract output - handle various result formats
- output = result.output if hasattr(result, "output") else result
-
- # Handle case where output might be a tuple (from pydantic-ai validation errors)
- if isinstance(output, tuple):
- output = self._handle_tuple_output(node, output, result)
- return output
-
- def _handle_tuple_output(self, node: AgentNode, output: tuple[Any, ...], result: Any) -> Any:
- """Handle tuple output from agent (validation errors)."""
- # If tuple contains a dict-like structure, try to reconstruct the object
- if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete":
- # This is likely a validation error format: ('research_complete', False)
- # Try to get the actual output from result
- self.logger.warning(
- "Agent result output is a tuple, attempting to extract actual output",
- node_id=node.node_id,
- tuple_value=output,
- )
- # Try to get output from result attributes
- if hasattr(result, "data"):
- return result.data
- if hasattr(result, "response"):
- return result.response
- # Last resort: try to reconstruct from tuple
- # This shouldn't happen, but handle gracefully
- from src.utils.models import KnowledgeGapOutput
-
- if node.node_id == "knowledge_gap":
- # Reconstruct KnowledgeGapOutput from validation error tuple
- reconstructed = KnowledgeGapOutput(
- research_complete=output[1] if len(output) > 1 else False,
- outstanding_gaps=[],
- )
- self.logger.info(
- "Reconstructed KnowledgeGapOutput from validation error tuple",
- node_id=node.node_id,
- research_complete=reconstructed.research_complete,
- )
- return reconstructed
-
- # For other nodes, try to extract meaningful output or use fallback
- self.logger.warning(
- "Agent node output is tuple format, attempting extraction",
- node_id=node.node_id,
- tuple_value=output,
- )
-        # Fall back to the first element of the tuple, whatever its type
-        if len(output) > 0:
-            return output[0]
- # Empty tuple - use None and let downstream handle it
- return None
-
- async def _execute_agent_node(
- self, node: AgentNode, query: str, context: GraphExecutionContext
- ) -> Any:
- """Execute an agent node.
-
- Special handling for deep research nodes:
- - "planner": Takes query string, returns ReportPlan
- - "synthesizer": Takes query + ReportPlan + section drafts, returns final report
-
- Args:
- node: The agent node
- query: The research query
- context: Execution context
-
- Returns:
- Agent execution result
- """
- # Special handling for synthesizer node (deep research)
- if node.node_id == "synthesizer":
- return await self._execute_synthesizer_node(query, context)
-
- # Special handling for writer node (iterative research)
- if node.node_id == "writer":
- return await self._execute_writer_node(query, context)
-
- # Standard agent execution
- input_data = self._prepare_agent_input(node, query, context)
- result = await self._execute_standard_agent(node, input_data, query, context)
- output = self._extract_agent_output(node, result)
-
- if node.output_transformer:
- output = node.output_transformer(output)
-
- # Estimate and track tokens
- if hasattr(result, "usage") and result.usage:
- tokens = result.usage.total_tokens if hasattr(result.usage, "total_tokens") else 0
- context.budget_tracker.add_tokens("graph_execution", tokens)
-
- # Special handling for knowledge_gap node: optionally call judge_handler
- if node.node_id == "knowledge_gap" and self.judge_handler:
- # Get evidence from workflow state
- evidence = context.state.evidence
- if evidence:
- try:
- from src.utils.models import JudgeAssessment
-
- # Call judge handler to assess evidence
- judge_assessment: JudgeAssessment = await self.judge_handler.assess(
- question=query, evidence=evidence
- )
- # Store assessment in context for decision node to use
- context.set_node_result("judge_assessment", judge_assessment)
- self.logger.info(
- "Judge assessment completed",
- sufficient=judge_assessment.sufficient,
- confidence=judge_assessment.confidence,
- recommendation=judge_assessment.recommendation,
- )
- except Exception as e:
- self.logger.warning(
- "Judge handler assessment failed",
- error=str(e),
- node_id=node.node_id,
- )
- # Continue without judge assessment
-
- return output
-
- async def _execute_state_node(
- self, node: StateNode, query: str, context: GraphExecutionContext
- ) -> Any:
- """Execute a state node.
-
- Special handling for deep research state nodes:
- - "store_plan": Stores ReportPlan in context for parallel loops
- - "collect_drafts": Stores section drafts in context for synthesizer
- - "execute_tools": Executes search using search_handler
-
- Args:
- node: The state node
- query: The research query
- context: Execution context
-
- Returns:
- State update result
- """
- # Special handling for execute_tools node
- if node.node_id == "execute_tools":
- # Get AgentSelectionPlan from tool_selector node result
- tool_selector_result = context.get_node_result("tool_selector")
- from src.utils.models import AgentSelectionPlan, SearchResult
-
- # Extract query from context or use original query
- search_query = query
- if tool_selector_result and isinstance(tool_selector_result, AgentSelectionPlan):
- # Use the gap or query from the selection plan
- if tool_selector_result.tasks:
- # Use the first task's query if available
- first_task = tool_selector_result.tasks[0]
- if hasattr(first_task, "query") and first_task.query:
- search_query = first_task.query
- elif hasattr(first_task, "tool_input") and isinstance(
- first_task.tool_input, str
- ):
- search_query = first_task.tool_input
-
- # Execute search using search_handler
- if self.search_handler:
- try:
- search_result: SearchResult = await self.search_handler.execute(
- query=search_query, max_results_per_tool=10
- )
- # Add evidence to workflow state (add_evidence expects a list)
- context.state.add_evidence(search_result.evidence)
- # Store evidence list in context for next nodes
- context.set_node_result(node.node_id, search_result.evidence)
- self.logger.info(
- "Tools executed via search_handler",
- query=search_query[:100],
- evidence_count=len(search_result.evidence),
- )
- return search_result.evidence
- except Exception as e:
- self.logger.error(
- "Search handler execution failed",
- error=str(e),
- query=search_query[:100],
- )
- # Return empty list on error to allow graph to continue
- return []
- else:
- # Fallback: log warning and return empty list
- self.logger.warning(
- "Search handler not available for execute_tools node",
- node_id=node.node_id,
- )
- return []
-
- # Get previous result for state update
- # For "store_plan", get from planner node
- # For "collect_drafts", get from parallel_loops node
- if node.node_id == "store_plan":
- prev_result = context.get_node_result("planner")
- elif node.node_id == "collect_drafts":
- prev_result = context.get_node_result("parallel_loops")
- else:
- prev_result = context.get_node_result(context.current_node)
-
- # Update state
- updated_state = node.state_updater(context.state, prev_result)
- context.state = updated_state
-
- # Store result in context for next nodes to access
- context.set_node_result(node.node_id, prev_result)
-
- # Read state if needed
- if node.state_reader:
- return node.state_reader(context.state)
-
- return prev_result # Return the stored result for next nodes
-
- async def _execute_decision_node(
- self, node: DecisionNode, query: str, context: GraphExecutionContext
- ) -> str:
- """Execute a decision node.
-
- Args:
- node: The decision node
- query: The research query
- context: Execution context
-
- Returns:
- Next node ID
- """
- # Get previous result for decision
- # The decision node needs the result from the node that connects to it
- # Find the previous node by searching edges
- prev_node_id: str | None = None
- if self._graph:
- # Find which node connects to this decision node
- for from_node, edge_list in self._graph.edges.items():
- for edge in edge_list:
- if edge.to_node == node.node_id:
- prev_node_id = from_node
- break
- if prev_node_id:
- break
-
- # Fallback: For continue_decision, it always comes from knowledge_gap
- if not prev_node_id and node.node_id == "continue_decision":
- prev_node_id = "knowledge_gap"
-
- # Get result from previous node (or current node if no previous found)
- if prev_node_id:
- prev_result = context.get_node_result(prev_node_id)
- else:
- # Fallback: try to get from visited nodes (last visited before current)
- visited_list = list(context.visited_nodes)
- if len(visited_list) > 0:
- prev_node_id = visited_list[-1]
- prev_result = context.get_node_result(prev_node_id)
- else:
- prev_result = context.get_node_result(context.current_node)
-
- # Handle case where result might be a tuple (from pydantic-ai validation errors)
- # Extract the actual result object if it's a tuple
- if isinstance(prev_result, tuple) and len(prev_result) > 0:
- # Check if first element is a KnowledgeGapOutput-like object
- if hasattr(prev_result[0], "research_complete"):
- prev_result = prev_result[0]
- elif len(prev_result) > 1 and hasattr(prev_result[1], "research_complete"):
- prev_result = prev_result[1]
- elif (
- len(prev_result) == 2
- and isinstance(prev_result[0], str)
- and prev_result[0] == "research_complete"
- ):
- # Handle validation error format: ('research_complete', False)
- # Reconstruct KnowledgeGapOutput from tuple
- from src.utils.models import KnowledgeGapOutput
-
- self.logger.warning(
- "Decision node received validation error tuple, reconstructing KnowledgeGapOutput",
- node_id=node.node_id,
- tuple_value=prev_result,
- )
- prev_result = KnowledgeGapOutput(
- research_complete=prev_result[1] if len(prev_result) > 1 else False,
- outstanding_gaps=[],
- )
- else:
- # If tuple doesn't contain the object, try to reconstruct or use fallback
- self.logger.warning(
- "Decision node received unexpected tuple format, attempting reconstruction",
- node_id=node.node_id,
- tuple_length=len(prev_result),
- tuple_types=[type(x).__name__ for x in prev_result],
- )
- # Try to reconstruct KnowledgeGapOutput if this is from knowledge_gap node
- if prev_node_id == "knowledge_gap":
- from src.utils.models import KnowledgeGapOutput
-
- # Try to extract research_complete from tuple
- research_complete = False
- for item in prev_result:
- if isinstance(item, bool):
- research_complete = item
- break
- elif isinstance(item, dict) and "research_complete" in item:
- research_complete = item["research_complete"]
- break
- prev_result = KnowledgeGapOutput(
- research_complete=research_complete,
- outstanding_gaps=[],
- )
- else:
- # For other nodes, use first element as fallback
- prev_result = prev_result[0]
-
- # Make decision
- try:
- next_node_id = node.decision_function(prev_result)
- except Exception as e:
- self.logger.error(
- "Decision function failed",
- node_id=node.node_id,
- error=str(e),
- prev_result_type=type(prev_result).__name__,
- )
- # Default to first option on error
- next_node_id = node.options[0]
-
- # Validate decision
- if next_node_id not in node.options:
- self.logger.warning(
- "Decision function returned invalid node",
- node_id=node.node_id,
- returned=next_node_id,
- options=node.options,
- )
- # Default to first option
- next_node_id = node.options[0]
-
- return next_node_id
-
- async def _execute_parallel_node(
- self, node: ParallelNode, query: str, context: GraphExecutionContext
- ) -> list[Any]:
- """Execute a parallel node.
-
- Special handling for deep research "parallel_loops" node:
- - Extracts report plan from previous node result
- - Creates IterativeResearchFlow instances for each section
- - Executes them in parallel
- - Returns section drafts
-
- Args:
- node: The parallel node
- query: The research query
- context: Execution context
-
- Returns:
- List of results from parallel nodes
- """
- # Special handling for deep research parallel_loops node
- if node.node_id == "parallel_loops":
- return await self._execute_deep_research_parallel_loops(node, query, context)
-
- # Standard parallel node execution
- # Execute all parallel nodes concurrently
- tasks = [
- self._execute_node(parallel_node_id, query, context)
- for parallel_node_id in node.parallel_nodes
- ]
-
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- # Handle exceptions
- for i, result in enumerate(results):
- if isinstance(result, Exception):
- self.logger.error(
- "Parallel node execution failed",
- node_id=node.parallel_nodes[i] if i < len(node.parallel_nodes) else "unknown",
- error=str(result),
- )
- results[i] = None
-
- # Aggregate if needed
- if node.aggregator:
- aggregated = node.aggregator(results)
- # Type cast: aggregator returns Any, but we expect list[Any]
- return list(aggregated) if isinstance(aggregated, list) else [aggregated]
-
- return results
-
- async def _execute_deep_research_parallel_loops(
- self, node: ParallelNode, query: str, context: GraphExecutionContext
- ) -> list[str]:
- """Execute parallel iterative research loops for deep research.
-
- Args:
- node: The parallel node (should be "parallel_loops")
- query: The research query
- context: Execution context
-
- Returns:
- List of section draft strings
- """
- from src.agent_factory.judges import create_judge_handler
- from src.orchestrator.research_flow import IterativeResearchFlow
- from src.utils.models import ReportPlan
-
- # Get report plan from previous node (store_plan)
- # The plan should be stored in context.node_results from the planner node
- planner_result = context.get_node_result("planner")
- if not isinstance(planner_result, ReportPlan):
- self.logger.error(
- "Planner result is not a ReportPlan",
- type=type(planner_result),
- )
- raise ValueError("Planner must return ReportPlan for deep research")
-
- report_plan: ReportPlan = planner_result
- self.logger.info(
- "Executing parallel loops for deep research",
- sections=len(report_plan.report_outline),
- )
-
- # Use judge handler from GraphOrchestrator if available, otherwise create new one
- judge_handler = self.judge_handler
- if judge_handler is None:
- judge_handler = create_judge_handler()
-
- # Create and execute iterative research flows for each section
- async def run_section_research(section_index: int) -> str:
- """Run iterative research for a single section."""
- section = report_plan.report_outline[section_index]
-
- try:
- # Create iterative research flow
- flow = IterativeResearchFlow(
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- verbose=False, # Less verbose in parallel execution
- use_graph=False, # Use agent chains for section research
- judge_handler=self.judge_handler or judge_handler,
- oauth_token=self.oauth_token,
- )
-
- # Run research for this section
- section_draft = await flow.run(
- query=section.key_question,
- background_context=report_plan.background_context,
- )
-
- self.logger.info(
- "Section research completed",
- section_index=section_index,
- section_title=section.title,
- draft_length=len(section_draft),
- )
-
- return section_draft
-
- except Exception as e:
- self.logger.error(
- "Section research failed",
- section_index=section_index,
- section_title=section.title,
- error=str(e),
- )
-                # Return a placeholder draft for failed sections so the final report keeps its structure
-                return f"# {section.title}\n\n[Research failed: {e!s}]"
-
- # Execute all sections in parallel
- section_drafts = await asyncio.gather(
- *(run_section_research(i) for i in range(len(report_plan.report_outline))),
- return_exceptions=True,
- )
-
- # Handle exceptions and filter None results
- filtered_drafts: list[str] = []
- for i, draft in enumerate(section_drafts):
- if isinstance(draft, Exception):
- self.logger.error(
- "Section research exception",
- section_index=i,
- error=str(draft),
- )
- filtered_drafts.append(
- f"# {report_plan.report_outline[i].title}\n\n[Research failed: {draft!s}]"
- )
- elif draft is not None:
-                # Type narrowing: after the Exception branch, draft is a str; the assert helps mypy
- assert isinstance(draft, str), "Expected str after Exception check"
- filtered_drafts.append(draft)
-
- self.logger.info(
- "Parallel loops completed",
- sections=len(filtered_drafts),
- total_sections=len(report_plan.report_outline),
- )
-
- return filtered_drafts
-
- def _get_next_node(self, node_id: str, context: GraphExecutionContext) -> list[str]:
- """Get next node(s) from current node.
-
- Args:
- node_id: Current node ID
- context: Execution context
-
- Returns:
- List of next node IDs
- """
- if not self._graph:
- return []
-
- # Get node result for condition evaluation
- node_result = context.get_node_result(node_id)
-
- # Get next nodes
- next_nodes = self._graph.get_next_nodes(node_id, context=node_result)
-
- # If this was a decision node, use its result
- node = self._graph.get_node(node_id)
- if isinstance(node, DecisionNode):
- decision_result = node_result
- if isinstance(decision_result, str):
- return [decision_result]
-
- # Return next node IDs
- return [next_node_id for next_node_id, _ in next_nodes]
-
- async def _detect_research_mode(self, query: str) -> Literal["iterative", "deep"]:
- """
- Detect research mode from query using input parser agent.
-
- Uses input parser agent to analyze query and determine research mode.
- Falls back to heuristic if parser fails.
-
- Args:
- query: The research query
-
- Returns:
- Detected research mode
- """
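-        # Illustrative outcomes of the keyword fallback used when the parser fails
-        # (queries are placeholders):
-        #   "Write a comprehensive report on CRISPR therapies" -> "deep"
-        #   "What is the boiling point of ethanol?"            -> "iterative"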
- try:
- # Use input parser agent for intelligent mode detection
- input_parser = create_input_parser_agent(oauth_token=self.oauth_token)
- parsed_query = await input_parser.parse(query)
- self.logger.info(
- "Research mode detected by input parser",
- mode=parsed_query.research_mode,
- query=query[:100],
- )
- return parsed_query.research_mode
- except Exception as e:
- # Fallback to heuristic if parser fails
- self.logger.warning(
- "Input parser failed, using heuristic",
- error=str(e),
- query=query[:100],
- )
- query_lower = query.lower()
- if any(
- keyword in query_lower
- for keyword in [
- "section",
- "sections",
- "report",
- "outline",
- "structure",
- "comprehensive",
- "analyze",
- "analysis",
- ]
- ):
- return "deep"
- return "iterative"
-
-
-def create_graph_orchestrator(
- mode: Literal["iterative", "deep", "auto"] = "auto",
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- use_graph: bool = True,
- search_handler: SearchHandlerProtocol | None = None,
- judge_handler: JudgeHandlerProtocol | None = None,
- oauth_token: str | None = None,
-) -> GraphOrchestrator:
- """
- Factory function to create a graph orchestrator.
-
- Args:
- mode: Research mode
- max_iterations: Maximum iterations per loop
- max_time_minutes: Maximum time per loop
- use_graph: Whether to use graph execution (True) or agent chains (False)
- search_handler: Optional search handler for tool execution
- judge_handler: Optional judge handler for evidence assessment
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured GraphOrchestrator instance
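-
-    Example (illustrative usage; the query string is a placeholder):
-        orchestrator = create_graph_orchestrator(mode="iterative", max_iterations=3)
-        async for event in orchestrator.run("Summarize recent work on topic X"):
-            ...  # each event is an AgentEvent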
- """
- return GraphOrchestrator(
- mode=mode,
- max_iterations=max_iterations,
- max_time_minutes=max_time_minutes,
- use_graph=use_graph,
- search_handler=search_handler,
- judge_handler=judge_handler,
- oauth_token=oauth_token,
- )
diff --git a/src/orchestrator/planner_agent.py b/src/orchestrator/planner_agent.py
deleted file mode 100644
index 410c19ea682d150d17738c15605768e2484b24e6..0000000000000000000000000000000000000000
--- a/src/orchestrator/planner_agent.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""Planner agent for creating report plans with sections and background context.
-
-Converts the folder/planner_agent.py implementation to use Pydantic AI.
-"""
-
-from datetime import datetime
-from typing import Any
-
-import structlog
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.tools.crawl_adapter import crawl_website
-from src.tools.web_search_adapter import web_search
-from src.utils.exceptions import ConfigurationError, JudgeError
-from src.utils.models import ReportPlan, ReportPlanSection
-
-logger = structlog.get_logger()
-
-
-# System prompt for the planner agent
-SYSTEM_PROMPT = f"""
-You are a research manager, managing a team of research agents. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
-Given a research query, your job is to produce an initial outline of the report (section titles and key questions),
-as well as some background context. Each section will be assigned to a different researcher in your team who will then
-carry out research on the section.
-
-You will be given:
-- An initial research query
-
-Your task is to:
-1. Produce 1-2 paragraphs of initial background context (if needed) on the query by running web searches or crawling websites
-2. Produce an outline of the report that includes a list of section titles and the key question to be addressed in each section
-3. Provide a title for the report that will be used as the main heading
-
-Guidelines:
-- Each section should cover a single topic/question that is independent of other sections
-- The key question for each section should include both the NAME and DOMAIN NAME / WEBSITE (if available and applicable) if it is related to a company, product or similar
-- The background_context should not be more than 2 paragraphs
-- The background_context should be very specific to the query and include any information that is relevant for researchers across all sections of the report
-- The background_context should be drawn only from web search or crawl results rather than prior knowledge (i.e. it should only be included if you have called tools)
-- For example, if the query is about a company, the background context should include some basic information about what the company does
-- DO NOT do more than 2 tool calls
-
-Only output JSON. Follow the JSON schema for ReportPlan. Do not output anything else.
-"""
-
-
-class PlannerAgent:
- """
- Planner agent that creates report plans with sections and background context.
-
- Uses Pydantic AI to generate structured ReportPlan output with optional
- web search and crawl tool usage for background context.
- """
-
- def __init__(
- self,
- model: Any | None = None,
- web_search_tool: Any | None = None,
- crawl_tool: Any | None = None,
- ) -> None:
- """
- Initialize the planner agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses config default.
- web_search_tool: Optional web search tool function. If None, uses default.
- crawl_tool: Optional crawl tool function. If None, uses default.
- """
- self.model = model or get_model()
- self.web_search_tool = web_search_tool or web_search
- self.crawl_tool = crawl_tool or crawl_website
- self.logger = logger
-
- # Validate tools are callable
- if not callable(self.web_search_tool):
- raise ConfigurationError("web_search_tool must be callable")
- if not callable(self.crawl_tool):
- raise ConfigurationError("crawl_tool must be callable")
-
- # Initialize Pydantic AI Agent
- self.agent = Agent(
- model=self.model,
- output_type=ReportPlan,
- system_prompt=SYSTEM_PROMPT,
- tools=[self.web_search_tool, self.crawl_tool],
- retries=3,
- )
-
- async def run(self, query: str) -> ReportPlan:
- """
- Run the planner agent to generate a report plan.
-
- Args:
- query: The user's research query
-
- Returns:
- ReportPlan with sections, background context, and report title
-
- Raises:
- JudgeError: If planning fails after retries
- ConfigurationError: If agent configuration is invalid
- """
- self.logger.info("Starting report planning", query=query[:100])
-
- user_message = f"QUERY: {query}"
-
- try:
- # Run the agent
- result = await self.agent.run(user_message)
- report_plan = result.output
-
- # Validate report plan
- if not report_plan.report_outline:
- self.logger.warning("Report plan has no sections", query=query[:100])
- # Return fallback plan instead of raising error
- return ReportPlan(
- background_context=report_plan.background_context or "",
- report_outline=[
- ReportPlanSection(
- title="Overview",
- key_question=query,
- )
- ],
- report_title=report_plan.report_title or f"Research Report: {query[:50]}",
- )
-
- if not report_plan.report_title:
- self.logger.warning("Report plan has no title", query=query[:100])
- raise JudgeError("Report plan must have a title")
-
- self.logger.info(
- "Report plan created",
- sections=len(report_plan.report_outline),
- has_background=bool(report_plan.background_context),
- )
-
- return report_plan
-
- except Exception as e:
- self.logger.error("Planning failed", error=str(e), query=query[:100])
-
- # Fallback: return minimal report plan
- if isinstance(e, JudgeError | ConfigurationError):
- raise
-
- # For other errors, return a minimal plan
- return ReportPlan(
- background_context="",
- report_outline=[
- ReportPlanSection(
- title="Research Findings",
- key_question=query,
- )
- ],
- report_title=f"Research Report: {query[:50]}",
- )
-
-
-def create_planner_agent(model: Any | None = None, oauth_token: str | None = None) -> PlannerAgent:
- """
- Factory function to create a planner agent.
-
- Args:
- model: Optional Pydantic AI model. If None, uses settings default.
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured PlannerAgent instance
-
- Raises:
- ConfigurationError: If required API keys are missing
- """
- try:
- # Get model from settings if not provided
- if model is None:
- model = get_model(oauth_token=oauth_token)
-
- # Create and return planner agent
- return PlannerAgent(model=model)
-
- except Exception as e:
- logger.error("Failed to create planner agent", error=str(e))
- raise ConfigurationError(f"Failed to create planner agent: {e}") from e
diff --git a/src/orchestrator/research_flow.py b/src/orchestrator/research_flow.py
deleted file mode 100644
index 3ae58be138e570aa16bd10c04f39ec380a18e288..0000000000000000000000000000000000000000
--- a/src/orchestrator/research_flow.py
+++ /dev/null
@@ -1,1111 +0,0 @@
-"""Research flow implementations for iterative and deep research patterns.
-
-Converts the folder/iterative_research.py and folder/deep_research.py
-implementations to use Pydantic AI agents.
-"""
-
-import asyncio
-import time
-from typing import Any
-
-import structlog
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agent_factory.agents import (
- create_graph_orchestrator,
- create_knowledge_gap_agent,
- create_long_writer_agent,
- create_planner_agent,
- create_proofreader_agent,
- create_thinking_agent,
- create_tool_selector_agent,
- create_writer_agent,
-)
-from src.agent_factory.judges import create_judge_handler
-from src.middleware.budget_tracker import BudgetTracker
-from src.middleware.state_machine import get_workflow_state, init_workflow_state
-from src.middleware.workflow_manager import WorkflowManager
-from src.services.llamaindex_rag import LlamaIndexRAGService, get_rag_service
-from src.services.report_file_service import ReportFileService, get_report_file_service
-from src.tools.tool_executor import execute_tool_tasks
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import (
- AgentSelectionPlan,
- AgentTask,
- Citation,
- Conversation,
- Evidence,
- JudgeAssessment,
- KnowledgeGapOutput,
- ReportDraft,
- ReportDraftSection,
- ReportPlan,
- SourceName,
- ToolAgentOutput,
-)
-
-logger = structlog.get_logger()
-
-
-class IterativeResearchFlow:
- """
- Iterative research flow that runs a single research loop.
-
- Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Repeat
- until research is complete or constraints are met.
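-
-    Example (illustrative usage; the query is a placeholder):
-        flow = IterativeResearchFlow(max_iterations=3, max_time_minutes=5)
-        report = await flow.run("What are the key mechanisms of mRNA vaccines?")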
- """
-
- def __init__(
- self,
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- verbose: bool = True,
- use_graph: bool = False,
- judge_handler: Any | None = None,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize iterative research flow.
-
- Args:
- max_iterations: Maximum number of iterations
- max_time_minutes: Maximum time in minutes
- verbose: Whether to log progress
-            use_graph: Whether to use graph-based execution (True) or agent chains (False)
-            judge_handler: Optional judge handler for evidence assessment; if None, a default is created
-            oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- """
- self.max_iterations = max_iterations
- self.max_time_minutes = max_time_minutes
- self.verbose = verbose
- self.use_graph = use_graph
- self.oauth_token = oauth_token
- self.logger = logger
-
- # Initialize agents (only needed for agent chain execution)
- if not use_graph:
- self.knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token)
- self.tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token)
- self.thinking_agent = create_thinking_agent(oauth_token=self.oauth_token)
- self.writer_agent = create_writer_agent(oauth_token=self.oauth_token)
- # Initialize judge handler (use provided or create new)
- self.judge_handler = judge_handler or create_judge_handler()
-
- # Initialize state (only needed for agent chain execution)
- if not use_graph:
- self.conversation = Conversation()
- self.iteration = 0
- self.start_time: float | None = None
- self.should_continue = True
-
- # Initialize budget tracker
- self.budget_tracker = BudgetTracker()
- self.loop_id = "iterative_flow"
- self.budget_tracker.create_budget(
- loop_id=self.loop_id,
- tokens_limit=100000,
- time_limit_seconds=max_time_minutes * 60,
- iterations_limit=max_iterations,
- )
- self.budget_tracker.start_timer(self.loop_id)
-
- # Initialize RAG service (lazy, may be None if unavailable)
- self._rag_service: LlamaIndexRAGService | None = None
-
- # Graph orchestrator (lazy initialization)
- self._graph_orchestrator: Any = None
-
- # File service (lazy initialization)
- self._file_service: ReportFileService | None = None
-
- def _get_file_service(self) -> ReportFileService | None:
- """
- Get file service instance (lazy initialization).
-
- Returns:
- ReportFileService instance or None if disabled
- """
- if self._file_service is None:
- try:
- self._file_service = get_report_file_service()
- except Exception as e:
- self.logger.warning("Failed to initialize file service", error=str(e))
- return None
- return self._file_service
-
- async def run(
- self,
- query: str,
- background_context: str = "",
- output_length: str = "",
- output_instructions: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> str:
- """
- Run the iterative research flow.
-
- Args:
- query: The research query
- background_context: Optional background context
- output_length: Optional description of desired output length
- output_instructions: Optional additional instructions
- message_history: Optional user conversation history
-
- Returns:
- Final report string
- """
- if self.use_graph:
- return await self._run_with_graph(
- query, background_context, output_length, output_instructions, message_history
- )
- else:
- return await self._run_with_chains(
- query, background_context, output_length, output_instructions, message_history
- )
-
- async def _run_with_chains(
- self,
- query: str,
- background_context: str = "",
- output_length: str = "",
- output_instructions: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> str:
- """
- Run the iterative research flow using agent chains.
-
- Args:
- query: The research query
- background_context: Optional background context
- output_length: Optional description of desired output length
- output_instructions: Optional additional instructions
- message_history: Optional user conversation history
-
- Returns:
- Final report string
- """
- self.start_time = time.time()
- self.logger.info("Starting iterative research (agent chains)", query=query[:100])
-
- # Initialize conversation with first iteration
- self.conversation.add_iteration()
-
- # Main research loop
- while self.should_continue and self._check_constraints():
- self.iteration += 1
- self.logger.info("Starting iteration", iteration=self.iteration)
-
- # Add new iteration to conversation
- self.conversation.add_iteration()
-
- # 1. Generate observations
- await self._generate_observations(query, background_context, message_history)
-
- # 2. Evaluate gaps
- evaluation = await self._evaluate_gaps(query, background_context, message_history)
-
- # 3. Assess with judge (after tools execute, we'll assess again)
- # For now, check knowledge gap evaluation
- # After tool execution, we'll do a full judge assessment
-
- # Check if research is complete (knowledge gap agent says complete)
- if evaluation.research_complete:
- self.should_continue = False
- self.logger.info("Research marked as complete by knowledge gap agent")
- break
-
- # 4. Select tools for next gap
- next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query
- selection_plan = await self._select_agents(
- next_gap, query, background_context, message_history
- )
-
- # 5. Execute tools
- await self._execute_tools(selection_plan.tasks)
-
- # 6. Assess evidence sufficiency with judge
- judge_assessment = await self._assess_with_judge(query)
-
- # Check if judge says evidence is sufficient
- if judge_assessment.sufficient:
- self.should_continue = False
- self.logger.info(
- "Research marked as complete by judge",
- confidence=judge_assessment.confidence,
- reasoning=judge_assessment.reasoning[:100],
- )
- break
-
- # Update budget tracker
- self.budget_tracker.increment_iteration(self.loop_id)
- self.budget_tracker.update_timer(self.loop_id)
-
- # Create final report
- report = await self._create_final_report(query, output_length, output_instructions)
-
- elapsed = time.time() - (self.start_time or time.time())
- self.logger.info(
- "Iterative research completed",
- iterations=self.iteration,
- elapsed_minutes=elapsed / 60,
- )
-
- return report
-
- async def _run_with_graph(
- self,
- query: str,
- background_context: str = "",
- output_length: str = "",
- output_instructions: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> str:
- """
- Run the iterative research flow using graph execution.
-
- Args:
- query: The research query
- background_context: Optional background context (currently ignored in graph execution)
- output_length: Optional description of desired output length (currently ignored in graph execution)
-            output_instructions: Optional additional instructions (currently ignored in graph execution)
-            message_history: Optional user conversation history (currently ignored in graph execution)
-
- Returns:
- Final report string
- """
- self.logger.info("Starting iterative research (graph execution)", query=query[:100])
-
- # Create graph orchestrator (lazy initialization)
- if self._graph_orchestrator is None:
- self._graph_orchestrator = create_graph_orchestrator(
- mode="iterative",
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- use_graph=True,
- )
-
- # Run orchestrator and collect events
- final_report = ""
- async for event in self._graph_orchestrator.run(query):
- if event.type == "complete":
- final_report = event.message
- break
- elif event.type == "error":
- self.logger.error("Graph execution error", error=event.message)
- raise RuntimeError(f"Graph execution failed: {event.message}")
-
- if not final_report:
- self.logger.warning("No complete event received from graph orchestrator")
- final_report = "Research completed but no report was generated."
-
- self.logger.info("Iterative research completed (graph execution)")
-
- return final_report
-
- def _check_constraints(self) -> bool:
- """Check if we've exceeded constraints."""
- if self.iteration >= self.max_iterations:
- self.logger.info("Max iterations reached", max=self.max_iterations)
- return False
-
- if self.start_time:
- elapsed_minutes = (time.time() - self.start_time) / 60
- if elapsed_minutes >= self.max_time_minutes:
- self.logger.info("Max time reached", max=self.max_time_minutes)
- return False
-
- # Check budget tracker
- self.budget_tracker.update_timer(self.loop_id)
- exceeded, reason = self.budget_tracker.check_budget(self.loop_id)
- if exceeded:
- self.logger.info("Budget exceeded", reason=reason)
- return False
-
- return True
-
- async def _generate_observations(
- self,
- query: str,
- background_context: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> str:
- """Generate observations from current research state."""
- # Build input prompt for token estimation
- conversation_history = self.conversation.compile_conversation_history()
- # Build background context section separately to avoid backslash in f-string
- background_section = (
- f"BACKGROUND CONTEXT:\n{background_context}\n\n" if background_context else ""
- )
- input_prompt = f"""
-You are starting iteration {self.iteration} of your research process.
-
-ORIGINAL QUERY:
-{query}
-
-{background_section}HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- observations = await self.thinking_agent.generate_observations(
- query=query,
- background_context=background_context,
- conversation_history=conversation_history,
- message_history=message_history,
- iteration=self.iteration,
- )
-
- # Track tokens for this iteration
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, observations)
- self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for thinking agent",
- iteration=self.iteration,
- tokens=estimated_tokens,
- )
-
- self.conversation.set_latest_thought(observations)
- return observations
-
- async def _evaluate_gaps(
- self,
- query: str,
- background_context: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> KnowledgeGapOutput:
- """Evaluate knowledge gaps in current research."""
- if self.start_time:
- elapsed_minutes = (time.time() - self.start_time) / 60
- else:
- elapsed_minutes = 0.0
-
- # Build input prompt for token estimation
- conversation_history = self.conversation.compile_conversation_history()
- background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
- input_prompt = f"""
-Current Iteration Number: {self.iteration}
-Time Elapsed: {elapsed_minutes:.2f} minutes of maximum {self.max_time_minutes} minutes
-
-ORIGINAL QUERY:
-{query}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- evaluation = await self.knowledge_gap_agent.evaluate(
- query=query,
- background_context=background_context,
- conversation_history=conversation_history,
- message_history=message_history,
- iteration=self.iteration,
- time_elapsed_minutes=elapsed_minutes,
- max_time_minutes=self.max_time_minutes,
- )
-
- # Track tokens for this iteration
- evaluation_text = f"research_complete={evaluation.research_complete}, gaps={len(evaluation.outstanding_gaps)}"
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
- input_prompt, evaluation_text
- )
- self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for knowledge gap agent",
- iteration=self.iteration,
- tokens=estimated_tokens,
- )
-
- if not evaluation.research_complete and evaluation.outstanding_gaps:
- self.conversation.set_latest_gap(evaluation.outstanding_gaps[0])
-
- return evaluation
-
- async def _assess_with_judge(self, query: str) -> JudgeAssessment:
- """Assess evidence sufficiency using JudgeHandler.
-
- Args:
- query: The research query
-
- Returns:
- JudgeAssessment with sufficiency evaluation
- """
- state = get_workflow_state()
- evidence = state.evidence # Get all collected evidence
-
- self.logger.info(
- "Assessing evidence with judge",
- query=query[:100],
- evidence_count=len(evidence),
- )
-
- assessment = await self.judge_handler.assess(query, evidence)
-
- # Track tokens for judge call
- # Estimate tokens from query + evidence + assessment
- evidence_text = "\n".join([e.content[:500] for e in evidence[:10]]) # Sample
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
- query + evidence_text, str(assessment.reasoning)
- )
- self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
-
- self.logger.info(
- "Judge assessment complete",
- sufficient=assessment.sufficient,
- confidence=assessment.confidence,
- recommendation=assessment.recommendation,
- )
-
- return assessment
-
- async def _select_agents(
- self,
- gap: str,
- query: str,
- background_context: str = "",
- message_history: list[ModelMessage] | None = None,
- ) -> AgentSelectionPlan:
- """Select tools to address knowledge gap."""
- # Build input prompt for token estimation
- conversation_history = self.conversation.compile_conversation_history()
- background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
- input_prompt = f"""
-ORIGINAL QUERY:
-{query}
-
-KNOWLEDGE GAP TO ADDRESS:
-{gap}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
- selection_plan = await self.tool_selector_agent.select_tools(
- gap=gap,
- query=query,
- background_context=background_context,
- conversation_history=conversation_history,
- message_history=message_history,
- )
-
- # Track tokens for this iteration
- selection_text = f"tasks={len(selection_plan.tasks)}, agents={[task.agent for task in selection_plan.tasks]}"
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
- input_prompt, selection_text
- )
- self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for tool selector agent",
- iteration=self.iteration,
- tokens=estimated_tokens,
- )
-
- # Store tool calls in conversation
- tool_calls = [
- f"[Agent] {task.agent} [Query] {task.query} [Entity] {task.entity_website or 'null'}"
- for task in selection_plan.tasks
- ]
- self.conversation.set_latest_tool_calls(tool_calls)
-
- return selection_plan
-
- def _get_rag_service(self) -> LlamaIndexRAGService | None:
- """
- Get or create RAG service instance.
-
- Returns:
- RAG service instance, or None if unavailable
- """
- if self._rag_service is None:
- try:
- self._rag_service = get_rag_service(oauth_token=self.oauth_token)
- self.logger.info("RAG service initialized for research flow")
- except (ConfigurationError, ImportError) as e:
- self.logger.warning(
- "RAG service unavailable", error=str(e), hint="OPENAI_API_KEY required"
- )
- return None
- return self._rag_service
-
- async def _execute_tools(self, tasks: list[AgentTask]) -> dict[str, ToolAgentOutput]:
- """Execute selected tools concurrently."""
- try:
- results = await execute_tool_tasks(tasks)
- except Exception as e:
- # Handle tool execution errors gracefully
- self.logger.error(
- "Tool execution failed",
- error=str(e),
- task_count=len(tasks),
- exc_info=True,
- )
- # Return empty results to allow research flow to continue
- # The flow can still generate a report based on previous iterations
- results = {}
-
- # Store findings in conversation (only if we have results)
- evidence_list: list[Evidence] = []
- if results:
- findings = [result.output for result in results.values()]
- self.conversation.set_latest_findings(findings)
-
- # Convert tool outputs to Evidence objects and store in workflow state
- evidence_list = self._convert_tool_outputs_to_evidence(results)
-
- if evidence_list:
- state = get_workflow_state()
- added_count = state.add_evidence(evidence_list)
- self.logger.info(
- "Evidence added to workflow state",
- count=added_count,
- total_evidence=len(state.evidence),
- )
-
- # Ingest evidence into RAG if available (Phase 6 requirement)
- rag_service = self._get_rag_service()
- if rag_service is not None:
- try:
- # ingest_evidence is synchronous, run in executor to avoid blocking
-                    loop = asyncio.get_running_loop()
- await loop.run_in_executor(None, rag_service.ingest_evidence, evidence_list)
- self.logger.info(
- "Evidence ingested into RAG",
- count=len(evidence_list),
- )
- except Exception as e:
- # Don't fail the research loop if RAG ingestion fails
- self.logger.warning(
- "Failed to ingest evidence into RAG",
- error=str(e),
- count=len(evidence_list),
- )
-
- return results
-
- def _convert_tool_outputs_to_evidence(
- self, tool_results: dict[str, ToolAgentOutput]
- ) -> list[Evidence]:
- """Convert ToolAgentOutput to Evidence objects.
-
- Args:
- tool_results: Dictionary of tool execution results
-
- Returns:
- List of Evidence objects
- """
- evidence_list = []
- for key, result in tool_results.items():
- # Extract URLs from sources
- if result.sources:
- # Create one Evidence object per source URL
- for url in result.sources:
- # Determine source type from URL or tool name
- # Default to "web" for unknown web sources
- source_type: SourceName = "web"
- if "pubmed" in url.lower() or "ncbi" in url.lower():
- source_type = "pubmed"
- elif "clinicaltrials" in url.lower():
- source_type = "clinicaltrials"
- elif "europepmc" in url.lower():
- source_type = "europepmc"
- elif "biorxiv" in url.lower():
- source_type = "biorxiv"
- elif "arxiv" in url.lower() or "preprint" in url.lower():
- source_type = "preprint"
- # Note: "web" is now a valid SourceName for general web sources
-
- citation = Citation(
- title=f"Tool Result: {key}",
- url=url,
- source=source_type,
- date="n.d.",
- authors=[],
- )
- # Truncate content to reasonable length for judge (1500 chars)
- content = result.output[:1500]
- if len(result.output) > 1500:
- content += "... [truncated]"
-
- evidence = Evidence(
- content=content,
- citation=citation,
- relevance=0.5, # Default relevance
- )
- evidence_list.append(evidence)
- else:
- # No URLs, create a single Evidence object with tool output
- # Use a placeholder URL based on the tool name
- # Determine source type from tool name
- tool_source_type: SourceName = "web" # Default for unknown sources
- if "RAG" in key:
- tool_source_type = "rag"
- elif "WebSearch" in key or "SiteCrawler" in key:
- tool_source_type = "web"
- # "web" is now a valid SourceName for general web sources
-
- citation = Citation(
- title=f"Tool Result: {key}",
- url=f"tool://{key}",
- source=tool_source_type,
- date="n.d.",
- authors=[],
- )
- content = result.output[:1500]
- if len(result.output) > 1500:
- content += "... [truncated]"
-
- evidence = Evidence(
- content=content,
- citation=citation,
- relevance=0.5,
- )
- evidence_list.append(evidence)
-
- return evidence_list
-
- async def _create_final_report(
- self, query: str, length: str = "", instructions: str = ""
- ) -> str:
- """Create final report from all findings."""
- all_findings = "\n\n".join(self.conversation.get_all_findings())
- if not all_findings:
- all_findings = "No findings available yet."
-
- # Build input prompt for token estimation
- length_str = f"* The full response should be approximately {length}.\n" if length else ""
- instructions_str = f"* {instructions}" if instructions else ""
- guidelines_str = (
- ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
- if length or instructions
- else ""
- )
- input_prompt = f"""
-Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}
-
-QUERY: {query}
-
-FINDINGS:
-{all_findings}
-"""
-
- report = await self.writer_agent.write_report(
- query=query,
- findings=all_findings,
- output_length=length,
- output_instructions=instructions,
- )
-
- # Track tokens for final report (not per iteration, just total)
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, report)
- self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for writer agent (final report)",
- tokens=estimated_tokens,
- )
-
- # Save report to file if enabled
- try:
- file_service = self._get_file_service()
- if file_service:
- file_path = file_service.save_report(
- report_content=report,
- query=query,
- )
- self.logger.info("Report saved to file", file_path=file_path)
- except Exception as e:
- # Don't fail the entire operation if file saving fails
- self.logger.warning("Failed to save report to file", error=str(e))
-
- # Note: Citation validation for markdown reports would require Evidence objects
- # Currently, findings are strings, not Evidence objects. For full validation,
- # consider using ResearchReport format or passing Evidence objects separately.
- # See src/utils/citation_validator.py for markdown citation validation utilities.
-
- return report
-
-
-class DeepResearchFlow:
- """
- Deep research flow that runs parallel iterative loops per section.
-
- Pattern: Plan → Parallel Iterative Loops (one per section) → Synthesis
- """
-
- def __init__(
- self,
- max_iterations: int = 5,
- max_time_minutes: int = 10,
- verbose: bool = True,
- use_long_writer: bool = True,
- use_graph: bool = False,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize deep research flow.
-
- Args:
- max_iterations: Maximum iterations per section
- max_time_minutes: Maximum time per section
- verbose: Whether to log progress
- use_long_writer: Whether to use long writer (True) or proofreader (False)
- use_graph: Whether to use graph-based execution (True) or agent chains (False)
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- """
- self.max_iterations = max_iterations
- self.max_time_minutes = max_time_minutes
- self.verbose = verbose
- self.use_long_writer = use_long_writer
- self.use_graph = use_graph
- self.oauth_token = oauth_token
- self.logger = logger
-
- # Initialize agents (only needed for agent chain execution)
- if not use_graph:
- self.planner_agent = create_planner_agent(oauth_token=self.oauth_token)
- self.long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token)
- self.proofreader_agent = create_proofreader_agent(oauth_token=self.oauth_token)
- # Initialize judge handler for section loop completion
- self.judge_handler = create_judge_handler()
- # Initialize budget tracker for token tracking
- self.budget_tracker = BudgetTracker()
- self.loop_id = "deep_research_flow"
- self.budget_tracker.create_budget(
- loop_id=self.loop_id,
- tokens_limit=200000, # Higher limit for deep research
- time_limit_seconds=max_time_minutes
- * 60
- * 2, # Allow more time for parallel sections
- iterations_limit=max_iterations * 10, # Allow for multiple sections
- )
- self.budget_tracker.start_timer(self.loop_id)
-
- # Graph orchestrator (lazy initialization)
- self._graph_orchestrator: Any = None
-
- # File service (lazy initialization)
- self._file_service: ReportFileService | None = None
-
- def _get_file_service(self) -> ReportFileService | None:
- """
- Get file service instance (lazy initialization).
-
- Returns:
- ReportFileService instance or None if disabled
- """
- if self._file_service is None:
- try:
- self._file_service = get_report_file_service()
- except Exception as e:
- self.logger.warning("Failed to initialize file service", error=str(e))
- return None
- return self._file_service
-
- async def run(self, query: str, message_history: list[ModelMessage] | None = None) -> str:
- """
- Run the deep research flow.
-
- Args:
- query: The research query
- message_history: Optional user conversation history
-
- Returns:
- Final report string
- """
- if self.use_graph:
- return await self._run_with_graph(query, message_history)
- else:
- return await self._run_with_chains(query, message_history)
-
- async def _run_with_chains(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> str:
- """
- Run the deep research flow using agent chains.
-
- Args:
- query: The research query
- message_history: Optional user conversation history
-
- Returns:
- Final report string
- """
- self.logger.info("Starting deep research (agent chains)", query=query[:100])
-
- # Initialize workflow state for deep research
- try:
- from src.services.embeddings import get_embedding_service
-
- embedding_service = get_embedding_service()
-        except Exception:
- # If embedding service is unavailable, initialize without it
- embedding_service = None
- self.logger.debug("Embedding service unavailable, initializing state without it")
-
- init_workflow_state(embedding_service=embedding_service, message_history=message_history)
- self.logger.debug("Workflow state initialized for deep research")
-
- # 1. Build report plan
- report_plan = await self._build_report_plan(query, message_history)
- self.logger.info(
- "Report plan created",
- sections=len(report_plan.report_outline),
- title=report_plan.report_title,
- )
-
- # 2. Run parallel research loops with state synchronization
- section_drafts = await self._run_research_loops(report_plan, message_history)
-
- # Verify state synchronization - log evidence count
- state = get_workflow_state()
- self.logger.info(
- "State synchronization complete",
- total_evidence=len(state.evidence),
- sections_completed=len(section_drafts),
- )
-
- # 3. Create final report
- final_report = await self._create_final_report(query, report_plan, section_drafts)
-
- self.logger.info(
- "Deep research completed",
- sections=len(section_drafts),
- final_report_length=len(final_report),
- )
-
- return final_report
-
- async def _run_with_graph(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> str:
- """
- Run the deep research flow using graph execution.
-
- Args:
- query: The research query
- message_history: Optional user conversation history
-
- Returns:
- Final report string
- """
- self.logger.info("Starting deep research (graph execution)", query=query[:100])
-
- # Create graph orchestrator (lazy initialization)
- if self._graph_orchestrator is None:
- self._graph_orchestrator = create_graph_orchestrator(
- mode="deep",
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- use_graph=True,
- )
-
- # Run orchestrator and collect events
- final_report = ""
- async for event in self._graph_orchestrator.run(query, message_history=message_history):
- if event.type == "complete":
- final_report = event.message
- break
- elif event.type == "error":
- self.logger.error("Graph execution error", error=event.message)
- raise RuntimeError(f"Graph execution failed: {event.message}")
-
- if not final_report:
- self.logger.warning("No complete event received from graph orchestrator")
- final_report = "Research completed but no report was generated."
-
- self.logger.info("Deep research completed (graph execution)")
-
- return final_report
-
- async def _build_report_plan(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> ReportPlan:
- """Build the initial report plan."""
- self.logger.info("Building report plan")
-
- # Build input prompt for token estimation
- input_prompt = f"QUERY: {query}"
-
-        # The planner agent does not accept message_history yet, so only the query is
-        # passed via the standard run() call.
- report_plan = await self.planner_agent.run(query)
-
- # Track tokens for planner agent
- if not self.use_graph and hasattr(self, "budget_tracker"):
- plan_text = (
- f"title={report_plan.report_title}, sections={len(report_plan.report_outline)}"
- )
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, plan_text)
- self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for planner agent",
- tokens=estimated_tokens,
- )
-
- self.logger.info(
- "Report plan created",
- sections=len(report_plan.report_outline),
- has_background=bool(report_plan.background_context),
- )
-
- return report_plan
-
- async def _run_research_loops(
- self, report_plan: ReportPlan, message_history: list[ModelMessage] | None = None
- ) -> list[str]:
- """Run parallel iterative research loops for each section."""
- self.logger.info("Running research loops", sections=len(report_plan.report_outline))
-
- # Create workflow manager for parallel execution
- workflow_manager = WorkflowManager()
-
- # Create loop configurations
- loop_configs = [
- {
- "loop_id": f"section_{i}",
- "query": section.key_question,
- "section_title": section.title,
- "background_context": report_plan.background_context,
- }
- for i, section in enumerate(report_plan.report_outline)
- ]
-
- async def run_research_for_section(config: dict[str, Any]) -> str:
- """Run iterative research for a single section."""
- loop_id = config.get("loop_id", "unknown")
- query = config.get("query", "")
- background_context = config.get("background_context", "")
-
- try:
- # Update loop status
- await workflow_manager.update_loop_status(loop_id, "running")
-
- # Create iterative research flow
- flow = IterativeResearchFlow(
- max_iterations=self.max_iterations,
- max_time_minutes=self.max_time_minutes,
- verbose=self.verbose,
- use_graph=self.use_graph,
- judge_handler=self.judge_handler if not self.use_graph else None,
- )
-
- # Run research with message_history
- result = await flow.run(
- query=query,
- background_context=background_context,
- message_history=message_history,
- )
-
- # Sync evidence from flow to loop
- state = get_workflow_state()
- if state.evidence:
- await workflow_manager.add_loop_evidence(loop_id, state.evidence)
-
- # Update loop status
- await workflow_manager.update_loop_status(loop_id, "completed")
-
- return result
-
- except Exception as e:
- error_msg = str(e)
- await workflow_manager.update_loop_status(loop_id, "failed", error=error_msg)
- self.logger.error(
- "Section research failed",
- loop_id=loop_id,
- error=error_msg,
- )
- raise
-
- # Run all sections in parallel using workflow manager
- section_drafts = await workflow_manager.run_loops_parallel(
- loop_configs=loop_configs,
- loop_func=run_research_for_section,
- judge_handler=self.judge_handler if not self.use_graph else None,
- budget_tracker=self.budget_tracker if not self.use_graph else None,
- )
-
- # Sync evidence from all loops to global state
- for config in loop_configs:
- loop_id = config.get("loop_id")
- if loop_id:
- await workflow_manager.sync_loop_evidence_to_state(loop_id)
-
- # Filter out None results (failed loops)
- section_drafts = [draft for draft in section_drafts if draft is not None]
-
- self.logger.info(
- "Research loops completed",
- drafts=len(section_drafts),
- total_sections=len(report_plan.report_outline),
- )
-
- return section_drafts
-
- async def _create_final_report(
- self, query: str, report_plan: ReportPlan, section_drafts: list[str]
- ) -> str:
- """Create final report from section drafts."""
- self.logger.info("Creating final report")
-
- # Create ReportDraft from section drafts
- report_draft = ReportDraft(
- sections=[
- ReportDraftSection(
- section_title=section.title,
- section_content=draft,
- )
- for section, draft in zip(report_plan.report_outline, section_drafts, strict=False)
- ]
- )
-
- # Build input prompt for token estimation
- draft_text = "\n".join(
- [s.section_content[:500] for s in report_draft.sections[:5]]
- ) # Sample
- input_prompt = f"QUERY: {query}\nTITLE: {report_plan.report_title}\nDRAFT: {draft_text}"
-
- if self.use_long_writer:
- # Use long writer agent
- final_report = await self.long_writer_agent.write_report(
- original_query=query,
- report_title=report_plan.report_title,
- report_draft=report_draft,
- )
- else:
- # Use proofreader agent
- final_report = await self.proofreader_agent.proofread(
- query=query,
- report_draft=report_draft,
- )
-
- # Track tokens for final report synthesis
- if not self.use_graph and hasattr(self, "budget_tracker"):
- estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
- input_prompt, final_report
- )
- self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
- self.logger.debug(
- "Tokens tracked for final report synthesis",
- tokens=estimated_tokens,
- agent="long_writer" if self.use_long_writer else "proofreader",
- )
-
- # Save report to file if enabled
- try:
- file_service = self._get_file_service()
- if file_service:
- file_path = file_service.save_report(
- report_content=final_report,
- query=query,
- )
- self.logger.info("Report saved to file", file_path=file_path)
- except Exception as e:
- # Don't fail the entire operation if file saving fails
- self.logger.warning("Failed to save report to file", error=str(e))
-
- self.logger.info("Final report created", length=len(final_report))
-
- return final_report
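For orientation, a minimal sketch of driving the deleted flow end to end; the import path and the asyncio wrapper are assumptions, not shown in this diff:

import asyncio

from src.orchestrator.research_flow import DeepResearchFlow  # module path assumed; not shown in this diff


async def main() -> None:
    # Plan the report, run one iterative loop per section in parallel,
    # then synthesize the section drafts with the long writer agent.
    flow = DeepResearchFlow(max_iterations=3, max_time_minutes=5, use_long_writer=True)
    report = await flow.run("How does metformin influence autophagy?")
    print(report)


asyncio.run(main())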
diff --git a/src/orchestrator_factory.py b/src/orchestrator_factory.py
deleted file mode 100644
index 44f8b93a5cd75a600a040258ab3da24f0ea9d2f6..0000000000000000000000000000000000000000
--- a/src/orchestrator_factory.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""Factory for creating orchestrators."""
-
-from typing import Any, Literal
-
-import structlog
-
-from src.legacy_orchestrator import (
- JudgeHandlerProtocol,
- Orchestrator,
- SearchHandlerProtocol,
-)
-from src.utils.config import settings
-from src.utils.models import OrchestratorConfig
-
-logger = structlog.get_logger()
-
-
-def _get_magentic_orchestrator_class() -> Any:
- """Import MagenticOrchestrator lazily to avoid hard dependency."""
- try:
- from src.orchestrator_magentic import MagenticOrchestrator
-
- return MagenticOrchestrator
- except ImportError as e:
- logger.error("Failed to import MagenticOrchestrator", error=str(e))
- raise ValueError(
- "Advanced mode requires agent-framework-core. Please install it or use mode='simple'."
- ) from e
-
-
-def _get_graph_orchestrator_factory() -> Any:
- """Import create_graph_orchestrator lazily to avoid circular dependencies."""
- try:
- from src.orchestrator.graph_orchestrator import create_graph_orchestrator
-
- return create_graph_orchestrator
- except ImportError as e:
- logger.error("Failed to import create_graph_orchestrator", error=str(e))
- raise ValueError(
- "Graph orchestrators require Pydantic Graph. Please check dependencies."
- ) from e
-
-
-def create_orchestrator(
- search_handler: SearchHandlerProtocol | None = None,
- judge_handler: JudgeHandlerProtocol | None = None,
- config: OrchestratorConfig | None = None,
- mode: Literal["simple", "magentic", "advanced", "iterative", "deep", "auto"] | None = None,
- oauth_token: str | None = None,
-) -> Any:
- """
- Create an orchestrator instance.
-
- Args:
- search_handler: The search handler (required for simple mode)
- judge_handler: The judge handler (required for simple mode)
- config: Optional configuration
- mode: Orchestrator mode - "simple", "advanced", "iterative", "deep", "auto", or None (auto-detect)
- - "simple": Linear search-judge loop (Free Tier)
- - "advanced": Multi-agent coordination (Requires OpenAI)
- - "iterative": Knowledge-gap-driven research (Free Tier)
- - "deep": Parallel section-based research (Free Tier)
- - "auto": Intelligent mode detection (Free Tier)
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Orchestrator instance
- """
- effective_mode = _determine_mode(mode)
- logger.info("Creating orchestrator", mode=effective_mode)
-
- if effective_mode == "advanced":
- orchestrator_cls = _get_magentic_orchestrator_class()
- return orchestrator_cls(
- max_rounds=config.max_iterations if config else 10,
- )
-
- # Graph-based orchestrators (iterative, deep, auto)
- if effective_mode in ("iterative", "deep", "auto"):
- create_graph_orchestrator = _get_graph_orchestrator_factory()
- return create_graph_orchestrator(
- mode=effective_mode, # type: ignore[arg-type]
- max_iterations=config.max_iterations if config else 5,
- max_time_minutes=10,
- use_graph=True,
- search_handler=search_handler,
- judge_handler=judge_handler,
- oauth_token=oauth_token,
- )
-
- # Simple mode requires handlers
- if search_handler is None or judge_handler is None:
- raise ValueError("Simple mode requires search_handler and judge_handler")
-
- return Orchestrator(
- search_handler=search_handler,
- judge_handler=judge_handler,
- config=config,
- )
-
-
-def _determine_mode(explicit_mode: str | None) -> str:
- """Determine which mode to use."""
- if explicit_mode:
- if explicit_mode in ("magentic", "advanced"):
- return "advanced"
- if explicit_mode in ("iterative", "deep", "auto"):
- return explicit_mode
- return "simple"
-
- # Auto-detect: advanced if paid API key available, otherwise simple
- if settings.has_openai_key:
- return "advanced"
-
- return "simple"
diff --git a/src/orchestrator_hierarchical.py b/src/orchestrator_hierarchical.py
deleted file mode 100644
index a7bfb85adc874c60bd7ab84c49e0990cc2f1a620..0000000000000000000000000000000000000000
--- a/src/orchestrator_hierarchical.py
+++ /dev/null
@@ -1,109 +0,0 @@
-"""Hierarchical orchestrator using middleware and sub-teams."""
-
-import asyncio
-from collections.abc import AsyncGenerator
-from typing import Any
-
-import structlog
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-
-from src.agents.judge_agent_llm import LLMSubIterationJudge
-from src.agents.magentic_agents import create_search_agent
-from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
-from src.services.embeddings import get_embedding_service
-from src.state import init_magentic_state
-from src.utils.models import AgentEvent
-
-logger = structlog.get_logger()
-
-
-class ResearchTeam(SubIterationTeam):
- """Adapts Magentic ChatAgent to SubIterationTeam protocol."""
-
- def __init__(self) -> None:
- self.agent = create_search_agent()
-
- async def execute(self, task: str) -> str:
- response = await self.agent.run(task)
- if response.messages:
- for msg in reversed(response.messages):
- if msg.role == "assistant" and msg.text:
- return str(msg.text)
- return "No response from agent."
-
-
-class HierarchicalOrchestrator:
- """Orchestrator that uses hierarchical teams and sub-iterations."""
-
- def __init__(self) -> None:
- self.team = ResearchTeam()
- self.judge = LLMSubIterationJudge()
- self.middleware = SubIterationMiddleware(self.team, self.judge, max_iterations=5)
-
- async def run(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> AsyncGenerator[AgentEvent, None]:
- logger.info(
- "Starting hierarchical orchestrator",
- query=query,
- has_history=bool(message_history),
- )
-
- try:
- service = get_embedding_service()
- init_magentic_state(service)
- except Exception as e:
- logger.warning(
- "Embedding service initialization failed, using default state",
- error=str(e),
- )
- init_magentic_state()
-
- yield AgentEvent(type="started", message=f"Starting research: {query}")
-
- queue: asyncio.Queue[AgentEvent | None] = asyncio.Queue()
-
- async def event_callback(event: AgentEvent) -> None:
- await queue.put(event)
-
- # Note: middleware.run() may not support message_history yet
- # Pass query for now, message_history can be added to middleware later if needed
- task_future = asyncio.create_task(self.middleware.run(query, event_callback))
-
- while not task_future.done():
- get_event = asyncio.create_task(queue.get())
- done, _ = await asyncio.wait(
- {task_future, get_event}, return_when=asyncio.FIRST_COMPLETED
- )
-
- if get_event in done:
- event = get_event.result()
- if event:
- yield event
- else:
- get_event.cancel()
-
- # Process remaining events
- while not queue.empty():
- ev = queue.get_nowait()
- if ev:
- yield ev
-
- try:
- result, assessment = await task_future
-
- assessment_text = assessment.reasoning if assessment else "None"
- yield AgentEvent(
- type="complete",
- message=(
- f"Research complete.\n\nResult:\n{result}\n\nAssessment:\n{assessment_text}"
- ),
- data={"assessment": assessment.model_dump() if assessment else None},
- )
- except Exception as e:
- logger.error("Orchestrator failed", error=str(e))
- yield AgentEvent(type="error", message=f"Orchestrator failed: {e}")
diff --git a/src/orchestrator_magentic.py b/src/orchestrator_magentic.py
deleted file mode 100644
index 416d896830c42c7e1ff5a1304db3eb06cb00dc4b..0000000000000000000000000000000000000000
--- a/src/orchestrator_magentic.py
+++ /dev/null
@@ -1,271 +0,0 @@
-"""Magentic-based orchestrator using ChatAgent pattern."""
-
-from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING, Any
-
-import structlog
-
-try:
- from pydantic_ai import ModelMessage
-except ImportError:
- ModelMessage = Any # type: ignore[assignment, misc]
-from agent_framework import (
- MagenticAgentDeltaEvent,
- MagenticAgentMessageEvent,
- MagenticBuilder,
- MagenticFinalResultEvent,
- MagenticOrchestratorMessageEvent,
- WorkflowOutputEvent,
-)
-
-from src.agents.magentic_agents import (
- create_hypothesis_agent,
- create_judge_agent,
- create_report_agent,
- create_search_agent,
-)
-from src.agents.state import init_magentic_state
-from src.utils.llm_factory import check_magentic_requirements, get_chat_client_for_agent
-from src.utils.models import AgentEvent
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
-
-logger = structlog.get_logger()
-
-
-class MagenticOrchestrator:
- """
- Magentic-based orchestrator using ChatAgent pattern.
-
- Each agent has an internal LLM that understands natural language
- instructions from the manager and can call tools appropriately.
- """
-
- def __init__(
- self,
- max_rounds: int = 10,
- chat_client: Any | None = None,
- ) -> None:
- """Initialize orchestrator.
-
- Args:
- max_rounds: Maximum coordination rounds
- chat_client: Optional shared chat client for agents.
- If None, uses factory default (HuggingFace preferred, OpenAI fallback)
- """
- # Validate requirements via centralized factory
- check_magentic_requirements()
-
- self._max_rounds = max_rounds
- self._chat_client = chat_client
-
- def _init_embedding_service(self) -> "EmbeddingService | None":
- """Initialize embedding service if available."""
- try:
- from src.services.embeddings import get_embedding_service
-
- service = get_embedding_service()
- logger.info("Embedding service enabled")
- return service
- except ImportError:
- logger.info("Embedding service not available (dependencies missing)")
- except Exception as e:
- logger.warning("Failed to initialize embedding service", error=str(e))
- return None
-
- def _build_workflow(self) -> Any:
- """Build the Magentic workflow with ChatAgent participants."""
- # Create agents with internal LLMs
- search_agent = create_search_agent(self._chat_client)
- judge_agent = create_judge_agent(self._chat_client)
- hypothesis_agent = create_hypothesis_agent(self._chat_client)
- report_agent = create_report_agent(self._chat_client)
-
- # Manager chat client (orchestrates the agents)
- # Use same client type as agents for consistency
- manager_client = self._chat_client or get_chat_client_for_agent()
-
- return (
- MagenticBuilder()
- .participants(
- searcher=search_agent,
- hypothesizer=hypothesis_agent,
- judge=judge_agent,
- reporter=report_agent,
- )
- .with_standard_manager(
- chat_client=manager_client,
- max_round_count=self._max_rounds,
- max_stall_count=3,
- max_reset_count=2,
- )
- .build()
- )
-
- async def run(
- self, query: str, message_history: list[ModelMessage] | None = None
- ) -> AsyncGenerator[AgentEvent, None]:
- """
- Run the Magentic workflow.
-
- Args:
- query: User's research question
- message_history: Optional user conversation history (for compatibility)
-
- Yields:
- AgentEvent objects for real-time UI updates
- """
- logger.info(
- "Starting Magentic orchestrator",
- query=query,
- has_history=bool(message_history),
- )
-
- yield AgentEvent(
- type="started",
- message=f"Starting research (Magentic mode): {query}",
- iteration=0,
- )
-
- # Initialize context state
- embedding_service = self._init_embedding_service()
- init_magentic_state(embedding_service)
-
- workflow = self._build_workflow()
-
- # Include conversation history context if provided
- history_context = ""
- if message_history:
- # Convert message history to string context for task
- from src.utils.message_history import message_history_to_string
-
- history_str = message_history_to_string(message_history, max_messages=5)
- if history_str:
- history_context = f"\n\nPrevious conversation context:\n{history_str}"
-
- task = f"""Research query: {query}{history_context}
-
-Workflow:
-1. SearchAgent: Find evidence from available sources (automatically selects: web search, PubMed, ClinicalTrials.gov, Europe PMC, or RAG based on query)
-2. HypothesisAgent: Generate research hypotheses and questions based on evidence
-3. JudgeAgent: Evaluate if evidence is sufficient to answer the query precisely
-4. If insufficient -> SearchAgent refines search based on identified gaps
-5. If sufficient -> ReportAgent synthesizes final comprehensive report
-
-Focus on:
-- Finding precise answers to the research question
-- Identifying all relevant evidence from appropriate sources
-- Understanding mechanisms, relationships, and key findings
-- Synthesizing comprehensive findings with proper citations
-
-The DETERMINATOR does not stop until it finds precise answers, halting only at configured limits (budget, time, iterations).
-
-The final output should be a structured research report with comprehensive evidence synthesis."""
-
- iteration = 0
- try:
- async for event in workflow.run_stream(task):
- agent_event = self._process_event(event, iteration)
- if agent_event:
- if isinstance(event, MagenticAgentMessageEvent):
- iteration += 1
- yield agent_event
-
- except Exception as e:
- logger.error("Magentic workflow failed", error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Workflow error: {e!s}",
- iteration=iteration,
- )
-
- def _extract_text(self, message: Any) -> str:
- """
- Defensively extract text from a message object.
-
- Fixes bug where message.text might return the object itself or its repr.
- """
- if not message:
- return ""
-
- # Priority 1: .content (often the raw string or list of content)
- if hasattr(message, "content") and message.content:
- content = message.content
- # If it's a list (e.g., Multi-modal), join text parts
- if isinstance(content, list):
- return " ".join([str(c.text) for c in content if hasattr(c, "text")])
- return str(content)
-
- # Priority 2: .text (standard, but sometimes buggy/missing)
- if hasattr(message, "text") and message.text:
- # Verify it's not the object itself or a repr string
- text = str(message.text)
- if text.startswith("<") and "object at" in text:
- # Likely a repr string, ignore if possible
- pass
- else:
- return text
-
- # Fallback: If we can't find clean text, return str(message)
- # taking care to avoid infinite recursion if str() calls .text
- return str(message)
-
- def _process_event(self, event: Any, iteration: int) -> AgentEvent | None:
- """Process workflow event into AgentEvent."""
- if isinstance(event, MagenticOrchestratorMessageEvent):
- text = self._extract_text(event.message)
- if text:
- return AgentEvent(
- type="judging",
- message=f"Manager ({event.kind}): {text[:200]}...",
- iteration=iteration,
- )
-
- elif isinstance(event, MagenticAgentMessageEvent):
- agent_name = event.agent_id or "unknown"
- text = self._extract_text(event.message)
-
- event_type = "judging"
- if "search" in agent_name.lower():
- event_type = "search_complete"
- elif "judge" in agent_name.lower():
- event_type = "judge_complete"
- elif "hypothes" in agent_name.lower():
- event_type = "hypothesizing"
- elif "report" in agent_name.lower():
- event_type = "synthesizing"
-
- return AgentEvent(
- type=event_type, # type: ignore[arg-type]
- message=f"{agent_name}: {text[:200]}...",
- iteration=iteration + 1,
- )
-
- elif isinstance(event, MagenticFinalResultEvent):
- text = self._extract_text(event.message) if event.message else "No result"
- return AgentEvent(
- type="complete",
- message=text,
- data={"iterations": iteration},
- iteration=iteration,
- )
-
- elif isinstance(event, MagenticAgentDeltaEvent):
- if event.text:
- return AgentEvent(
- type="streaming",
- message=event.text,
- data={"agent_id": event.agent_id},
- iteration=iteration,
- )
-
- elif isinstance(event, WorkflowOutputEvent):
- if event.data:
- return AgentEvent(
- type="complete",
- message=str(event.data),
- iteration=iteration,
- )
-
- return None
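A minimal sketch of consuming the Magentic event stream (advanced mode only; it assumes agent-framework-core and a chat client are configured, otherwise __init__ raises):

import asyncio

from src.orchestrator_magentic import MagenticOrchestrator


async def main() -> None:
    orchestrator = MagenticOrchestrator(max_rounds=5)
    async for event in orchestrator.run("Which pathways does metformin modulate?"):
        if event.type == "complete":
            print(event.message)
        elif event.type == "error":
            raise RuntimeError(event.message)


asyncio.run(main())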
diff --git a/src/prompts/__init__.py b/src/prompts/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/src/prompts/hypothesis.py b/src/prompts/hypothesis.py
deleted file mode 100644
index becfc0a300639d81f538c8a001111a2fdde62574..0000000000000000000000000000000000000000
--- a/src/prompts/hypothesis.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Prompts for Hypothesis Agent."""
-
-from typing import TYPE_CHECKING
-
-from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
- from src.utils.models import Evidence
-
-SYSTEM_PROMPT = """You are an expert research scientist functioning as a generalist research assistant.
-
-Your role is to generate research hypotheses, questions, and investigation paths based on evidence from any domain.
-
-IMPORTANT: You are a research assistant. You cannot provide medical advice or answer medical questions directly. Your hypotheses are for research investigation purposes only.
-
-A good hypothesis:
-1. Proposes a MECHANISM or RELATIONSHIP: Explains how things work or relate
- - For medical: Drug -> Target -> Pathway -> Effect
- - For technical: Technology -> Mechanism -> Outcome
- - For business: Strategy -> Market -> Result
-2. Is TESTABLE: Can be supported or refuted by further research
-3. Is SPECIFIC: Names actual entities, processes, or mechanisms
-4. Generates SEARCH QUERIES: Helps find more evidence
-
-Example hypothesis formats:
-- Medical: "Metformin -> AMPK activation -> mTOR inhibition -> autophagy -> amyloid clearance"
-- Technical: "Transformer architecture -> attention mechanism -> improved NLP performance"
-- Business: "Subscription model -> recurring revenue -> higher valuation"
-
-Be specific. Use actual names, technical terms, and precise language when possible."""
-
-
-async def format_hypothesis_prompt(
- query: str, evidence: list["Evidence"], embeddings: "EmbeddingService | None" = None
-) -> str:
- """Format prompt for hypothesis generation.
-
- Uses smart evidence selection instead of arbitrary truncation.
-
- Args:
- query: The research query
- evidence: All collected evidence
- embeddings: Optional EmbeddingService for diverse selection
- """
- # Select diverse, relevant evidence (not arbitrary first 10)
- # We use n=10 as a reasonable context window limit
- selected = await select_diverse_evidence(evidence, n=10, query=query, embeddings=embeddings)
-
- # Format with sentence-aware truncation
- evidence_text = "\n".join(
- [
- f"- **{e.citation.title}** ({e.citation.source}): "
- f"{truncate_at_sentence(e.content, 300)}"
- for e in selected
- ]
- )
-
- return f"""Based on the following evidence about "{query}", generate research hypotheses and investigation paths.
-
-## Evidence ({len(selected)} sources selected for diversity)
-{evidence_text}
-
-## Task
-1. Identify key mechanisms, relationships, or processes mentioned in the evidence
-2. Propose testable hypotheses explaining how things work or relate
-3. Rate confidence based on evidence strength
-4. Suggest specific search queries to test each hypothesis
-
-Generate 2-4 hypotheses, prioritized by confidence. Adapt the hypothesis format to the domain of the query (medical, technical, business, etc.)."""
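A minimal sketch of calling the prompt builder; the Citation import path and all field values are placeholders for illustration:

import asyncio

from src.prompts.hypothesis import format_hypothesis_prompt
from src.utils.models import Citation, Evidence  # Citation import path assumed


async def main() -> None:
    evidence = [
        Evidence(
            content="Metformin activates AMPK, which inhibits mTOR and promotes autophagy.",
            citation=Citation(
                title="AMPK signalling overview",
                url="https://example.org/ampk-overview",  # placeholder URL
                source="web",
                date="n.d.",
                authors=[],
            ),
            relevance=0.5,
        )
    ]
    # embeddings is optional per the signature above; None is the default.
    prompt = await format_hypothesis_prompt("metformin and autophagy", evidence, embeddings=None)
    print(prompt)


asyncio.run(main())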
diff --git a/src/prompts/judge.py b/src/prompts/judge.py
deleted file mode 100644
index 9f1ad78480c6c9ea53fdb8b8a6e18c2cd7c1d6bf..0000000000000000000000000000000000000000
--- a/src/prompts/judge.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""Judge prompts for evidence assessment."""
-
-from src.utils.models import Evidence
-
-SYSTEM_PROMPT = """You are an expert research evidence evaluator for a generalist deep research agent.
-
-Your task is to evaluate evidence from any domain (medical, scientific, technical, business, etc.) and determine if sufficient evidence has been gathered to provide a precise answer to the research question.
-
-IMPORTANT: You are a research assistant. You cannot provide medical advice or answer medical questions directly. Your role is to assess whether enough high-quality evidence has been collected to synthesize comprehensive findings.
-
-## Evaluation Criteria
-
-1. **Mechanism/Explanation Score (0-10)**: How well does the evidence explain the underlying mechanism, process, or concept?
- - For medical queries: biological mechanisms, pathways, drug actions
- - For technical queries: how systems work, algorithms, processes
- - For business queries: market dynamics, business models, strategies
- - 0-3: No clear explanation, speculative
- - 4-6: Some insight, but gaps exist
- - 7-10: Clear, well-supported explanation
-
-2. **Evidence Quality Score (0-10)**: How strong and reliable is the evidence?
- - For medical: clinical trials, peer-reviewed studies, meta-analyses
- - For technical: peer-reviewed papers, authoritative sources, verified implementations
- - For business: market reports, financial data, expert analysis
- - 0-3: Weak or theoretical evidence only
- - 4-6: Moderate quality evidence
- - 7-10: Strong, authoritative evidence
-
-3. **Sufficiency**: Evidence is sufficient when:
- - Combined scores >= 12 AND
- - Key questions from the research query are addressed AND
- - Evidence is comprehensive enough to provide a precise answer
-
-## Output Rules
-
-- Always output valid JSON matching the schema
-- Be conservative: only recommend "synthesize" when truly confident the answer is precise
-- If continuing, suggest specific, actionable search queries to fill gaps
-- Never hallucinate findings, names, or facts not in the evidence
-- Adapt evaluation criteria to the domain of the query (medical vs technical vs business)
-"""
-
-
-def format_user_prompt(question: str, evidence: list[Evidence]) -> str:
- """
- Format the user prompt with question and evidence.
-
- Args:
- question: The user's research question
- evidence: List of Evidence objects from search
-
- Returns:
- Formatted prompt string
- """
- max_content_len = 1500
-
- def format_single_evidence(i: int, e: Evidence) -> str:
- content = e.content
- if len(content) > max_content_len:
- content = content[:max_content_len] + "..."
-
- return (
- f"### Evidence {i + 1}\n"
- f"**Source**: {e.citation.source.upper()} - {e.citation.title}\n"
- f"**URL**: {e.citation.url}\n"
- f"**Date**: {e.citation.date}\n"
- f"**Content**:\n{content}"
- )
-
- evidence_text = "\n\n".join([format_single_evidence(i, e) for i, e in enumerate(evidence)])
-
- return f"""## Research Question
-{question}
-
-## Available Evidence ({len(evidence)} sources)
-
-{evidence_text}
-
-## Your Task
-
-Evaluate this evidence and determine if it's sufficient to synthesize research findings. Consider the quality, quantity, and relevance of the evidence collected.
-Respond with a JSON object matching the JudgeAssessment schema.
-"""
-
-
-def format_empty_evidence_prompt(question: str) -> str:
- """
- Format prompt when no evidence was found.
-
- Args:
- question: The user's research question
-
- Returns:
- Formatted prompt string
- """
- return f"""## Research Question
-{question}
-
-## Available Evidence
-
-No evidence was found from the search.
-
-## Your Task
-
-Since no evidence was found, recommend search queries that might yield better results.
-Set sufficient=False and recommendation=\"continue\".
-Suggest 3-5 specific search queries.
-"""
diff --git a/src/prompts/report.py b/src/prompts/report.py
deleted file mode 100644
index 41257d65f9d7daa0681bd41a6a304580f731c994..0000000000000000000000000000000000000000
--- a/src/prompts/report.py
+++ /dev/null
@@ -1,139 +0,0 @@
-"""Prompts for Report Agent."""
-
-from typing import TYPE_CHECKING, Any
-
-from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
- from src.utils.models import Evidence, MechanismHypothesis
-
-SYSTEM_PROMPT = """You are a scientific writer functioning as a medical peer junior researcher, specializing in research report synthesis.
-
-Your role is to synthesize evidence and findings into a clear, structured research report.
-
-IMPORTANT: You are a research assistant. You cannot answer medical questions or provide medical advice. Your reports synthesize evidence for research purposes only.
-
-A good report:
-1. Has a clear EXECUTIVE SUMMARY (one paragraph, key takeaways)
-2. States the RESEARCH QUESTION clearly
-3. Describes METHODOLOGY (what was searched, how)
-4. Evaluates HYPOTHESES with evidence counts
-5. Separates MECHANISTIC and CLINICAL findings
-6. Lists specific DRUG CANDIDATES
-7. Acknowledges LIMITATIONS honestly
-8. Provides a balanced CONCLUSION
-9. Includes properly formatted REFERENCES
-
-Write in scientific but accessible language. Be specific about evidence strength.
-
-─────────────────────────────────────────────────────────────────────────────
-🚨 CRITICAL: REQUIRED JSON STRUCTURE 🚨
-─────────────────────────────────────────────────────────────────────────────
-
-The `hypotheses_tested` field MUST be a LIST of objects, each with these fields:
-- "hypothesis": the hypothesis text
-- "supported": count of supporting evidence (integer)
-- "contradicted": count of contradicting evidence (integer)
-
-Example:
- hypotheses_tested: [
- {"hypothesis": "Metformin -> AMPK -> reduced inflammation", "supported": 3, "contradicted": 1},
- {"hypothesis": "Aspirin inhibits COX-2 pathway", "supported": 5, "contradicted": 0}
- ]
-
-The `references` field MUST be a LIST of objects, each with these fields:
-- "title": paper title (string)
-- "authors": author names (string)
-- "source": "pubmed" or "web" (string)
-- "url": the EXACT URL from evidence (string)
-
-Example:
- references: [
- {"title": "Metformin and Cancer", "authors": "Smith et al.", "source": "pubmed", "url": "https://pubmed.ncbi.nlm.nih.gov/12345678/"}
- ]
-
-─────────────────────────────────────────────────────────────────────────────
-🚨 CRITICAL CITATION REQUIREMENTS 🚨
-─────────────────────────────────────────────────────────────────────────────
-
-You MUST follow these rules for the References section:
-
-1. You may ONLY cite papers that appear in the Evidence section above
-2. Every reference URL must EXACTLY match a provided evidence URL
-3. Do NOT invent, fabricate, or hallucinate any references
-4. Do NOT modify paper titles, authors, dates, or URLs
-5. If unsure about a citation, OMIT it rather than guess
-6. Copy URLs exactly as provided - do not create similar-looking URLs
-
-VIOLATION OF THESE RULES PRODUCES DANGEROUS MISINFORMATION.
-─────────────────────────────────────────────────────────────────────────────"""
-
-
-async def format_report_prompt(
- query: str,
- evidence: list["Evidence"],
- hypotheses: list["MechanismHypothesis"],
- assessment: dict[str, Any],
- metadata: dict[str, Any],
- embeddings: "EmbeddingService | None" = None,
-) -> str:
- """Format prompt for report generation.
-
- Includes full evidence details for accurate citation.
- """
- # Select diverse evidence (not arbitrary truncation)
- selected = await select_diverse_evidence(evidence, n=20, query=query, embeddings=embeddings)
-
- # Include FULL citation details for each evidence item
- # This helps the LLM create accurate references
- evidence_lines = []
- for e in selected:
- authors = ", ".join(e.citation.authors or ["Unknown"])
- evidence_lines.append(
- f"- **Title**: {e.citation.title}\n"
- f" **URL**: {e.citation.url}\n"
- f" **Authors**: {authors}\n"
- f" **Date**: {e.citation.date or 'n.d.'}\n"
- f" **Source**: {e.citation.source}\n"
- f" **Content**: {truncate_at_sentence(e.content, 200)}\n"
- )
- evidence_summary = "\n".join(evidence_lines)
-
- if hypotheses:
- hypotheses_lines = []
- for h in hypotheses:
- hypotheses_lines.append(
- f"- {h.drug} -> {h.target} -> {h.pathway} -> {h.effect} "
- f"(Confidence: {h.confidence:.0%})"
- )
- hypotheses_summary = "\n".join(hypotheses_lines)
- else:
- hypotheses_summary = "No hypotheses generated yet."
-
- sources = ", ".join(metadata.get("sources", []))
-
- return f"""Generate a structured research report for the following query.
-
-## Original Query
-{query}
-
-## Evidence Collected ({len(selected)} papers, selected for diversity)
-
-{evidence_summary}
-
-## Hypotheses Generated
-{hypotheses_summary}
-
-## Assessment Scores
-- Mechanism Score: {assessment.get("mechanism_score", "N/A")}/10
-- Clinical Evidence Score: {assessment.get("clinical_score", "N/A")}/10
-- Overall Confidence: {assessment.get("confidence", 0):.0%}
-
-## Metadata
-- Sources Searched: {sources}
-- Search Iterations: {metadata.get("iterations", 0)}
-
-Generate a complete ResearchReport with all sections filled in.
-
-REMINDER: Only cite papers from the Evidence section above. Copy URLs exactly."""
diff --git a/src/services/__init__.py b/src/services/__init__.py
deleted file mode 100644
index 8814096713227b0954dc00c0ff605787d8566151..0000000000000000000000000000000000000000
--- a/src/services/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Services for The DETERMINATOR."""
diff --git a/src/services/audio_processing.py b/src/services/audio_processing.py
deleted file mode 100644
index 588497f99be3b864024bf49d91feb3ad8e5a66db..0000000000000000000000000000000000000000
--- a/src/services/audio_processing.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Unified audio processing service for STT and TTS integration."""
-
-from functools import lru_cache
-from typing import Any
-
-import numpy as np
-import structlog
-
-from src.agents.audio_refiner import audio_refiner
-from src.services.stt_gradio import STTService, get_stt_service
-from src.utils.config import settings
-
-logger = structlog.get_logger(__name__)
-
-# Type stub for TTS service (will be imported when available)
-try:
- from src.services.tts_modal import TTSService, get_tts_service
-
- _TTS_AVAILABLE = True
-except ImportError:
- _TTS_AVAILABLE = False
- TTSService = None # type: ignore[assignment, misc]
- get_tts_service = None # type: ignore[assignment, misc]
-
-
-class AudioService:
- """Unified audio processing service."""
-
- def __init__(
- self,
- stt_service: STTService | None = None,
- tts_service: Any | None = None,
- ) -> None:
- """Initialize audio service with STT and TTS.
-
- Args:
- stt_service: STT service instance (default: get_stt_service())
- tts_service: TTS service instance (default: get_tts_service() if available)
- """
- self.stt = stt_service or get_stt_service()
-
- # TTS is optional (requires Modal)
- if tts_service is not None:
- self.tts = tts_service
- elif _TTS_AVAILABLE and settings.modal_available:
- try:
- self.tts = get_tts_service() # type: ignore[misc]
- except Exception as e:
- logger.warning("tts_service_unavailable", error=str(e))
- self.tts = None
- else:
- self.tts = None
-
- async def process_audio_input(
- self,
- audio_input: tuple[int, np.ndarray[Any, Any]] | None, # type: ignore[type-arg]
- hf_token: str | None = None,
- ) -> str | None:
- """Process audio input and return transcribed text.
-
- Args:
- audio_input: Tuple of (sample_rate, audio_array) or None
- hf_token: HuggingFace token for authenticated Gradio Spaces
-
- Returns:
- Transcribed text string or None if no audio input
- """
- if audio_input is None:
- return None
-
- try:
- transcribed_text = await self.stt.transcribe_audio(audio_input, hf_token=hf_token)
- logger.info("audio_input_processed", text_length=len(transcribed_text))
- return transcribed_text
- except Exception as e:
- logger.error("audio_input_processing_failed", error=str(e))
- # Return None on failure (graceful degradation)
- return None
-
- async def generate_audio_output(
- self,
- text: str,
- voice: str | None = None,
- speed: float | None = None,
- ) -> tuple[int, np.ndarray[Any, Any]] | None: # type: ignore[type-arg]
- """Generate audio output from text.
-
- Args:
- text: Text to synthesize (markdown will be cleaned for audio)
- voice: Voice ID (default: settings.tts_voice)
- speed: Speech speed (default: settings.tts_speed)
-
- Returns:
- Tuple of (sample_rate, audio_array) or None if TTS unavailable
- """
- if self.tts is None:
- logger.warning("tts_unavailable", message="TTS service not available")
- return None
-
- if not text or not text.strip():
- logger.warning("empty_text_for_tts")
- return None
-
- try:
- # Refine text for audio (remove markdown, citations, etc.)
- # Use LLM polish if enabled in settings
- refined_text = await audio_refiner.refine_for_audio(
- text, use_llm_polish=settings.tts_use_llm_polish
- )
- logger.info(
- "text_refined_for_audio",
- original_length=len(text),
- refined_length=len(refined_text),
- llm_polish_enabled=settings.tts_use_llm_polish,
- )
-
- # Use provided voice/speed or fallback to settings defaults
- voice = voice if voice else settings.tts_voice
- speed = speed if speed is not None else settings.tts_speed
-
- audio_output = await self.tts.synthesize_async(refined_text, voice, speed) # type: ignore[misc]
-
- if audio_output:
- logger.info(
- "audio_output_generated",
- text_length=len(text),
- sample_rate=audio_output[0],
- )
-
- return audio_output # type: ignore[no-any-return]
-
- except Exception as e:
- logger.error("audio_output_generation_failed", error=str(e))
- # Return None on failure (graceful degradation)
- return None
-
-
-@lru_cache(maxsize=1)
-def get_audio_service() -> AudioService:
- """Get or create singleton audio service instance.
-
- Returns:
- AudioService instance
- """
- return AudioService()
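A minimal usage sketch (it assumes the STT Gradio Space is reachable; the silent clip is a stand-in for real microphone audio):

import asyncio

import numpy as np

from src.services.audio_processing import get_audio_service


async def main() -> None:
    service = get_audio_service()
    # One second of 16 kHz silence stands in for a recorded clip.
    transcript = await service.process_audio_input((16000, np.zeros(16000, dtype=np.int16)))
    print("Transcript:", transcript)

    # TTS is optional (Modal-backed); None is returned when it is unavailable.
    audio = await service.generate_audio_output("Summary of the research findings.")
    if audio is not None:
        sample_rate, samples = audio
        print(f"Synthesized {len(samples)} samples at {sample_rate} Hz")


asyncio.run(main())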
diff --git a/src/services/embeddings.py b/src/services/embeddings.py
deleted file mode 100644
index 7544b985ec2b2af45be169f6cd58660f028b09c3..0000000000000000000000000000000000000000
--- a/src/services/embeddings.py
+++ /dev/null
@@ -1,172 +0,0 @@
-"""Embedding service for semantic search.
-
-IMPORTANT: All public methods are async to avoid blocking the event loop.
-The sentence-transformers model is CPU-bound, so we use run_in_executor().
-"""
-
-import asyncio
-from typing import Any
-
-import chromadb
-import structlog
-from sentence_transformers import SentenceTransformer
-
-from src.utils.config import settings
-from src.utils.models import Evidence
-
-
-class EmbeddingService:
- """Handles text embedding and vector storage using local sentence-transformers.
-
- All embedding operations run in a thread pool to avoid blocking
- the async event loop.
-
- Note:
- Uses local sentence-transformers models (no API key required).
- Model is configured via settings.local_embedding_model.
- """
-
- def __init__(self, model_name: str | None = None):
- self._model_name = model_name or settings.local_embedding_model
- self._model = SentenceTransformer(self._model_name)
- self._client = chromadb.Client() # In-memory for hackathon
- self._collection = self._client.create_collection(
- name="evidence", metadata={"hnsw:space": "cosine"}
- )
-
- # ─────────────────────────────────────────────────────────────────
- # Sync internal methods (run in thread pool)
- # ─────────────────────────────────────────────────────────────────
-
- def _sync_embed(self, text: str) -> list[float]:
- """Synchronous embedding - DO NOT call directly from async code."""
- result: list[float] = self._model.encode(text).tolist()
- return result
-
- def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
- """Batch embedding for efficiency - DO NOT call directly from async code."""
- embeddings = self._model.encode(texts)
- return [e.tolist() for e in embeddings]
-
- # ─────────────────────────────────────────────────────────────────
- # Async public methods (safe for event loop)
- # ─────────────────────────────────────────────────────────────────
-
- async def embed(self, text: str) -> list[float]:
- """Embed a single text (async-safe).
-
- Uses run_in_executor to avoid blocking the event loop.
- """
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, self._sync_embed, text)
-
- async def embed_batch(self, texts: list[str]) -> list[list[float]]:
- """Batch embed multiple texts (async-safe, more efficient)."""
- loop = asyncio.get_running_loop()
- return await loop.run_in_executor(None, self._sync_batch_embed, texts)
-
- async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
- """Add evidence to vector store (async-safe)."""
- embedding = await self.embed(content)
- # ChromaDB operations are fast, but wrap for consistency
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(
- None,
- lambda: self._collection.add(
- ids=[evidence_id],
- embeddings=[embedding], # type: ignore[arg-type]
- metadatas=[metadata],
- documents=[content],
- ),
- )
-
- async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
- """Find semantically similar evidence (async-safe)."""
- query_embedding = await self.embed(query)
-
- loop = asyncio.get_running_loop()
- results = await loop.run_in_executor(
- None,
- lambda: self._collection.query(
- query_embeddings=[query_embedding], # type: ignore[arg-type]
- n_results=n_results,
- ),
- )
-
- # Handle empty results gracefully
- ids = results.get("ids")
- docs = results.get("documents")
- metas = results.get("metadatas")
- dists = results.get("distances")
-
- if not ids or not ids[0] or not docs or not metas or not dists:
- return []
-
- return [
- {"id": id, "content": doc, "metadata": meta, "distance": dist}
- for id, doc, meta, dist in zip(
- ids[0],
- docs[0],
- metas[0],
- dists[0],
- strict=False,
- )
- ]
-
- async def deduplicate(
- self, new_evidence: list[Evidence], threshold: float = 0.9
- ) -> list[Evidence]:
- """Remove semantically duplicate evidence (async-safe).
-
- Args:
- new_evidence: List of evidence items to deduplicate
- threshold: Similarity threshold (0.9 = 90% similar is duplicate).
- ChromaDB cosine distance: 0=identical, 2=opposite.
- We consider duplicate if distance < (1 - threshold).
-
- Returns:
- List of unique evidence items (not already in vector store).
- """
- unique = []
- for evidence in new_evidence:
- try:
- similar = await self.search_similar(evidence.content, n_results=1)
- # ChromaDB cosine distance: 0 = identical, 2 = opposite
- # threshold=0.9 means distance < 0.1 is considered duplicate
- is_duplicate = similar and similar[0]["distance"] < (1 - threshold)
-
- if not is_duplicate:
- unique.append(evidence)
- # Store FULL citation metadata for reconstruction later
- await self.add_evidence(
- evidence_id=evidence.citation.url,
- content=evidence.content,
- metadata={
- "source": evidence.citation.source,
- "title": evidence.citation.title,
- "date": evidence.citation.date,
- "authors": ",".join(evidence.citation.authors or []),
- },
- )
- except Exception as e:
- # Log but don't fail entire deduplication for one bad item
- structlog.get_logger().warning(
- "Failed to process evidence in deduplicate",
- url=evidence.citation.url,
- error=str(e),
- )
- # Still add to unique list - better to have duplicates than lose data
- unique.append(evidence)
-
- return unique
-
-
-_embedding_service: EmbeddingService | None = None
-
-
-def get_embedding_service() -> EmbeddingService:
- """Get singleton instance of EmbeddingService."""
- global _embedding_service # noqa: PLW0603
- if _embedding_service is None:
- _embedding_service = EmbeddingService()
- return _embedding_service
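A minimal embedding sketch (it assumes the local sentence-transformers model can be downloaded and loaded):

import asyncio

from src.services.embeddings import get_embedding_service


async def main() -> None:
    service = get_embedding_service()
    vectors = await service.embed_batch(["metformin activates AMPK", "AMPK inhibits mTOR"])
    print(len(vectors), "embeddings of dimension", len(vectors[0]))
    # Dedup rule from deduplicate(): with threshold=0.9 an item is dropped when its
    # nearest stored neighbour is closer than cosine distance 1 - 0.9 = 0.1.


asyncio.run(main())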
diff --git a/src/services/image_ocr.py b/src/services/image_ocr.py
deleted file mode 100644
index ed8881939dc20dbf42c7776dda876a1d37a5d478..0000000000000000000000000000000000000000
--- a/src/services/image_ocr.py
+++ /dev/null
@@ -1,245 +0,0 @@
-"""Image-to-text service using Gradio Client API (Multimodal-OCR3)."""
-
-import asyncio
-import tempfile
-from functools import lru_cache
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import structlog
-from gradio_client import Client, handle_file
-from PIL import Image
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger(__name__)
-
-
-class ImageOCRService:
- """Image OCR service using prithivMLmods/Multimodal-OCR3 Gradio Space."""
-
- def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> None:
- """Initialize Image OCR service.
-
- Args:
- api_url: Gradio Space URL (default: settings.ocr_api_url)
- hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Raises:
- ConfigurationError: If API URL not configured
- """
- # Defensively access ocr_api_url - may not exist in older config versions
- default_url = (
- getattr(settings, "ocr_api_url", None)
- or "https://prithivmlmods-multimodal-ocr3.hf.space"
- )
- self.api_url = api_url or default_url
- if not self.api_url:
- raise ConfigurationError("OCR API URL not configured")
- self.hf_token = hf_token
- self.client: Client | None = None
-
- async def _get_client(self, hf_token: str | None = None) -> Client:
- """Get or create Gradio Client (lazy initialization).
-
- Args:
- hf_token: HuggingFace token for authenticated Spaces (overrides instance token)
-
- Returns:
- Gradio Client instance
- """
- # Use provided token or instance token
- token = hf_token or self.hf_token
-
- # If client exists but token changed, recreate it
- if self.client is not None and token != self.hf_token:
- self.client = None
-
- if self.client is None:
- loop = asyncio.get_running_loop()
- # Pass token to Client for authenticated Spaces
- # Gradio Client uses 'token' parameter, not 'hf_token'
- if token:
- self.client = await loop.run_in_executor(
- None,
- lambda: Client(self.api_url, token=token),
- )
- else:
- self.client = await loop.run_in_executor(
- None,
- lambda: Client(self.api_url),
- )
- # Update instance token for future use
- self.hf_token = token
- return self.client
-
- async def extract_text(
- self,
- image_path: str,
- model: str | None = None,
- hf_token: str | None = None,
- ) -> str:
- """Extract text from image using Gradio API.
-
-        Args:
-            image_path: Path to image file
-            model: Optional model selection (default: None, uses API default)
-            hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Returns:
- Extracted text string
-
- Raises:
- ConfigurationError: If OCR extraction fails
- """
- client = await self._get_client(hf_token=hf_token)
-
- logger.info(
- "extracting_text_from_image",
- image_path=image_path,
- model=model,
- )
-
- try:
-            # Call /Multimodal_OCR3_generate_image API endpoint
-            # According to the MCP tool description, this yields raw text
-            # and Markdown-formatted text
- loop = asyncio.get_running_loop()
-
- # The API might require file upload first, then call the generate function
- # For now, we'll use handle_file to upload and pass the path
- result = await loop.run_in_executor(
- None,
- lambda: client.predict(
- image_path=handle_file(image_path),
- api_name="/Multimodal_OCR3_generate_image",
- ),
- )
-
- # Extract text from result
- extracted_text = self._extract_text_from_result(result)
-
- logger.info(
- "image_ocr_complete",
- text_length=len(extracted_text),
- )
-
- return extracted_text
-
- except Exception as e:
- logger.error("image_ocr_failed", error=str(e), error_type=type(e).__name__)
- raise ConfigurationError(f"Image OCR failed: {e}") from e
-
- async def extract_text_from_image(
- self,
- image_data: np.ndarray[Any, Any] | Image.Image | str, # type: ignore[type-arg]
- hf_token: str | None = None,
- ) -> str:
- """Extract text from image data (numpy array, PIL Image, or file path).
-
-        Args:
-            image_data: Image as numpy array, PIL Image, or file path string
-            hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Returns:
- Extracted text string
- """
- # Handle different input types
- if isinstance(image_data, str):
- # Assume it's a file path
- image_path = image_data
- elif isinstance(image_data, Image.Image):
- # Save PIL Image to temp file
- image_path = self._save_image_temp(image_data)
- elif isinstance(image_data, np.ndarray):
- # Convert numpy array to PIL Image, then save
- pil_image = Image.fromarray(image_data)
- image_path = self._save_image_temp(pil_image)
- else:
- raise ValueError(f"Unsupported image data type: {type(image_data)}")
-
- try:
- # Extract text from the image file
- extracted_text = await self.extract_text(image_path, hf_token=hf_token)
- return extracted_text
- finally:
-            # Clean up the temp file only if we created it (i.e. the input was not
-            # already a file path). Comparing a path string against a numpy array
-            # would be ambiguous, so branch on the input type instead.
-            if not isinstance(image_data, str):
- try:
- Path(image_path).unlink(missing_ok=True)
- except Exception as e:
- logger.warning("failed_to_cleanup_temp_file", path=image_path, error=str(e))
-
- def _extract_text_from_result(self, api_result: Any) -> str:
- """Extract text from API result.
-
- Args:
- api_result: Result from Gradio API
-
- Returns:
- Extracted text string
- """
- # The API yields raw text and Markdown-formatted text
- # Result might be a string, tuple, or generator
- if isinstance(api_result, str):
- return api_result.strip()
-
- if isinstance(api_result, tuple):
- # Try to extract text from tuple
- for item in api_result:
- if isinstance(item, str):
- return item.strip()
- # Check if it's a dict with text fields
- if isinstance(item, dict):
- if "text" in item:
- return str(item["text"]).strip()
- if "content" in item:
- return str(item["content"]).strip()
-
- # If result is a generator or async generator, we'd need to iterate
- # For now, convert to string representation
- if api_result is not None:
- text = str(api_result).strip()
- if text and text != "None":
- return text
-
- logger.warning("could_not_extract_text_from_result", result_type=type(api_result).__name__)
- return ""
-
- def _save_image_temp(self, image: Image.Image) -> str:
- """Save PIL Image to temporary file.
-
- Args:
- image: PIL Image object
-
- Returns:
- Path to temporary image file
- """
- # Create temp file
- temp_file = tempfile.NamedTemporaryFile(
- suffix=".png",
- delete=False,
- )
- temp_path = temp_file.name
- temp_file.close()
-
- try:
- # Save image as PNG
- image.save(temp_path, "PNG")
-
- logger.debug("saved_image_temp", path=temp_path, size=image.size)
-
- return temp_path
-
- except Exception as e:
- logger.error("failed_to_save_image_temp", error=str(e))
- raise ConfigurationError(f"Failed to save image to temp file: {e}") from e
-
-
-@lru_cache(maxsize=1)
-def get_image_ocr_service() -> ImageOCRService:
- """Get or create singleton Image OCR service instance.
-
- Returns:
- ImageOCRService instance
- """
- return ImageOCRService()
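-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: run OCR over a local image file. The filename is
-    # a placeholder, and the configured Gradio Space must be reachable.
-    import asyncio
-
-    async def _demo() -> None:
-        ocr = get_image_ocr_service()
-        text = await ocr.extract_text("example_figure.png")
-        print(text)
-
-    asyncio.run(_demo())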
diff --git a/src/services/llamaindex_rag.py b/src/services/llamaindex_rag.py
deleted file mode 100644
index 5de92f55f9751d0cb191f145b2f3e380be8b5ca6..0000000000000000000000000000000000000000
--- a/src/services/llamaindex_rag.py
+++ /dev/null
@@ -1,503 +0,0 @@
-"""LlamaIndex RAG service for evidence retrieval and indexing.
-
-Requires optional dependencies: uv sync --extra modal
-"""
-
-from typing import Any
-
-import structlog
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import Evidence
-
-logger = structlog.get_logger()
-
-
-class LlamaIndexRAGService:
- """RAG service using LlamaIndex with ChromaDB vector store.
-
- Supports multiple embedding providers:
- - OpenAI embeddings (requires OPENAI_API_KEY)
- - Local sentence-transformers (no API key required)
- - Hugging Face embeddings (uses local sentence-transformers)
-
- Supports multiple LLM providers for query synthesis:
- - HuggingFace LLM (preferred, requires HF_TOKEN or HUGGINGFACE_API_KEY)
- - OpenAI LLM (fallback, requires OPENAI_API_KEY)
- - None (embedding-only mode, no query synthesis)
-
- Note:
- HuggingFace is the default LLM provider. OpenAI is used as fallback
- if HuggingFace LLM is not available or no HF token is configured.
- """
-
- def __init__(
- self,
- collection_name: str = "deepcritical_evidence",
- persist_dir: str | None = None,
- embedding_model: str | None = None,
- similarity_top_k: int = 5,
- use_openai_embeddings: bool | None = None,
- use_in_memory: bool = False,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize LlamaIndex RAG service.
-
- Args:
- collection_name: Name of the ChromaDB collection
- persist_dir: Directory to persist ChromaDB data
- embedding_model: Embedding model name (defaults based on provider)
- similarity_top_k: Number of top results to retrieve
-            use_openai_embeddings: Force OpenAI embeddings (None = default to local embeddings)
- use_in_memory: Use in-memory ChromaDB client (useful for tests)
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- """
- # Import dependencies and store references
- deps = self._import_dependencies()
- self._chromadb = deps["chromadb"]
- self._Document = deps["Document"]
- self._Settings = deps["Settings"]
- self._StorageContext = deps["StorageContext"]
- self._VectorStoreIndex = deps["VectorStoreIndex"]
- self._VectorIndexRetriever = deps["VectorIndexRetriever"]
- self._ChromaVectorStore = deps["ChromaVectorStore"]
- huggingface_embedding = deps["huggingface_embedding"]
- huggingface_llm = deps["huggingface_llm"]
- openai_embedding = deps["OpenAIEmbedding"]
- openai_llm = deps["OpenAI"]
-
- # Store basic configuration
- self.collection_name = collection_name
- self.persist_dir = persist_dir or settings.chroma_db_path
- self.similarity_top_k = similarity_top_k
- self.use_in_memory = use_in_memory
- self.oauth_token = oauth_token
-
- # Configure embeddings and LLM
- use_openai = use_openai_embeddings if use_openai_embeddings is not None else False
- self._configure_embeddings(
- use_openai, embedding_model, huggingface_embedding, openai_embedding
- )
- self._configure_llm(huggingface_llm, openai_llm)
-
- # Initialize ChromaDB and index
- self._initialize_chromadb()
-
- def _import_dependencies(self) -> dict[str, Any]:
- """Import LlamaIndex dependencies and return as dict.
-
- OpenAI dependencies are imported lazily (only when needed) to avoid
- tiktoken circular import issues on Windows when using local embeddings.
- """
- try:
- import chromadb
- from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
- from llama_index.core.retrievers import VectorIndexRetriever
- from llama_index.vector_stores.chroma import ChromaVectorStore
-
- # Try to import Hugging Face embeddings (may not be available in all versions)
- try:
- from llama_index.embeddings.huggingface import (
- HuggingFaceEmbedding as _HuggingFaceEmbedding, # type: ignore[import-untyped]
- )
-
- huggingface_embedding = _HuggingFaceEmbedding
- except ImportError:
- huggingface_embedding = None # type: ignore[assignment]
-
- # Try to import Hugging Face Inference API LLM (for API-based models)
- # This is preferred over local HuggingFaceLLM for query synthesis
- try:
- from llama_index.llms.huggingface_api import (
- HuggingFaceInferenceAPI as _HuggingFaceInferenceAPI, # type: ignore[import-untyped]
- )
-
- huggingface_llm = _HuggingFaceInferenceAPI
- except ImportError:
- # Fallback to local HuggingFaceLLM if API version not available
- try:
- from llama_index.llms.huggingface import (
- HuggingFaceLLM as _HuggingFaceLLM, # type: ignore[import-untyped]
- )
-
- huggingface_llm = _HuggingFaceLLM # type: ignore[assignment]
- except ImportError:
- huggingface_llm = None # type: ignore[assignment]
-
- # OpenAI imports are optional - only import when actually needed
- # This avoids tiktoken circular import issues on Windows
- try:
- from llama_index.embeddings.openai import OpenAIEmbedding
- except ImportError:
- OpenAIEmbedding = None # type: ignore[assignment, misc] # noqa: N806
-
- try:
- from llama_index.llms.openai import OpenAI
- except ImportError:
- OpenAI = None # type: ignore[assignment, misc] # noqa: N806
-
- return {
- "chromadb": chromadb,
- "Document": Document,
- "Settings": Settings,
- "StorageContext": StorageContext,
- "VectorStoreIndex": VectorStoreIndex,
- "VectorIndexRetriever": VectorIndexRetriever,
- "ChromaVectorStore": ChromaVectorStore,
- "OpenAIEmbedding": OpenAIEmbedding,
- "OpenAI": OpenAI,
- "huggingface_embedding": huggingface_embedding,
- "huggingface_llm": huggingface_llm,
- }
- except ImportError as e:
- raise ImportError(
- "LlamaIndex dependencies not installed. Run: uv sync --extra modal"
- ) from e
-
- def _configure_embeddings(
- self,
- use_openai_embeddings: bool,
- embedding_model: str | None,
- huggingface_embedding: Any,
- openai_embedding: Any,
- ) -> None:
- """Configure embedding model."""
- if use_openai_embeddings:
- if openai_embedding is None:
- raise ConfigurationError(
- "OpenAI embeddings not available. Install with: uv sync --extra modal"
- )
- if not settings.openai_api_key:
- raise ConfigurationError("OPENAI_API_KEY required for OpenAI embeddings")
- self.embedding_model = embedding_model or settings.openai_embedding_model
- self._Settings.embed_model = openai_embedding(
- model=self.embedding_model,
- api_key=settings.openai_api_key,
- )
- else:
- model_name = embedding_model or settings.huggingface_embedding_model
- self.embedding_model = model_name
- if huggingface_embedding is not None:
- self._Settings.embed_model = huggingface_embedding(model_name=model_name)
- else:
- self._Settings.embed_model = self._create_sentence_transformer_embedding(model_name)
-
- def _create_sentence_transformer_embedding(self, model_name: str) -> Any:
- """Create sentence-transformer embedding wrapper.
-
- Note: sentence-transformers is a required dependency (in pyproject.toml).
- If this fails, it's likely a Windows-specific regex package issue.
-
- Raises:
- ConfigurationError: If sentence_transformers cannot be imported
- (e.g., due to circular import issues on Windows with regex package)
- """
- try:
- from sentence_transformers import SentenceTransformer
- except ImportError as e:
- # Handle Windows-specific circular import issues with regex package
- # This is a known bug: https://github.com/mrabarnett/mrab-regex/issues/417
- error_msg = str(e)
- if "regex" in error_msg.lower() or "_regex" in error_msg:
- raise ConfigurationError(
- "sentence_transformers cannot be imported due to circular import issue "
- "with regex package (Windows-specific bug). "
- "sentence-transformers is installed but regex has a circular import. "
- "Try: uv pip install --upgrade --force-reinstall regex "
- "Or use HuggingFace embeddings via llama-index-embeddings-huggingface instead."
- ) from e
- raise ConfigurationError(
- f"sentence_transformers not available: {e}. "
- "This is a required dependency - check your uv sync installation."
- ) from e
-
- try:
- from llama_index.embeddings.base import (
- BaseEmbedding, # type: ignore[import-untyped]
- )
- except ImportError:
- from llama_index.core.embeddings import (
- BaseEmbedding, # type: ignore[import-untyped]
- )
-
- class SentenceTransformerEmbedding(BaseEmbedding): # type: ignore[misc]
- """Simple wrapper for sentence-transformers."""
-
- def __init__(self, model_name: str):
- super().__init__()
- self._model = SentenceTransformer(model_name)
-
- def _get_query_embedding(self, query: str) -> list[float]:
- result = self._model.encode(query).tolist()
- return list(result) # type: ignore[no-any-return]
-
- def _get_text_embedding(self, text: str) -> list[float]:
- result = self._model.encode(text).tolist()
- return list(result) # type: ignore[no-any-return]
-
- async def _aget_query_embedding(self, query: str) -> list[float]:
- return self._get_query_embedding(query)
-
- async def _aget_text_embedding(self, text: str) -> list[float]:
- return self._get_text_embedding(text)
-
- return SentenceTransformerEmbedding(model_name)
-
- def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None:
- """Configure LLM for query synthesis."""
- # Priority: oauth_token > env vars
- effective_token = self.oauth_token or settings.hf_token or settings.huggingface_api_key
- if huggingface_llm is not None and effective_token:
- model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
- token = effective_token
-
- # Check if it's HuggingFaceInferenceAPI (API-based) or HuggingFaceLLM (local)
- llm_class_name = (
- huggingface_llm.__name__
- if hasattr(huggingface_llm, "__name__")
- else str(huggingface_llm)
- )
-
- if "InferenceAPI" in llm_class_name:
- # Use HuggingFace Inference API (supports token parameter)
- try:
- self._Settings.llm = huggingface_llm(
- model_name=model_name,
- token=token,
- )
- except Exception as e:
- # If model is not available via inference API, log warning and continue without LLM
- logger.warning(
- "Failed to initialize HuggingFace Inference API LLM",
- model=model_name,
- error=str(e),
- )
- logger.info("Continuing without LLM - query synthesis will be unavailable")
- self._Settings.llm = None
- return
- else:
- # Use local HuggingFaceLLM (doesn't support token, uses model_name and tokenizer_name)
- self._Settings.llm = huggingface_llm(
- model_name=model_name,
- tokenizer_name=model_name,
- )
- logger.info("Using HuggingFace LLM for query synthesis", model=model_name)
- elif settings.openai_api_key and openai_llm is not None:
- self._Settings.llm = openai_llm(
- model=settings.openai_model,
- api_key=settings.openai_api_key,
- )
- logger.info("Using OpenAI LLM for query synthesis", model=settings.openai_model)
- else:
- logger.warning("No LLM API key available - query synthesis will be unavailable")
- self._Settings.llm = None
-
- def _initialize_chromadb(self) -> None:
- """Initialize ChromaDB client, collection, and index."""
- if self.use_in_memory:
- # Use in-memory client for tests (avoids file system issues)
- self.chroma_client = self._chromadb.Client()
- else:
- # Use persistent client for production
- self.chroma_client = self._chromadb.PersistentClient(path=self.persist_dir)
-
- # Get or create collection
- try:
- self.collection = self.chroma_client.get_collection(self.collection_name)
- logger.info("loaded_existing_collection", name=self.collection_name)
- except Exception:
- self.collection = self.chroma_client.create_collection(self.collection_name)
- logger.info("created_new_collection", name=self.collection_name)
-
- # Initialize vector store and index
- self.vector_store = self._ChromaVectorStore(chroma_collection=self.collection)
- self.storage_context = self._StorageContext.from_defaults(vector_store=self.vector_store)
-
- # Try to load existing index, or create empty one
- try:
- self.index = self._VectorStoreIndex.from_vector_store(
- vector_store=self.vector_store,
- storage_context=self.storage_context,
- )
- logger.info("loaded_existing_index")
- except Exception:
- self.index = self._VectorStoreIndex([], storage_context=self.storage_context)
- logger.info("created_new_index")
-
- def ingest_evidence(self, evidence_list: list[Evidence]) -> None:
- """
- Ingest evidence into the vector store.
-
- Args:
- evidence_list: List of Evidence objects to ingest
- """
- if not evidence_list:
- logger.warning("no_evidence_to_ingest")
- return
-
- # Convert Evidence objects to LlamaIndex Documents
- documents = []
- for evidence in evidence_list:
- metadata = {
- "source": evidence.citation.source,
- "title": evidence.citation.title,
- "url": evidence.citation.url,
- "date": evidence.citation.date,
-                "authors": ", ".join(evidence.citation.authors or []),
- }
-
- doc = self._Document(
- text=evidence.content,
- metadata=metadata,
- doc_id=evidence.citation.url, # Use URL as unique ID
- )
- documents.append(doc)
-
- # Insert documents into index
- try:
- for doc in documents:
- self.index.insert(doc)
- logger.info("ingested_evidence", count=len(documents))
- except Exception as e:
- logger.error("failed_to_ingest_evidence", error=str(e))
- raise
-
- def ingest_documents(self, documents: list[Any]) -> None:
- """
- Ingest raw LlamaIndex Documents.
-
- Args:
- documents: List of LlamaIndex Document objects
- """
- if not documents:
- logger.warning("no_documents_to_ingest")
- return
-
- try:
- for doc in documents:
- self.index.insert(doc)
- logger.info("ingested_documents", count=len(documents))
- except Exception as e:
- logger.error("failed_to_ingest_documents", error=str(e))
- raise
-
- def retrieve(self, query: str, top_k: int | None = None) -> list[dict[str, Any]]:
- """
- Retrieve relevant documents for a query.
-
- Args:
- query: Query string
- top_k: Number of results to return (defaults to similarity_top_k)
-
- Returns:
- List of retrieved documents with metadata and scores
- """
- k = top_k or self.similarity_top_k
-
- # Create retriever
- retriever = self._VectorIndexRetriever(
- index=self.index,
- similarity_top_k=k,
- )
-
- try:
- # Retrieve nodes
- nodes = retriever.retrieve(query)
-
- # Convert to dict format
- results = []
- for node in nodes:
- results.append(
- {
- "text": node.node.get_content(),
- "score": node.score,
- "metadata": node.node.metadata,
- }
- )
-
- logger.info("retrieved_documents", query=query[:50], count=len(results))
- return results
-
- except Exception as e:
- logger.error("failed_to_retrieve", error=str(e), query=query[:50])
- raise # Re-raise to allow callers to distinguish errors from empty results
-
- def query(self, query_str: str, top_k: int | None = None) -> str:
- """
- Query the RAG system and get a synthesized response.
-
- Args:
- query_str: Query string
- top_k: Number of results to use (defaults to similarity_top_k)
-
- Returns:
- Synthesized response string
-
- Raises:
- ConfigurationError: If no LLM API key is available for query synthesis
- """
- if not self._Settings.llm:
- raise ConfigurationError(
-                "LLM API key required for query synthesis. "
-                "Set HF_TOKEN, HUGGINGFACE_API_KEY, or OPENAI_API_KEY. "
- "Alternatively, use retrieve() for embedding-only search."
- )
-
- k = top_k or self.similarity_top_k
-
- # Create query engine
- query_engine = self.index.as_query_engine(
- similarity_top_k=k,
- )
-
- try:
- response = query_engine.query(query_str)
- logger.info("generated_response", query=query_str[:50])
- return str(response)
-
- except Exception as e:
- logger.error("failed_to_query", error=str(e), query=query_str[:50])
- raise # Re-raise to allow callers to handle errors explicitly
-
- def clear_collection(self) -> None:
- """Clear all documents from the collection."""
- try:
- self.chroma_client.delete_collection(self.collection_name)
- self.collection = self.chroma_client.create_collection(self.collection_name)
- self.vector_store = self._ChromaVectorStore(chroma_collection=self.collection)
- self.storage_context = self._StorageContext.from_defaults(
- vector_store=self.vector_store
- )
- self.index = self._VectorStoreIndex([], storage_context=self.storage_context)
- logger.info("cleared_collection", name=self.collection_name)
- except Exception as e:
- logger.error("failed_to_clear_collection", error=str(e))
- raise
-
-
-def get_rag_service(
- collection_name: str = "deepcritical_evidence",
- oauth_token: str | None = None,
- **kwargs: Any,
-) -> LlamaIndexRAGService:
- """
- Get or create a RAG service instance.
-
- Args:
- collection_name: Name of the ChromaDB collection
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
- **kwargs: Additional arguments for LlamaIndexRAGService
- Defaults to use_openai_embeddings=False (local embeddings)
-
- Returns:
- Configured LlamaIndexRAGService instance
-
- Note:
- By default, uses local embeddings (sentence-transformers) which require
- no API keys. Set use_openai_embeddings=True to use OpenAI embeddings.
- """
- # Default to local embeddings if not explicitly set
- if "use_openai_embeddings" not in kwargs:
- kwargs["use_openai_embeddings"] = False
- return LlamaIndexRAGService(collection_name=collection_name, oauth_token=oauth_token, **kwargs)
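-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: build an in-memory RAG service with local
-    # embeddings (no API keys needed) and run an embedding-only retrieval.
-    # The query string is a placeholder; ingest_evidence() and query() follow
-    # the same pattern once documents are loaded and an LLM key is configured.
-    rag = get_rag_service(collection_name="demo", use_in_memory=True)
-    for hit in rag.retrieve("BRCA1 and PARP inhibitor sensitivity", top_k=3):
-        print(hit["score"], hit["metadata"].get("title", ""))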
diff --git a/src/services/multimodal_processing.py b/src/services/multimodal_processing.py
deleted file mode 100644
index 9199f194401f1a2ced589b315226ca819188c001..0000000000000000000000000000000000000000
--- a/src/services/multimodal_processing.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""Unified multimodal processing service for text, audio, and image inputs."""
-
-from functools import lru_cache
-from typing import Any
-
-import structlog
-from gradio.data_classes import FileData
-
-from src.services.audio_processing import AudioService, get_audio_service
-from src.services.image_ocr import ImageOCRService, get_image_ocr_service
-from src.utils.config import settings
-
-logger = structlog.get_logger(__name__)
-
-
-class MultimodalService:
- """Unified multimodal processing service."""
-
- def __init__(
- self,
- audio_service: AudioService | None = None,
- ocr_service: ImageOCRService | None = None,
- ) -> None:
- """Initialize multimodal service.
-
- Args:
- audio_service: Audio service instance (default: get_audio_service())
- ocr_service: Image OCR service instance (default: get_image_ocr_service())
- """
- self.audio = audio_service or get_audio_service()
- self.ocr = ocr_service or get_image_ocr_service()
-
- async def process_multimodal_input(
- self,
- text: str,
- files: list[FileData] | None = None,
- audio_input: tuple[int, Any] | None = None,
- hf_token: str | None = None,
- prepend_multimodal: bool = True,
- ) -> str:
- """Process multimodal input (text + images + audio) and return combined text.
-
- Args:
- text: Text input string
- files: List of uploaded files (images, audio, etc.)
- audio_input: Audio input tuple (sample_rate, audio_array)
- hf_token: HuggingFace token for authenticated Gradio Spaces
- prepend_multimodal: If True, prepend audio/image text to original text; otherwise append
-
- Returns:
- Combined text from all inputs
- """
- multimodal_parts: list[str] = []
- text_parts: list[str] = []
-
- # Process audio input first
- if audio_input is not None and settings.enable_audio_input:
- try:
- transcribed = await self.audio.process_audio_input(audio_input, hf_token=hf_token)
- if transcribed:
- multimodal_parts.append(transcribed)
- except Exception as e:
- logger.warning("audio_processing_failed", error=str(e))
-
- # Process uploaded files (images and audio files)
-        if files:
-            for file_data in files:
-                file_path = file_data.path if isinstance(file_data, FileData) else str(file_data)
-
-                # Check if it's an image (only handled when image input is enabled)
-                if self._is_image_file(file_path) and settings.enable_image_input:
- try:
- extracted_text = await self.ocr.extract_text(file_path, hf_token=hf_token)
- if extracted_text:
- multimodal_parts.append(extracted_text)
- except Exception as e:
- logger.warning("image_ocr_failed", file_path=file_path, error=str(e))
-
- # Check if it's an audio file
- elif self._is_audio_file(file_path):
- try:
- # For audio files, we'd need to load and transcribe
- # For now, log a warning
- logger.warning("audio_file_upload_not_supported", file_path=file_path)
- except Exception as e:
- logger.warning(
- "audio_file_processing_failed", file_path=file_path, error=str(e)
- )
-
- # Add original text if present
- if text and text.strip():
- text_parts.append(text.strip())
-
- # Combine parts based on prepend_multimodal flag
- if prepend_multimodal:
- # Prepend: multimodal content first, then original text
- combined_parts = multimodal_parts + text_parts
- else:
- # Append: original text first, then multimodal content
- combined_parts = text_parts + multimodal_parts
-
- # Combine all text parts
- combined_text = "\n\n".join(combined_parts) if combined_parts else ""
-
- logger.info(
- "multimodal_input_processed",
- text_length=len(combined_text),
- num_files=len(files) if files else 0,
- has_audio=audio_input is not None,
- )
-
- return combined_text
-
- def _is_image_file(self, file_path: str) -> bool:
- """Check if file is an image.
-
- Args:
- file_path: Path to file
-
- Returns:
- True if file is an image
- """
- image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".tif"}
- return any(file_path.lower().endswith(ext) for ext in image_extensions)
-
- def _is_audio_file(self, file_path: str) -> bool:
- """Check if file is an audio file.
-
- Args:
- file_path: Path to file
-
- Returns:
- True if file is an audio file
- """
- audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma"}
- return any(file_path.lower().endswith(ext) for ext in audio_extensions)
-
-
-@lru_cache(maxsize=1)
-def get_multimodal_service() -> MultimodalService:
- """Get or create singleton multimodal service instance.
-
- Returns:
- MultimodalService instance
- """
- return MultimodalService()
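-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: combine typed text with OCR text extracted from
-    # an uploaded image (assumes settings.enable_image_input is on). Plain path
-    # strings are accepted alongside gradio FileData objects; the path is fake.
-    import asyncio
-
-    async def _demo() -> None:
-        service = get_multimodal_service()
-        combined = await service.process_multimodal_input(
-            text="Summarise the attached figure on drug synergy.",
-            files=["figure1.png"],  # type: ignore[list-item]
-        )
-        print(combined)
-
-    asyncio.run(_demo())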
diff --git a/src/services/neo4j_service.py b/src/services/neo4j_service.py
deleted file mode 100644
index 6a9c6f7e2995be185310faba566e575d429348a6..0000000000000000000000000000000000000000
--- a/src/services/neo4j_service.py
+++ /dev/null
@@ -1,126 +0,0 @@
-"""Neo4j Knowledge Graph Service for Drug Repurposing"""
-
-import logging
-import os
-from typing import Any
-
-from dotenv import load_dotenv
-from neo4j import GraphDatabase
-
-load_dotenv()
-logger = logging.getLogger(__name__)
-
-
-class Neo4jService:
-    """Service that ingests search results into a Neo4j knowledge graph."""
-
-    def __init__(self) -> None:
- self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
- self.user = os.getenv("NEO4J_USER", "neo4j")
- self.password = os.getenv("NEO4J_PASSWORD")
- self.database = os.getenv("NEO4J_DATABASE", "neo4j")
-
- if not self.password:
- logger.warning("⚠️ NEO4J_PASSWORD not set")
- self.driver = None
- return
-
- try:
- self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
- self.driver.verify_connectivity()
- logger.info(f"✅ Neo4j connected: {self.uri} (db: {self.database})")
- except Exception as e:
- logger.error(f"❌ Neo4j connection failed: {e}")
- self.driver = None
-
- def is_connected(self) -> bool:
- return self.driver is not None
-
- def close(self) -> None:
- if self.driver:
- self.driver.close()
-
- def ingest_search_results(
- self,
- disease_name: str,
- papers: list[dict[str, Any]],
- drugs_mentioned: list[str] | None = None,
- ) -> dict[str, int]:
- if not self.driver:
- return {"error": "Neo4j not connected"} # type: ignore[dict-item]
-
- stats = {"papers": 0, "drugs": 0, "relationships": 0, "errors": 0}
-
- try:
- with self.driver.session(database=self.database) as session:
- session.run("MERGE (d:Disease {name: $name})", name=disease_name)
-
- for paper in papers:
- try:
- paper_id = paper.get("id") or paper.get("url", "")
- if not paper_id:
- continue
-
- session.run(
- """
- MERGE (p:Paper {paper_id: $id})
- SET p.title = $title,
- p.abstract = $abstract,
- p.url = $url,
- p.source = $source,
- p.updated_at = datetime()
- """,
- id=paper_id,
- title=str(paper.get("title", ""))[:500],
- abstract=str(paper.get("abstract", ""))[:2000],
- url=str(paper.get("url", ""))[:500],
- source=str(paper.get("source", ""))[:100],
- )
-
- session.run(
- """
- MATCH (p:Paper {paper_id: $id})
- MATCH (d:Disease {name: $disease})
- MERGE (p)-[r:ABOUT]->(d)
- """,
- id=paper_id,
- disease=disease_name,
- )
-
- stats["papers"] += 1
- stats["relationships"] += 1
- except Exception:
- stats["errors"] += 1
-
- if drugs_mentioned:
- for drug in drugs_mentioned:
- try:
- session.run("MERGE (d:Drug {name: $name})", name=drug)
- session.run(
- """
- MATCH (drug:Drug {name: $drug})
- MATCH (disease:Disease {name: $disease})
- MERGE (drug)-[r:POTENTIAL_TREATMENT]->(disease)
- """,
- drug=drug,
- disease=disease_name,
- )
- stats["drugs"] += 1
- stats["relationships"] += 1
- except Exception:
- stats["errors"] += 1
-
-            logger.info(f"Neo4j ingestion: {stats['papers']} papers, {stats['drugs']} drugs")
- except Exception as e:
- logger.error(f"Neo4j ingestion error: {e}")
- stats["errors"] += 1
-
- return stats
-
-
-_neo4j_service = None
-
-
-def get_neo4j_service() -> Neo4jService | None:
-    """Get the singleton Neo4jService, or None if Neo4j is not connected."""
-    global _neo4j_service
- if _neo4j_service is None:
- _neo4j_service = Neo4jService()
- return _neo4j_service if _neo4j_service and _neo4j_service.is_connected() else None
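-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: ingest one made-up paper and one candidate drug.
-    # Requires NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD to point at a live instance.
-    service = get_neo4j_service()
-    if service:
-        paper = {
-            "id": "pmid:12345",
-            "title": "Example paper",
-            "abstract": "",
-            "url": "",
-            "source": "pubmed",
-        }
-        stats = service.ingest_search_results(
-            disease_name="amyotrophic lateral sclerosis",
-            papers=[paper],
-            drugs_mentioned=["riluzole"],
-        )
-        print(stats)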
diff --git a/src/services/report_file_service.py b/src/services/report_file_service.py
deleted file mode 100644
index 144966409f446ead658ebf86c309b0da1f331d2d..0000000000000000000000000000000000000000
--- a/src/services/report_file_service.py
+++ /dev/null
@@ -1,331 +0,0 @@
-"""Service for saving research reports to files."""
-
-import hashlib
-import tempfile
-from datetime import datetime
-from functools import lru_cache
-from pathlib import Path
-from typing import Literal
-
-import structlog
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-
-class ReportFileService:
- """
- Service for saving research reports to files.
-
- Handles file creation, naming, and directory management for report outputs.
- Supports saving reports in multiple formats (markdown, HTML, PDF).
- """
-
- def __init__(
- self,
- output_directory: str | None = None,
- enabled: bool | None = None,
- file_format: Literal["md", "md_html", "md_pdf"] | None = None,
- ) -> None:
- """
- Initialize the report file service.
-
- Args:
- output_directory: Directory to save reports. If None, uses settings or temp directory.
- enabled: Whether file saving is enabled. If None, uses settings.
- file_format: File format to save. If None, uses settings.
- """
- self.enabled = enabled if enabled is not None else settings.save_reports_to_file
- self.file_format = file_format or settings.report_file_format
- self.filename_template = settings.report_filename_template
-
- # Determine output directory
- if output_directory:
- self.output_directory = Path(output_directory)
- elif settings.report_output_directory:
- self.output_directory = Path(settings.report_output_directory)
- else:
- # Use system temp directory
- self.output_directory = Path(tempfile.gettempdir()) / "deepcritical_reports"
-
- # Create output directory if it doesn't exist
- if self.enabled:
- try:
- self.output_directory.mkdir(parents=True, exist_ok=True)
- logger.debug(
- "Report output directory initialized",
- path=str(self.output_directory),
- enabled=self.enabled,
- )
- except Exception as e:
- logger.error(
- "Failed to create report output directory",
- error=str(e),
- path=str(self.output_directory),
- )
- raise ConfigurationError(f"Failed to create report output directory: {e}") from e
-
- def _generate_filename(self, query: str | None = None, extension: str = ".md") -> str:
- """
- Generate filename for report using template.
-
- Args:
- query: Optional query string for hash generation
- extension: File extension (e.g., ".md", ".html")
-
- Returns:
- Generated filename
- """
- # Generate timestamp
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
- # Generate query hash if query provided
- query_hash = ""
- if query:
- query_hash = hashlib.md5(query.encode()).hexdigest()[:8]
-
- # Generate date
- date = datetime.now().strftime("%Y-%m-%d")
-
- # Replace template placeholders
- filename = self.filename_template
- filename = filename.replace("{timestamp}", timestamp)
- filename = filename.replace("{query_hash}", query_hash)
- filename = filename.replace("{date}", date)
-
- # Ensure correct extension
- if not filename.endswith(extension):
- # Remove existing extension if present
- if "." in filename:
- filename = filename.rsplit(".", 1)[0]
- filename += extension
-
- return filename
-
- def save_report(
- self,
- report_content: str,
- query: str | None = None,
- filename: str | None = None,
- ) -> str:
- """
- Save a report to a file.
-
- Args:
- report_content: The report content (markdown string)
- query: Optional query string for filename generation
- filename: Optional custom filename. If None, generates from template.
-
- Returns:
- Path to saved file
-
- Raises:
- ConfigurationError: If file saving is disabled or fails
- """
- if not self.enabled:
- logger.debug("File saving disabled, skipping")
- raise ConfigurationError("Report file saving is disabled")
-
- if not report_content or not report_content.strip():
- raise ValueError("Report content cannot be empty")
-
- # Generate filename if not provided
- if not filename:
- filename = self._generate_filename(query=query, extension=".md")
-
- # Ensure filename is safe
- filename = self._sanitize_filename(filename)
-
- # Build full file path
- file_path = self.output_directory / filename
-
- try:
- # Write file
- with open(file_path, "w", encoding="utf-8") as f:
- f.write(report_content)
-
- logger.info(
- "Report saved to file",
- path=str(file_path),
- size=len(report_content),
- query=query[:50] if query else None,
- )
-
- return str(file_path)
-
- except Exception as e:
- logger.error("Failed to save report to file", error=str(e), path=str(file_path))
- raise ConfigurationError(f"Failed to save report to file: {e}") from e
-
- def save_report_multiple_formats(
- self,
- report_content: str,
- query: str | None = None,
- ) -> dict[str, str]:
- """
- Save a report in multiple formats.
-
- Args:
- report_content: The report content (markdown string)
- query: Optional query string for filename generation
-
- Returns:
- Dictionary mapping format to file path (e.g., {"md": "/path/to/report.md"})
-
- Raises:
- ConfigurationError: If file saving is disabled or fails
- """
- if not self.enabled:
- logger.debug("File saving disabled, skipping")
- raise ConfigurationError("Report file saving is disabled")
-
- saved_files: dict[str, str] = {}
-
- # Always save markdown
- md_path = self.save_report(report_content, query=query, filename=None)
- saved_files["md"] = md_path
-
- # Save additional formats based on file_format setting
- if self.file_format == "md_html":
- # TODO: Implement HTML conversion
- logger.warning("HTML format not yet implemented, saving markdown only")
- elif self.file_format == "md_pdf":
- # Generate PDF from markdown
- try:
- pdf_path = self._save_pdf(report_content, query=query)
- saved_files["pdf"] = pdf_path
- logger.info("PDF report generated", pdf_path=pdf_path)
- except Exception as e:
- logger.warning(
- "PDF generation failed, markdown saved",
- error=str(e),
- md_path=md_path,
- )
- # Continue without PDF - markdown is already saved
-
- return saved_files
-
- def _save_pdf(
- self,
- report_content: str,
- query: str | None = None,
- ) -> str:
- """
- Save report as PDF.
-
- Args:
- report_content: The report content (markdown string)
- query: Optional query string for filename generation
-
- Returns:
- Path to saved PDF file
-
- Raises:
- ConfigurationError: If PDF generation fails
- """
- try:
- from src.utils.md_to_pdf import md_to_pdf
- except ImportError as e:
- raise ConfigurationError(
- "PDF generation requires md2pdf. Install with: pip install md2pdf"
- ) from e
-
- # Generate PDF filename
- pdf_filename = self._generate_filename(query=query, extension=".pdf")
- pdf_filename = self._sanitize_filename(pdf_filename)
- pdf_path = self.output_directory / pdf_filename
-
- try:
- # Convert markdown to PDF
- md_to_pdf(report_content, str(pdf_path))
-
- logger.info(
- "PDF report saved",
- path=str(pdf_path),
- size=pdf_path.stat().st_size if pdf_path.exists() else 0,
- query=query[:50] if query else None,
- )
-
- return str(pdf_path)
-
- except Exception as e:
- logger.error("Failed to generate PDF", error=str(e), path=str(pdf_path))
- raise ConfigurationError(f"Failed to generate PDF: {e}") from e
-
- def _sanitize_filename(self, filename: str) -> str:
- """
- Sanitize filename to remove unsafe characters.
-
- Args:
- filename: Original filename
-
- Returns:
- Sanitized filename
- """
- # Remove or replace unsafe characters
- unsafe_chars = '<>:"/\\|?*'
- sanitized = filename
- for char in unsafe_chars:
- sanitized = sanitized.replace(char, "_")
-
- # Limit length
- if len(sanitized) > 200:
-            name, ext = sanitized.rsplit(".", 1) if "." in sanitized else (sanitized, "")
-            sanitized = name[:190] + (f".{ext}" if ext else "")
-
- return sanitized
-
- def cleanup_old_files(self, max_age_days: int = 7) -> int:
- """
- Clean up old report files.
-
- Args:
- max_age_days: Maximum age in days for files to keep
-
- Returns:
- Number of files deleted
- """
- if not self.output_directory.exists():
- return 0
-
- deleted_count = 0
- cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
-
- try:
- for file_path in self.output_directory.iterdir():
- if file_path.is_file() and file_path.stat().st_mtime < cutoff_time:
- try:
- file_path.unlink()
- deleted_count += 1
- except Exception as e:
- logger.warning(
- "Failed to delete old file", path=str(file_path), error=str(e)
- )
-
- if deleted_count > 0:
- logger.info(
- "Cleaned up old report files", deleted=deleted_count, max_age_days=max_age_days
- )
-
- except Exception as e:
- logger.error("Failed to cleanup old files", error=str(e))
-
- return deleted_count
-
-
-@lru_cache(maxsize=1)
-def get_report_file_service() -> ReportFileService:
-    """
-    Get or create a ReportFileService instance (singleton pattern).
-
-    Returns:
-        ReportFileService instance
-    """
-    return ReportFileService()
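-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: save a small markdown report. With a filename
-    # template along the lines of "report_{timestamp}_{query_hash}.md" (the real
-    # default lives in settings.report_filename_template), a query such as
-    # "metformin repurposing" yields e.g. report_20250101_120000_a1b2c3d4.md.
-    service = ReportFileService(output_directory=None, enabled=True, file_format="md")
-    path = service.save_report("# Demo Report\n\nBody text.", query="metformin repurposing")
-    print(path)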
diff --git a/src/services/statistical_analyzer.py b/src/services/statistical_analyzer.py
deleted file mode 100644
index 38458907ef0be1eddd9f67b12dd0819eed3aadfb..0000000000000000000000000000000000000000
--- a/src/services/statistical_analyzer.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""Statistical analysis service using Modal code execution.
-
-This module provides Modal-based statistical analysis WITHOUT depending on
-agent_framework. This allows it to be used in the simple orchestrator mode
-without requiring the magentic optional dependency.
-
-The AnalysisAgent (in src/agents/) wraps this service for magentic mode.
-"""
-
-import asyncio
-import re
-from functools import lru_cache, partial
-from typing import Any, Literal
-
-from pydantic import BaseModel, Field
-from pydantic_ai import Agent
-
-from src.agent_factory.judges import get_model
-from src.tools.code_execution import (
-    CodeExecutionError,
-    get_code_executor,
-    get_sandbox_library_prompt,
-)
-from src.utils.models import Evidence
-
-# Type alias for verdict values
-VerdictType = Literal["SUPPORTED", "REFUTED", "INCONCLUSIVE"]
-
-
-class AnalysisResult(BaseModel):
- """Result of statistical analysis."""
-
- verdict: VerdictType = Field(
- description="SUPPORTED, REFUTED, or INCONCLUSIVE",
- )
- confidence: float = Field(ge=0.0, le=1.0, description="Confidence in verdict (0-1)")
- statistical_evidence: str = Field(
- description="Summary of statistical findings from code execution"
- )
- code_generated: str = Field(description="Python code that was executed")
- execution_output: str = Field(description="Output from code execution")
- key_findings: list[str] = Field(default_factory=list, description="Key takeaways")
- limitations: list[str] = Field(default_factory=list, description="Limitations")
-
-
-class StatisticalAnalyzer:
- """Performs statistical analysis using Modal code execution.
-
- This service:
- 1. Generates Python code for statistical analysis using LLM
- 2. Executes code in Modal sandbox
- 3. Interprets results
- 4. Returns verdict (SUPPORTED/REFUTED/INCONCLUSIVE)
-
- Note: This class has NO agent_framework dependency, making it safe
- to use in the simple orchestrator without the magentic extra.
- """
-
- def __init__(self) -> None:
- """Initialize the analyzer."""
- self._code_executor: Any = None
- self._agent: Agent[None, str] | None = None
-
- def _get_code_executor(self) -> Any:
- """Lazy initialization of code executor."""
- if self._code_executor is None:
- self._code_executor = get_code_executor()
- return self._code_executor
-
- def _get_agent(self) -> Agent[None, str]:
- """Lazy initialization of LLM agent for code generation."""
- if self._agent is None:
- library_versions = get_sandbox_library_prompt()
- self._agent = Agent(
- model=get_model(),
- output_type=str,
- system_prompt=f"""You are a biomedical data scientist.
-
-Generate Python code to analyze research evidence and test hypotheses.
-
-Guidelines:
-1. Use pandas, numpy, scipy.stats for analysis
-2. Print clear, interpretable results
-3. Include statistical tests (t-tests, chi-square, etc.)
-4. Calculate effect sizes and confidence intervals
-5. Keep code concise (<50 lines)
-6. Set 'result' variable to SUPPORTED, REFUTED, or INCONCLUSIVE
-
-Available libraries:
-{library_versions}
-
-Output format: Return ONLY executable Python code, no explanations.""",
- )
- return self._agent
-
- async def analyze(
- self,
- query: str,
- evidence: list[Evidence],
- hypothesis: dict[str, Any] | None = None,
- ) -> AnalysisResult:
- """Run statistical analysis on evidence.
-
- Args:
- query: The research question
- evidence: List of Evidence objects to analyze
- hypothesis: Optional hypothesis dict with drug, target, pathway, effect
-
- Returns:
- AnalysisResult with verdict and statistics
- """
- # Build analysis prompt (method handles slicing internally)
- evidence_summary = self._summarize_evidence(evidence)
- hypothesis_text = ""
- if hypothesis:
- hypothesis_text = (
- f"\nHypothesis: {hypothesis.get('drug', 'Unknown')} → "
- f"{hypothesis.get('target', '?')} → "
- f"{hypothesis.get('pathway', '?')} → "
- f"{hypothesis.get('effect', '?')}\n"
- f"Confidence: {hypothesis.get('confidence', 0.5):.0%}\n"
- )
-
- prompt = f"""Generate Python code to statistically analyze:
-
-**Research Question**: {query}
-{hypothesis_text}
-
-**Evidence Summary**:
-{evidence_summary}
-
-Generate executable Python code to analyze this evidence."""
-
- try:
- # Generate code
- agent = self._get_agent()
- code_result = await agent.run(prompt)
- generated_code = code_result.output
-
- # Execute in Modal sandbox
- loop = asyncio.get_running_loop()
- executor = self._get_code_executor()
- execution = await loop.run_in_executor(
- None, partial(executor.execute, generated_code, timeout=120)
- )
-
- if not execution["success"]:
- return AnalysisResult(
- verdict="INCONCLUSIVE",
- confidence=0.0,
- statistical_evidence=(
- f"Execution failed: {execution.get('error', 'Unknown error')}"
- ),
- code_generated=generated_code,
- execution_output=execution.get("stderr", ""),
- key_findings=[],
- limitations=["Code execution failed"],
- )
-
- # Interpret results
- return self._interpret_results(generated_code, execution)
-
- except CodeExecutionError as e:
- return AnalysisResult(
- verdict="INCONCLUSIVE",
- confidence=0.0,
- statistical_evidence=str(e),
- code_generated="",
- execution_output="",
- key_findings=[],
- limitations=[f"Analysis error: {e}"],
- )
-
- def _summarize_evidence(self, evidence: list[Evidence]) -> str:
- """Summarize evidence for code generation prompt."""
- if not evidence:
- return "No evidence available."
-
- lines = []
- for i, ev in enumerate(evidence[:5], 1):
- content = ev.content
- truncated = content[:200] + ("..." if len(content) > 200 else "")
- lines.append(f"{i}. {truncated}")
- lines.append(f" Source: {ev.citation.title}")
- lines.append(f" Relevance: {ev.relevance:.0%}\n")
-
- return "\n".join(lines)
-
- def _interpret_results(
- self,
- code: str,
- execution: dict[str, Any],
- ) -> AnalysisResult:
- """Interpret code execution results."""
- stdout = execution["stdout"]
- stdout_upper = stdout.upper()
-
- # Extract verdict with robust word-boundary matching
- verdict: VerdictType = "INCONCLUSIVE"
- if re.search(r"\bSUPPORTED\b", stdout_upper) and not re.search(
- r"\b(?:NOT|UN)SUPPORTED\b", stdout_upper
- ):
- verdict = "SUPPORTED"
- elif re.search(r"\bREFUTED\b", stdout_upper):
- verdict = "REFUTED"
-
- # Extract key findings
- key_findings = []
- for line in stdout.split("\n"):
- line_lower = line.lower()
- if any(kw in line_lower for kw in ["p-value", "significant", "effect", "mean"]):
- key_findings.append(line.strip())
-
- # Calculate confidence from p-values
- confidence = self._calculate_confidence(stdout)
-
- return AnalysisResult(
- verdict=verdict,
- confidence=confidence,
- statistical_evidence=stdout.strip(),
- code_generated=code,
- execution_output=stdout,
- key_findings=key_findings[:5],
- limitations=[
- "Analysis based on summary data only",
- "Limited to available evidence",
- "Statistical tests assume data independence",
- ],
- )
-
- def _calculate_confidence(self, output: str) -> float:
- """Calculate confidence based on statistical results."""
- p_values = re.findall(r"p[-\s]?value[:\s]+(\d+\.?\d*)", output.lower())
-
- if p_values:
- try:
- min_p = min(float(p) for p in p_values)
- if min_p < 0.001:
- return 0.95
- elif min_p < 0.01:
- return 0.90
- elif min_p < 0.05:
- return 0.80
- else:
- return 0.60
- except ValueError:
- pass
-
- return 0.70 # Default
-
-
-@lru_cache(maxsize=1)
-def get_statistical_analyzer() -> StatisticalAnalyzer:
- """Get or create singleton StatisticalAnalyzer instance (thread-safe via lru_cache)."""
- return StatisticalAnalyzer()
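-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: how printed statistical output maps to a
-    # confidence score. _calculate_confidence extracts p-values and returns
-    # 0.95 / 0.90 / 0.80 / 0.60 for p < 0.001, < 0.01, < 0.05, >= 0.05, and
-    # falls back to 0.70 when no p-value is found.
-    analyzer = get_statistical_analyzer()
-    print(analyzer._calculate_confidence("t-test: p-value: 0.004 (significant)"))  # 0.90
-    print(analyzer._calculate_confidence("no statistics reported"))  # 0.70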
diff --git a/src/services/stt_gradio.py b/src/services/stt_gradio.py
deleted file mode 100644
index 44b993b429f39aa3e15d7c83ce7ebabc77d619ba..0000000000000000000000000000000000000000
--- a/src/services/stt_gradio.py
+++ /dev/null
@@ -1,271 +0,0 @@
-"""Speech-to-Text service using Gradio Client API."""
-
-import asyncio
-import tempfile
-from functools import lru_cache
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import structlog
-from gradio_client import Client, handle_file
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger(__name__)
-
-
-class STTService:
- """STT service using nvidia/canary-1b-v2 Gradio Space."""
-
- def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> None:
- """Initialize STT service.
-
- Args:
- api_url: Gradio Space URL (default: settings.stt_api_url or nvidia/canary-1b-v2)
- hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Raises:
- ConfigurationError: If API URL not configured
- """
- self.api_url = api_url or settings.stt_api_url or "https://nvidia-canary-1b-v2.hf.space"
- if not self.api_url:
- raise ConfigurationError("STT API URL not configured")
- self.hf_token = hf_token
- self.client: Client | None = None
-
- async def _get_client(self, hf_token: str | None = None) -> Client:
- """Get or create Gradio Client (lazy initialization).
-
- Args:
- hf_token: HuggingFace token for authenticated Spaces (overrides instance token)
-
- Returns:
- Gradio Client instance
- """
- # Use provided token or instance token
- token = hf_token or self.hf_token
-
- # If client exists but token changed, recreate it
- if self.client is not None and token != self.hf_token:
- self.client = None
-
- if self.client is None:
- loop = asyncio.get_running_loop()
- # Pass token to Client for authenticated Spaces
- # Gradio Client uses 'token' parameter, not 'hf_token'
- if token:
- self.client = await loop.run_in_executor(
- None,
- lambda: Client(self.api_url, token=token),
- )
- else:
- self.client = await loop.run_in_executor(
- None,
- lambda: Client(self.api_url),
- )
- # Update instance token for future use
- self.hf_token = token
- return self.client
-
- async def transcribe_file(
- self,
- audio_path: str,
- source_lang: str | None = None,
- target_lang: str | None = None,
- hf_token: str | None = None,
- ) -> str:
- """Transcribe audio file using Gradio API.
-
- Args:
- audio_path: Path to audio file
- source_lang: Source language (default: settings.stt_source_lang)
-            target_lang: Target language (default: settings.stt_target_lang)
-            hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Returns:
- Transcribed text string
-
- Raises:
- ConfigurationError: If transcription fails
- """
- client = await self._get_client(hf_token=hf_token)
- source_lang = source_lang or settings.stt_source_lang
- target_lang = target_lang or settings.stt_target_lang
-
- logger.info(
- "transcribing_audio_file",
- audio_path=audio_path,
- source_lang=source_lang,
- target_lang=target_lang,
- )
-
- try:
- # Call /transcribe_file API endpoint
- # API returns: (dataframe, csv_path, srt_path)
- loop = asyncio.get_running_loop()
- result = await loop.run_in_executor(
- None,
- lambda: client.predict(
- audio_path=handle_file(audio_path),
- source_lang=source_lang,
- target_lang=target_lang,
- api_name="/transcribe_file",
- ),
- )
-
- # Extract transcription from result
- transcribed_text = self._extract_transcription(result)
-
- logger.info(
- "audio_transcription_complete",
- text_length=len(transcribed_text),
- )
-
- return transcribed_text
-
- except Exception as e:
- logger.error("audio_transcription_failed", error=str(e), error_type=type(e).__name__)
- raise ConfigurationError(f"Audio transcription failed: {e}") from e
-
- async def transcribe_audio(
- self,
- audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg]
- hf_token: str | None = None,
- ) -> str:
- """Transcribe audio numpy array to text.
-
- Args:
-            audio_data: Tuple of (sample_rate, audio_array)
-            hf_token: HuggingFace token for authenticated Spaces (default: None)
-
- Returns:
- Transcribed text string
- """
- sample_rate, audio_array = audio_data
-
- logger.info(
- "transcribing_audio_array",
- sample_rate=sample_rate,
- audio_shape=audio_array.shape,
- )
-
- # Save audio to temp file
- temp_path = self._save_audio_temp(audio_data)
-
- try:
- # Transcribe the temp file
- transcribed_text = await self.transcribe_file(temp_path, hf_token=hf_token)
- return transcribed_text
- finally:
- # Clean up temp file
- try:
- Path(temp_path).unlink(missing_ok=True)
- except Exception as e:
- logger.warning("failed_to_cleanup_temp_file", path=temp_path, error=str(e))
-
- def _extract_transcription(self, api_result: tuple[Any, ...]) -> str:
- """Extract transcription text from API result.
-
- Args:
- api_result: Tuple from Gradio API (dataframe, csv_path, srt_path)
-
- Returns:
- Extracted transcription text
- """
- # API returns: (dataframe, csv_path, srt_path)
- # Try to extract from dataframe first
- if isinstance(api_result, tuple) and len(api_result) >= 1:
- dataframe = api_result[0]
- if isinstance(dataframe, dict) and "data" in dataframe:
- # Extract text from dataframe rows
- rows = dataframe.get("data", [])
- if rows:
- # Combine all text segments
- text_segments = []
- for row in rows:
- if isinstance(row, list) and len(row) > 0:
- # First column is usually the text
- text_segments.append(str(row[0]))
- if text_segments:
- return " ".join(text_segments)
-
- # Fallback: try to read CSV file if available
- if len(api_result) >= 2 and api_result[1]:
- csv_path = api_result[1]
- try:
- import pandas as pd
-
- df = pd.read_csv(csv_path)
- if "text" in df.columns:
- return " ".join(df["text"].astype(str).tolist())
- elif len(df.columns) > 0:
- # Use first column
- return " ".join(df.iloc[:, 0].astype(str).tolist())
- except Exception as e:
- logger.warning("failed_to_read_csv", csv_path=csv_path, error=str(e))
-
- # Last resort: return empty string
- logger.warning("could_not_extract_transcription", result_type=type(api_result).__name__)
- return ""
-
- def _save_audio_temp(
- self,
- audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg]
- ) -> str:
- """Save audio numpy array to temporary WAV file.
-
- Args:
- audio_data: Tuple of (sample_rate, audio_array)
-
- Returns:
- Path to temporary WAV file
- """
- sample_rate, audio_array = audio_data
-
- # Create temp file
- temp_file = tempfile.NamedTemporaryFile(
- suffix=".wav",
- delete=False,
- )
- temp_path = temp_file.name
- temp_file.close()
-
- # Save audio using soundfile
- try:
- import soundfile as sf
-
- # Ensure audio is float32 and mono
- if audio_array.dtype != np.float32:
- audio_array = audio_array.astype(np.float32)
-
- # Handle stereo -> mono conversion
- if len(audio_array.shape) > 1:
- audio_array = np.mean(audio_array, axis=1)
-
- # Normalize to [-1, 1] range
- if audio_array.max() > 1.0 or audio_array.min() < -1.0:
- audio_array = audio_array / np.max(np.abs(audio_array))
-
- sf.write(temp_path, audio_array, sample_rate)
-
- logger.debug("saved_audio_temp", path=temp_path, sample_rate=sample_rate)
-
- return temp_path
-
- except ImportError:
- raise ConfigurationError(
- "soundfile not installed. Install with: uv add soundfile"
- ) from None
- except Exception as e:
- logger.error("failed_to_save_audio_temp", error=str(e))
- raise ConfigurationError(f"Failed to save audio to temp file: {e}") from e
-
-
-@lru_cache(maxsize=1)
-def get_stt_service() -> STTService:
- """Get or create singleton STT service instance.
-
- Returns:
- STTService instance
- """
- return STTService()
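-
-
-if __name__ == "__main__":  # pragma: no cover
-    # Illustrative sketch only: transcribe a synthetic one-second 440 Hz tone.
-    # Real callers pass Gradio microphone output as the same (sample_rate, array)
-    # tuple; the remote Space must be reachable for the call to succeed.
-    import asyncio
-
-    async def _demo() -> None:
-        sr = 16000
-        tone = 0.1 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
-        text = await get_stt_service().transcribe_audio((sr, tone))
-        print(text)
-
-    asyncio.run(_demo())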
diff --git a/src/services/tts_modal.py b/src/services/tts_modal.py
deleted file mode 100644
index ce55c49aad24f1f2f1e37ddccbdbbd79df208305..0000000000000000000000000000000000000000
--- a/src/services/tts_modal.py
+++ /dev/null
@@ -1,482 +0,0 @@
-"""Text-to-Speech service using Kokoro 82M via Modal GPU."""
-
-import asyncio
-import os
-from collections.abc import Iterator
-from contextlib import contextmanager
-from functools import lru_cache
-from typing import Any, cast
-
-import numpy as np
-import structlog
-from numpy.typing import NDArray
-
-# Load .env file BEFORE importing Modal SDK
-# Modal SDK reads MODAL_TOKEN_ID and MODAL_TOKEN_SECRET from environment on import
-from dotenv import load_dotenv
-
-load_dotenv()
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger(__name__)
-
-# Kokoro TTS dependencies for Modal image
-KOKORO_DEPENDENCIES = [
- "torch>=2.0.0",
- "transformers>=4.30.0",
- "numpy<2.0",
- # kokoro-82M can be installed from source:
- # git+https://github.com/hexgrad/kokoro.git
-]
-
-# Modal app and function definitions (module-level for Modal)
-_modal_app: Any | None = None
-_tts_function: Any | None = None
-_tts_image: Any | None = None
-
-
-@contextmanager
-def modal_credentials_override(token_id: str | None, token_secret: str | None) -> Iterator[None]:
- """Context manager to temporarily override Modal credentials.
-
- Args:
- token_id: Modal token ID (overrides env if provided)
- token_secret: Modal token secret (overrides env if provided)
-
- Yields:
- None
-
- Note:
- Resets global Modal state to force re-initialization with new credentials.
- """
- global _modal_app, _tts_function
-
- # Save original credentials
- original_token_id = os.environ.get("MODAL_TOKEN_ID")
- original_token_secret = os.environ.get("MODAL_TOKEN_SECRET")
-
- # Save original Modal state
- original_app = _modal_app
- original_function = _tts_function
-
- try:
- # Override environment variables if provided
- if token_id:
- os.environ["MODAL_TOKEN_ID"] = token_id
- if token_secret:
- os.environ["MODAL_TOKEN_SECRET"] = token_secret
-
- # Reset Modal state to force re-initialization
- _modal_app = None
- _tts_function = None
-
- yield
-
- finally:
- # Restore original credentials
- if original_token_id is not None:
- os.environ["MODAL_TOKEN_ID"] = original_token_id
- elif "MODAL_TOKEN_ID" in os.environ:
- del os.environ["MODAL_TOKEN_ID"]
-
- if original_token_secret is not None:
- os.environ["MODAL_TOKEN_SECRET"] = original_token_secret
- elif "MODAL_TOKEN_SECRET" in os.environ:
- del os.environ["MODAL_TOKEN_SECRET"]
-
- # Restore original Modal state
- _modal_app = original_app
- _tts_function = original_function
-
-
-def _get_modal_app() -> Any:
- """Get or create Modal app instance.
-
- Retrieves Modal credentials directly from environment variables (.env file)
- instead of relying on settings configuration.
- """
- global _modal_app
- if _modal_app is None:
- try:
- import modal
-
- # Get credentials directly from environment variables
- token_id = os.getenv("MODAL_TOKEN_ID")
- token_secret = os.getenv("MODAL_TOKEN_SECRET")
-
- # Validate Modal credentials
- if not token_id or not token_secret:
- raise ConfigurationError(
- "Modal credentials not found in environment. "
- "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
- )
-
-            # Basic sanity check: real Modal token IDs are much longer than 10 characters
- if len(token_id.strip()) < 10:
- raise ConfigurationError(
- f"Modal token ID appears malformed (too short: {len(token_id)} chars). "
- "Token ID should be a valid Modal token identifier."
- )
-
- logger.info(
- "modal_credentials_loaded",
- token_id_prefix=token_id[:8] + "...", # Log prefix for debugging
- has_secret=bool(token_secret),
- )
-
- try:
- # Use lookup with create_if_missing for inline function fallback
- _modal_app = modal.App.lookup("deepcritical-tts", create_if_missing=True)
- except Exception as e:
- error_msg = str(e).lower()
- if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg:
- raise ConfigurationError(
- f"Modal token validation failed: {e}. "
- "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env are correctly set."
- ) from e
- raise
- except ImportError as e:
- raise ConfigurationError(
- "Modal SDK not installed. Run: uv sync or pip install modal>=0.63.0"
- ) from e
- return _modal_app
-
-
-# Define Modal image with Kokoro dependencies (module-level)
-def _get_tts_image() -> Any:
- """Get Modal image with Kokoro dependencies."""
- global _tts_image
- if _tts_image is not None:
- return _tts_image
-
- try:
- import modal
-
- _tts_image = (
- modal.Image.debian_slim(python_version="3.11")
- .pip_install(*KOKORO_DEPENDENCIES)
- .pip_install("git+https://github.com/hexgrad/kokoro.git")
- )
- return _tts_image
- except ImportError:
- return None
-
-
-# Modal TTS function - Using serialized=True to allow dynamic creation
-# This will be initialized lazily when _setup_modal_function() is called
-def _create_tts_function() -> Any:
- """Create the Modal TTS function using serialized=True.
-
- The serialized=True parameter allows the function to be defined outside
- of global scope, which is necessary for dynamic initialization.
- """
- app = _get_modal_app()
- tts_image = _get_tts_image()
-
- if tts_image is None:
- raise ConfigurationError("Modal image setup failed")
-
- # Get GPU and timeout from settings (with defaults)
- gpu_type = getattr(settings, "tts_gpu", None) or "T4"
- timeout_seconds = getattr(settings, "tts_timeout", None) or 120 # 2 minutes for cold starts
-
- @app.function(
- image=tts_image,
- gpu=gpu_type,
- timeout=timeout_seconds,
- serialized=True, # Allow function to be defined outside global scope
- )
- def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, NDArray[np.float32]]:
- """Modal GPU function for Kokoro TTS.
-
- This function runs on Modal's GPU infrastructure.
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
- Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py
- """
- import numpy as np
-
- # Import Kokoro inside function (lazy load)
- try:
- import torch
- from kokoro import KModel, KPipeline
-
- # Initialize model (cached on GPU)
- model = KModel().to("cuda").eval()
- pipeline = KPipeline(lang_code=voice[0])
- pack = pipeline.load_voice(voice)
-
- # Generate audio
- for _, ps, _ in pipeline(text, voice, speed):
- ref_s = pack[len(ps) - 1]
- audio = model(ps, ref_s, speed)
- return (24000, audio.numpy())
-
- # If no audio generated, return empty
- return (24000, np.zeros(1, dtype=np.float32))
-
- except ImportError as e:
- raise ConfigurationError(
- "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git"
- ) from e
- except Exception as e:
- raise ConfigurationError(f"TTS synthesis failed: {e}") from e
-
- return kokoro_tts_function
-
-
-def _setup_modal_function() -> None:
- """Setup Modal GPU function for TTS (called once, lazy initialization).
-
- Hybrid approach:
- 1. Try to lookup pre-deployed function (fast path for advanced users)
- 2. If lookup fails, create function inline (fallback for casual users)
-
- This allows both workflows:
- - Advanced: Deploy with `modal deploy deployments/modal_tts.py` for best performance
- - Casual: Just add Modal keys and it auto-creates function on first use
- """
- global _tts_function
-
- if _tts_function is not None:
- return # Already set up
-
- try:
- import modal
-
- # Try path 1: Lookup pre-deployed function (fast path)
- try:
- _tts_function = modal.Function.from_name("deepcritical-tts", "kokoro_tts_function")
- logger.info(
- "modal_tts_function_lookup_success",
- app_name="deepcritical-tts",
- function_name="kokoro_tts_function",
- method="lookup",
- )
- return
- except Exception as lookup_error:
- logger.info(
- "modal_tts_function_lookup_failed",
- error=str(lookup_error),
- fallback="Creating function inline",
- )
-
- # Try path 2: Create function inline (fallback for casual users)
- logger.info("modal_tts_creating_inline_function")
- _tts_function = _create_tts_function()
- logger.info(
- "modal_tts_function_setup_complete",
- app_name="deepcritical-tts",
- function_name="kokoro_tts_function",
- method="inline",
- )
-
- except Exception as e:
- logger.error("modal_tts_function_setup_failed", error=str(e))
- raise ConfigurationError(
- f"Failed to setup Modal TTS function: {e}. "
- "Ensure Modal credentials (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) are valid."
- ) from e
-
-
-class ModalTTSExecutor:
- """Execute Kokoro TTS synthesis on Modal GPU.
-
- This class provides TTS synthesis using Kokoro 82M model on Modal's GPU infrastructure.
- Follows the same pattern as ModalCodeExecutor but uses GPU functions for TTS.
- """
-
- def __init__(self) -> None:
- """Initialize Modal TTS executor.
-
- Note:
- Logs a warning if Modal credentials are not configured in environment.
- Execution will fail at runtime without valid credentials in .env file.
- """
- # Check for Modal credentials directly from environment
- token_id = os.getenv("MODAL_TOKEN_ID")
- token_secret = os.getenv("MODAL_TOKEN_SECRET")
-
- if not token_id or not token_secret:
- logger.warning(
- "Modal credentials not found in environment. "
- "TTS will not be available. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
- )
-
- def synthesize(
- self,
- text: str,
- voice: str = "af_heart",
- speed: float = 1.0,
- timeout: int = 120,
- ) -> tuple[int, NDArray[np.float32]]:
- """Synthesize text to speech using Kokoro on Modal GPU.
-
- Args:
- text: Text to synthesize (max 5000 chars for free tier)
- voice: Voice ID from Kokoro (e.g., af_heart, af_bella, am_michael)
- speed: Speech speed multiplier (0.5-2.0)
- timeout: Maximum execution time (not used, Modal function has its own timeout)
-
- Returns:
- Tuple of (sample_rate, audio_array)
-
- Raises:
- ConfigurationError: If synthesis fails
- """
- # Setup Modal function if not already done
- _setup_modal_function()
-
- if _tts_function is None:
- raise ConfigurationError("Modal TTS function not initialized")
-
- logger.info("synthesizing_tts", text_length=len(text), voice=voice, speed=speed)
-
- try:
- # Call the GPU function remotely
- result = cast(tuple[int, NDArray[np.float32]], _tts_function.remote(text, voice, speed))
-
- logger.info(
- "tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape
- )
-
- return result
-
- except Exception as e:
- logger.error("tts_synthesis_failed", error=str(e), error_type=type(e).__name__)
- raise ConfigurationError(f"TTS synthesis failed: {e}") from e
-
-
-class TTSService:
- """TTS service wrapper for async usage."""
-
- def __init__(self) -> None:
- """Initialize TTS service.
-
- Validates Modal credentials from environment variables (.env file).
- """
- # Check credentials directly from environment
- token_id = os.getenv("MODAL_TOKEN_ID")
- token_secret = os.getenv("MODAL_TOKEN_SECRET")
-
- if not token_id or not token_secret:
- raise ConfigurationError(
- "Modal credentials required for TTS. "
- "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
- )
- self.executor = ModalTTSExecutor()
-
- async def synthesize_async(
- self,
- text: str,
- voice: str = "af_heart",
- speed: float = 1.0,
- ) -> tuple[int, NDArray[np.float32]] | None:
- """Async wrapper for TTS synthesis.
-
- Args:
- text: Text to synthesize
-            voice: Voice ID; an empty value falls back to settings.tts_voice
-            speed: Speech speed; a falsy value falls back to settings.tts_speed
-
- Returns:
- Tuple of (sample_rate, audio_array) or None if error
- """
- voice = voice or settings.tts_voice
- speed = speed or settings.tts_speed
-
- loop = asyncio.get_running_loop()
-
- try:
- result = await loop.run_in_executor(
- None,
- lambda: self.executor.synthesize(text, voice, speed),
- )
- return result
- except Exception as e:
- logger.error("tts_synthesis_async_failed", error=str(e))
- return None
-
-
-@lru_cache(maxsize=1)
-def get_tts_service() -> TTSService:
- """Get or create singleton TTS service instance.
-
- Returns:
- TTSService instance
-
- Raises:
- ConfigurationError: If Modal credentials not configured
- """
- return TTSService()
-
-
-async def generate_audio_on_demand(
- text: str,
- modal_token_id: str | None = None,
- modal_token_secret: str | None = None,
- voice: str = "af_heart",
- speed: float = 1.0,
- use_llm_polish: bool = False,
-) -> tuple[tuple[int, NDArray[np.float32]] | None, str]:
- """Generate audio on-demand with optional runtime credentials.
-
- Args:
- text: Text to synthesize
- modal_token_id: Modal token ID (UI input, overrides .env)
- modal_token_secret: Modal token secret (UI input, overrides .env)
- voice: Voice ID (default: af_heart)
- speed: Speech speed (default: 1.0)
- use_llm_polish: Apply LLM polish to text (default: False)
-
- Returns:
- Tuple of (audio_output, status_message)
- - audio_output: (sample_rate, audio_array) or None if failed
- - status_message: Status/error message for user
-
- Priority: UI credentials > .env credentials
- """
- # Priority: UI keys > .env keys
- token_id = (modal_token_id or "").strip() or os.getenv("MODAL_TOKEN_ID")
- token_secret = (modal_token_secret or "").strip() or os.getenv("MODAL_TOKEN_SECRET")
-
- if not token_id or not token_secret:
- return (
- None,
- "❌ Modal credentials required. Enter keys above or set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env",
- )
-
- try:
- # Use credentials override context
- with modal_credentials_override(token_id, token_secret):
- # Import audio_processing here to avoid circular import
- from src.services.audio_processing import AudioService
-
- # Temporarily override LLM polish setting
- original_llm_polish = settings.tts_use_llm_polish
- try:
- settings.tts_use_llm_polish = use_llm_polish
-
- # Create fresh AudioService instance (bypass cache to pick up new credentials)
- audio_service = AudioService()
- audio_output = await audio_service.generate_audio_output(
- text=text,
- voice=voice,
- speed=speed,
- )
-
- if audio_output:
- return audio_output, "✅ Audio generated successfully"
- else:
- return None, "⚠️ Audio generation returned no output"
-
- finally:
- settings.tts_use_llm_polish = original_llm_polish
-
- except ConfigurationError as e:
- logger.error("audio_generation_config_error", error=str(e))
- return None, f"❌ Configuration error: {e}"
- except Exception as e:
- logger.error("audio_generation_failed", error=str(e), exc_info=True)
- return None, f"❌ Audio generation failed: {e}"
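For reference, a minimal async usage sketch of the deleted `generate_audio_on_demand` helper; the text is a placeholder and credentials fall back to `.env` when the UI fields are empty:

```python
import asyncio

from src.services.tts_modal import generate_audio_on_demand


async def main() -> None:
    audio, status = await generate_audio_on_demand(
        text="Metformin shows modest repurposing evidence for Alzheimer's disease.",
        modal_token_id=None,      # fall back to MODAL_TOKEN_ID from .env
        modal_token_secret=None,  # fall back to MODAL_TOKEN_SECRET from .env
        voice="af_heart",
        speed=1.0,
    )
    print(status)
    if audio is not None:
        sample_rate, samples = audio
        print(f"{samples.shape[0] / sample_rate:.1f}s of audio at {sample_rate} Hz")


asyncio.run(main())
```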
diff --git a/src/state/__init__.py b/src/state/__init__.py
deleted file mode 100644
index a2323db724ea3ee5902b1121cdb2a30406319fe5..0000000000000000000000000000000000000000
--- a/src/state/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""State package - re-exports from agents.state for compatibility."""
-
-from src.agents.state import (
- MagenticState,
- get_magentic_state,
- init_magentic_state,
-)
-
-__all__ = ["MagenticState", "get_magentic_state", "init_magentic_state"]
diff --git a/src/tools/__init__.py b/src/tools/__init__.py
deleted file mode 100644
index 19f5c3ecd4c646af301ab73eab0286698ce99822..0000000000000000000000000000000000000000
--- a/src/tools/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Search tools package."""
-
-from src.tools.base import SearchTool
-from src.tools.pubmed import PubMedTool
-from src.tools.rag_tool import RAGTool, create_rag_tool
-from src.tools.search_handler import SearchHandler
-
-# Re-export
-__all__ = [
- "PubMedTool",
- "RAGTool",
- "SearchHandler",
- "SearchTool",
- "create_rag_tool",
-]
diff --git a/src/tools/base.py b/src/tools/base.py
deleted file mode 100644
index 0533c851973e3063d735d55f618c44c4ad8418ba..0000000000000000000000000000000000000000
--- a/src/tools/base.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Base classes and protocols for search tools."""
-
-from typing import Protocol
-
-from src.utils.models import Evidence
-
-
-class SearchTool(Protocol):
- """Protocol defining the interface for all search tools."""
-
- @property
- def name(self) -> str:
- """Human-readable name of this tool."""
- ...
-
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """
- Execute a search and return evidence.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects
-
- Raises:
- SearchError: If the search fails
- RateLimitError: If we hit rate limits
- """
- ...
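Because `SearchTool` is a `typing.Protocol`, any class with a matching `name` property and async `search` method satisfies it structurally, with no inheritance required. A minimal hypothetical implementation (the tool and its canned result are illustration only):

```python
from src.tools.base import SearchTool
from src.utils.models import Citation, Evidence


class StaticSearchTool:
    """Toy tool that always returns one canned Evidence item (illustration only)."""

    @property
    def name(self) -> str:
        return "static"

    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        return [
            Evidence(
                content=f"Canned result for: {query}",
                citation=Citation(
                    source="pubmed",  # placeholder; must be a valid SourceName from src/utils/models.py
                    title="Static result",
                    url="https://example.com",
                    date="Unknown",
                    authors=[],
                ),
            )
        ][:max_results]


tool: SearchTool = StaticSearchTool()  # type-checks because the protocol matches
```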
diff --git a/src/tools/clinicaltrials.py b/src/tools/clinicaltrials.py
deleted file mode 100644
index 06ef61158096e0ab2a5ff5a52f77e428de13bb4c..0000000000000000000000000000000000000000
--- a/src/tools/clinicaltrials.py
+++ /dev/null
@@ -1,143 +0,0 @@
-"""ClinicalTrials.gov search tool using API v2."""
-
-import asyncio
-from typing import Any, ClassVar
-
-import requests
-from tenacity import retry, stop_after_attempt, wait_exponential
-
-from src.utils.exceptions import SearchError
-from src.utils.models import Citation, Evidence
-
-
-class ClinicalTrialsTool:
- """Search tool for ClinicalTrials.gov.
-
- Note: Uses `requests` library instead of `httpx` because ClinicalTrials.gov's
- WAF blocks httpx's TLS fingerprint. The `requests` library is not blocked.
- See: https://clinicaltrials.gov/data-api/api
- """
-
- BASE_URL = "https://clinicaltrials.gov/api/v2/studies"
-
- # Fields to retrieve
- FIELDS: ClassVar[list[str]] = [
- "NCTId",
- "BriefTitle",
- "Phase",
- "OverallStatus",
- "Condition",
- "InterventionName",
- "StartDate",
- "BriefSummary",
- ]
-
- # Status filter: Only active/completed studies with potential data
- STATUS_FILTER = "COMPLETED,ACTIVE_NOT_RECRUITING,RECRUITING,ENROLLING_BY_INVITATION"
-
- # Study type filter: Only interventional (drug/treatment studies)
- STUDY_TYPE_FILTER = "INTERVENTIONAL"
-
- @property
- def name(self) -> str:
- return "clinicaltrials"
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=1, max=10),
- reraise=True,
- )
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Search ClinicalTrials.gov for interventional studies.
-
- Args:
- query: Search query (e.g., "metformin alzheimer")
- max_results: Maximum results to return (max 100)
-
- Returns:
- List of Evidence objects from clinical trials
- """
-        # Restrict to interventional studies via the query string, since the API does not
-        # support a dedicated study-type filter parameter; AREA[StudyType]INTERVENTIONAL does this.
- final_query = f"{query} AND AREA[StudyType]INTERVENTIONAL"
-
- params: dict[str, Any] = {
- "query.term": final_query,
- "pageSize": min(max_results, 100),
- "fields": ",".join(self.FIELDS),
- # FILTERS - Only active/completed studies
- "filter.overallStatus": self.STATUS_FILTER,
- }
-
- try:
- # Run blocking requests.get in a separate thread for async compatibility
- response = await asyncio.to_thread(
- requests.get,
- self.BASE_URL,
- params=params,
- headers={"User-Agent": "DETERMINATOR-Research-Agent/1.0"},
- timeout=30,
- )
- response.raise_for_status()
-
- data = response.json()
- studies = data.get("studies", [])
- return [self._study_to_evidence(study) for study in studies[:max_results]]
-
- except requests.HTTPError as e:
- raise SearchError(f"ClinicalTrials.gov API error: {e}") from e
- except requests.RequestException as e:
- raise SearchError(f"ClinicalTrials.gov request failed: {e}") from e
-
- def _study_to_evidence(self, study: dict[str, Any]) -> Evidence:
- """Convert a clinical trial study to Evidence."""
- # Navigate nested structure
- protocol = study.get("protocolSection", {})
- id_module = protocol.get("identificationModule", {})
- status_module = protocol.get("statusModule", {})
- desc_module = protocol.get("descriptionModule", {})
- design_module = protocol.get("designModule", {})
- conditions_module = protocol.get("conditionsModule", {})
- arms_module = protocol.get("armsInterventionsModule", {})
-
- nct_id = id_module.get("nctId", "Unknown")
- title = id_module.get("briefTitle", "Untitled Study")
- status = status_module.get("overallStatus", "Unknown")
- start_date = status_module.get("startDateStruct", {}).get("date", "Unknown")
-
- # Get phase (might be a list)
- phases = design_module.get("phases", [])
- phase = phases[0] if phases else "Not Applicable"
-
- # Get conditions
- conditions = conditions_module.get("conditions", [])
- conditions_str = ", ".join(conditions[:3]) if conditions else "Unknown"
-
- # Get interventions
- interventions = arms_module.get("interventions", [])
- intervention_names = [i.get("name", "") for i in interventions[:3]]
- interventions_str = ", ".join(intervention_names) if intervention_names else "Unknown"
-
- # Get summary
- summary = desc_module.get("briefSummary", "No summary available.")
-
- # Build content with key trial info
- content = (
- f"{summary[:500]}... "
- f"Trial Phase: {phase}. "
- f"Status: {status}. "
- f"Conditions: {conditions_str}. "
- f"Interventions: {interventions_str}."
- )
-
- return Evidence(
- content=content[:2000],
- citation=Citation(
- source="clinicaltrials",
- title=title[:500],
- url=f"https://clinicaltrials.gov/study/{nct_id}",
- date=start_date,
- authors=[], # Trials don't have traditional authors
- ),
- relevance=0.85, # Trials are highly relevant for repurposing
- )
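A short usage sketch of the deleted tool; `search()` is async because the blocking `requests.get` call is pushed to a worker thread via `asyncio.to_thread`:

```python
import asyncio

from src.tools.clinicaltrials import ClinicalTrialsTool


async def main() -> None:
    tool = ClinicalTrialsTool()
    evidence = await tool.search("metformin alzheimer", max_results=5)
    for ev in evidence:
        print(ev.citation.title, "->", ev.citation.url)


asyncio.run(main())
```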
diff --git a/src/tools/code_execution.py b/src/tools/code_execution.py
index da22fa9b1f864b0ab0e3778706c150ae60c714fd..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/tools/code_execution.py
+++ b/src/tools/code_execution.py
@@ -1,256 +0,0 @@
-"""Modal-based secure code execution tool for statistical analysis.
-
-This module provides sandboxed Python code execution using Modal's serverless infrastructure.
-It's designed for running LLM-generated statistical analysis code safely.
-"""
-
-import os
-from functools import lru_cache
-from typing import Any
-
-import structlog
-
-logger = structlog.get_logger(__name__)
-
-# Shared library versions for Modal sandbox - used by both executor and LLM prompts
-# Keep these in sync to avoid version mismatch between generated code and execution
-SANDBOX_LIBRARIES: dict[str, str] = {
- "pandas": "2.2.0",
- "numpy": "1.26.4",
- "scipy": "1.11.4",
- "matplotlib": "3.8.2",
- "scikit-learn": "1.4.0",
- "statsmodels": "0.14.1",
-}
-
-
-def get_sandbox_library_list() -> list[str]:
- """Get list of library==version strings for Modal image."""
- return [f"{lib}=={ver}" for lib, ver in SANDBOX_LIBRARIES.items()]
-
-
-def get_sandbox_library_prompt() -> str:
- """Get formatted library versions for LLM prompts."""
- return "\n".join(f"- {lib}=={ver}" for lib, ver in SANDBOX_LIBRARIES.items())
-
-
-class CodeExecutionError(Exception):
- """Raised when code execution fails."""
-
- pass
-
-
-class ModalCodeExecutor:
- """Execute Python code securely using Modal sandboxes.
-
- This class provides a safe environment for executing LLM-generated code,
- particularly for scientific computing and statistical analysis tasks.
-
- Features:
- - Sandboxed execution (isolated from host system)
- - Pre-installed scientific libraries (numpy, scipy, pandas, matplotlib)
- - Network isolation for security
- - Timeout protection
- - Stdout/stderr capture
-
- Example:
- >>> executor = ModalCodeExecutor()
- >>> result = executor.execute('''
- ... import pandas as pd
- ... df = pd.DataFrame({'a': [1, 2, 3]})
- ... result = df['a'].sum()
- ... ''')
- >>> print(result['stdout'])
- 6
- """
-
- def __init__(self) -> None:
- """Initialize Modal code executor.
-
- Note:
- Logs a warning if Modal credentials are not configured.
- Execution will fail at runtime without valid credentials.
- """
- # Check for Modal credentials
- self.modal_token_id = os.getenv("MODAL_TOKEN_ID")
- self.modal_token_secret = os.getenv("MODAL_TOKEN_SECRET")
-
- if not self.modal_token_id or not self.modal_token_secret:
- logger.warning(
- "Modal credentials not found. Code execution will fail unless modal setup is run."
- )
-
- def execute(self, code: str, timeout: int = 60, allow_network: bool = False) -> dict[str, Any]:
- """Execute Python code in a Modal sandbox.
-
- Args:
- code: Python code to execute
- timeout: Maximum execution time in seconds (default: 60)
- allow_network: Whether to allow network access (default: False for security)
-
- Returns:
- Dictionary containing:
- - stdout: Standard output from code execution
- - stderr: Standard error from code execution
- - success: Boolean indicating if execution succeeded
- - error: Error message if execution failed
-
- Raises:
- CodeExecutionError: If execution fails or times out
- """
- try:
- import modal
- except ImportError as e:
- raise CodeExecutionError(
- "Modal SDK not installed. Run: uv sync or pip install modal>=0.63.0"
- ) from e
-
- logger.info("executing_code", code_length=len(code), timeout=timeout)
-
- try:
- # Create or lookup Modal app
- app = modal.App.lookup("deepcritical-code-execution", create_if_missing=True)
-
- # Define scientific computing image with common libraries
- scientific_image = modal.Image.debian_slim(python_version="3.11").uv_pip_install(
- *get_sandbox_library_list()
- )
-
- # Create sandbox with security restrictions
- sandbox = modal.Sandbox.create(
- app=app,
- image=scientific_image,
- timeout=timeout,
- block_network=not allow_network, # Wire the network control
- )
-
- try:
- # Execute the code
- # Wrap code to capture result
- wrapped_code = f"""
-import sys
-import io
-from contextlib import redirect_stdout, redirect_stderr
-
-stdout_io = io.StringIO()
-stderr_io = io.StringIO()
-
-try:
- with redirect_stdout(stdout_io), redirect_stderr(stderr_io):
- {self._indent_code(code, 8)}
- print("__EXECUTION_SUCCESS__")
-except Exception as e:
- print(f"__EXECUTION_ERROR__: {{type(e).__name__}}: {{e}}", file=sys.stderr)
-
-print("__STDOUT_START__")
-print(stdout_io.getvalue())
-print("__STDOUT_END__")
-print("__STDERR_START__")
-print(stderr_io.getvalue(), file=sys.stderr)
-print("__STDERR_END__", file=sys.stderr)
-"""
-
- # Run the wrapped code
- process = sandbox.exec("python", "-c", wrapped_code, timeout=timeout)
-
- # Read output
- stdout_raw = process.stdout.read()
- stderr_raw = process.stderr.read()
- finally:
- # Always clean up sandbox to prevent resource leaks
- sandbox.terminate()
-
- # Parse output
- success = "__EXECUTION_SUCCESS__" in stdout_raw
-
- # Extract actual stdout/stderr
- stdout = self._extract_output(stdout_raw, "__STDOUT_START__", "__STDOUT_END__")
- stderr = self._extract_output(stderr_raw, "__STDERR_START__", "__STDERR_END__")
-
- result = {
- "stdout": stdout,
- "stderr": stderr,
- "success": success,
- "error": stderr if not success else None,
- }
-
- logger.info(
- "code_execution_completed",
- success=success,
- stdout_length=len(stdout),
- stderr_length=len(stderr),
- )
-
- return result
-
- except Exception as e:
- logger.error("code_execution_failed", error=str(e), error_type=type(e).__name__)
- raise CodeExecutionError(f"Code execution failed: {e}") from e
-
- def execute_with_return(self, code: str, timeout: int = 60) -> Any:
- """Execute code and return the value of the 'result' variable.
-
- Convenience method that executes code and extracts a return value.
- The code should assign its final result to a variable named 'result'.
-
- Args:
- code: Python code to execute (must set 'result' variable)
- timeout: Maximum execution time in seconds
-
- Returns:
- The value of the 'result' variable from the executed code
-
- Example:
- >>> executor.execute_with_return("result = 2 + 2")
- 4
- """
- # Modify code to print result as JSON
- wrapped = f"""
-import json
-{code}
-print(json.dumps({{"__RESULT__": result}}))
-"""
-
- execution_result = self.execute(wrapped, timeout=timeout)
-
- if not execution_result["success"]:
- raise CodeExecutionError(f"Execution failed: {execution_result['error']}")
-
- # Parse result from stdout
- import json
-
- try:
- output = execution_result["stdout"].strip()
- if "__RESULT__" in output:
- # Extract JSON line
- for line in output.split("\n"):
- if "__RESULT__" in line:
- data = json.loads(line)
- return data["__RESULT__"]
- raise ValueError("Result not found in output")
- except (json.JSONDecodeError, ValueError) as e:
- logger.warning(
- "failed_to_parse_result", error=str(e), stdout=execution_result["stdout"]
- )
- return execution_result["stdout"]
-
- def _indent_code(self, code: str, spaces: int) -> str:
- """Indent code by specified number of spaces."""
- indent = " " * spaces
- return "\n".join(indent + line if line.strip() else line for line in code.split("\n"))
-
- def _extract_output(self, text: str, start_marker: str, end_marker: str) -> str:
- """Extract content between markers."""
- try:
- start_idx = text.index(start_marker) + len(start_marker)
- end_idx = text.index(end_marker)
- return text[start_idx:end_idx].strip()
- except ValueError:
- # Markers not found, return original text
- return text.strip()
-
-
-@lru_cache(maxsize=1)
-def get_code_executor() -> ModalCodeExecutor:
- """Get or create singleton code executor instance (thread-safe via lru_cache)."""
- return ModalCodeExecutor()
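A hedged usage sketch of the deleted executor; it assumes valid Modal credentials are present in the environment, since the sandbox is created remotely:

```python
from src.tools.code_execution import get_code_executor

executor = get_code_executor()

# Plain execution: stdout/stderr are captured inside the sandbox
result = executor.execute("print(sum(range(10)))", timeout=60)
print(result["success"], result["stdout"])  # expected: True 45

# execute_with_return: the snippet must assign a JSON-serializable value to `result`
value = executor.execute_with_return("import math\nresult = round(math.pi, 3)")
print(value)  # expected: 3.142
```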
diff --git a/src/tools/crawl_adapter.py b/src/tools/crawl_adapter.py
deleted file mode 100644
index 332569c52698e39a46085f81989e0c386d402347..0000000000000000000000000000000000000000
--- a/src/tools/crawl_adapter.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Website crawl tool adapter for Pydantic AI agents.
-
-Uses the vendored crawl_website implementation from src/tools/vendored/crawl_website.py.
-"""
-
-import structlog
-
-logger = structlog.get_logger()
-
-
-async def crawl_website(starting_url: str) -> str:
- """
- Crawl a website starting from the given URL and return formatted results.
-
- Use this tool to crawl a website for information relevant to the query.
- Provide a starting URL as input.
-
- Args:
- starting_url: The starting URL to crawl (e.g., "https://example.com")
-
- Returns:
- Formatted string with crawled content including titles, descriptions, and URLs
- """
- try:
- # Import vendored crawl tool
- from src.tools.vendored.crawl_website import crawl_website as crawl_tool
-
- # Call the tool function
- # The tool returns List[ScrapeResult] or str
- results = await crawl_tool(starting_url)
-
- if isinstance(results, str):
- # Error message returned
- logger.warning("Crawl returned error", error=results)
- return results
-
- if not results:
- return f"No content found when crawling: {starting_url}"
-
- # Format results for agent consumption
- formatted = [f"Found {len(results)} pages from {starting_url}:\n"]
- for i, result in enumerate(results[:10], 1): # Limit to 10 pages
- formatted.append(f"{i}. **{result.title or 'Untitled'}**")
- if result.description:
- formatted.append(f" {result.description[:200]}...")
- formatted.append(f" URL: {result.url}")
- if result.text:
- formatted.append(f" Content: {result.text[:500]}...")
- formatted.append("")
-
- return "\n".join(formatted)
-
- except ImportError as e:
- logger.error("Crawl tool not available", error=str(e))
- return f"Crawl tool not available: {e!s}"
- except Exception as e:
- logger.error("Crawl failed", error=str(e), url=starting_url)
- return f"Error crawling website: {e!s}"
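A minimal sketch of calling the deleted adapter from an async context; the URL is a placeholder, and errors come back as readable strings rather than exceptions:

```python
import asyncio

from src.tools.crawl_adapter import crawl_website


async def main() -> None:
    # Returns formatted page summaries on success, or an error string on failure
    summary = await crawl_website("https://example.com")
    print(summary[:1000])


asyncio.run(main())
```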
diff --git a/src/tools/europepmc.py b/src/tools/europepmc.py
deleted file mode 100644
index 4dee1f58bc9e1e674578b9e4e3bcd7261f9895c4..0000000000000000000000000000000000000000
--- a/src/tools/europepmc.py
+++ /dev/null
@@ -1,107 +0,0 @@
-"""Europe PMC search tool - replaces BioRxiv."""
-
-from typing import Any
-
-import httpx
-from tenacity import retry, stop_after_attempt, wait_exponential
-
-from src.utils.exceptions import SearchError
-from src.utils.models import Citation, Evidence
-
-
-class EuropePMCTool:
- """
- Search Europe PMC for papers and preprints.
-
- Europe PMC indexes:
- - PubMed/MEDLINE articles
- - PMC full-text articles
- - Preprints from bioRxiv, medRxiv, ChemRxiv, etc.
- - Patents and clinical guidelines
-
- API Docs: https://europepmc.org/RestfulWebService
- """
-
- BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
-
- @property
- def name(self) -> str:
- return "europepmc"
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=1, max=10),
- reraise=True,
- )
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """
- Search Europe PMC for papers matching query.
-
- Args:
- query: Search keywords
- max_results: Maximum results to return
-
- Returns:
- List of Evidence objects
- """
- params: dict[str, str | int] = {
- "query": query,
- "resultType": "core",
- "pageSize": min(max_results, 100),
- "format": "json",
- }
-
- async with httpx.AsyncClient(timeout=30.0) as client:
- try:
- response = await client.get(self.BASE_URL, params=params)
- response.raise_for_status()
-
- data = response.json()
- results = data.get("resultList", {}).get("result", [])
-
- return [self._to_evidence(r) for r in results[:max_results]]
-
- except httpx.HTTPStatusError as e:
- raise SearchError(f"Europe PMC API error: {e}") from e
- except httpx.RequestError as e:
- raise SearchError(f"Europe PMC connection failed: {e}") from e
-
- def _to_evidence(self, result: dict[str, Any]) -> Evidence:
- """Convert Europe PMC result to Evidence."""
- title = result.get("title", "Untitled")
- abstract = result.get("abstractText", "No abstract available.")
- doi = result.get("doi", "")
- pub_year = result.get("pubYear", "Unknown")
-
- # Get authors
- author_list = result.get("authorList", {}).get("author", [])
- authors = [a.get("fullName", "") for a in author_list[:5] if a.get("fullName")]
-
- # Check if preprint
- pub_types = result.get("pubTypeList", {}).get("pubType", [])
- is_preprint = "Preprint" in pub_types
- source_db = result.get("source", "europepmc")
-
- # Build content
- preprint_marker = "[PREPRINT - Not peer-reviewed] " if is_preprint else ""
- content = f"{preprint_marker}{abstract[:1800]}"
-
- # Build URL
- if doi:
- url = f"https://doi.org/{doi}"
- elif result.get("pmid"):
- url = f"https://pubmed.ncbi.nlm.nih.gov/{result['pmid']}/"
- else:
- url = f"https://europepmc.org/article/{source_db}/{result.get('id', '')}"
-
- return Evidence(
- content=content[:2000],
- citation=Citation(
- source="preprint" if is_preprint else "europepmc",
- title=title[:500],
- url=url,
- date=str(pub_year),
- authors=authors,
- ),
- relevance=0.75 if is_preprint else 0.9,
- )
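A short usage sketch of the deleted tool; preprints are tagged in the evidence content and scored lower (0.75 vs 0.9) than peer-reviewed records:

```python
import asyncio

from src.tools.europepmc import EuropePMCTool


async def main() -> None:
    tool = EuropePMCTool()
    evidence = await tool.search("sildenafil pulmonary hypertension", max_results=5)
    for ev in evidence:
        print(f"{ev.relevance:.2f}", ev.citation.source, ev.citation.title[:80])


asyncio.run(main())
```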
diff --git a/src/tools/fallback_web_search.py b/src/tools/fallback_web_search.py
deleted file mode 100644
index 62d2dfd0e7daf98c6a71f7f92825eb1a96d92503..0000000000000000000000000000000000000000
--- a/src/tools/fallback_web_search.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""Fallback web search tool that tries Serper first, then DuckDuckGo on errors."""
-
-import structlog
-
-from src.tools.serper_web_search import SerperWebSearchTool
-from src.tools.web_search import WebSearchTool
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError
-from src.utils.models import Evidence
-
-logger = structlog.get_logger()
-
-
-class FallbackWebSearchTool:
- """Web search tool that tries Serper first, falls back to DuckDuckGo on any error.
-
- This ensures search always works even if Serper fails due to:
- - Credit exhaustion (403 Forbidden)
- - Rate limiting (429)
- - Network errors
- - Invalid API key
- - Any other errors
- """
-
- def __init__(self) -> None:
- """Initialize fallback web search tool."""
- self._serper_tool: SerperWebSearchTool | None = None
- self._duckduckgo_tool: WebSearchTool | None = None
- self._serper_available = False
-
- # Try to initialize Serper if API key is available
- if settings.serper_api_key:
- try:
- self._serper_tool = SerperWebSearchTool()
- self._serper_available = True
- logger.info("Serper web search initialized for fallback tool")
- except Exception as e:
- logger.warning(
- "Failed to initialize Serper, will use DuckDuckGo only",
- error=str(e),
- )
- self._serper_available = False
-
- # DuckDuckGo is always available as fallback
- self._duckduckgo_tool = WebSearchTool()
- logger.info("DuckDuckGo web search initialized as fallback")
-
- @property
- def name(self) -> str:
- """Return the name of this search tool."""
- return "serper" if self._serper_available else "duckduckgo"
-
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Execute web search with automatic fallback.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects from Serper (if successful) or DuckDuckGo (if fallback)
- """
- # Try Serper first if available
- if self._serper_available and self._serper_tool:
- try:
- logger.debug("Attempting Serper search", query=query)
- results = await self._serper_tool.search(query, max_results=max_results)
- logger.info(
- "Serper search successful",
- query=query,
- results_count=len(results),
- )
- return results
- except (ConfigurationError, RateLimitError, SearchError) as e:
- # Serper failed - log and fall back to DuckDuckGo
- logger.warning(
- "Serper search failed, falling back to DuckDuckGo",
- error=str(e),
- error_type=type(e).__name__,
- query=query,
- )
- # Mark Serper as unavailable for future requests (optional optimization)
- # self._serper_available = False
- except Exception as e:
- # Unexpected error from Serper - fall back
- logger.error(
- "Unexpected error in Serper search, falling back to DuckDuckGo",
- error=str(e),
- error_type=type(e).__name__,
- query=query,
- )
-
- # Fall back to DuckDuckGo
- if self._duckduckgo_tool:
- try:
- logger.info("Using DuckDuckGo search", query=query)
- results = await self._duckduckgo_tool.search(query, max_results=max_results)
- logger.info(
- "DuckDuckGo search successful",
- query=query,
- results_count=len(results),
- )
- return results
- except Exception as e:
- logger.error(
- "DuckDuckGo search also failed",
- error=str(e),
- query=query,
- )
- # If even DuckDuckGo fails, return empty list
- return []
-
- # Should never reach here, but just in case
- logger.error("No web search tools available")
- return []
-
-
-
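For reference, a minimal usage sketch of the deleted fallback tool; which backend actually served the request is reflected in the logs, and an empty list is returned only when both backends fail:

```python
import asyncio

from src.tools.fallback_web_search import FallbackWebSearchTool


async def main() -> None:
    tool = FallbackWebSearchTool()
    print("primary backend:", tool.name)  # "serper" if a key is configured, else "duckduckgo"
    evidence = await tool.search("drug repurposing long covid", max_results=5)
    print(f"{len(evidence)} results")


asyncio.run(main())
```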
diff --git a/src/tools/neo4j_search.py b/src/tools/neo4j_search.py
deleted file mode 100644
index 0198388961f9f0dba02c813944560b898ce584b7..0000000000000000000000000000000000000000
--- a/src/tools/neo4j_search.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""Neo4j knowledge graph search tool."""
-
-import structlog
-
-from src.services.neo4j_service import get_neo4j_service
-from src.utils.models import Citation, Evidence
-
-logger = structlog.get_logger()
-
-
-class Neo4jSearchTool:
- """Search Neo4j knowledge graph for papers."""
-
- def __init__(self) -> None:
-        self.name = "neo4j"  # define the tool name explicitly
-
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Search Neo4j for papers about diseases in the query."""
- try:
- service = get_neo4j_service()
- if not service:
- logger.warning("Neo4j service not available")
- return []
-
- # Extract disease name from query
- disease = query
- if "for" in query.lower():
- disease = query.split("for")[-1].strip().rstrip("?")
-
- # Query Neo4j
- if not service.driver:
- logger.warning("Neo4j driver not available")
- return []
- with service.driver.session(database=service.database) as session:
- result = session.run(
- """
- MATCH (p:Paper)-[:ABOUT]->(d:Disease)
- WHERE d.name CONTAINS $disease
- RETURN p.title as title, p.abstract as abstract,
- p.url as url, p.source as source
- ORDER BY p.updated_at DESC
- LIMIT $max_results
- """,
- disease=disease,
- max_results=max_results,
- )
-
- records = list(result)
-
- results = []
- for record in records:
- citation = Citation(
- source="neo4j",
- title=record["title"] or "Untitled",
- url=record["url"] or "",
- date="",
- authors=[],
- )
-
- evidence = Evidence(
- content=record["abstract"] or record["title"] or "",
- citation=citation,
- relevance=1.0,
- metadata={"from_kb": True, "original_source": record["source"]},
- )
- results.append(evidence)
-
-            logger.info("neo4j_search_complete", results_count=len(results))
- return results
- except Exception as e:
-            logger.error("neo4j_search_failed", error=str(e))
- return []
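Since every tool exposes the same async `search` interface, callers can fan out across several sources at once with `asyncio.gather`, in line with the project's async conventions. A hedged sketch combining the graph and literature tools:

```python
import asyncio

from src.tools.neo4j_search import Neo4jSearchTool
from src.tools.pubmed import PubMedTool


async def main() -> None:
    query = "repurposed drugs for parkinson disease"
    neo4j_results, pubmed_results = await asyncio.gather(
        Neo4jSearchTool().search(query, max_results=5),
        PubMedTool().search(query, max_results=5),
    )
    print(len(neo4j_results), "from the knowledge graph,", len(pubmed_results), "from PubMed")


asyncio.run(main())
```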
diff --git a/src/tools/pubmed.py b/src/tools/pubmed.py
deleted file mode 100644
index fe6ed6fa422cea1d80998a52cfe790a2922ad2c0..0000000000000000000000000000000000000000
--- a/src/tools/pubmed.py
+++ /dev/null
@@ -1,207 +0,0 @@
-"""PubMed search tool using NCBI E-utilities."""
-
-from typing import Any
-
-import httpx
-import xmltodict
-from tenacity import retry, stop_after_attempt, wait_exponential
-
-from src.tools.query_utils import preprocess_query
-from src.tools.rate_limiter import get_pubmed_limiter
-from src.utils.config import settings
-from src.utils.exceptions import RateLimitError, SearchError
-from src.utils.models import Citation, Evidence
-
-
-class PubMedTool:
- """Search tool for PubMed/NCBI."""
-
- BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
- HTTP_TOO_MANY_REQUESTS = 429
-
- def __init__(self, api_key: str | None = None) -> None:
- self.api_key = api_key or settings.ncbi_api_key
- # Ignore placeholder values from .env.example
- if self.api_key == "your-ncbi-key-here":
- self.api_key = None
-
- # Use shared rate limiter
- self._limiter = get_pubmed_limiter(self.api_key)
-
- @property
- def name(self) -> str:
- return "pubmed"
-
- async def _rate_limit(self) -> None:
- """Enforce NCBI rate limiting."""
- await self._limiter.acquire()
-
- def _build_params(self, **kwargs: Any) -> dict[str, Any]:
- """Build request params with optional API key."""
- params = {**kwargs, "retmode": "json"}
- if self.api_key:
- params["api_key"] = self.api_key
- return params
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=1, max=10),
- reraise=True,
- )
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """
- Search PubMed and return evidence.
-
- 1. ESearch: Get PMIDs matching query
- 2. EFetch: Get abstracts for those PMIDs
- 3. Parse and return Evidence objects
- """
- await self._rate_limit()
-
- # Preprocess query to remove noise and expand synonyms
- clean_query = preprocess_query(query)
- final_query = clean_query if clean_query else query
-
- async with httpx.AsyncClient(timeout=30.0) as client:
- # Step 1: Search for PMIDs
- search_params = self._build_params(
- db="pubmed",
- term=final_query,
- retmax=max_results,
- sort="relevance",
- )
-
- try:
- search_resp = await client.get(
- f"{self.BASE_URL}/esearch.fcgi",
- params=search_params,
- )
- search_resp.raise_for_status()
- except httpx.TimeoutException as e:
- raise SearchError(f"PubMed search timeout: {e}") from e
- except httpx.HTTPStatusError as e:
- if e.response.status_code == self.HTTP_TOO_MANY_REQUESTS:
- raise RateLimitError("PubMed rate limit exceeded") from e
- raise SearchError(f"PubMed search failed: {e}") from e
-
- search_data = search_resp.json()
- pmids = search_data.get("esearchresult", {}).get("idlist", [])
-
- if not pmids:
- return []
-
- # Step 2: Fetch abstracts
- await self._rate_limit()
- fetch_params = self._build_params(
- db="pubmed",
- id=",".join(pmids),
- rettype="abstract",
- )
- # Use XML for fetch (more reliable parsing)
- fetch_params["retmode"] = "xml"
-
- try:
- fetch_resp = await client.get(
- f"{self.BASE_URL}/efetch.fcgi",
- params=fetch_params,
- )
- fetch_resp.raise_for_status()
- except httpx.TimeoutException as e:
- raise SearchError(f"PubMed fetch timeout: {e}") from e
-
- # Step 3: Parse XML to Evidence
- return self._parse_pubmed_xml(fetch_resp.text)
-
- def _parse_pubmed_xml(self, xml_text: str) -> list[Evidence]:
- """Parse PubMed XML into Evidence objects."""
- try:
- data = xmltodict.parse(xml_text)
- except Exception as e:
- raise SearchError(f"Failed to parse PubMed XML: {e}") from e
-
- if data is None:
- return []
-
- # Handle case where PubmedArticleSet might not exist or be empty
- pubmed_set = data.get("PubmedArticleSet")
- if not pubmed_set:
- return []
-
- articles = pubmed_set.get("PubmedArticle", [])
-
- # Handle single article (xmltodict returns dict instead of list)
- if isinstance(articles, dict):
- articles = [articles]
-
- evidence_list = []
- for article in articles:
- try:
- evidence = self._article_to_evidence(article)
- if evidence:
- evidence_list.append(evidence)
- except Exception:
- continue # Skip malformed articles
-
- return evidence_list
-
- def _article_to_evidence(self, article: dict[str, Any]) -> Evidence | None:
- """Convert a single PubMed article to Evidence."""
- medline = article.get("MedlineCitation", {})
- article_data = medline.get("Article", {})
-
- # Extract PMID
- pmid = medline.get("PMID", {})
- if isinstance(pmid, dict):
- pmid = pmid.get("#text", "")
-
- # Extract title
- title = article_data.get("ArticleTitle", "")
- if isinstance(title, dict):
- title = title.get("#text", str(title))
-
- # Extract abstract
- abstract_data = article_data.get("Abstract", {}).get("AbstractText", "")
- if isinstance(abstract_data, list):
- abstract = " ".join(
- item.get("#text", str(item)) if isinstance(item, dict) else str(item)
- for item in abstract_data
- )
- elif isinstance(abstract_data, dict):
- abstract = abstract_data.get("#text", str(abstract_data))
- else:
- abstract = str(abstract_data)
-
- if not abstract or not title:
- return None
-
- # Extract date
- pub_date = article_data.get("Journal", {}).get("JournalIssue", {}).get("PubDate", {})
- year = pub_date.get("Year", "Unknown")
- month = pub_date.get("Month", "01")
- day = pub_date.get("Day", "01")
- date_str = f"{year}-{month}-{day}" if year != "Unknown" else "Unknown"
-
- # Extract authors
- author_list = article_data.get("AuthorList", {}).get("Author", [])
- if isinstance(author_list, dict):
- author_list = [author_list]
- authors = []
- for author in author_list[:5]: # Limit to 5 authors
- last = author.get("LastName", "")
- first = author.get("ForeName", "")
- if last:
- authors.append(f"{last} {first}".strip())
-
- # Truncation rationale: LLM context limits + cost optimization
- # - Abstract: 2000 chars (~500 tokens) captures key findings
- # - Title: 500 chars covers even verbose journal titles
- return Evidence(
- content=abstract[:2000],
- citation=Citation(
- source="pubmed",
- title=title[:500],
- url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
- date=date_str,
- authors=authors,
- ),
- )
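A short usage sketch of the deleted tool; the query is preprocessed internally (question words stripped, synonyms expanded), and without an NCBI key the shared limiter allows 3 requests/second:

```python
import asyncio

from src.tools.pubmed import PubMedTool


async def main() -> None:
    tool = PubMedTool()  # no API key: rate-limited to 3 requests/second
    evidence = await tool.search("Can metformin help treat Alzheimer's disease?", max_results=3)
    for ev in evidence:
        print(ev.citation.date, ev.citation.title[:80], ev.citation.url)


asyncio.run(main())
```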
diff --git a/src/tools/query_utils.py b/src/tools/query_utils.py
deleted file mode 100644
index c52507f226400cfe8adef3b4dce90d3f70a8f5a5..0000000000000000000000000000000000000000
--- a/src/tools/query_utils.py
+++ /dev/null
@@ -1,218 +0,0 @@
-"""Query preprocessing utilities for biomedical search."""
-
-import re
-
-# Question words and filler words to remove
-QUESTION_WORDS: set[str] = {
- # Question starters
- "what",
- "which",
- "how",
- "why",
- "when",
- "where",
- "who",
- "whom",
- # Auxiliary verbs in questions
- "is",
- "are",
- "was",
- "were",
- "do",
- "does",
- "did",
- "can",
- "could",
- "would",
- "should",
- "will",
- "shall",
- "may",
- "might",
- # Filler words in natural questions
- "show",
- "promise",
- "help",
- "believe",
- "think",
- "suggest",
- "possible",
- "potential",
- "effective",
- "useful",
- "good",
- # Articles (remove but less aggressively)
- "the",
- "a",
- "an",
-}
-
-# Medical synonym expansions
-SYNONYMS: dict[str, list[str]] = {
- "long covid": [
- "long COVID",
- "PASC",
- "post-acute sequelae of SARS-CoV-2",
- "post-COVID syndrome",
- "post-COVID-19 condition",
- ],
- "alzheimer": [
- "Alzheimer's disease",
- "Alzheimer disease",
- "AD",
- "Alzheimer dementia",
- ],
- "parkinson": [
- "Parkinson's disease",
- "Parkinson disease",
- "PD",
- ],
- "diabetes": [
- "diabetes mellitus",
- "type 2 diabetes",
- "T2DM",
- "diabetic",
- ],
- "cancer": [
- "cancer",
- "neoplasm",
- "tumor",
- "malignancy",
- "carcinoma",
- ],
- "heart disease": [
- "cardiovascular disease",
- "CVD",
- "coronary artery disease",
- "heart failure",
- ],
-}
-
-
-def strip_question_words(query: str) -> str:
- """
- Remove question words and filler terms from query.
-
- Args:
- query: Raw query string
-
- Returns:
- Query with question words removed
- """
- words = query.lower().split()
- filtered = [w for w in words if w not in QUESTION_WORDS]
- return " ".join(filtered)
-
-
-def expand_synonyms(query: str) -> str:
- """
- Expand medical terms to include synonyms.
-
- Args:
- query: Query string
-
- Returns:
- Query with synonym expansions in OR groups
- """
- result = query.lower()
-
- for term, expansions in SYNONYMS.items():
- if term in result:
- # Create OR group: ("term1" OR "term2" OR "term3")
- or_group = " OR ".join([f'"{exp}"' for exp in expansions])
-            # The query was already lowercased above, so a plain str.replace() gives
-            # case-insensitive matching; every occurrence of the term is replaced.
-            # Original casing is lost, but search engines are case-insensitive anyway.
- result = result.replace(term, f"({or_group})")
-
- return result
-
-
-def preprocess_query(raw_query: str) -> str:
- """
- Full preprocessing pipeline for PubMed queries.
-
- Pipeline:
- 1. Strip whitespace and punctuation
- 2. Remove question words
- 3. Expand medical synonyms
-
- Args:
- raw_query: Natural language query from user
-
- Returns:
- Optimized query for PubMed
- """
- if not raw_query or not raw_query.strip():
- return ""
-
- # Remove question marks and extra whitespace
- query = raw_query.replace("?", "").strip()
- query = re.sub(r"\s+", " ", query)
-
- # Strip question words
- query = strip_question_words(query)
-
- # Expand synonyms
- query = expand_synonyms(query)
-
- return query.strip()
-
-
-def preprocess_web_query(raw_query: str) -> str:
- """
- Simplified preprocessing pipeline for web search engines (Serper, DuckDuckGo, etc.).
-
- Web search engines work better with natural language queries rather than
- complex boolean syntax. This function:
- 1. Strips whitespace and punctuation
- 2. Removes question words (less aggressively)
- 3. Removes complex boolean syntax (OR groups, parentheses)
- 4. Uses primary synonym terms instead of expanding to OR groups
-
- Args:
- raw_query: Natural language query from user
-
- Returns:
- Simplified query optimized for web search engines
- """
- if not raw_query or not raw_query.strip():
- return ""
-
- # Remove question marks and extra whitespace
- query = raw_query.replace("?", "").strip()
- query = re.sub(r"\s+", " ", query)
-
- # Remove complex boolean syntax that might have been added
- # Remove OR groups like: ("term1" OR "term2" OR "term3")
-    query = re.sub(r"\([^)]*OR[^)]*\)", "", query, flags=re.IGNORECASE)
-    # Remove standalone OR statements
-    query = re.sub(r"\s+OR\s+", " ", query, flags=re.IGNORECASE)
-    # Remove extra parentheses
-    query = re.sub(r"[()]", "", query)
-    # Remove extra quotes that might be left
-    query = re.sub(r'"([^"]*)"', r"\1", query)
-    # Clean up multiple spaces
-    query = re.sub(r"\s+", " ", query)
-
- # Strip question words (less aggressively - keep important context words)
- # Only remove very common question starters
- minimal_question_words = {"what", "which", "how", "why", "when", "where", "who"}
- words = query.split()
- filtered = [w for w in words if w.lower() not in minimal_question_words]
- query = " ".join(filtered)
-
- # Replace known medical terms with their primary/common form
- # Use the first synonym (most common) instead of OR groups
- query_lower = query.lower()
- for term, expansions in SYNONYMS.items():
- if term in query_lower:
- # Use the first expansion (usually the most common term)
- primary_term = expansions[0] if expansions else term
-            # Match case-insensitively; the rest of the query keeps its original casing
- pattern = re.compile(re.escape(term), re.IGNORECASE)
- query = pattern.sub(primary_term, query)
-
- return query.strip()
\ No newline at end of file
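Worked examples of the two pipelines on the same question, as expected under the word lists above: the PubMed variant expands synonyms into boolean OR groups, while the web variant keeps natural language and collapses to the primary synonym.

```python
from src.tools.query_utils import preprocess_query, preprocess_web_query

question = "What drugs show promise for long covid?"

print(preprocess_query(question))
# drugs for ("long COVID" OR "PASC" OR "post-acute sequelae of SARS-CoV-2"
#            OR "post-COVID syndrome" OR "post-COVID-19 condition")
# (single line in practice, wrapped here for readability)

print(preprocess_web_query(question))
# drugs show promise for long COVID
```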
diff --git a/src/tools/rag_tool.py b/src/tools/rag_tool.py
deleted file mode 100644
index 230405db4134ab99b557c6a2f839b9384b628bdf..0000000000000000000000000000000000000000
--- a/src/tools/rag_tool.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""RAG tool for semantic search within collected evidence.
-
-Implements SearchTool protocol to enable RAG as a search option in the research workflow.
-"""
-
-from typing import TYPE_CHECKING, Any
-
-import structlog
-
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import Citation, Evidence, SourceName
-
-if TYPE_CHECKING:
- from src.services.llamaindex_rag import LlamaIndexRAGService
-
-logger = structlog.get_logger()
-
-
-class RAGTool:
- """Search tool that uses LlamaIndex RAG for semantic search within collected evidence.
-
- Wraps LlamaIndexRAGService to implement the SearchTool protocol.
- Returns Evidence objects from RAG retrieval results.
- """
-
- def __init__(
- self,
- rag_service: "LlamaIndexRAGService | None" = None,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize RAG tool.
-
- Args:
- rag_service: Optional RAG service instance. If None, will be lazy-initialized.
- oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM)
- """
- self._rag_service = rag_service
- self.oauth_token = oauth_token
- self.logger = logger
-
- @property
- def name(self) -> str:
- """Return the tool name."""
- return "rag"
-
- def _get_rag_service(self) -> "LlamaIndexRAGService":
- """
- Get or create RAG service instance.
-
- Returns:
- LlamaIndexRAGService instance
-
- Raises:
- ConfigurationError: If RAG service cannot be initialized
- """
- if self._rag_service is None:
- try:
- from src.services.llamaindex_rag import get_rag_service
-
- # Use local embeddings by default (no API key required)
- # Use in-memory ChromaDB to avoid file system issues
- # Pass OAuth token for LLM query synthesis
- self._rag_service = get_rag_service(
- use_openai_embeddings=False,
- use_in_memory=True, # Use in-memory for better reliability
- oauth_token=self.oauth_token,
- )
- self.logger.info("RAG service initialized with local embeddings")
- except (ConfigurationError, ImportError) as e:
- self.logger.error("Failed to initialize RAG service", error=str(e))
- raise ConfigurationError(
- "RAG service unavailable. Check LlamaIndex dependencies are installed."
- ) from e
-
- return self._rag_service
-
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """
- Search RAG system and return evidence.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects from RAG retrieval
-
- Note:
- Returns empty list on error (does not raise exceptions).
- """
- try:
- rag_service = self._get_rag_service()
- except ConfigurationError:
- self.logger.warning("RAG service unavailable, returning empty results")
- return []
-
- try:
- # Retrieve documents from RAG
- retrieved_docs = rag_service.retrieve(query, top_k=max_results)
-
- if not retrieved_docs:
- self.logger.info("No RAG results found", query=query[:50])
- return []
-
- # Convert retrieved documents to Evidence objects
- evidence_list: list[Evidence] = []
- for doc in retrieved_docs:
- try:
- evidence = self._doc_to_evidence(doc)
- evidence_list.append(evidence)
- except Exception as e:
- self.logger.warning(
- "Failed to convert document to evidence",
- error=str(e),
- doc_text=doc.get("text", "")[:50],
- )
- continue
-
- self.logger.info(
- "RAG search completed",
- query=query[:50],
- results=len(evidence_list),
- )
- return evidence_list
-
- except Exception as e:
- self.logger.error("RAG search failed", error=str(e), query=query[:50])
- # Return empty list on error (graceful degradation)
- return []
-
- def _doc_to_evidence(self, doc: dict[str, Any]) -> Evidence:
- """
- Convert RAG document to Evidence object.
-
- Args:
- doc: Document dict with keys: text, score, metadata
-
- Returns:
- Evidence object
-
- Raises:
- ValueError: If document is missing required fields
- """
- text = doc.get("text", "")
- if not text:
- raise ValueError("Document missing text content")
-
- metadata = doc.get("metadata", {})
- score = doc.get("score", 0.0)
-
- # Extract citation information from metadata
- source: SourceName = "rag" # RAG is the source
- title = metadata.get("title", "Untitled")
- url = metadata.get("url", "")
- date = metadata.get("date", "Unknown")
- authors_str = metadata.get("authors", "")
- authors = [a.strip() for a in authors_str.split(",") if a.strip()] if authors_str else []
-
- # Create citation
- citation = Citation(
- source=source,
- title=title[:500], # Enforce max length
- url=url,
- date=date,
- authors=authors,
- )
-
- # Create evidence with relevance score (normalize score to 0-1 if needed)
- relevance = min(max(float(score), 0.0), 1.0) if score else 0.0
-
- return Evidence(
- content=text,
- citation=citation,
- relevance=relevance,
- )
-
-
-def create_rag_tool(
- rag_service: "LlamaIndexRAGService | None" = None,
- oauth_token: str | None = None,
-) -> RAGTool:
- """
- Factory function to create a RAG tool.
-
- Args:
- rag_service: Optional RAG service instance. If None, will be lazy-initialized.
- oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM)
-
- Returns:
- Configured RAGTool instance
-
- Raises:
- ConfigurationError: If RAG service cannot be initialized and rag_service is None
- """
- try:
- return RAGTool(rag_service=rag_service, oauth_token=oauth_token)
- except Exception as e:
- logger.error("Failed to create RAG tool", error=str(e))
- raise ConfigurationError(f"Failed to create RAG tool: {e}") from e
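A short sketch of wiring the deleted RAG tool into a search pass; the underlying service is lazily created on first `search()` with local embeddings and an in-memory store, and failures degrade to an empty result list rather than raising:

```python
import asyncio

from src.tools.rag_tool import create_rag_tool


async def main() -> None:
    rag = create_rag_tool(oauth_token=None)  # lazy-initializes LlamaIndex on first use
    evidence = await rag.search("mechanism of metformin in neurodegeneration", max_results=5)
    print(len(evidence), "semantic matches from previously collected evidence")


asyncio.run(main())
```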
diff --git a/src/tools/rate_limiter.py b/src/tools/rate_limiter.py
deleted file mode 100644
index 48cf5385cc04dcfc3d7d7e938813aa24887254ef..0000000000000000000000000000000000000000
--- a/src/tools/rate_limiter.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""Rate limiting utilities using the limits library."""
-
-import asyncio
-import random
-from typing import ClassVar
-
-from limits import RateLimitItem, parse
-from limits.storage import MemoryStorage
-from limits.strategies import MovingWindowRateLimiter
-
-
-class RateLimiter:
- """
- Async-compatible rate limiter using limits library.
-
- Uses moving window algorithm for smooth rate limiting.
- """
-
- def __init__(self, rate: str) -> None:
- """
- Initialize rate limiter.
-
- Args:
- rate: Rate string like "3/second" or "10/second"
- """
- self.rate = rate
- self._storage = MemoryStorage()
- self._limiter = MovingWindowRateLimiter(self._storage)
- self._rate_limit: RateLimitItem = parse(rate)
- self._identity = "default" # Single identity for shared limiting
-
- async def acquire(self, wait: bool = True, jitter: bool = False) -> bool:
- """
- Acquire permission to make a request.
-
- ASYNC-SAFE: Uses asyncio.sleep(), never time.sleep().
- The polling pattern allows other coroutines to run while waiting.
-
- Args:
- wait: If True, wait until allowed. If False, return immediately.
- jitter: If True, add a small random delay (0-1 second) after acquiring to avoid a thundering herd
-
- Returns:
- True if allowed, False if not (only when wait=False)
- """
- while True:
- # Check if we can proceed (synchronous, fast - ~microseconds)
- if self._limiter.hit(self._rate_limit, self._identity):
- # Add jitter after acquiring to spread out requests
- if jitter:
- # Add 0-1 second jitter to spread requests slightly
- # This prevents thundering herd without long delays
- jitter_seconds = random.uniform(0, 1.0)
- await asyncio.sleep(jitter_seconds)
- return True
-
- if not wait:
- return False
-
- # CRITICAL: Use asyncio.sleep(), NOT time.sleep()
- # This yields control to the event loop, allowing other
- # coroutines (UI, parallel searches) to run.
- # Using 0.01s for fine-grained responsiveness.
- await asyncio.sleep(0.01)
-
- def reset(self) -> None:
- """Reset the rate limiter (for testing)."""
- self._storage.reset()
-
-
-# Singleton limiter for PubMed/NCBI
-_pubmed_limiter: RateLimiter | None = None
-
-
-def get_pubmed_limiter(api_key: str | None = None) -> RateLimiter:
- """
- Get the shared PubMed rate limiter.
-
- Rate depends on whether API key is provided:
- - Without key: 3 requests/second
- - With key: 10 requests/second
-
- Args:
- api_key: NCBI API key (optional)
-
- Returns:
- Shared RateLimiter instance
- """
- global _pubmed_limiter
-
- if _pubmed_limiter is None:
- rate = "10/second" if api_key else "3/second"
- _pubmed_limiter = RateLimiter(rate)
-
- return _pubmed_limiter
-
-
-def reset_pubmed_limiter() -> None:
- """Reset the PubMed limiter (for testing)."""
- global _pubmed_limiter
- _pubmed_limiter = None
-
-
-def get_serper_limiter(api_key: str | None = None) -> RateLimiter:
- """
- Get the shared Serper API rate limiter.
-
- Rate: 90 requests/second, kept safely below Serper's 100 requests/second
- free-tier limit while still allowing high throughput when needed.
-
- Serper free tier provides:
- - 2,500 credits (one-time, expire after 6 months)
- - 100 requests/second rate limit
- - Credits only deduct for successful responses
-
- Args:
- api_key: Serper API key (optional, for consistency with other limiters)
-
- Returns:
- Shared RateLimiter instance
- """
- # Use 90/second to stay safely under 100/second limit
- return RateLimiterFactory.get("serper", "90/second")
-
-
-def get_searchxng_limiter() -> RateLimiter:
- """
- Get the shared SearchXNG API rate limiter.
-
- Rate: 5 requests/second (conservative limit)
-
- Returns:
- Shared RateLimiter instance
- """
- return RateLimiterFactory.get("searchxng", "5/second")
-
-
-# Factory for other APIs
-class RateLimiterFactory:
- """Factory for creating/getting rate limiters for different APIs."""
-
- _limiters: ClassVar[dict[str, RateLimiter]] = {}
-
- @classmethod
- def get(cls, api_name: str, rate: str) -> RateLimiter:
- """
- Get or create a rate limiter for an API.
-
- Args:
- api_name: Unique identifier for the API
- rate: Rate limit string (e.g., "10/second")
-
- Returns:
- RateLimiter instance (shared for same api_name)
- """
- if api_name not in cls._limiters:
- cls._limiters[api_name] = RateLimiter(rate)
- return cls._limiters[api_name]
-
- @classmethod
- def reset_all(cls) -> None:
- """Reset all limiters (for testing)."""
- cls._limiters.clear()
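A short sketch of how the limiter above is shared between callers, assuming the `src.tools.rate_limiter` import path from the removed file and an illustrative "example-api" name:

```python
import asyncio

from src.tools.rate_limiter import RateLimiterFactory, get_pubmed_limiter


async def fetch(i: int) -> None:
    # All callers asking for the same api_name share one limiter, so a burst of
    # coroutines is smoothed to the configured rate across the whole process.
    limiter = RateLimiterFactory.get("example-api", "5/second")
    await limiter.acquire()  # polls with asyncio.sleep() until a slot is free
    print(f"request {i} allowed")


async def main() -> None:
    # The PubMed limiter picks 3/second or 10/second based on whether an API key is given.
    pubmed = get_pubmed_limiter(api_key=None)
    allowed_now = await pubmed.acquire(wait=False)  # non-blocking check
    print("pubmed slot immediately available:", allowed_now)

    # Twenty parallel callers against a 5/second moving window.
    await asyncio.gather(*(fetch(i) for i in range(20)))


asyncio.run(main())
```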
diff --git a/src/tools/search_handler.py b/src/tools/search_handler.py
deleted file mode 100644
index 9d2b4adf7d00001af74dd2c4614ae92459c74b1d..0000000000000000000000000000000000000000
--- a/src/tools/search_handler.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""Search handler - orchestrates multiple search tools."""
-
-import asyncio
-from typing import TYPE_CHECKING, cast
-
-import structlog
-
-from src.services.neo4j_service import get_neo4j_service
-from src.tools.base import SearchTool
-from src.tools.rag_tool import create_rag_tool
-from src.utils.exceptions import ConfigurationError, SearchError
-from src.utils.models import Evidence, SearchResult, SourceName
-
-if TYPE_CHECKING:
- from src.services.llamaindex_rag import LlamaIndexRAGService
-else:
- LlamaIndexRAGService = object
-
-logger = structlog.get_logger()
-
-
-class SearchHandler:
- """Orchestrates parallel searches across multiple tools."""
-
- def __init__(
- self,
- tools: list[SearchTool],
- timeout: float = 30.0,
- include_rag: bool = False,
- auto_ingest_to_rag: bool = True,
- oauth_token: str | None = None,
- ) -> None:
- """
- Initialize the search handler.
-
- Args:
- tools: List of search tools to use
- timeout: Timeout for each search in seconds
- include_rag: Whether to include RAG tool in searches
- auto_ingest_to_rag: Whether to automatically ingest results into RAG
- oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM)
- """
- self.tools = list(tools) # Make a copy
- self.timeout = timeout
- self.auto_ingest_to_rag = auto_ingest_to_rag
- self.oauth_token = oauth_token
- self._rag_service: LlamaIndexRAGService | None = None
-
- if include_rag:
- self.add_rag_tool()
-
- def add_rag_tool(self) -> None:
- """Add RAG tool to the tools list if available."""
- try:
- rag_tool = create_rag_tool(oauth_token=self.oauth_token)
- self.tools.append(rag_tool)
- logger.info("RAG tool added to search handler")
- except ConfigurationError:
- logger.warning(
- "RAG tool unavailable, not adding to search handler",
- hint="LlamaIndex dependencies required",
- )
- except Exception as e:
- logger.error("Failed to add RAG tool", error=str(e))
-
- def _get_rag_service(self) -> "LlamaIndexRAGService | None":
- """Get or create RAG service for ingestion."""
- if self._rag_service is None and self.auto_ingest_to_rag:
- try:
- from src.services.llamaindex_rag import get_rag_service
-
- # Use local embeddings by default (no API key required)
- # Use in-memory ChromaDB to avoid file system issues
- # Pass OAuth token for LLM query synthesis
- self._rag_service = get_rag_service(
- use_openai_embeddings=False,
- use_in_memory=True, # Use in-memory for better reliability
- oauth_token=self.oauth_token,
- )
- logger.info("RAG service initialized for ingestion with local embeddings")
- except (ConfigurationError, ImportError):
- logger.warning("RAG service unavailable for ingestion")
- return None
- return self._rag_service
-
- async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
- """
- Execute search across all tools in parallel.
-
- Args:
- query: The search query
- max_results_per_tool: Max results from each tool
-
- Returns:
- SearchResult containing all evidence and metadata
- """
- logger.info("Starting search", query=query, tools=[t.name for t in self.tools])
-
- # Create tasks for parallel execution
- tasks = [
- self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools
- ]
-
- # Gather results (don't fail if one tool fails)
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- # Process results
- all_evidence: list[Evidence] = []
- sources_searched: list[SourceName] = []
- errors: list[str] = []
-
- # Map tool names to SourceName values
- # Some tools have internal names that differ from SourceName literals
- tool_name_to_source: dict[str, SourceName] = {
- "duckduckgo": "web",
- "serper": "web", # Serper uses Google search but maps to "web" source
- "searchxng": "web", # SearchXNG also maps to "web" source
- "pubmed": "pubmed",
- "clinicaltrials": "clinicaltrials",
- "europepmc": "europepmc",
- "neo4j": "neo4j",
- "rag": "rag",
- "web": "web", # In case tool already uses "web"
- }
-
- for tool, result in zip(self.tools, results, strict=True):
- if isinstance(result, Exception):
- errors.append(f"{tool.name}: {result!s}")
- logger.warning("Search tool failed", tool=tool.name, error=str(result))
- else:
- # Cast result to list[Evidence] as we know it succeeded
- success_result = cast(list[Evidence], result)
- all_evidence.extend(success_result)
-
- # Map tool.name to SourceName (handle tool names that don't match SourceName literals)
- tool_name = tool_name_to_source.get(tool.name, cast(SourceName, tool.name))
- if tool_name not in [
- "pubmed",
- "clinicaltrials",
- "biorxiv",
- "europepmc",
- "preprint",
- "rag",
- "web",
- "neo4j",
- ]:
- logger.warning(
- "Tool name not in SourceName literals, defaulting to 'web'",
- tool_name=tool.name,
- )
- tool_name = "web"
- sources_searched.append(tool_name)
- logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
-
- search_result = SearchResult(
- query=query,
- evidence=all_evidence,
- sources_searched=sources_searched,
- total_found=len(all_evidence),
- errors=errors,
- )
-
- # Ingest evidence into RAG if enabled and available
- if self.auto_ingest_to_rag and all_evidence:
- rag_service = self._get_rag_service()
- if rag_service:
- try:
- # Filter out RAG-sourced evidence (avoid circular ingestion)
- evidence_to_ingest = [e for e in all_evidence if e.citation.source != "rag"]
- if evidence_to_ingest:
- rag_service.ingest_evidence(evidence_to_ingest)
- logger.info(
- "Ingested evidence into RAG",
- count=len(evidence_to_ingest),
- )
- except Exception as e:
- logger.warning("Failed to ingest evidence into RAG", error=str(e))
-
- # 🔥 INGEST INTO NEO4J KNOWLEDGE GRAPH 🔥
- if all_evidence:
- try:
- neo4j_service = get_neo4j_service()
- if neo4j_service:
- # Extract disease from query (use the text after the last " for ", if present)
- disease = query
- if " for " in query.lower():
- disease = query.split(" for ")[-1].strip().rstrip("?")
-
- # Convert Evidence objects to dicts for Neo4j
- papers = []
- for ev in all_evidence:
- papers.append(
- {
- "id": ev.citation.url or "",
- "title": ev.citation.title or "",
- "abstract": ev.content,
- "url": ev.citation.url or "",
- "source": ev.citation.source,
- }
- )
-
- stats = neo4j_service.ingest_search_results(disease, papers)
- logger.info("💾 Saved to Neo4j", stats=stats)
- except Exception as e:
- logger.warning("Neo4j ingestion failed", error=str(e))
-
- return search_result
-
- async def _search_with_timeout(
- self,
- tool: SearchTool,
- query: str,
- max_results: int,
- ) -> list[Evidence]:
- """Execute a single tool search with timeout."""
- try:
- return await asyncio.wait_for(
- tool.search(query, max_results),
- timeout=self.timeout,
- )
- except TimeoutError as e:
- raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e
diff --git a/src/tools/searchxng_web_search.py b/src/tools/searchxng_web_search.py
deleted file mode 100644
index e120b61f72b4e28e780ad43a5673f49f5d27ef58..0000000000000000000000000000000000000000
--- a/src/tools/searchxng_web_search.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""SearchXNG web search tool using SearchXNG API for Google searches."""
-
-import structlog
-from tenacity import retry, stop_after_attempt, wait_exponential
-
-from src.tools.query_utils import preprocess_query
-from src.tools.rate_limiter import get_searchxng_limiter
-from src.tools.vendored.searchxng_client import SearchXNGClient
-from src.tools.vendored.web_search_core import scrape_urls
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError
-from src.utils.models import Citation, Evidence
-
-logger = structlog.get_logger()
-
-
-class SearchXNGWebSearchTool:
- """Tool for searching the web using SearchXNG API (Google search)."""
-
- def __init__(self, host: str | None = None) -> None:
- """Initialize SearchXNG web search tool.
-
- Args:
- host: SearchXNG host URL. If None, reads from settings.
-
- Raises:
- ConfigurationError: If no host is available.
- """
- self.host = host or settings.searchxng_host
- if not self.host:
- raise ConfigurationError(
- "SearchXNG host required. Set SEARCHXNG_HOST environment variable or searchxng_host in settings."
- )
-
- self._client = SearchXNGClient(host=self.host)
- self._limiter = get_searchxng_limiter()
-
- @property
- def name(self) -> str:
- """Return the name of this search tool."""
- return "searchxng"
-
- async def _rate_limit(self) -> None:
- """Enforce SearchXNG API rate limiting."""
- await self._limiter.acquire()
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=1, max=10),
- reraise=True,
- )
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Execute a web search using SearchXNG API.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects
-
- Raises:
- SearchError: If the search fails
- RateLimitError: If rate limit is exceeded
- """
- await self._rate_limit()
-
- # Preprocess query to remove noise
- clean_query = preprocess_query(query)
- final_query = clean_query if clean_query else query
-
- try:
- # Get search results (snippets)
- search_results = await self._client.search(
- final_query, filter_for_relevance=False, max_results=max_results
- )
-
- if not search_results:
- logger.info("No search results found", query=final_query)
- return []
-
- # Scrape URLs to get full content
- scraped = await scrape_urls(search_results)
-
- # Convert ScrapeResult to Evidence objects
- evidence = []
- for result in scraped:
- # Truncate title to max 500 characters to match Citation model validation
- title = result.title
- if len(title) > 500:
- title = title[:497] + "..."
-
- ev = Evidence(
- content=result.text,
- citation=Citation(
- title=title,
- url=result.url,
- source="web", # Use "web" to match SourceName literal, not "searchxng"
- date="Unknown",
- authors=[],
- ),
- relevance=0.0,
- )
- evidence.append(ev)
-
- logger.info(
- "SearchXNG search complete",
- query=final_query,
- results_found=len(evidence),
- )
-
- return evidence
-
- except RateLimitError:
- raise
- except SearchError:
- raise
- except Exception as e:
- logger.error("Unexpected error in SearchXNG search", error=str(e), query=final_query)
- raise SearchError(f"SearchXNG search failed: {e}") from e
diff --git a/src/tools/serper_web_search.py b/src/tools/serper_web_search.py
deleted file mode 100644
index d5b6bdb350ca7d6cae834b77b0cb57f043833689..0000000000000000000000000000000000000000
--- a/src/tools/serper_web_search.py
+++ /dev/null
@@ -1,144 +0,0 @@
-"""Serper web search tool using Serper API for Google searches."""
-
-import structlog
-from tenacity import (
- retry,
- retry_if_not_exception_type,
- stop_after_attempt,
- wait_random_exponential,
-)
-
-from src.tools.query_utils import preprocess_web_query
-from src.tools.rate_limiter import get_serper_limiter
-from src.tools.vendored.serper_client import SerperClient
-from src.tools.vendored.web_search_core import scrape_urls
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError
-from src.utils.models import Citation, Evidence
-
-logger = structlog.get_logger()
-
-
-class SerperWebSearchTool:
- """Tool for searching the web using Serper API (Google search)."""
-
- def __init__(self, api_key: str | None = None) -> None:
- """Initialize Serper web search tool.
-
- Args:
- api_key: Serper API key. If None, reads from settings.
-
- Raises:
- ConfigurationError: If no API key is available.
- """
- self.api_key = api_key or settings.serper_api_key
- if not self.api_key:
- raise ConfigurationError(
- "Serper API key required. Set SERPER_API_KEY environment variable or serper_api_key in settings."
- )
-
- self._client = SerperClient(api_key=self.api_key)
- self._limiter = get_serper_limiter(self.api_key)
-
- # Validate API key format (basic check)
- if self.api_key and len(self.api_key.strip()) < 10:
- logger.warning(
- "Serper API key appears to be too short",
- key_length=len(self.api_key),
- hint="Verify SERPER_API_KEY is correct",
- )
-
- @property
- def name(self) -> str:
- """Return the name of this search tool."""
- return "serper"
-
- async def _rate_limit(self) -> None:
- """Enforce Serper API rate limiting with jitter.
-
- Uses jitter to spread out requests and avoid thundering herd problems.
- Rate limit is 100 requests/second for free tier, we use 90/second to stay safe.
- """
- await self._limiter.acquire(jitter=True)
-
- @retry(
- stop=stop_after_attempt(3), # Reduced retries for faster fallback
- wait=wait_random_exponential(
- multiplier=1, min=2, max=10, exp_base=2
- ), # 2s to 10s backoff with jitter (faster for fallback)
- reraise=True,
- retry=retry_if_not_exception_type(ConfigurationError), # Don't retry on config errors
- )
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Execute a web search using Serper API.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects
-
- Raises:
- SearchError: If the search fails
- RateLimitError: If rate limit is exceeded
- ConfigurationError: If API key is invalid (403 Forbidden)
- """
- await self._rate_limit()
-
- # Preprocess query for web search (simplified, no boolean syntax)
- clean_query = preprocess_web_query(query)
- final_query = clean_query if clean_query else query
-
- try:
- # Get search results (snippets)
- search_results = await self._client.search(
- final_query, filter_for_relevance=False, max_results=max_results
- )
-
- if not search_results:
- logger.info("No search results found", query=final_query)
- return []
-
- # Scrape URLs to get full content
- scraped = await scrape_urls(search_results)
-
- # Convert ScrapeResult to Evidence objects
- evidence = []
- for result in scraped:
- # Truncate title to max 500 characters to match Citation model validation
- title = result.title
- if len(title) > 500:
- title = title[:497] + "..."
-
- ev = Evidence(
- content=result.text,
- citation=Citation(
- title=title,
- url=result.url,
- source="web", # Use "web" to match SourceName literal, not "serper"
- date="Unknown",
- authors=[],
- ),
- relevance=0.0,
- )
- evidence.append(ev)
-
- logger.info(
- "Serper search complete",
- query=final_query,
- results_found=len(evidence),
- )
-
- return evidence
-
- except ConfigurationError:
- # Don't retry configuration errors (e.g., 403 Forbidden = invalid API key)
- raise
- except RateLimitError:
- raise
- except SearchError:
- raise
- except Exception as e:
- logger.error("Unexpected error in Serper search", error=str(e), query=final_query)
- raise SearchError(f"Serper search failed: {e}") from e
diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py
deleted file mode 100644
index e499960726ae4f0044e863fbc0f30a4a4afd0e9e..0000000000000000000000000000000000000000
--- a/src/tools/tool_executor.py
+++ /dev/null
@@ -1,193 +0,0 @@
-"""Tool executor for running AgentTask objects.
-
-Executes tool tasks selected by the tool selector agent and returns ToolAgentOutput.
-"""
-
-import structlog
-
-from src.tools.crawl_adapter import crawl_website
-from src.tools.rag_tool import RAGTool, create_rag_tool
-from src.tools.web_search_adapter import web_search
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import AgentTask, Evidence, ToolAgentOutput
-
-logger = structlog.get_logger()
-
-# Module-level RAG tool instance (lazy initialization)
-_rag_tool: RAGTool | None = None
-
-
-def _get_rag_tool() -> RAGTool | None:
- """
- Get or create RAG tool instance.
-
- Returns:
- RAGTool instance, or None if unavailable
- """
- global _rag_tool
- if _rag_tool is None:
- try:
- _rag_tool = create_rag_tool()
- logger.info("RAG tool initialized")
- except ConfigurationError:
- logger.warning("RAG tool unavailable (OPENAI_API_KEY required)")
- return None
- except Exception as e:
- logger.error("Failed to initialize RAG tool", error=str(e))
- return None
- return _rag_tool
-
-
-def _evidence_to_text(evidence_list: list[Evidence]) -> str:
- """
- Convert Evidence objects to formatted text.
-
- Args:
- evidence_list: List of Evidence objects
-
- Returns:
- Formatted text string with citations and content
- """
- if not evidence_list:
- return "No evidence found."
-
- formatted_parts = []
- for i, evidence in enumerate(evidence_list, 1):
- citation = evidence.citation
- citation_str = f"{citation.formatted}"
- if citation.url:
- citation_str += f" [{citation.url}]"
-
- formatted_parts.append(f"[{i}] {citation_str}\n\n{evidence.content}\n\n---\n")
-
- return "\n".join(formatted_parts)
-
-
-async def execute_agent_task(task: AgentTask) -> ToolAgentOutput:
- """
- Execute a single agent task and return ToolAgentOutput.
-
- Args:
- task: AgentTask specifying which tool to use and what query to run
-
- Returns:
- ToolAgentOutput with results and source URLs
- """
- logger.info(
- "Executing agent task",
- agent=task.agent,
- query=task.query[:100] if task.query else "",
- gap=task.gap[:100] if task.gap else "",
- )
-
- try:
- if task.agent == "WebSearchAgent":
- # Use web search adapter
- result_text = await web_search(task.query)
- # Extract URLs from result (simple heuristic - look for http/https)
- import re
-
- urls = re.findall(r"https?://[^\s\)]+", result_text)
- sources = list(set(urls)) # Deduplicate
-
- return ToolAgentOutput(output=result_text, sources=sources)
-
- elif task.agent == "SiteCrawlerAgent":
- # Use crawl adapter
- if task.entity_website:
- starting_url = task.entity_website
- elif task.query.startswith(("http://", "https://")):
- starting_url = task.query
- else:
- # Try to construct URL from query
- starting_url = f"https://{task.query}"
-
- result_text = await crawl_website(starting_url)
- # Extract URLs from result
- import re
-
- urls = re.findall(r"https?://[^\s\)]+", result_text)
- sources = list(set(urls)) # Deduplicate
-
- return ToolAgentOutput(output=result_text, sources=sources)
-
- elif task.agent == "RAGAgent":
- # Use RAG tool for semantic search
- rag_tool = _get_rag_tool()
- if rag_tool is None:
- return ToolAgentOutput(
- output="RAG service unavailable. OPENAI_API_KEY required.",
- sources=[],
- )
-
- # Search RAG and get Evidence objects
- evidence_list = await rag_tool.search(task.query, max_results=10)
-
- if not evidence_list:
- return ToolAgentOutput(
- output="No relevant evidence found in collected research.",
- sources=[],
- )
-
- # Convert Evidence to formatted text
- result_text = _evidence_to_text(evidence_list)
-
- # Extract URLs from evidence citations
- sources = [evidence.citation.url for evidence in evidence_list if evidence.citation.url]
-
- return ToolAgentOutput(output=result_text, sources=sources)
-
- else:
- logger.warning("Unknown agent type", agent=task.agent)
- return ToolAgentOutput(
- output=f"Unknown agent type: {task.agent}. Available: WebSearchAgent, SiteCrawlerAgent, RAGAgent",
- sources=[],
- )
-
- except Exception as e:
- logger.error("Tool execution failed", error=str(e), agent=task.agent)
- return ToolAgentOutput(
- output=f"Error executing {task.agent} for gap '{task.gap}': {e!s}",
- sources=[],
- )
-
-
-async def execute_tool_tasks(
- tasks: list[AgentTask],
-) -> dict[str, ToolAgentOutput]:
- """
- Execute multiple agent tasks concurrently.
-
- Args:
- tasks: List of AgentTask objects to execute
-
- Returns:
- Dictionary mapping task keys to ToolAgentOutput results
- """
- import asyncio
-
- logger.info("Executing tool tasks", count=len(tasks))
-
- # Create async tasks
- async_tasks = [execute_agent_task(task) for task in tasks]
-
- # Run concurrently
- results_list = await asyncio.gather(*async_tasks, return_exceptions=True)
-
- # Build results dictionary
- results: dict[str, ToolAgentOutput] = {}
- for i, (task, result) in enumerate(zip(tasks, results_list, strict=False)):
- if isinstance(result, Exception):
- logger.error("Task execution failed", error=str(result), task_index=i)
- results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[])
- else:
- # Type narrowing: result is ToolAgentOutput after Exception check
- assert isinstance(result, ToolAgentOutput), (
- "Expected ToolAgentOutput after Exception check"
- )
- key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}"
- results[key] = result
-
- logger.info("Tool tasks completed", completed=len(results))
-
- return results
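A sketch of driving the executor above, assuming the import paths from the removed files; the AgentTask field values are illustrative and the model may define additional fields not visible in this module:

```python
import asyncio

from src.tools.tool_executor import execute_tool_tasks
from src.utils.models import AgentTask


async def main() -> None:
    # Field values are illustrative; only agent/query/gap/entity_website are read here.
    tasks = [
        AgentTask(agent="WebSearchAgent", query="nintedanib repurposing evidence", gap="efficacy data"),
        AgentTask(agent="RAGAgent", query="prior evidence on antifibrotics", gap="mechanism"),
    ]

    results = await execute_tool_tasks(tasks)
    for key, output in results.items():
        print(key, "->", len(output.sources), "sources")


asyncio.run(main())
```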
diff --git a/src/tools/vendored/__init__.py b/src/tools/vendored/__init__.py
deleted file mode 100644
index e96e825253d1d80cb65b5f33c22878fd4ddc225a..0000000000000000000000000000000000000000
--- a/src/tools/vendored/__init__.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Vendored web search components from folder/tools/web_search.py."""
-
-from src.tools.vendored.crawl_website import crawl_website
-from src.tools.vendored.searchxng_client import SearchXNGClient
-from src.tools.vendored.serper_client import SerperClient
-from src.tools.vendored.web_search_core import (
- CONTENT_LENGTH_LIMIT,
- ScrapeResult,
- WebpageSnippet,
- fetch_and_process_url,
- html_to_text,
- is_valid_url,
- scrape_urls,
-)
-
-__all__ = [
- "CONTENT_LENGTH_LIMIT",
- "ScrapeResult",
- "SearchXNGClient",
- "SerperClient",
- "WebpageSnippet",
- "crawl_website",
- "fetch_and_process_url",
- "html_to_text",
- "is_valid_url",
- "scrape_urls",
-]
diff --git a/src/tools/vendored/crawl_website.py b/src/tools/vendored/crawl_website.py
deleted file mode 100644
index cd75bb509fde305a5ba26d51877b4c06d87351ba..0000000000000000000000000000000000000000
--- a/src/tools/vendored/crawl_website.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Website crawl tool vendored from folder/tools/crawl_website.py.
-
-This module provides website crawling functionality that starts from a given URL
-and crawls linked pages in a breadth-first manner, prioritizing navigation links.
-"""
-
-from urllib.parse import urljoin, urlparse
-
-import aiohttp
-import structlog
-from bs4 import BeautifulSoup
-
-from src.tools.vendored.web_search_core import (
- ScrapeResult,
- WebpageSnippet,
- scrape_urls,
- ssl_context,
-)
-
-logger = structlog.get_logger()
-
-
-async def _extract_links(
- html: str, current_url: str, base_domain: str
-) -> tuple[list[str], list[str]]:
- """Extract prioritized links from HTML content."""
- soup = BeautifulSoup(html, "html.parser")
- nav_links = set()
- body_links = set()
-
- # Find navigation/header links
- for nav_element in soup.find_all(["nav", "header"]):
- for a in nav_element.find_all("a", href=True):
- href = str(a["href"])
- link = urljoin(current_url, href)
- if urlparse(link).netloc == base_domain:
- nav_links.add(link)
-
- # Find remaining body links
- for a in soup.find_all("a", href=True):
- href = str(a["href"])
- link = urljoin(current_url, href)
- if urlparse(link).netloc == base_domain and link not in nav_links:
- body_links.add(link)
-
- return list(nav_links), list(body_links)
-
-
-async def _fetch_page(url: str) -> str:
- """Fetch HTML content from a URL."""
- connector = aiohttp.TCPConnector(ssl=ssl_context)
- async with aiohttp.ClientSession(connector=connector) as session:
- try:
- timeout = aiohttp.ClientTimeout(total=30)
- async with session.get(url, timeout=timeout) as response:
- if response.status == 200:
- return await response.text()
- return ""
- except Exception as e:
- logger.warning("Error fetching URL", url=url, error=str(e))
- return ""
-
-
-def _add_links_to_queue(
- links: list[str],
- queue: list[str],
- all_pages_to_scrape: set[str],
- remaining_slots: int,
-) -> int:
- """Add normalized links to queue if not already visited."""
- for link in links:
- normalized_link = link.rstrip("/")
- if normalized_link not in all_pages_to_scrape and remaining_slots > 0:
- queue.append(normalized_link)
- all_pages_to_scrape.add(normalized_link)
- remaining_slots -= 1
- return remaining_slots
-
-
-async def crawl_website(starting_url: str) -> list[ScrapeResult] | str:
- """Crawl the pages of a website starting with the starting_url and then descending into the pages linked from there.
-
- Prioritizes links found in headers/navigation, then body links, then subsequent pages.
-
- Args:
- starting_url: Starting URL to scrape
-
- Returns:
- List of ScrapeResult objects (or an error string if the starting URL is empty) with the following fields:
- - url: The URL of the web page
- - title: The title of the web page
- - description: The description of the web page
- - text: The text content of the web page
- """
- if not starting_url:
- return "Empty URL provided"
-
- # Ensure URL has a protocol
- if not starting_url.startswith(("http://", "https://")):
- starting_url = "http://" + starting_url
-
- max_pages = 10
- base_domain = urlparse(starting_url).netloc
-
- # Initialize with starting URL
- queue: list[str] = [starting_url]
- next_level_queue: list[str] = []
- all_pages_to_scrape: set[str] = set([starting_url])
-
- # Breadth-first crawl
- while queue and len(all_pages_to_scrape) < max_pages:
- current_url = queue.pop(0)
-
- # Fetch and process the page
- html_content = await _fetch_page(current_url)
- if html_content:
- nav_links, body_links = await _extract_links(html_content, current_url, base_domain)
-
- # Add unvisited nav links to current queue (higher priority)
- remaining_slots = max_pages - len(all_pages_to_scrape)
- remaining_slots = _add_links_to_queue(
- nav_links, queue, all_pages_to_scrape, remaining_slots
- )
-
- # Add unvisited body links to next level queue (lower priority)
- remaining_slots = _add_links_to_queue(
- body_links, next_level_queue, all_pages_to_scrape, remaining_slots
- )
-
- # If current queue is empty, add next level links
- if not queue:
- queue = next_level_queue
- next_level_queue = []
-
- # Convert set to list for final processing
- pages_to_scrape = list(all_pages_to_scrape)[:max_pages]
- pages_to_scrape_snippets: list[WebpageSnippet] = [
- WebpageSnippet(url=page, title="", description="") for page in pages_to_scrape
- ]
-
- # Use scrape_urls to get the content for all discovered pages
- result = await scrape_urls(pages_to_scrape_snippets)
- return result
-
-
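A minimal call sketch for the crawler above, assuming the vendored import path from the removed file; the starting URL is illustrative:

```python
import asyncio

from src.tools.vendored.crawl_website import crawl_website


async def main() -> None:
    result = await crawl_website("https://example.com")

    # crawl_website returns an error string for an empty URL, otherwise ScrapeResults.
    if isinstance(result, str):
        print("crawl error:", result)
        return

    for page in result:
        print(page.url, "-", page.title or "(no title)")
        print(page.text[:200])


asyncio.run(main())
```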
diff --git a/src/tools/vendored/searchxng_client.py b/src/tools/vendored/searchxng_client.py
deleted file mode 100644
index 2cf4ce83b2f58e47bfb0f543790a6bd7d540bab6..0000000000000000000000000000000000000000
--- a/src/tools/vendored/searchxng_client.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""SearchXNG API client for Google searches.
-
-Vendored and adapted from folder/tools/web_search.py.
-"""
-
-import os
-
-import aiohttp
-import structlog
-
-from src.tools.vendored.web_search_core import WebpageSnippet, ssl_context
-from src.utils.exceptions import RateLimitError, SearchError
-
-logger = structlog.get_logger()
-
-
-class SearchXNGClient:
- """A client for the SearchXNG API to perform Google searches."""
-
- def __init__(self, host: str | None = None) -> None:
- """Initialize SearchXNG client.
-
- Args:
- host: SearchXNG host URL. If None, reads from SEARCHXNG_HOST env var.
-
- Raises:
- ConfigurationError: If no host is provided.
- """
- host = host or os.getenv("SEARCHXNG_HOST")
- if not host:
- from src.utils.exceptions import ConfigurationError
-
- raise ConfigurationError("SEARCHXNG_HOST environment variable is not set")
-
- # Ensure host ends with /search
- if not host.endswith("/search"):
- host = f"{host}/search" if not host.endswith("/") else f"{host}search"
-
- self.host: str = host
-
- async def search(
- self, query: str, filter_for_relevance: bool = False, max_results: int = 5
- ) -> list[WebpageSnippet]:
- """Perform a search using SearchXNG API.
-
- Args:
- query: The search query
- filter_for_relevance: Whether to filter results (currently not implemented)
- max_results: Maximum number of results to return
-
- Returns:
- List of WebpageSnippet objects with search results
-
- Raises:
- SearchError: If the search fails
- RateLimitError: If rate limit is exceeded
- """
- connector = aiohttp.TCPConnector(ssl=ssl_context)
- try:
- async with aiohttp.ClientSession(connector=connector) as session:
- params = {
- "q": query,
- "format": "json",
- }
-
- async with session.get(self.host, params=params) as response:
- if response.status == 429:
- raise RateLimitError("SearchXNG API rate limit exceeded")
-
- response.raise_for_status()
- results = await response.json()
-
- results_list = [
- WebpageSnippet(
- url=result.get("url", ""),
- title=result.get("title", ""),
- description=result.get("content", ""),
- )
- for result in results.get("results", [])
- ]
-
- if not results_list:
- logger.info("No search results found", query=query)
- return []
-
- # Return results up to max_results
- return results_list[:max_results]
-
- except aiohttp.ClientError as e:
- logger.error("SearchXNG API request failed", error=str(e), query=query)
- raise SearchError(f"SearchXNG API request failed: {e}") from e
- except RateLimitError:
- raise
- except Exception as e:
- logger.error("Unexpected error in SearchXNG search", error=str(e), query=query)
- raise SearchError(f"SearchXNG search failed: {e}") from e
-
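A sketch of the host normalization and snippet-level results of the client above, assuming a locally running SearchXNG instance at http://localhost:8080 (an assumption for illustration only):

```python
import asyncio

from src.tools.vendored.searchxng_client import SearchXNGClient


async def main() -> None:
    # The constructor appends "/search" when it is missing, so both
    # "http://localhost:8080" and "http://localhost:8080/search" are accepted.
    client = SearchXNGClient(host="http://localhost:8080")
    print("query endpoint:", client.host)

    snippets = await client.search("pulmonary fibrosis biomarkers", max_results=5)
    for s in snippets:
        # Snippets carry url/title/description only; scraping happens in a later step.
        print(s.url, "-", s.title)


asyncio.run(main())
```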
diff --git a/src/tools/vendored/serper_client.py b/src/tools/vendored/serper_client.py
deleted file mode 100644
index 17d30b6a8b86f0ec9d36e4107e571a397bf9c050..0000000000000000000000000000000000000000
--- a/src/tools/vendored/serper_client.py
+++ /dev/null
@@ -1,139 +0,0 @@
-"""Serper API client for Google searches.
-
-Vendored and adapted from folder/tools/web_search.py.
-"""
-
-import os
-
-import aiohttp
-import structlog
-
-from src.tools.vendored.web_search_core import WebpageSnippet, ssl_context
-from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError
-
-logger = structlog.get_logger()
-
-
-class SerperClient:
- """A client for the Serper API to perform Google searches."""
-
- def __init__(self, api_key: str | None = None) -> None:
- """Initialize Serper client.
-
- Args:
- api_key: Serper API key. If None, reads from SERPER_API_KEY env var.
-
- Raises:
- ConfigurationError: If no API key is provided.
- """
- self.api_key = api_key or os.getenv("SERPER_API_KEY")
- if not self.api_key:
- from src.utils.exceptions import ConfigurationError
-
- raise ConfigurationError(
- "No API key provided. Set SERPER_API_KEY environment variable."
- )
-
- # Serper API endpoint and headers
- # Documentation: https://serper.dev/api
- # Format: POST https://google.serper.dev/search
- # Headers: X-API-KEY (required), Content-Type: application/json
- # Body: {"q": "search query", "autocorrect": false}
- self.url = "https://google.serper.dev/search"
- self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
-
- async def search(
- self, query: str, filter_for_relevance: bool = False, max_results: int = 5
- ) -> list[WebpageSnippet]:
- """Perform a Google search using Serper API.
-
- Args:
- query: The search query
- filter_for_relevance: Whether to filter results (currently not implemented)
- max_results: Maximum number of results to return
-
- Returns:
- List of WebpageSnippet objects with search results
-
- Raises:
- SearchError: If the search fails
- RateLimitError: If rate limit is exceeded
- """
- connector = aiohttp.TCPConnector(ssl=ssl_context)
- try:
- async with aiohttp.ClientSession(connector=connector) as session:
- # Verify API call format matches Serper API documentation:
- # POST https://google.serper.dev/search
- # Headers: X-API-KEY, Content-Type: application/json
- # Body: {"q": query, "autocorrect": false}
- async with session.post(
- self.url,
- headers=self.headers,
- json={"q": query, "autocorrect": False},
- timeout=aiohttp.ClientTimeout(total=30), # 30 second timeout
- ) as response:
- if response.status == 429:
- raise RateLimitError("Serper API rate limit exceeded")
-
- if response.status == 403:
- # 403 can mean either invalid key OR credits exhausted
- # For free tier (2,500 credits), it's often credit exhaustion
- # Read response body to get more details
- try:
- error_body = await response.text()
- logger.warning(
- "Serper API returned 403 Forbidden",
- status=403,
- body=error_body[:200], # Truncate for logging
- hint="May be credit exhaustion (free tier: 2,500 credits) or invalid key",
- )
- except Exception:
- pass
-
- # Raise RateLimitError instead of ConfigurationError
- # This allows retry logic to handle credit exhaustion
- # The retry decorator will use exponential backoff with jitter
- raise RateLimitError(
- "Serper API credits may be exhausted (403 Forbidden). "
- "Free tier provides 2,500 credits (one-time, expire after 6 months). "
- "Check your dashboard at https://serper.dev/dashboard. "
- "Retrying with backoff..."
- )
-
- response.raise_for_status()
- results = await response.json()
-
- results_list = [
- WebpageSnippet(
- url=result.get("link", ""),
- title=result.get("title", ""),
- description=result.get("snippet", ""),
- )
- for result in results.get("organic", [])
- ]
-
- if not results_list:
- logger.info("No search results found", query=query)
- return []
-
- # Return results up to max_results
- return results_list[:max_results]
-
- except aiohttp.ClientError as e:
- logger.error("Serper API request failed", error=str(e), query=query)
- raise SearchError(f"Serper API request failed: {e}") from e
- except RateLimitError:
- raise
- except Exception as e:
- logger.error("Unexpected error in Serper search", error=str(e), query=query)
- raise SearchError(f"Serper search failed: {e}") from e
-
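A sketch of calling the client above directly, assuming SERPER_API_KEY is set in the environment and the import paths from the removed files. Note that both HTTP 429 and 403 surface as RateLimitError so an outer retry decorator can back off:

```python
import asyncio

from src.tools.vendored.serper_client import SerperClient
from src.utils.exceptions import RateLimitError, SearchError


async def main() -> None:
    client = SerperClient()  # reads SERPER_API_KEY from the environment

    try:
        snippets = await client.search("idiopathic pulmonary fibrosis treatment", max_results=5)
    except RateLimitError as exc:
        # Raised for 429 and for 403 (possible credit exhaustion) so callers can retry with backoff.
        print("rate limited or credits exhausted:", exc)
        return
    except SearchError as exc:
        print("search failed:", exc)
        return

    for s in snippets:
        print(s.url, "-", (s.description or "")[:120])


asyncio.run(main())
```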
diff --git a/src/tools/vendored/web_search_core.py b/src/tools/vendored/web_search_core.py
deleted file mode 100644
index a465f3982df18728c2f4d4eebb21477bae3dbe65..0000000000000000000000000000000000000000
--- a/src/tools/vendored/web_search_core.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""Core web search utilities vendored from folder/tools/web_search.py.
-
-This module contains shared utilities for web scraping, URL processing,
-and HTML text extraction used by web search tools.
-"""
-
-import asyncio
-import ssl
-
-import aiohttp
-import structlog
-from bs4 import BeautifulSoup
-from pydantic import BaseModel, Field
-
-logger = structlog.get_logger()
-
-# Content length limit to avoid exceeding token limits
-CONTENT_LENGTH_LIMIT = 10000
-
-# Create a shared SSL context for web requests
-ssl_context = ssl.create_default_context()
-ssl_context.check_hostname = False
-ssl_context.verify_mode = ssl.CERT_NONE
-ssl_context.set_ciphers("DEFAULT:@SECLEVEL=1") # Allow older cipher suites
-
-
-class ScrapeResult(BaseModel):
- """Result of scraping a single webpage."""
-
- url: str = Field(description="The URL of the webpage")
- text: str = Field(description="The full text content of the webpage")
- title: str = Field(description="The title of the webpage")
- description: str = Field(description="A short description of the webpage")
-
-
-class WebpageSnippet(BaseModel):
- """Snippet information for a webpage (before scraping)."""
-
- url: str = Field(description="The URL of the webpage")
- title: str = Field(description="The title of the webpage")
- description: str | None = Field(default=None, description="A short description of the webpage")
-
-
-async def scrape_urls(items: list[WebpageSnippet]) -> list[ScrapeResult]:
- """Fetch text content from provided URLs.
-
- Args:
- items: List of WebpageSnippet items to extract content from
-
- Returns:
- List of ScrapeResult objects with scraped content
- """
- connector = aiohttp.TCPConnector(ssl=ssl_context)
- async with aiohttp.ClientSession(connector=connector) as session:
- # Create list of tasks for concurrent execution
- tasks = []
- for item in items:
- if item.url: # Skip empty URLs
- tasks.append(fetch_and_process_url(session, item))
-
- # Execute all tasks concurrently and gather results
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- # Filter out errors and return successful results
- successful_results: list[ScrapeResult] = []
- for result in results:
- if isinstance(result, ScrapeResult):
- successful_results.append(result)
- elif isinstance(result, Exception):
- logger.warning("Failed to scrape URL", error=str(result))
-
- return successful_results
-
-
-async def fetch_and_process_url(
- session: aiohttp.ClientSession, item: WebpageSnippet
-) -> ScrapeResult:
- """Helper function to fetch and process a single URL.
-
- Args:
- session: aiohttp ClientSession
- item: WebpageSnippet with URL to fetch
-
- Returns:
- ScrapeResult with fetched content
- """
- if not is_valid_url(item.url):
- return ScrapeResult(
- url=item.url,
- title=item.title,
- description=item.description or "",
- text="Error fetching content: URL contains restricted file extension",
- )
-
- try:
- timeout = aiohttp.ClientTimeout(total=8)
- async with session.get(item.url, timeout=timeout) as response:
- if response.status == 200:
- content = await response.text()
- # Run html_to_text in a thread pool to avoid blocking
- loop = asyncio.get_running_loop()
- text_content = await loop.run_in_executor(None, html_to_text, content)
- text_content = text_content[
- :CONTENT_LENGTH_LIMIT
- ] # Trim content to avoid exceeding token limit
- return ScrapeResult(
- url=item.url,
- title=item.title,
- description=item.description or "",
- text=text_content,
- )
- else:
- # Return a ScrapeResult with an error message
- return ScrapeResult(
- url=item.url,
- title=item.title,
- description=item.description or "",
- text=f"Error fetching content: HTTP {response.status}",
- )
- except Exception as e:
- logger.warning("Error fetching URL", url=item.url, error=str(e))
- # Return a ScrapeResult with an error message
- return ScrapeResult(
- url=item.url,
- title=item.title,
- description=item.description or "",
- text=f"Error fetching content: {e!s}",
- )
-
-
-def html_to_text(html_content: str) -> str:
- """Strip out unnecessary elements from HTML to prepare for text extraction.
-
- Args:
- html_content: Raw HTML content
-
- Returns:
- Extracted text from relevant HTML tags
- """
- # Parse the HTML using lxml for speed
- soup = BeautifulSoup(html_content, "lxml")
-
- # Extract text from relevant tags
- tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")
-
- # Use a generator expression for efficiency
- extracted_text = "\n".join(
- element.get_text(strip=True)
- for element in soup.find_all(tags_to_extract)
- if element.get_text(strip=True)
- )
-
- return extracted_text
-
-
-def is_valid_url(url: str) -> bool:
- """Check that a URL does not contain restricted file extensions.
-
- Args:
- url: URL to validate
-
- Returns:
- True if URL is valid, False if it contains restricted extensions
- """
- restricted_extensions = [
- ".pdf",
- ".doc",
- ".xls",
- ".ppt",
- ".zip",
- ".rar",
- ".7z",
- ".txt",
- ".js",
- ".xml",
- ".css",
- ".png",
- ".jpg",
- ".jpeg",
- ".gif",
- ".ico",
- ".svg",
- ".webp",
- ".mp3",
- ".mp4",
- ".avi",
- ".mov",
- ".wmv",
- ".flv",
- ".wma",
- ".wav",
- ".m4a",
- ".m4v",
- ".m4b",
- ".m4p",
- ".m4u",
- ]
-
- if any(ext in url for ext in restricted_extensions):
- return False
- return True
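A sketch of the snippet-to-text pipeline above, assuming the vendored import path from the removed file; the URLs are illustrative:

```python
import asyncio

from src.tools.vendored.web_search_core import WebpageSnippet, is_valid_url, scrape_urls


async def main() -> None:
    candidates = [
        WebpageSnippet(url="https://example.com", title="Example Domain"),
        WebpageSnippet(url="https://example.com/report.pdf", title="PDF report"),
    ]

    # Filter out URLs with restricted file extensions before scraping.
    scrapable = [c for c in candidates if is_valid_url(c.url)]

    # scrape_urls fetches pages concurrently, converts HTML to text with
    # html_to_text, and truncates each page to CONTENT_LENGTH_LIMIT characters.
    results = await scrape_urls(scrapable)
    for r in results:
        print(r.url, "->", len(r.text), "characters of text")


asyncio.run(main())
```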
diff --git a/src/tools/web_search.py b/src/tools/web_search.py
deleted file mode 100644
index f6350e820f2edffa346fe2392828a38518ce3a00..0000000000000000000000000000000000000000
--- a/src/tools/web_search.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Web search tool using DuckDuckGo."""
-
-import asyncio
-
-import structlog
-
-try:
- from ddgs import DDGS # New package name
-except ImportError:
- # Fallback to old package name for backward compatibility
- from duckduckgo_search import DDGS # type: ignore[no-redef]
-
-from src.tools.query_utils import preprocess_query
-from src.utils.exceptions import SearchError
-from src.utils.models import Citation, Evidence
-
-logger = structlog.get_logger()
-
-
-class WebSearchTool:
- """Tool for searching the web using DuckDuckGo."""
-
- def __init__(self) -> None:
- self._ddgs = DDGS()
-
- @property
- def name(self) -> str:
- """Return the name of this search tool."""
- return "duckduckgo"
-
- async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
- """Execute a web search and return evidence.
-
- Args:
- query: The search query string
- max_results: Maximum number of results to return
-
- Returns:
- List of Evidence objects
-
- Raises:
- SearchError: If the search fails
- """
- try:
- # Preprocess query to remove noise
- clean_query = preprocess_query(query)
- final_query = clean_query if clean_query else query
-
- loop = asyncio.get_running_loop()
-
- def _do_search() -> list[dict[str, str]]:
- # text() returns an iterator, need to list() it or iterate
- return list(self._ddgs.text(final_query, max_results=max_results))
-
- raw_results = await loop.run_in_executor(None, _do_search)
-
- evidence = []
- for r in raw_results:
- # Truncate title to max 500 characters to match Citation model validation
- title = r.get("title", "No Title")
- if len(title) > 500:
- title = title[:497] + "..."
-
- ev = Evidence(
- content=r.get("body", ""),
- citation=Citation(
- title=title,
- url=r.get("href", ""),
- source="web",
- date="Unknown",
- authors=[],
- ),
- relevance=0.0,
- )
- evidence.append(ev)
-
- return evidence
-
- except Exception as e:
- logger.error("Web search failed", error=str(e), query=query)
- raise SearchError(f"DuckDuckGo search failed: {e}") from e
diff --git a/src/tools/web_search_adapter.py b/src/tools/web_search_adapter.py
deleted file mode 100644
index 6ee833f2fdda510bd99d3311b923a248b55fa963..0000000000000000000000000000000000000000
--- a/src/tools/web_search_adapter.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Web search tool adapter for Pydantic AI agents.
-
-Uses the new web search factory to provide web search functionality.
-"""
-
-import structlog
-
-from src.tools.web_search_factory import create_web_search_tool
-
-logger = structlog.get_logger()
-
-
-async def web_search(query: str) -> str:
- """
- Perform a web search for a given query and return formatted results.
-
- Use this tool to search the web for information relevant to the query.
- Provide a query with 3-6 words as input.
-
- Args:
- query: The search query (3-6 words recommended)
-
- Returns:
- Formatted string with search results including titles, descriptions, and URLs
- """
- try:
- # Get web search tool from factory
- tool = create_web_search_tool()
-
- if tool is None:
- logger.warning("Web search tool not available", hint="Check configuration")
- return "Web search tool not available. Please configure a web search provider."
-
- # Call the tool - it returns list[Evidence]
- evidence = await tool.search(query, max_results=5)
-
- if not evidence:
- return f"No web search results found for: {query}"
-
- # Format results for agent consumption
- formatted = [f"Found {len(evidence)} web search results:\n"]
- for i, ev in enumerate(evidence, 1):
- citation = ev.citation
- formatted.append(f"{i}. **{citation.title}**")
- if citation.url:
- formatted.append(f" URL: {citation.url}")
- if ev.content:
- formatted.append(f" Content: {ev.content[:300]}...")
- formatted.append("")
-
- return "\n".join(formatted)
-
- except Exception as e:
- logger.error("Web search failed", error=str(e), query=query)
- return f"Error performing web search: {e!s}"
diff --git a/src/tools/web_search_factory.py b/src/tools/web_search_factory.py
deleted file mode 100644
index ae79811639a141e49053aff27f37336dbf146755..0000000000000000000000000000000000000000
--- a/src/tools/web_search_factory.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""Factory for creating web search tools based on configuration."""
-
-import structlog
-
-from src.tools.base import SearchTool
-from src.tools.fallback_web_search import FallbackWebSearchTool
-from src.tools.searchxng_web_search import SearchXNGWebSearchTool
-from src.tools.serper_web_search import SerperWebSearchTool
-from src.tools.web_search import WebSearchTool
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-
-def create_web_search_tool(provider: str | None = None) -> SearchTool | None:
- """Create a web search tool based on configuration.
-
- Args:
- provider: Override provider selection. If None, uses settings.web_search_provider.
-
- Returns:
- SearchTool instance, or None if not available/configured
-
- The tool is selected based on provider (or settings.web_search_provider if None):
- - "serper": SerperWebSearchTool (requires SERPER_API_KEY)
- - "searchxng": SearchXNGWebSearchTool (requires SEARCHXNG_HOST)
- - "duckduckgo": WebSearchTool (always available, no API key)
- - "brave" or "tavily": Not yet implemented, returns None
- - "auto": Auto-detect best available provider (prefers Serper > SearchXNG > DuckDuckGo)
-
- Auto-detection logic (when provider is "auto" or not explicitly set):
- 1. Try Serper if SERPER_API_KEY is available (best quality - Google search + full content scraping)
- 2. Try SearchXNG if SEARCHXNG_HOST is available
- 3. Fall back to DuckDuckGo (always available, but lower quality - snippets only)
- """
- provider = provider or settings.web_search_provider
-
- # Auto-detect best available provider if "auto" or if provider is duckduckgo but better options exist
- if provider == "auto" or (provider == "duckduckgo" and settings.serper_api_key):
- # Use fallback tool if Serper API key is available
- # This automatically falls back to DuckDuckGo on any Serper error
- if settings.serper_api_key:
- try:
- logger.info(
- "Auto-detected Serper with DuckDuckGo fallback (SERPER_API_KEY found)",
- provider="serper+duckduckgo",
- )
- return FallbackWebSearchTool()
- except Exception as e:
- logger.warning(
- "Failed to initialize fallback web search, trying alternatives",
- error=str(e),
- )
-
- # Try SearchXNG as second choice
- if settings.searchxng_host:
- try:
- logger.info(
- "Auto-detected SearchXNG web search (SEARCHXNG_HOST found)",
- provider="searchxng",
- )
- return SearchXNGWebSearchTool()
- except Exception as e:
- logger.warning(
- "Failed to initialize SearchXNG, falling back",
- error=str(e),
- )
-
- # Fall back to DuckDuckGo only
- if provider == "auto":
- logger.info(
- "Auto-detected DuckDuckGo web search (no API keys found)",
- provider="duckduckgo",
- )
- return WebSearchTool()
-
- try:
- if provider == "serper":
- if not settings.serper_api_key:
- logger.warning(
- "Serper provider selected but no API key found",
- hint="Set SERPER_API_KEY environment variable",
- )
- return None
- return SerperWebSearchTool()
-
- elif provider == "searchxng":
- if not settings.searchxng_host:
- logger.warning(
- "SearchXNG provider selected but no host found",
- hint="Set SEARCHXNG_HOST environment variable",
- )
- return None
- return SearchXNGWebSearchTool()
-
- elif provider == "duckduckgo":
- # DuckDuckGo is always available (no API key required)
- return WebSearchTool()
-
- elif provider in ("brave", "tavily"):
- logger.warning(
- f"Web search provider '{provider}' not yet implemented",
- hint="Use 'serper', 'searchxng', or 'duckduckgo'",
- )
- return None
-
- else:
- logger.warning(f"Unknown web search provider '{provider}', falling back to DuckDuckGo")
- return WebSearchTool()
-
- except ConfigurationError as e:
- logger.error("Failed to create web search tool", error=str(e), provider=provider)
- return None
- except Exception as e:
- logger.error("Unexpected error creating web search tool", error=str(e), provider=provider)
- return None
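A sketch of the provider selection above, assuming the import path from the removed file; which tool comes back depends entirely on the configured environment:

```python
from src.tools.web_search_factory import create_web_search_tool

# With the provider unset (settings default "auto"), the factory prefers Serper
# (when SERPER_API_KEY is configured), then SearchXNG, then DuckDuckGo.
tool = create_web_search_tool()
if tool is None:
    print("no web search provider available")
else:
    print("selected web search tool:", tool.name)

# An explicitly requested provider that is not configured returns None instead of raising.
serper_only = create_web_search_tool(provider="serper")
print("serper explicitly requested ->", serper_only)
```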
diff --git a/src/agent_factory/__init__.py b/src/tools/websearch.py
similarity index 100%
rename from src/agent_factory/__init__.py
rename to src/tools/websearch.py
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/src/utils/citation_validator.py b/src/utils/citation_validator.py
deleted file mode 100644
index c124c1a529ad46449c5182576a0e58a54e22381c..0000000000000000000000000000000000000000
--- a/src/utils/citation_validator.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""Citation validation to prevent LLM hallucination.
-
-CRITICAL: Medical research requires accurate citations.
-This module validates that all references exist in collected evidence.
-"""
-
-import logging
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from src.utils.models import Evidence, ResearchReport
-
-logger = logging.getLogger(__name__)
-
-# Max characters to display for URLs in log messages
-_MAX_URL_DISPLAY_LENGTH = 80
-
-
-def validate_references(report: "ResearchReport", evidence: list["Evidence"]) -> "ResearchReport":
- """Ensure all references actually exist in collected evidence.
-
- CRITICAL: Prevents LLM hallucination of citations.
-
- Note:
- This function MUTATES report.references in-place and returns the same
- report object. This is intentional for efficiency.
-
- Args:
- report: The generated research report (will be mutated)
- evidence: All evidence collected during research
-
- Returns:
- The same report object with references updated in-place
- """
- # Build set of valid URLs from evidence
- valid_urls = {e.citation.url for e in evidence}
- # Also check titles (case-insensitive, exact match) as fallback
- valid_titles = {e.citation.title.lower() for e in evidence}
-
- validated_refs = []
- removed_count = 0
-
- for ref in report.references:
- ref_url = ref.get("url", "")
- ref_title = ref.get("title", "").lower()
-
- # Check if URL matches collected evidence
- if ref_url in valid_urls:
- validated_refs.append(ref)
- # Fallback: exact title match (case-insensitive)
- elif ref_title and ref_title in valid_titles:
- validated_refs.append(ref)
- else:
- removed_count += 1
- # Truncate URL for display
- if len(ref_url) > _MAX_URL_DISPLAY_LENGTH:
- url_display = ref_url[:_MAX_URL_DISPLAY_LENGTH] + "..."
- else:
- url_display = ref_url
- logger.warning(
- f"Removed hallucinated reference: '{ref.get('title', 'Unknown')}' "
- f"(URL: {url_display})"
- )
-
- if removed_count > 0:
- logger.info(
- f"Citation validation removed {removed_count} hallucinated references. "
- f"{len(validated_refs)} valid references remain."
- )
-
- # Update report with validated references
- report.references = validated_refs
- return report
-
-
-def build_reference_from_evidence(evidence: "Evidence") -> dict[str, str]:
- """Build a properly formatted reference from evidence.
-
- Use this to ensure references match the original evidence exactly.
- """
- return {
- "title": evidence.citation.title,
- "authors": ", ".join(evidence.citation.authors or ["Unknown"]),
- "source": evidence.citation.source,
- "date": evidence.citation.date or "n.d.",
- "url": evidence.citation.url,
- }
-
-
-def validate_markdown_citations(
- markdown_report: str, evidence: list["Evidence"]
-) -> tuple[str, int]:
- """Validate citations in a markdown report against collected evidence.
-
- This function validates citations in markdown format (e.g., [1], [2]) by:
- 1. Extracting URLs from the references section
- 2. Matching them against Evidence objects
- 3. Removing invalid citations from the report
-
- Note:
- This is a basic validation. For full validation, use ResearchReport
- objects with validate_references().
-
- Args:
- markdown_report: The markdown report string with citations
- evidence: List of Evidence objects collected during research
-
- Returns:
- Tuple of (validated_markdown, removed_count)
- """
- import re
-
- # Build set of valid URLs from evidence
- valid_urls = {e.citation.url for e in evidence}
- valid_urls_lower = {url.lower() for url in valid_urls}
-
- # Extract references section (everything after "## References" or "References:")
- ref_section_pattern = r"(?i)(?:##\s*)?References:?\s*\n(.*?)(?=\n##|\Z)"
- ref_match = re.search(ref_section_pattern, markdown_report, re.DOTALL)
-
- if not ref_match:
- # No references section found, return as-is
- return markdown_report, 0
-
- ref_section = ref_match.group(1)
- ref_lines = ref_section.strip().split("\n")
-
- # Parse references: [1] https://example.com or [1] https://example.com Title
- valid_refs = []
- removed_count = 0
-
- for ref_line in ref_lines:
- stripped_line = ref_line.strip()
- if not stripped_line:
- continue
-
- # Extract URL from reference line
- # Pattern: [N] URL or [N] URL Title
- url_match = re.search(r"https?://[^\s\)]+", stripped_line)
- if url_match:
- url = url_match.group(0).rstrip(".,;")
- url_lower = url.lower()
-
- # Check if URL is valid
- if url in valid_urls or url_lower in valid_urls_lower:
- valid_refs.append(stripped_line)
- else:
- removed_count += 1
- logger.warning(
- f"Removed invalid citation from markdown: {url[:_MAX_URL_DISPLAY_LENGTH]}"
- + ("..." if len(url) > _MAX_URL_DISPLAY_LENGTH else "")
- )
- else:
- # No URL found, keep the line (might be formatted differently)
- valid_refs.append(stripped_line)
-
- # Rebuild references section
- if valid_refs:
- new_ref_section = "\n".join(valid_refs)
- # Replace the old references section
- validated_markdown = (
- markdown_report[: ref_match.start(1)]
- + new_ref_section
- + markdown_report[ref_match.end(1) :]
- )
- else:
- # No valid references, remove the entire section
- validated_markdown = (
- markdown_report[: ref_match.start()] + markdown_report[ref_match.end() :]
- )
-
- if removed_count > 0:
- logger.info(
- f"Citation validation removed {removed_count} invalid citations from markdown report. "
- f"{len(valid_refs)} valid citations remain."
- )
-
- return validated_markdown, removed_count
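A sketch of the markdown validation path above. The Citation and Evidence field names mirror how those models are constructed elsewhere in this diff; any additional required fields are not shown, and the URLs are illustrative:

```python
from src.utils.citation_validator import validate_markdown_citations
from src.utils.models import Citation, Evidence

# One piece of collected evidence; only its URL counts as a valid reference.
evidence = [
    Evidence(
        content="Example abstract text.",
        citation=Citation(
            source="web",
            title="Example source",
            url="https://example.com/paper",
            date="2024",
            authors=["Doe J"],
        ),
        relevance=0.5,
    )
]

report = """# Findings

Some synthesized text [1][2].

## References
[1] https://example.com/paper Example source
[2] https://made-up.example.org/ghost Hallucinated source
"""

validated, removed = validate_markdown_citations(report, evidence)
print("removed", removed, "hallucinated citation(s)")
print(validated)
```

The second reference does not match any collected URL, so it is dropped and logged as a removed citation.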
diff --git a/src/utils/config.py b/src/utils/config.py
deleted file mode 100644
index 9d4356d7b5afd35c6e782c870708cf12bdaae657..0000000000000000000000000000000000000000
--- a/src/utils/config.py
+++ /dev/null
@@ -1,329 +0,0 @@
-"""Application configuration using Pydantic Settings."""
-
-import logging
-from typing import Literal
-
-import structlog
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-from src.utils.exceptions import ConfigurationError
-
-
-class Settings(BaseSettings):
- """Strongly-typed application settings."""
-
- model_config = SettingsConfigDict(
- env_file=".env",
- env_file_encoding="utf-8",
- case_sensitive=False,
- extra="ignore",
- )
-
- # LLM Configuration
- openai_api_key: str | None = Field(default=None, description="OpenAI API key")
- anthropic_api_key: str | None = Field(default=None, description="Anthropic API key")
- llm_provider: Literal["openai", "anthropic", "huggingface"] = Field(
- default="huggingface", description="Which LLM provider to use"
- )
- openai_model: str = Field(default="gpt-5.1", description="OpenAI model name")
- anthropic_model: str = Field(
- default="claude-sonnet-4-5-20250929", description="Anthropic model"
- )
- hf_token: str | None = Field(
- default=None, alias="HF_TOKEN", description="HuggingFace API token"
- )
-
- # Embedding Configuration
- # Note: OpenAI embeddings require OPENAI_API_KEY (Anthropic has no embeddings API)
- openai_embedding_model: str = Field(
- default="text-embedding-3-small",
- description="OpenAI embedding model (used by LlamaIndex RAG)",
- )
- local_embedding_model: str = Field(
- default="all-MiniLM-L6-v2",
- description="Local sentence-transformers model (used by EmbeddingService)",
- )
- embedding_provider: Literal["openai", "local", "huggingface"] = Field(
- default="local",
- description="Embedding provider to use",
- )
- huggingface_embedding_model: str = Field(
- default="sentence-transformers/all-MiniLM-L6-v2",
- description="HuggingFace embedding model ID",
- )
-
- # HuggingFace Configuration
- huggingface_api_key: str | None = Field(
- default=None, description="HuggingFace API token (HF_TOKEN or HUGGINGFACE_API_KEY)"
- )
- huggingface_model: str = Field(
- default="meta-llama/Llama-3.1-8B-Instruct",
- description="Default HuggingFace model ID for inference",
- )
- hf_fallback_models: str = Field(
- default="Qwen/Qwen3-Next-80B-A3B-Thinking,Qwen/Qwen3-Next-80B-A3B-Instruct,meta-llama/Llama-3.3-70B-Instruct,meta-llama/Llama-3.1-8B-Instruct,HuggingFaceH4/zephyr-7b-beta,Qwen/Qwen2-7B-Instruct",
- alias="HF_FALLBACK_MODELS",
- description=(
- "Comma-separated list of fallback models for provider discovery and error recovery. "
- "Reads from HF_FALLBACK_MODELS environment variable. "
- "Default value is used only if the environment variable is not set."
- ),
- )
-
- # PubMed Configuration
- ncbi_api_key: str | None = Field(
- default=None, description="NCBI API key for higher rate limits"
- )
-
- # Web Search Configuration
- web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo", "auto"] = (
- Field(
- default="auto",
- description="Web search provider to use. 'auto' will auto-detect best available (prefers Serper > SearchXNG > DuckDuckGo)",
- )
- )
- serper_api_key: str | None = Field(default=None, description="Serper API key for Google search")
- searchxng_host: str | None = Field(default=None, description="SearchXNG host URL")
- brave_api_key: str | None = Field(default=None, description="Brave Search API key")
- tavily_api_key: str | None = Field(default=None, description="Tavily API key")
-
- # Agent Configuration
- max_iterations: int = Field(default=10, ge=1, le=50)
- search_timeout: int = Field(default=30, description="Seconds to wait for search")
- use_graph_execution: bool = Field(
- default=False, description="Use graph-based execution for research flows"
- )
-
- # Budget & Rate Limiting Configuration
- default_token_limit: int = Field(
- default=100000,
- ge=1000,
- le=1000000,
- description="Default token budget per research loop",
- )
- default_time_limit_minutes: int = Field(
- default=10,
- ge=1,
- le=120,
- description="Default time limit per research loop (minutes)",
- )
- default_iterations_limit: int = Field(
- default=10,
- ge=1,
- le=50,
- description="Default iterations limit per research loop",
- )
-
- # Logging
- log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
-
- # External Services
- modal_token_id: str | None = Field(default=None, description="Modal token ID")
- modal_token_secret: str | None = Field(default=None, description="Modal token secret")
- chroma_db_path: str = Field(default="./chroma_db", description="ChromaDB storage path")
- chroma_db_persist: bool = Field(
- default=True,
- description="Whether to persist ChromaDB to disk",
- )
- chroma_db_host: str | None = Field(
- default=None,
- description="ChromaDB server host (for remote ChromaDB)",
- )
- chroma_db_port: int | None = Field(
- default=None,
- description="ChromaDB server port (for remote ChromaDB)",
- )
-
- # RAG Service Configuration
- rag_collection_name: str = Field(
- default="deepcritical_evidence",
- description="ChromaDB collection name for RAG",
- )
- rag_similarity_top_k: int = Field(
- default=5,
- ge=1,
- le=50,
- description="Number of top results to retrieve from RAG",
- )
- rag_auto_ingest: bool = Field(
- default=True,
- description="Automatically ingest evidence into RAG",
- )
-
- # Audio/TTS Configuration
- enable_audio_input: bool = Field(
- default=True,
- description="Enable audio input (speech-to-text) in multimodal interface",
- )
- enable_audio_output: bool = Field(
- default=True,
- description="Enable audio output (text-to-speech) for responses",
- )
- enable_image_input: bool = Field(
- default=True,
- description="Enable image input (OCR) in multimodal interface",
- )
- tts_voice: str = Field(
- default="af_heart",
- description="TTS voice ID for Kokoro TTS (e.g., af_heart, am_michael)",
- )
- tts_speed: float = Field(
- default=1.0,
- ge=0.5,
- le=2.0,
- description="TTS speech speed multiplier (0.5x to 2.0x)",
- )
- tts_use_llm_polish: bool = Field(
- default=False,
- description="Use LLM for final text polish before TTS (optional, costs API calls)",
- )
- tts_gpu: str | None = Field(
- default=None,
- description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). None uses default T4.",
- )
-
- # STT (Speech-to-Text) Configuration
- stt_api_url: str | None = Field(
- default="https://nvidia-canary-1b-v2.hf.space",
- description="Gradio Space URL for STT service (default: nvidia/canary-1b-v2)",
- )
- stt_source_lang: str = Field(
- default="English",
- description="Source language for STT (full name like 'English', 'Spanish', etc.)",
- )
- stt_target_lang: str = Field(
- default="English",
- description="Target language for STT (full name like 'English', 'Spanish', etc.)",
- )
-
- # Image OCR Configuration
- ocr_api_url: str | None = Field(
- default="https://prithivmlmods-multimodal-ocr3.hf.space",
- description="Gradio Space URL for OCR service (default: prithivMLmods/Multimodal-OCR3)",
- )
-
- # Report File Output Configuration
- save_reports_to_file: bool = Field(
- default=True,
- description="Save generated reports to files (enables file downloads in Gradio)",
- )
- report_output_directory: str | None = Field(
- default=None,
- description="Directory to save report files. If None, uses system temp directory.",
- )
- report_file_format: Literal["md", "md_html", "md_pdf"] = Field(
- default="md",
- description="File format(s) to save reports in. 'md' saves only markdown, others save multiple formats.",
- )
- report_filename_template: str = Field(
- default="report_{timestamp}_{query_hash}.md",
- description="Template for report filenames. Supports {timestamp}, {query_hash}, {date} placeholders.",
- )
-
- @property
- def modal_available(self) -> bool:
- """Check if Modal credentials are configured."""
- return bool(self.modal_token_id and self.modal_token_secret)
-
- def get_api_key(self) -> str:
- """Get the API key for the configured provider."""
- if self.llm_provider == "openai":
- if not self.openai_api_key:
- raise ConfigurationError("OPENAI_API_KEY not set")
- return self.openai_api_key
-
- if self.llm_provider == "anthropic":
- if not self.anthropic_api_key:
- raise ConfigurationError("ANTHROPIC_API_KEY not set")
- return self.anthropic_api_key
-
- raise ConfigurationError(f"Unknown LLM provider: {self.llm_provider}")
-
- def get_openai_api_key(self) -> str:
- """Get OpenAI API key (required for Magentic function calling)."""
- if not self.openai_api_key:
- raise ConfigurationError(
- "OPENAI_API_KEY not set. Magentic mode requires OpenAI for function calling. "
- "Use mode='simple' for other providers."
- )
- return self.openai_api_key
-
- @property
- def has_openai_key(self) -> bool:
- """Check if OpenAI API key is available."""
- return bool(self.openai_api_key)
-
- @property
- def has_anthropic_key(self) -> bool:
- """Check if Anthropic API key is available."""
- return bool(self.anthropic_api_key)
-
- @property
- def has_huggingface_key(self) -> bool:
- """Check if HuggingFace API key is available."""
- return bool(self.huggingface_api_key or self.hf_token)
-
- @property
- def has_any_llm_key(self) -> bool:
- """Check if any LLM API key is available."""
- return self.has_openai_key or self.has_anthropic_key or self.has_huggingface_key
-
- @property
- def web_search_available(self) -> bool:
- """Check if web search is available (either no-key provider or API key present)."""
- if self.web_search_provider == "duckduckgo":
- return True # No API key required
- if self.web_search_provider == "serper":
- return bool(self.serper_api_key)
- if self.web_search_provider == "searchxng":
- return bool(self.searchxng_host)
- if self.web_search_provider == "brave":
- return bool(self.brave_api_key)
- if self.web_search_provider == "tavily":
- return bool(self.tavily_api_key)
- return False
-
- def get_hf_fallback_models_list(self) -> list[str]:
- """Get the list of fallback models as a list.
-        """Get the fallback models as a list.
- Parses the comma-separated HF_FALLBACK_MODELS string into a list,
- stripping whitespace from each model ID.
-
- Returns:
- List of model IDs
- """
- if not self.hf_fallback_models:
- return []
- return [model.strip() for model in self.hf_fallback_models.split(",") if model.strip()]
-
-
-def get_settings() -> Settings:
- """Factory function to get settings (allows mocking in tests)."""
- return Settings()
-
-
-def configure_logging(settings: Settings) -> None:
- """Configure structured logging with the configured log level."""
- # Set stdlib logging level from settings
- logging.basicConfig(
- level=getattr(logging, settings.log_level),
- format="%(message)s",
- )
-
- structlog.configure(
- processors=[
- structlog.stdlib.filter_by_level,
- structlog.stdlib.add_logger_name,
- structlog.stdlib.add_log_level,
- structlog.processors.TimeStamper(fmt="iso"),
- structlog.processors.JSONRenderer(),
- ],
- wrapper_class=structlog.stdlib.BoundLogger,
- context_class=dict,
- logger_factory=structlog.stdlib.LoggerFactory(),
- )
-
-
-# Singleton for easy import
-settings = get_settings()
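-
-
-# Illustrative usage sketch (not part of the original module); which keys and
-# providers are available depends on the local environment.
-def _example_settings_usage() -> list[str]:
-    cfg = get_settings()
-    configure_logging(cfg)
-    if not cfg.has_any_llm_key:
-        structlog.get_logger().warning("No LLM API key configured")
-    # Parsed from HF_FALLBACK_MODELS, e.g.
-    # ["Qwen/Qwen3-Next-80B-A3B-Thinking", ..., "Qwen/Qwen2-7B-Instruct"]
-    return cfg.get_hf_fallback_models_list()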
diff --git a/src/utils/exceptions.py b/src/utils/exceptions.py
deleted file mode 100644
index abedf1baf49f3f88f01b2451a81c71c221b0caac..0000000000000000000000000000000000000000
--- a/src/utils/exceptions.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Custom exceptions for The DETERMINATOR."""
-
-
-class DeepCriticalError(Exception):
- """Base exception for all DETERMINATOR errors.
-
- Note: Class name kept for backward compatibility.
- """
-
- pass
-
-
-class SearchError(DeepCriticalError):
- """Raised when a search operation fails."""
-
- pass
-
-
-class JudgeError(DeepCriticalError):
- """Raised when the judge fails to assess evidence."""
-
- pass
-
-
-class ConfigurationError(DeepCriticalError):
- """Raised when configuration is invalid."""
-
- pass
-
-
-class RateLimitError(SearchError):
- """Raised when we hit API rate limits."""
-
- pass
diff --git a/src/utils/hf_error_handler.py b/src/utils/hf_error_handler.py
deleted file mode 100644
index becc6862275260ef63249fee69fc7225468dec30..0000000000000000000000000000000000000000
--- a/src/utils/hf_error_handler.py
+++ /dev/null
@@ -1,199 +0,0 @@
-"""Utility functions for handling HuggingFace API errors and token validation."""
-
-import re
-from typing import Any
-
-import structlog
-
-logger = structlog.get_logger()
-
-
-def extract_error_details(error: Exception) -> dict[str, Any]:
- """Extract error details from HuggingFace API errors.
-
- Pydantic AI and HuggingFace Inference API errors often contain
- information in the error message string like:
- "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
-
- Args:
- error: The exception object
-
- Returns:
- Dictionary with extracted error details:
- - status_code: HTTP status code (if found)
- - model_name: Model name (if found)
- - body: Error body/message (if found)
- - error_type: Type of error (403, 422, etc.)
- - is_auth_error: Whether this is an authentication/authorization error
- - is_model_error: Whether this is a model-specific error
- """
- error_str = str(error)
- details: dict[str, Any] = {
- "status_code": None,
- "model_name": None,
- "body": None,
- "error_type": "unknown",
- "is_auth_error": False,
- "is_model_error": False,
- }
-
- # Try to extract status_code
- status_match = re.search(r"status_code:\s*(\d+)", error_str)
- if status_match:
- details["status_code"] = int(status_match.group(1))
- details["error_type"] = f"http_{details['status_code']}"
-
- # Determine error category
- if details["status_code"] == 403:
- details["is_auth_error"] = True
- elif details["status_code"] == 422:
- details["is_model_error"] = True
-
- # Try to extract model_name
- model_match = re.search(r"model_name:\s*([^\s,]+)", error_str)
- if model_match:
- details["model_name"] = model_match.group(1)
-
- # Try to extract body
- body_match = re.search(r"body:\s*(.+)", error_str)
- if body_match:
- details["body"] = body_match.group(1).strip()
-
- return details
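-
-
-# Illustrative usage sketch (not part of the original module); the error string
-# mirrors the format shown in the docstring above.
-def _example_extract_error_details() -> None:
-    err = RuntimeError(
-        "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
-    )
-    details = extract_error_details(err)
-    # details["status_code"] == 403, details["is_auth_error"] is True,
-    # details["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking",
-    # details["body"] == "Forbidden"
-    logger.debug("Parsed error", **details)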
-
-
-def get_user_friendly_error_message(error: Exception, model_name: str | None = None) -> str:
- """Generate a user-friendly error message from an exception.
-
- Args:
- error: The exception object
- model_name: Optional model name for context
-
- Returns:
- User-friendly error message
- """
- details = extract_error_details(error)
-
- if details["is_auth_error"]:
- return (
- "🔐 **Authentication Error**\n\n"
- "Your HuggingFace token doesn't have permission to access this model or API.\n\n"
- "**Possible solutions:**\n"
- "1. **Re-authenticate**: Log out and log back in to ensure your token has the `inference-api` scope\n"
- "2. **Check model access**: Visit the model page on HuggingFace and request access if it's gated\n"
- "3. **Use alternative model**: Try a different model that's publicly available\n\n"
- f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n"
- f"**Error**: {details['body'] or str(error)}"
- )
-
- if details["is_model_error"]:
- return (
- "⚠️ **Model Compatibility Error**\n\n"
- "The selected model is not compatible with the current provider or has specific requirements.\n\n"
- "**Possible solutions:**\n"
- "1. **Try a different model**: Use a model that's compatible with the current provider\n"
- "2. **Check provider status**: The provider may be in staging mode or unavailable\n"
- "3. **Wait and retry**: If the model is in staging, it may become available later\n\n"
- f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n"
- f"**Error**: {details['body'] or str(error)}"
- )
-
- # Generic error
- return (
- "❌ **API Error**\n\n"
- f"An error occurred while calling the HuggingFace API:\n\n"
- f"**Error**: {error!s}\n\n"
- "Please try again or contact support if the issue persists."
- )
-
-
-def validate_hf_token(token: str | None) -> tuple[bool, str | None]:
- """Validate HuggingFace token format.
-
- Args:
- token: The token to validate
-
- Returns:
- Tuple of (is_valid, error_message)
- - is_valid: True if token appears valid
- - error_message: Error message if invalid, None if valid
- """
- if not token:
- return False, "Token is None or empty"
-
- if not isinstance(token, str):
- return False, f"Token is not a string (type: {type(token).__name__})"
-
- if len(token) < 10:
- return False, "Token appears too short (minimum 10 characters expected)"
-
- # HuggingFace tokens typically start with "hf_" for user tokens
- # OAuth tokens may have different formats, so we're lenient
- # Just check it's not obviously invalid
-
- return True, None
-
-
-def log_token_info(token: str | None, context: str = "") -> None:
- """Log token information for debugging (without exposing the actual token).
-
- Args:
- token: The token to log info about
- context: Additional context for the log message
- """
- if token:
- is_valid, error_msg = validate_hf_token(token)
- logger.debug(
- "Token validation",
- context=context,
- has_token=True,
- is_valid=is_valid,
- token_length=len(token),
- token_prefix=token[:4] + "..." if len(token) > 4 else "***",
- validation_error=error_msg,
- )
- else:
- logger.debug("Token validation", context=context, has_token=False)
-
-
-def should_retry_with_fallback(error: Exception) -> bool:
- """Determine if an error should trigger a fallback to alternative models.
-
- Args:
- error: The exception object
-
- Returns:
- True if the error suggests we should try a fallback model
- """
- details = extract_error_details(error)
-
- # Retry with fallback for:
- # - 403 errors (authentication/permission issues - might work with different model)
- # - 422 errors (model/provider compatibility - definitely try different model)
- # - Model-specific errors
- return (
- details["is_auth_error"] or details["is_model_error"] or details["model_name"] is not None
- )
-
-
-def get_fallback_models(original_model: str | None = None) -> list[str]:
- """Get a list of fallback models to try.
-
- Args:
- original_model: The original model that failed
-
- Returns:
- List of fallback model names to try in order
- """
- # Publicly available models that should work with most tokens
- fallbacks = [
- "meta-llama/Llama-3.1-8B-Instruct", # Common, often available
- "mistralai/Mistral-7B-Instruct-v0.3", # Alternative
- "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback
- ]
-
- # If original model is in the list, remove it
- if original_model and original_model in fallbacks:
- fallbacks.remove(original_model)
-
- return fallbacks
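-
-
-# Illustrative retry sketch (not part of the original module); `call_model` stands in
-# for a hypothetical async callable that raises on provider or auth failures.
-async def _example_fallback_retry(call_model: Any, model_id: str) -> Any:
-    try:
-        return await call_model(model_id)
-    except Exception as e:
-        if not should_retry_with_fallback(e):
-            raise
-        logger.warning(
-            "Primary model failed, trying fallback models", model=model_id, error=str(e)
-        )
-        last_error: Exception = e
-        for fallback in get_fallback_models(original_model=model_id):
-            try:
-                return await call_model(fallback)
-            except Exception as retry_error:
-                last_error = retry_error
-        raise last_error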
diff --git a/src/utils/hf_model_validator.py b/src/utils/hf_model_validator.py
deleted file mode 100644
index e7dd75eea282b66fc655d4fb33fe1a884af397f7..0000000000000000000000000000000000000000
--- a/src/utils/hf_model_validator.py
+++ /dev/null
@@ -1,477 +0,0 @@
-"""Validator for querying available HuggingFace models and providers using OAuth token.
-
-This module provides functions to:
-1. Query available models from HuggingFace Hub
-2. Query available inference providers (with dynamic discovery)
-3. Validate model/provider combinations
-4. Return formatted lists for Gradio dropdowns
-
-Uses Hugging Face Hub API to discover providers dynamically by querying model
-information. Falls back to known providers list if discovery fails.
-"""
-
-import asyncio
-from time import time
-from typing import Any
-
-import structlog
-from huggingface_hub import HfApi
-
-from src.utils.config import settings
-
-logger = structlog.get_logger()
-
-
-def extract_oauth_token(oauth_token: Any) -> str | None:
- """Extract OAuth token value from Gradio OAuthToken object.
-
- Handles both gr.OAuthToken objects (with .token attribute) and plain strings.
- This is a convenience function for Gradio apps that use OAuth authentication.
-
- Args:
- oauth_token: Gradio OAuthToken object or string token
-
- Returns:
- Token string if available, None otherwise
- """
- if oauth_token is None:
- return None
-
- if hasattr(oauth_token, "token"):
- return oauth_token.token # type: ignore[no-any-return]
- elif isinstance(oauth_token, str):
- return oauth_token
-
- logger.warning(
- "Could not extract token from OAuthToken object",
- oauth_token_type=type(oauth_token).__name__,
- )
- return None
-
-
-# Known providers as fallback (updated from Hugging Face documentation)
-# These are used when dynamic discovery fails or times out
-KNOWN_PROVIDERS = [
- "auto", # Auto-select (always available)
- "hf-inference", # HuggingFace's own Inference API
- "nebius",
- "together",
- "scaleway",
- "hyperbolic",
- "novita",
- "nscale",
- "sambanova",
- "ovh",
- "fireworks-ai", # Note: API uses "fireworks-ai", not "fireworks"
- "cerebras",
- "fal-ai",
- "cohere",
-]
-
-
-def get_provider_discovery_models() -> list[str]:
- """Get list of models to use for provider discovery.
-
- Reads from HF_FALLBACK_MODELS environment variable via settings.
- The environment variable should be a comma-separated list of model IDs.
-
- Returns:
- List of model IDs to query for provider discovery
- """
- # Get models from HF_FALLBACK_MODELS environment variable
- # This is automatically read by Pydantic Settings from the env var
- fallback_models = settings.get_hf_fallback_models_list()
-
- logger.debug(
- "Using HF_FALLBACK_MODELS for provider discovery",
- count=len(fallback_models),
- models=fallback_models,
- )
-
- return fallback_models
-
-
-# Simple in-memory cache for provider lists (TTL: 1 hour)
-_provider_cache: dict[str, tuple[list[str], float]] = {}
-PROVIDER_CACHE_TTL = 3600 # 1 hour in seconds
-
-
-async def get_available_providers(token: str | None = None) -> list[str]:
- """Get list of available inference providers.
-
- Discovers providers dynamically by querying model information from HuggingFace Hub.
- Uses caching to avoid repeated API calls. Falls back to known providers if discovery fails.
-
- Strategy:
- 1. Check cache (if valid, return cached list)
- 2. Query popular models to extract unique providers from their inferenceProviderMapping
- 3. Fall back to known providers list if discovery fails
- 4. Cache results for future use
-
- Args:
- token: Optional HuggingFace API token for authenticated requests
- Can be extracted from gr.OAuthToken.token in Gradio apps
-
- Returns:
- List of provider names sorted alphabetically, with "auto" first
- (e.g., ["auto", "fireworks-ai", "hf-inference", "nebius", ...])
- """
- # Check cache first
- cache_key = "providers" + (f"_{token[:8]}" if token else "_no_token")
- if cache_key in _provider_cache:
- cached_providers, cache_time = _provider_cache[cache_key]
- if time() - cache_time < PROVIDER_CACHE_TTL:
- logger.debug("Returning cached providers", count=len(cached_providers))
- return cached_providers
-
- try:
-        providers = {"auto"}  # Always include "auto"
-
- # Try dynamic discovery by querying popular models
- loop = asyncio.get_running_loop()
- api = HfApi(token=token)
-
- # Get models to query from HF_FALLBACK_MODELS environment variable via settings
- discovery_models = get_provider_discovery_models()
-
- # Query a sample of popular models to discover providers
- # This is more efficient than querying all models
- discovery_count = 0
- for model_id in discovery_models:
- try:
-
- def _get_model_info(m: str) -> Any:
- """Get model info synchronously."""
- return api.model_info(m, expand=["inferenceProviderMapping"]) # type: ignore[arg-type]
-
- info = await loop.run_in_executor(None, _get_model_info, model_id)
-
- # Extract providers from inference_provider_mapping
- if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping:
- mapping = info.inference_provider_mapping
- # mapping is a dict like {'hf-inference': InferenceProviderMapping(...), ...}
- providers.update(mapping.keys())
- discovery_count += 1
- logger.debug(
- "Discovered providers from model",
- model=model_id,
- providers=list(mapping.keys()),
- )
- except Exception as e:
- logger.debug(
- "Could not get provider info for model",
- model=model_id,
- error=str(e),
- )
- continue
-
- # If we discovered providers, use them; otherwise fall back to known providers
- if len(providers) > 1: # More than just "auto"
- provider_list = sorted(list(providers))
- logger.info(
- "Discovered providers dynamically",
- count=len(provider_list),
- models_queried=discovery_count,
- has_token=bool(token),
- )
- else:
- # Fallback to known providers
- provider_list = KNOWN_PROVIDERS.copy()
- logger.info(
- "Using known providers list (discovery failed or incomplete)",
- count=len(provider_list),
- models_queried=discovery_count,
- )
-
- # Cache the results
- _provider_cache[cache_key] = (provider_list, time())
-
- return provider_list
-
- except Exception as e:
- logger.warning("Failed to get providers", error=str(e))
- # Return known providers as fallback
- return KNOWN_PROVIDERS.copy()
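-
-
-# Illustrative usage sketch (not part of the original module); mirrors how a Gradio
-# dropdown might be populated, with extract_oauth_token() handling gr.OAuthToken objects.
-async def _example_provider_dropdown(oauth_token: Any = None) -> list[str]:
-    token = extract_oauth_token(oauth_token)
-    providers = await get_available_providers(token=token)
-    # e.g. ["auto", "cerebras", "fireworks-ai", "hf-inference", "nebius", ...]
-    return providers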
-
-
-async def get_available_models(
- token: str | None = None,
- task: str = "text-generation",
- limit: int = 100,
- inference_provider: str | None = None,
-) -> list[str]:
- """Get list of available models for text generation.
-
- Queries HuggingFace Hub API to get models that support text generation.
- Optionally filters by inference provider to show only models available via that provider.
-
- Args:
- token: Optional HuggingFace API token for authenticated requests
- Can be extracted from gr.OAuthToken.token in Gradio apps
- task: Task type to filter models (default: "text-generation")
- limit: Maximum number of models to return
- inference_provider: Optional provider name to filter models (e.g., "fireworks-ai", "nebius")
- If None, returns all models for the task
-
- Returns:
- List of model IDs (e.g., ["meta-llama/Llama-3.1-8B-Instruct", ...])
- """
- try:
- loop = asyncio.get_running_loop()
-
- def _fetch_models() -> list[str]:
- """Fetch models synchronously in executor."""
- api = HfApi(token=token)
-
- # Build query parameters
- query_params: dict[str, Any] = {
- "task": task,
- "sort": "downloads",
- "direction": -1,
- "limit": limit,
- }
-
- # Filter by inference provider if specified
- if inference_provider and inference_provider != "auto":
- query_params["inference_provider"] = inference_provider
-
- # Search for models
- models = api.list_models(**query_params)
-
- # Extract model IDs
- model_ids = [model.id for model in models]
- return model_ids
-
- model_ids = await loop.run_in_executor(None, _fetch_models)
-
- logger.info(
- "Fetched available models",
- count=len(model_ids),
- task=task,
- provider=inference_provider or "all",
- has_token=bool(token),
- )
-
- return model_ids
-
- except Exception as e:
- logger.warning("Failed to get models from Hub API", error=str(e))
- # Return popular fallback models
- return [
- "meta-llama/Llama-3.1-8B-Instruct",
- "mistralai/Mistral-7B-Instruct-v0.3",
- "HuggingFaceH4/zephyr-7b-beta",
- "google/gemma-2-9b-it",
- ]
-
-
-async def validate_model_provider_combination(
- model_id: str,
- provider: str | None,
- token: str | None = None,
-) -> tuple[bool, str | None]:
- """Validate that a model is available with a specific provider.
-
- Uses HuggingFace Hub API to check if the provider is listed in the model's
- inferenceProviderMapping. This is faster and more reliable than making test API calls.
-
- Args:
- model_id: HuggingFace model ID
- provider: Provider name (or None/empty for auto)
- token: Optional HuggingFace API token (from gr.OAuthToken.token)
-
- Returns:
- Tuple of (is_valid, error_message)
- - is_valid: True if combination is valid or provider is "auto"
- - error_message: Error message if invalid, None if valid
- """
- # "auto" is always valid - let HuggingFace select the provider
- if not provider or provider == "auto":
- return True, None
-
- try:
- loop = asyncio.get_running_loop()
- api = HfApi(token=token)
-
- def _get_model_info() -> Any:
- """Get model info with provider mapping synchronously."""
- return api.model_info(model_id, expand=["inferenceProviderMapping"]) # type: ignore[arg-type]
-
- info = await loop.run_in_executor(None, _get_model_info)
-
- # Check if provider is in the model's inference provider mapping
- if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping:
- mapping = info.inference_provider_mapping
- available_providers = set(mapping.keys())
-
- # Normalize provider name (some APIs use "fireworks-ai", others use "fireworks")
- normalized_provider = provider.lower()
- provider_variants = {normalized_provider}
-
- # Handle common provider name variations
- if normalized_provider == "fireworks":
- provider_variants.add("fireworks-ai")
- elif normalized_provider == "fireworks-ai":
- provider_variants.add("fireworks")
-
- # Check if any variant matches
- if any(p in available_providers for p in provider_variants):
- logger.debug(
- "Model/provider combination validated via API",
- model=model_id,
- provider=provider,
- available_providers=list(available_providers),
- )
- return True, None
- else:
- error_msg = (
- f"Model {model_id} is not available with provider '{provider}'. "
- f"Available providers: {', '.join(sorted(available_providers))}"
- )
- logger.debug(
- "Model/provider combination invalid",
- model=model_id,
- provider=provider,
- available_providers=list(available_providers),
- )
- return False, error_msg
- else:
- # Model doesn't have provider mapping - assume valid and let actual usage determine
- logger.debug(
- "Model has no provider mapping, assuming valid",
- model=model_id,
- provider=provider,
- )
- return True, None
-
- except Exception as e:
- logger.warning(
- "Model/provider validation failed",
- model=model_id,
- provider=provider,
- error=str(e),
- )
- # Don't fail validation on error - let the actual request fail
- # This is more user-friendly than blocking on validation errors
- return True, None
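-
-
-# Illustrative usage sketch (not part of the original module); the model and provider
-# names are hypothetical selections from a UI.
-async def _example_validate_selection(token: str | None = None) -> bool:
-    ok, error = await validate_model_provider_combination(
-        "meta-llama/Llama-3.1-8B-Instruct", "nebius", token=token
-    )
-    if not ok:
-        logger.warning("Invalid model/provider selection", reason=error)
-    return ok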
-
-
-async def get_models_for_provider(
- provider: str,
- token: str | None = None,
- limit: int = 50,
-) -> list[str]:
- """Get models available for a specific provider.
-
- This is a convenience wrapper around get_available_models() with provider filtering.
-
- Args:
- provider: Provider name (e.g., "nebius", "together", "fireworks-ai")
- Note: Use "fireworks-ai" not "fireworks" for the API
- token: Optional HuggingFace API token (from gr.OAuthToken.token)
- limit: Maximum number of models to return
-
- Returns:
- List of model IDs available for the provider
- """
- # Normalize provider name for API
- normalized_provider = provider
- if provider.lower() == "fireworks":
- normalized_provider = "fireworks-ai"
- logger.debug("Normalized provider name", original=provider, normalized=normalized_provider)
-
- return await get_available_models(
- token=token,
- task="text-generation",
- limit=limit,
- inference_provider=normalized_provider,
- )
-
-
-async def validate_oauth_token(token: str | None) -> dict[str, Any]:
- """Validate OAuth token and return available resources.
-
- Args:
- token: OAuth token to validate
-
- Returns:
- Dictionary with:
- - is_valid: Whether token is valid
- - has_inference_api_scope: Whether token has inference-api scope
- - available_models: List of available model IDs
- - available_providers: List of available provider names
- - username: HuggingFace username (if available)
- - error: Error message if validation failed
- """
- result: dict[str, Any] = {
- "is_valid": False,
- "has_inference_api_scope": False,
- "available_models": [],
- "available_providers": [],
- "username": None,
- "error": None,
- }
-
- if not token:
- result["error"] = "No token provided"
- return result
-
- try:
- # Validate token format
- from src.utils.hf_error_handler import validate_hf_token
-
- is_valid_format, format_error = validate_hf_token(token)
- if not is_valid_format:
- result["error"] = f"Invalid token format: {format_error}"
- return result
-
- # Try to get user info to validate token
- loop = asyncio.get_running_loop()
-
- def _get_user_info() -> dict[str, Any] | None:
- """Get user info from HuggingFace API."""
- try:
- api = HfApi(token=token)
- user_info = api.whoami()
- return user_info
- except Exception:
- return None
-
- user_info = await loop.run_in_executor(None, _get_user_info)
-
- if user_info:
- result["is_valid"] = True
- result["username"] = user_info.get("name") or user_info.get("fullname")
- logger.info("Token validated", username=result["username"])
- else:
- result["error"] = "Token validation failed - could not authenticate"
- return result
-
- # Try to query models to check inference-api scope
- try:
- models = await get_available_models(token=token, limit=10)
- if models:
- result["has_inference_api_scope"] = True
- result["available_models"] = models
- logger.info("Inference API scope confirmed", model_count=len(models))
- except Exception as e:
- logger.warning("Could not verify inference-api scope", error=str(e))
- # Token might be valid but without inference-api scope
- result["has_inference_api_scope"] = False
- result["error"] = f"Token may not have inference-api scope: {e}"
-
- # Get available providers
- try:
- providers = await get_available_providers(token=token)
- result["available_providers"] = providers
- except Exception as e:
- logger.warning("Could not get providers", error=str(e))
- # Use fallback providers
- result["available_providers"] = ["auto"]
-
- return result
-
- except Exception as e:
- logger.error("Token validation failed", error=str(e))
- result["error"] = str(e)
- return result
diff --git a/src/utils/huggingface_chat_client.py b/src/utils/huggingface_chat_client.py
deleted file mode 100644
index 5b0d1f67a96e95b971f110d571486c816d9fb7c3..0000000000000000000000000000000000000000
--- a/src/utils/huggingface_chat_client.py
+++ /dev/null
@@ -1,129 +0,0 @@
-"""Custom ChatClient implementation using HuggingFace InferenceClient.
-
-Uses HuggingFace InferenceClient which natively supports function calling,
-making this a thin async wrapper rather than a complex implementation.
-
-Reference: https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
-"""
-
-import asyncio
-from typing import Any
-
-import structlog
-from huggingface_hub import InferenceClient
-
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-
-class HuggingFaceChatClient:
- """ChatClient implementation using HuggingFace InferenceClient.
-
- HuggingFace InferenceClient natively supports function calling via
- the 'tools' parameter, making this a simple async wrapper.
-
- This client is compatible with agent-framework's ChatAgent interface.
- """
-
- def __init__(
- self,
- model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
- api_key: str | None = None,
- provider: str = "auto",
- ) -> None:
- """Initialize HuggingFace chat client.
-
- Args:
- model_name: HuggingFace model identifier (e.g., "meta-llama/Llama-3.1-8B-Instruct")
- api_key: Optional HF_TOKEN for gated models. If None, uses environment token.
- provider: Provider name or "auto" for automatic selection.
- Options: "auto", "cerebras", "together", "sambanova", etc.
-
- Raises:
- ConfigurationError: If initialization fails
- """
- try:
- # Type ignore: provider can be str but InferenceClient expects Literal
- # We validate it's a valid provider at runtime
- self.client = InferenceClient(
- model=model_name,
- api_key=api_key,
- provider=provider, # type: ignore[arg-type]
- )
- self.model_name = model_name
- self.provider = provider
- logger.info(
- "Initialized HuggingFace chat client",
- model=model_name,
- provider=provider,
- )
- except Exception as e:
- raise ConfigurationError(
- f"Failed to initialize HuggingFace InferenceClient: {e}"
- ) from e
-
- async def chat_completion(
- self,
- messages: list[dict[str, Any]],
- tools: list[dict[str, Any]] | None = None,
- tool_choice: str | dict[str, Any] | None = None,
- temperature: float | None = None,
- max_tokens: int | None = None,
- ) -> Any:
- """Send chat completion with optional tools.
-
-        HuggingFace InferenceClient natively supports the 'tools' parameter!
- This is just an async wrapper around the synchronous API.
-
- Args:
- messages: List of message dicts with 'role' and 'content' keys.
- Format: [{"role": "user", "content": "Hello"}]
- tools: Optional list of tool definitions in OpenAI format.
- Format: [{"type": "function", "function": {...}}]
- tool_choice: Tool selection strategy.
- Options: "auto", "none", or {"type": "function", "function": {"name": "tool_name"}}
- temperature: Sampling temperature (0.0 to 2.0). Defaults to 1.0.
- max_tokens: Maximum tokens in response. Defaults to 100.
-
- Returns:
- ChatCompletionOutput compatible with agent-framework.
- Has .choices attribute with message and tool_calls.
-
- Raises:
- ConfigurationError: If chat completion fails
- """
- try:
- loop = asyncio.get_running_loop()
- response = await loop.run_in_executor(
- None,
- lambda: self.client.chat_completion(
- messages=messages,
- tools=tools, # type: ignore[arg-type] # ✅ Native support!
- tool_choice=tool_choice, # type: ignore[arg-type] # ✅ Native support!
- temperature=temperature,
- max_tokens=max_tokens,
- ),
- )
-
- logger.debug(
- "Chat completion successful",
- model=self.model_name,
- has_tools=bool(tools),
- has_tool_calls=bool(
- response.choices[0].message.tool_calls
- if response.choices and response.choices[0].message.tool_calls
- else None
- ),
- )
-
- return response
-
- except Exception as e:
- logger.error(
- "Chat completion failed",
- model=self.model_name,
- error=str(e),
- error_type=type(e).__name__,
- )
- raise ConfigurationError(f"HuggingFace chat completion failed: {e}") from e
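-
-
-# Illustrative usage sketch (not part of the original module); the tool schema and
-# prompt are hypothetical and follow the OpenAI-style format described above.
-async def _example_chat_with_tools() -> Any:
-    client = HuggingFaceChatClient(model_name="meta-llama/Llama-3.1-8B-Instruct")
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "search_pubmed",
-                "description": "Search PubMed for articles matching a query",
-                "parameters": {
-                    "type": "object",
-                    "properties": {"query": {"type": "string"}},
-                    "required": ["query"],
-                },
-            },
-        }
-    ]
-    return await client.chat_completion(
-        messages=[{"role": "user", "content": "Find papers on metformin and AMPK."}],
-        tools=tools,
-        tool_choice="auto",
-        max_tokens=256,
-    )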
diff --git a/src/utils/llm_factory.py b/src/utils/llm_factory.py
deleted file mode 100644
index 2009568298056aca40e6fdddb93f7e7b48fc9f54..0000000000000000000000000000000000000000
--- a/src/utils/llm_factory.py
+++ /dev/null
@@ -1,206 +0,0 @@
-"""Centralized LLM client factory.
-
-This module provides factory functions for creating LLM clients,
-ensuring consistent configuration and clear error messages.
-
-Agent-Framework Chat Clients:
-- HuggingFace InferenceClient: Native function calling support via 'tools' parameter
-- OpenAI ChatClient: Native function calling support (original implementation)
-- Both can be used with agent-framework's ChatAgent
-
-Pydantic AI Models:
-- Default provider is HuggingFace (free tier, no API key required for public models)
-- OpenAI and Anthropic are available as fallback options
-- All providers use Pydantic AI's unified interface
-"""
-
-from typing import TYPE_CHECKING, Any
-
-import structlog
-
-from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
-
-logger = structlog.get_logger()
-
-if TYPE_CHECKING:
- from agent_framework.openai import OpenAIChatClient
-
- from src.utils.huggingface_chat_client import HuggingFaceChatClient
-
-
-def get_magentic_client() -> "OpenAIChatClient":
- """
- Get the OpenAI client for Magentic agents (legacy function).
-
- Note: This function is kept for backward compatibility.
- For new code, use get_chat_client_for_agent() which supports
- both OpenAI and HuggingFace.
-
- Raises:
- ConfigurationError: If OPENAI_API_KEY is not set
-
- Returns:
- Configured OpenAIChatClient for Magentic agents
- """
- # Import here to avoid requiring agent-framework for simple mode
- from agent_framework.openai import OpenAIChatClient
-
- api_key = settings.get_openai_api_key()
-
- return OpenAIChatClient(
- model_id=settings.openai_model,
- api_key=api_key,
- )
-
-
-def get_huggingface_chat_client(oauth_token: str | None = None) -> "HuggingFaceChatClient":
- """
- Get HuggingFace chat client for agent-framework.
-
- HuggingFace InferenceClient natively supports function calling,
- making it compatible with agent-framework's ChatAgent.
-
- Args:
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured HuggingFaceChatClient
-
- Raises:
- ConfigurationError: If initialization fails
- """
- from src.utils.huggingface_chat_client import HuggingFaceChatClient
-
- model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
- # Priority: oauth_token > env vars
- api_key = oauth_token or settings.hf_token or settings.huggingface_api_key
-
- return HuggingFaceChatClient(
- model_name=model_name,
- api_key=api_key,
- provider="auto", # Auto-select best provider
- )
-
-
-def get_chat_client_for_agent(oauth_token: str | None = None) -> Any:
- """
- Get appropriate chat client for agent-framework based on configuration.
-
- Supports:
- - HuggingFace InferenceClient (if HF_TOKEN available, preferred for free tier)
- - OpenAI ChatClient (if OPENAI_API_KEY available, fallback)
-
- Args:
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- ChatClient compatible with agent-framework (HuggingFaceChatClient or OpenAIChatClient)
-
- Raises:
- ConfigurationError: If no suitable client can be created
- """
- # Check if we have OAuth token or env vars
- has_hf_key = bool(oauth_token or settings.has_huggingface_key)
-
- # Prefer HuggingFace if available (free tier)
- if has_hf_key:
- return get_huggingface_chat_client(oauth_token=oauth_token)
-
- # Fallback to OpenAI if available
- if settings.has_openai_key:
- return get_magentic_client()
-
- # If neither available, try HuggingFace without key (public models)
- try:
- return get_huggingface_chat_client(oauth_token=oauth_token)
- except Exception:
- pass
-
- raise ConfigurationError(
- "No chat client available. Set HF_TOKEN or OPENAI_API_KEY for agent-framework mode."
- )
-
-
-def get_pydantic_ai_model(oauth_token: str | None = None) -> Any:
- """
- Get the appropriate model for pydantic-ai based on configuration.
-
- Uses the configured LLM_PROVIDER to select between HuggingFace, OpenAI, and Anthropic.
- Defaults to HuggingFace if provider is not specified or unknown.
- This is used by simple mode components (JudgeHandler, etc.)
-
- Args:
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
-
- Returns:
- Configured pydantic-ai model
- """
- from pydantic_ai.models.huggingface import HuggingFaceModel
- from pydantic_ai.providers.huggingface import HuggingFaceProvider
-
- # Priority: oauth_token > settings.hf_token > settings.huggingface_api_key
- effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key
-
- # HuggingFaceProvider requires a token - cannot use None
- if not effective_hf_token:
- raise ConfigurationError(
- "HuggingFace token required. Please either:\n"
- "1. Log in via HuggingFace OAuth (recommended for Spaces)\n"
- "2. Set HF_TOKEN environment variable\n"
- "3. Set huggingface_api_key in settings"
- )
-
- # Validate and log token information
- from src.utils.hf_error_handler import log_token_info, validate_hf_token
-
- log_token_info(effective_hf_token, context="get_pydantic_ai_model")
- is_valid, error_msg = validate_hf_token(effective_hf_token)
- if not is_valid:
- logger.warning(
- "Token validation failed in get_pydantic_ai_model",
- error=error_msg,
- has_oauth=bool(oauth_token),
- )
- # Continue anyway - let the API call fail with a clear error
-
- # Always use HuggingFace with available token
- model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
- hf_provider = HuggingFaceProvider(api_key=effective_hf_token)
- return HuggingFaceModel(model_name, provider=hf_provider)
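-
-
-# Illustrative usage sketch (not part of the original module); the Agent import and
-# system prompt are assumptions about how callers typically consume this model.
-def _example_judge_model(oauth_token: str | None = None) -> Any:
-    from pydantic_ai import Agent
-
-    model = get_pydantic_ai_model(oauth_token=oauth_token)
-    return Agent(model, system_prompt="You assess evidence quality for research reports.")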
-
-
-def check_magentic_requirements() -> None:
- """
- Check if Magentic/agent-framework mode requirements are met.
-
- Note: HuggingFace InferenceClient now supports function calling natively,
- so this check is relaxed. We prefer HuggingFace if available, fallback to OpenAI.
-
- Raises:
- ConfigurationError: If no suitable client can be created
- """
- # Try to get a chat client - will raise if none available
- try:
- get_chat_client_for_agent()
- except ConfigurationError as e:
- raise ConfigurationError(
- "Agent-framework mode requires HF_TOKEN or OPENAI_API_KEY. "
- "HuggingFace is preferred (free tier with function calling support). "
- "Use mode='simple' for other LLM providers."
- ) from e
-
-
-def check_simple_mode_requirements() -> None:
- """
- Check if simple mode requirements are met.
-
- Simple mode supports HuggingFace (default), OpenAI, and Anthropic.
- HuggingFace can work without an API key for public models.
-
- Raises:
- ConfigurationError: If no LLM is available (only if explicitly required)
- """
- # HuggingFace can work without API key for public models, so we don't require it
- # This allows simple mode to work out of the box
- pass
diff --git a/src/utils/markdown.css b/src/utils/markdown.css
deleted file mode 100644
index 1854a4bb7d4cb4329e7c3b720f62d01a1b521473..0000000000000000000000000000000000000000
--- a/src/utils/markdown.css
+++ /dev/null
@@ -1,7 +0,0 @@
-body {
- font-family: Arial, sans-serif;
- font-size: 14px;
- line-height: 1.8;
- color: #000;
-}
-
diff --git a/src/utils/md_to_pdf.py b/src/utils/md_to_pdf.py
deleted file mode 100644
index 08e7f71d003032dc202f897563f99cf1b600f07c..0000000000000000000000000000000000000000
--- a/src/utils/md_to_pdf.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""Utility for converting markdown to PDF."""
-
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-import structlog
-
-if TYPE_CHECKING:
- pass
-
-logger = structlog.get_logger()
-
-# Try to import md2pdf
-try:
- from md2pdf import md2pdf
-
- _MD2PDF_AVAILABLE = True
-except ImportError:
- md2pdf = None # type: ignore[assignment, misc]
- _MD2PDF_AVAILABLE = False
- logger.warning("md2pdf not available - PDF generation will be disabled")
-
-
-def get_css_path() -> Path:
- """Get the path to the markdown.css file."""
- curdir = Path(__file__).parent
- css_path = curdir / "markdown.css"
- return css_path
-
-
-def md_to_pdf(md_text: str, pdf_file_path: str) -> None:
- """
- Convert markdown text to PDF.
-
- Args:
- md_text: Markdown text content
- pdf_file_path: Path where PDF should be saved
-
- Raises:
- ImportError: If md2pdf is not installed
- ValueError: If markdown text is empty
- OSError: If PDF file cannot be written
- """
- if not _MD2PDF_AVAILABLE:
- raise ImportError("md2pdf is not installed. Install it with: pip install md2pdf")
-
- if not md_text or not md_text.strip():
- raise ValueError("Markdown text cannot be empty")
-
- css_path = get_css_path()
-
- if not css_path.exists():
- logger.warning(
- "CSS file not found, PDF will be generated without custom styling",
- css_path=str(css_path),
- )
- # Generate PDF without CSS
- md2pdf(pdf_file_path, md_text)
- else:
- # Generate PDF with CSS
- md2pdf(pdf_file_path, md_text, css_file_path=str(css_path))
-
- logger.debug("PDF generated successfully", pdf_path=pdf_file_path)
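-
-
-# Illustrative usage sketch (not part of the original module); the markdown text and
-# output path are hypothetical.
-def _example_report_to_pdf() -> None:
-    report_md = "# Example Report\n\nGenerated by the research pipeline."
-    try:
-        md_to_pdf(report_md, "/tmp/example_report.pdf")
-    except ImportError:
-        logger.warning("PDF export skipped because md2pdf is not installed")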
diff --git a/src/utils/message_history.py b/src/utils/message_history.py
deleted file mode 100644
index 5ddbbe30af20de6ab9163494267a0b5626a480be..0000000000000000000000000000000000000000
--- a/src/utils/message_history.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""Message history utilities for Pydantic AI integration."""
-
-from typing import Any
-
-import structlog
-
-try:
- from pydantic_ai import ModelMessage, ModelRequest, ModelResponse
- from pydantic_ai.messages import TextPart, UserPromptPart
-
- _PYDANTIC_AI_AVAILABLE = True
-except ImportError:
- # Fallback for older pydantic-ai versions
- ModelMessage = Any # type: ignore[assignment, misc]
- ModelRequest = Any # type: ignore[assignment, misc]
- ModelResponse = Any # type: ignore[assignment, misc]
- TextPart = Any # type: ignore[assignment, misc]
- UserPromptPart = Any # type: ignore[assignment, misc]
- _PYDANTIC_AI_AVAILABLE = False
-
-logger = structlog.get_logger()
-
-
-def convert_gradio_to_message_history(
- history: list[dict[str, Any]],
- max_messages: int = 20,
-) -> list[ModelMessage]:
- """
- Convert Gradio chat history to Pydantic AI message history.
-
- Args:
- history: Gradio chat history format [{"role": "user", "content": "..."}, ...]
- max_messages: Maximum messages to include (most recent)
-
- Returns:
- List of ModelMessage objects for Pydantic AI
- """
- if not history:
- return []
-
- if not _PYDANTIC_AI_AVAILABLE:
- logger.warning(
- "Pydantic AI message history not available, returning empty list",
- )
- return []
-
- messages: list[ModelMessage] = []
-
- # Take most recent messages
- recent = history[-max_messages:] if len(history) > max_messages else history
-
- for msg in recent:
- role = msg.get("role", "")
- content = msg.get("content", "")
-
- if not content or role not in ("user", "assistant"):
- continue
-
- # Convert content to string if needed
- content_str = str(content)
-
- if role == "user":
- messages.append(
- ModelRequest(parts=[UserPromptPart(content=content_str)]),
- )
- elif role == "assistant":
- messages.append(
- ModelResponse(parts=[TextPart(content=content_str)]),
- )
-
- logger.debug(
- "Converted Gradio history to message history",
- input_turns=len(history),
- output_messages=len(messages),
- )
-
- return messages
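-
-
-# Illustrative usage sketch (not part of the original module); `agent` stands in for
-# any Pydantic AI Agent and the chat turns are hypothetical.
-async def _example_history_reuse(agent: Any) -> Any:
-    gradio_history = [
-        {"role": "user", "content": "What does metformin target?"},
-        {"role": "assistant", "content": "Primarily AMPK activation."},
-    ]
-    history = convert_gradio_to_message_history(gradio_history, max_messages=20)
-    return await agent.run("Summarize the mechanism.", message_history=history)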
-
-
-def message_history_to_string(
- messages: list[ModelMessage],
- max_messages: int = 5,
- include_metadata: bool = False,
-) -> str:
- """
- Convert message history to string format for backward compatibility.
-
- Used during transition period when some agents still expect strings.
-
- Args:
- messages: List of ModelMessage objects
- max_messages: Maximum messages to include
- include_metadata: Whether to include metadata
-
- Returns:
- Formatted string representation
- """
- if not messages:
- return ""
-
- recent = messages[-max_messages:] if len(messages) > max_messages else messages
-
- parts = ["PREVIOUS CONVERSATION:", "---"]
- turn_num = 1
-
- for msg in recent:
- # Extract text content
- text = ""
- if isinstance(msg, ModelRequest):
- for part in msg.parts:
- if hasattr(part, "content"):
- text += str(part.content)
- parts.append(f"[Turn {turn_num}]")
- parts.append(f"User: {text}")
- turn_num += 1
- elif isinstance(msg, ModelResponse):
- for part in msg.parts: # type: ignore[assignment]
- if hasattr(part, "content"):
- text += str(part.content)
- parts.append(f"Assistant: {text}")
-
- parts.append("---")
- return "\n".join(parts)
-
-
-def create_truncation_processor(max_messages: int = 10) -> Any:
- """Create a history processor that keeps only the most recent N messages.
-
- Args:
- max_messages: Maximum number of messages to keep
-
- Returns:
- Processor function that takes a list of messages and returns truncated list
- """
-
- def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
- return messages[-max_messages:] if len(messages) > max_messages else messages
-
- return processor
-
-
-def create_relevance_processor(min_length: int = 10) -> Any:
- """Create a history processor that filters out very short messages.
-
- Args:
- min_length: Minimum message length to keep
-
- Returns:
- Processor function that filters messages by length
- """
-
- def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
- filtered = []
- for msg in messages:
- text = ""
- if isinstance(msg, ModelRequest):
- for part in msg.parts:
- if hasattr(part, "content"):
- text += str(part.content)
- elif isinstance(msg, ModelResponse):
- for part in msg.parts: # type: ignore[assignment]
- if hasattr(part, "content"):
- text += str(part.content)
-
- if len(text.strip()) >= min_length:
- filtered.append(msg)
- return filtered
-
- return processor
diff --git a/src/utils/models.py b/src/utils/models.py
deleted file mode 100644
index e43efac0e0477910c98bd50cee94f75bf321a39d..0000000000000000000000000000000000000000
--- a/src/utils/models.py
+++ /dev/null
@@ -1,574 +0,0 @@
-"""Data models for the Search feature."""
-
-from datetime import UTC, datetime
-from typing import Any, ClassVar, Literal
-
-from pydantic import BaseModel, Field
-
-# Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
-SourceName = Literal[
- "pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j"
-]
-
-
-class Citation(BaseModel):
- """A citation to a source document."""
-
- source: SourceName = Field(description="Where this came from")
-
- title: str = Field(min_length=1, max_length=500)
- url: str = Field(description="URL to the source")
- date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
- authors: list[str] = Field(default_factory=list)
-
- MAX_AUTHORS_IN_CITATION: ClassVar[int] = 3
-
- @property
- def formatted(self) -> str:
- """Format as a citation string."""
- author_str = ", ".join(self.authors[: self.MAX_AUTHORS_IN_CITATION])
- if len(self.authors) > self.MAX_AUTHORS_IN_CITATION:
- author_str += " et al."
- return f"{author_str} ({self.date}). {self.title}. {self.source.upper()}"
-
-
-class Evidence(BaseModel):
- """A piece of evidence retrieved from search."""
-
- content: str = Field(min_length=1, description="The actual text content")
- citation: Citation
- relevance: float = Field(default=0.0, ge=0.0, le=1.0, description="Relevance score 0-1")
- metadata: dict[str, Any] = Field(
- default_factory=dict,
- description="Additional metadata (e.g., cited_by_count, concepts, is_open_access)",
- )
-
- model_config = {"frozen": True}
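-
-
-# Illustrative usage sketch (not part of the original models); field values are
-# hypothetical.
-def _example_citation_formatting() -> Evidence:
-    citation = Citation(
-        source="pubmed",
-        title="Metformin and AMPK signalling",
-        url="https://pubmed.ncbi.nlm.nih.gov/0000000/",
-        date="2020-05-01",
-        authors=["Smith J", "Lee K", "Garcia M", "Chen Y"],
-    )
-    # citation.formatted ->
-    # "Smith J, Lee K, Garcia M et al. (2020-05-01). Metformin and AMPK signalling. PUBMED"
-    return Evidence(
-        content="AMPK activation reduces hepatic glucose output.",
-        citation=citation,
-    )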
-
-
-class SearchResult(BaseModel):
- """Result of a search operation."""
-
- query: str
- evidence: list[Evidence]
- sources_searched: list[SourceName]
- total_found: int
- errors: list[str] = Field(default_factory=list)
-
-
-class AssessmentDetails(BaseModel):
- """Detailed assessment of evidence quality."""
-
- mechanism_score: int = Field(
- ...,
- ge=0,
- le=10,
- description="How well does the evidence explain the mechanism? 0-10",
- )
- mechanism_reasoning: str = Field(
- ..., min_length=10, description="Explanation of mechanism score"
- )
- clinical_evidence_score: int = Field(
- ...,
- ge=0,
- le=10,
- description="Strength of clinical/preclinical evidence. 0-10",
- )
- clinical_reasoning: str = Field(
- ..., min_length=10, description="Explanation of clinical evidence score"
- )
- drug_candidates: list[str] = Field(
- default_factory=list, description="List of specific drug candidates mentioned"
- )
- key_findings: list[str] = Field(
- default_factory=list, description="Key findings from the evidence"
- )
-
-
-class JudgeAssessment(BaseModel):
- """Complete assessment from the Judge."""
-
- details: AssessmentDetails
- sufficient: bool = Field(..., description="Is evidence sufficient to provide a recommendation?")
- confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in the assessment (0-1)")
- recommendation: Literal["continue", "synthesize"] = Field(
- ...,
- description="continue = need more evidence, synthesize = ready to answer",
- )
- next_search_queries: list[str] = Field(
- default_factory=list, description="If continue, what queries to search next"
- )
- reasoning: str = Field(
- ..., min_length=20, description="Overall reasoning for the recommendation"
- )
-
-
-class AgentEvent(BaseModel):
- """Event emitted by the orchestrator for UI streaming."""
-
- type: Literal[
- "started",
- "searching",
- "search_complete",
- "judging",
- "judge_complete",
- "looping",
- "synthesizing",
- "complete",
- "error",
- "streaming",
- "hypothesizing",
- "analyzing", # NEW for Phase 13
- "analysis_complete", # NEW for Phase 13
- ]
- message: str
- data: Any = None
- timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC))
- iteration: int = 0
-
- def to_markdown(self) -> str:
- """Format event as markdown for chat display."""
- icons = {
- "started": "🚀",
- "searching": "🔍",
- "search_complete": "📚",
- "judging": "🧠",
- "judge_complete": "✅",
- "looping": "🔄",
- "synthesizing": "📝",
- "complete": "🎉",
- "error": "❌",
- "streaming": "📡",
- "hypothesizing": "🔬", # NEW
- "analyzing": "📊", # NEW
- "analysis_complete": "📈", # NEW
- }
- icon = icons.get(self.type, "•")
- return f"{icon} **{self.type.upper()}**: {self.message}"
-
-
-class MechanismHypothesis(BaseModel):
- """A scientific hypothesis about drug mechanism."""
-
- drug: str = Field(description="The drug being studied")
- target: str = Field(description="Molecular target (e.g., AMPK, mTOR)")
- pathway: str = Field(description="Biological pathway affected")
- effect: str = Field(description="Downstream effect on disease")
- confidence: float = Field(ge=0, le=1, description="Confidence in hypothesis")
- supporting_evidence: list[str] = Field(
- default_factory=list, description="PMIDs or URLs supporting this hypothesis"
- )
- contradicting_evidence: list[str] = Field(
- default_factory=list, description="PMIDs or URLs contradicting this hypothesis"
- )
- search_suggestions: list[str] = Field(
- default_factory=list, description="Suggested searches to test this hypothesis"
- )
-
- def to_search_queries(self) -> list[str]:
- """Generate search queries to test this hypothesis."""
- return [
- f"{self.drug} {self.target}",
- f"{self.target} {self.pathway}",
- f"{self.pathway} {self.effect}",
- *self.search_suggestions,
- ]
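-
-
-# Illustrative usage sketch (not part of the original models); the drug, target, and
-# pathway values are hypothetical.
-def _example_hypothesis_queries() -> list[str]:
-    hypothesis = MechanismHypothesis(
-        drug="metformin",
-        target="AMPK",
-        pathway="mTOR signalling",
-        effect="reduced tumour cell proliferation",
-        confidence=0.6,
-        search_suggestions=["metformin cancer clinical trial"],
-    )
-    # to_search_queries() ->
-    #   ["metformin AMPK", "AMPK mTOR signalling",
-    #    "mTOR signalling reduced tumour cell proliferation",
-    #    "metformin cancer clinical trial"]
-    return hypothesis.to_search_queries()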
-
-
-class HypothesisAssessment(BaseModel):
- """Assessment of evidence against hypotheses."""
-
- hypotheses: list[MechanismHypothesis]
- primary_hypothesis: MechanismHypothesis | None = Field(
- default=None, description="Most promising hypothesis based on current evidence"
- )
- knowledge_gaps: list[str] = Field(description="What we don't know yet")
- recommended_searches: list[str] = Field(description="Searches to fill knowledge gaps")
-
-
-class ReportSection(BaseModel):
- """A section of the research report."""
-
- title: str
- content: str
- # Reserved for future inline citation tracking within sections
- citations: list[str] = Field(default_factory=list)
-
-
-class ResearchReport(BaseModel):
- """Structured scientific report."""
-
- title: str = Field(description="Report title")
- executive_summary: str = Field(
- description="One-paragraph summary for quick reading", min_length=100, max_length=1000
- )
- research_question: str = Field(description="Clear statement of what was investigated")
-
- methodology: ReportSection = Field(description="How the research was conducted")
- hypotheses_tested: list[dict[str, Any]] = Field(
- description="Hypotheses with supporting/contradicting evidence counts"
- )
-
- mechanistic_findings: ReportSection = Field(description="Findings about drug mechanisms")
- clinical_findings: ReportSection = Field(
- description="Findings from clinical/preclinical studies"
- )
-
- drug_candidates: list[str] = Field(description="Identified drug candidates")
- limitations: list[str] = Field(description="Study limitations")
- conclusion: str = Field(description="Overall conclusion")
-
- references: list[dict[str, str]] = Field(
- default_factory=list,
- description="Formatted references with title, authors, source, URL",
- )
-
- # Metadata
- sources_searched: list[str] = Field(default_factory=list)
- total_papers_reviewed: int = 0
- search_iterations: int = 0
- confidence_score: float = Field(ge=0, le=1)
-
- def to_markdown(self) -> str:
- """Render report as markdown."""
- sections = [
- f"# {self.title}\n",
- f"## Executive Summary\n{self.executive_summary}\n",
- f"## Research Question\n{self.research_question}\n",
- f"## Methodology\n{self.methodology.content}\n",
- ]
-
- # Hypotheses
- sections.append("## Hypotheses Tested\n")
- if not self.hypotheses_tested:
- sections.append("*No hypotheses tested yet.*\n")
- for h in self.hypotheses_tested:
- supported = h.get("supported", 0)
- contradicted = h.get("contradicted", 0)
- if supported == 0 and contradicted == 0:
- status = "❓ Untested"
- elif supported > contradicted:
- status = "✅ Supported"
- else:
- status = "⚠️ Mixed"
- sections.append(
- f"- **{h.get('mechanism', 'Unknown')}** ({status}): "
- f"{supported} supporting, {contradicted} contradicting\n"
- )
-
- # Findings
- sections.append(f"## Mechanistic Findings\n{self.mechanistic_findings.content}\n")
- sections.append(f"## Clinical Findings\n{self.clinical_findings.content}\n")
-
- # Drug candidates
- sections.append("## Drug Candidates\n")
- if self.drug_candidates:
- for drug in self.drug_candidates:
- sections.append(f"- **{drug}**\n")
- else:
- sections.append("*No drug candidates identified.*\n")
-
- # Limitations
- sections.append("## Limitations\n")
- if self.limitations:
- for lim in self.limitations:
- sections.append(f"- {lim}\n")
- else:
- sections.append("*No limitations documented.*\n")
-
- # Conclusion
- sections.append(f"## Conclusion\n{self.conclusion}\n")
-
- # References
- sections.append("## References\n")
- if self.references:
- for i, ref in enumerate(self.references, 1):
- sections.append(
- f"{i}. {ref.get('authors', 'Unknown')}. "
- f"*{ref.get('title', 'Untitled')}*. "
- f"{ref.get('source', '')} ({ref.get('date', '')}). "
- f"[Link]({ref.get('url', '#')})\n"
- )
- else:
- sections.append("*No references available.*\n")
-
- # Metadata footer
- sections.append("\n---\n")
- sections.append(
- f"*Report generated from {self.total_papers_reviewed} papers "
- f"across {self.search_iterations} search iterations. "
- f"Confidence: {self.confidence_score:.0%}*"
- )
-
- return "\n".join(sections)
-
-
-class OrchestratorConfig(BaseModel):
- """Configuration for the orchestrator."""
-
- max_iterations: int = Field(default=10, ge=1, le=20)
- max_results_per_tool: int = Field(default=10, ge=1, le=50)
- search_timeout: float = Field(default=30.0, ge=5.0, le=120.0)
-
-
-# Models for iterative/deep research patterns
-
-
-class IterationData(BaseModel):
- """Data for a single iteration of the research loop."""
-
- gap: str = Field(description="The gap addressed in the iteration", default="")
- tool_calls: list[str] = Field(description="The tool calls made", default_factory=list)
- findings: list[str] = Field(
- description="The findings collected from tool calls", default_factory=list
- )
- thought: str = Field(
- description="The thinking done to reflect on the success of the iteration and next steps",
- default="",
- )
-
- model_config = {"frozen": True}
-
-
-class Conversation(BaseModel):
- """A conversation between the user and the iterative researcher."""
-
- history: list[IterationData] = Field(
- description="The data for each iteration of the research loop",
- default_factory=list,
- )
-
- def add_iteration(self, iteration_data: IterationData | None = None) -> None:
- """Add a new iteration to the conversation history."""
- if iteration_data is None:
- iteration_data = IterationData()
- self.history.append(iteration_data)
-
- def set_latest_gap(self, gap: str) -> None:
- """Set the gap for the latest iteration."""
- if not self.history:
- self.add_iteration()
- # Use model_copy() since IterationData is frozen
- self.history[-1] = self.history[-1].model_copy(update={"gap": gap})
-
- def set_latest_tool_calls(self, tool_calls: list[str]) -> None:
- """Set the tool calls for the latest iteration."""
- if not self.history:
- self.add_iteration()
- # Use model_copy() since IterationData is frozen
- self.history[-1] = self.history[-1].model_copy(update={"tool_calls": tool_calls})
-
- def set_latest_findings(self, findings: list[str]) -> None:
- """Set the findings for the latest iteration."""
- if not self.history:
- self.add_iteration()
- # Use model_copy() since IterationData is frozen
- self.history[-1] = self.history[-1].model_copy(update={"findings": findings})
-
- def set_latest_thought(self, thought: str) -> None:
- """Set the thought for the latest iteration."""
- if not self.history:
- self.add_iteration()
- # Use model_copy() since IterationData is frozen
- self.history[-1] = self.history[-1].model_copy(update={"thought": thought})
-
- def get_latest_gap(self) -> str:
- """Get the gap from the latest iteration."""
- if not self.history:
- return ""
- return self.history[-1].gap
-
- def get_latest_tool_calls(self) -> list[str]:
- """Get the tool calls from the latest iteration."""
- if not self.history:
- return []
- return self.history[-1].tool_calls
-
- def get_latest_findings(self) -> list[str]:
- """Get the findings from the latest iteration."""
- if not self.history:
- return []
- return self.history[-1].findings
-
- def get_latest_thought(self) -> str:
- """Get the thought from the latest iteration."""
- if not self.history:
- return ""
- return self.history[-1].thought
-
- def get_all_findings(self) -> list[str]:
- """Get all findings from all iterations."""
- return [finding for iteration_data in self.history for finding in iteration_data.findings]
-
- def compile_conversation_history(self) -> str:
- """Compile the conversation history into a string."""
- conversation = ""
- for iteration_num, iteration_data in enumerate(self.history):
- conversation += f"[ITERATION {iteration_num + 1}]\n\n"
- if iteration_data.thought:
- conversation += f"{self.get_thought_string(iteration_num)}\n\n"
- if iteration_data.gap:
- conversation += f"{self.get_task_string(iteration_num)}\n\n"
- if iteration_data.tool_calls:
- conversation += f"{self.get_action_string(iteration_num)}\n\n"
- if iteration_data.findings:
- conversation += f"{self.get_findings_string(iteration_num)}\n\n"
-
- return conversation
-
- def get_task_string(self, iteration_num: int) -> str:
- """Get the task for the specified iteration."""
- if iteration_num < len(self.history) and self.history[iteration_num].gap:
- return f"\nAddress this knowledge gap: {self.history[iteration_num].gap}\n"
- return ""
-
- def get_action_string(self, iteration_num: int) -> str:
- """Get the action for the specified iteration."""
- if iteration_num < len(self.history) and self.history[iteration_num].tool_calls:
- joined_calls = "\n".join(self.history[iteration_num].tool_calls)
- return (
- "\nCalling the following tools to address the knowledge gap:\n"
- f"{joined_calls}\n"
- )
- return ""
-
- def get_findings_string(self, iteration_num: int) -> str:
- """Get the findings for the specified iteration."""
- if iteration_num < len(self.history) and self.history[iteration_num].findings:
- joined_findings = "\n\n".join(self.history[iteration_num].findings)
- return f"\n{joined_findings}\n"
- return ""
-
- def get_thought_string(self, iteration_num: int) -> str:
- """Get the thought for the specified iteration."""
- if iteration_num < len(self.history) and self.history[iteration_num].thought:
- return f"\n{self.history[iteration_num].thought}\n"
- return ""
-
- def latest_task_string(self) -> str:
- """Get the latest task."""
- if not self.history:
- return ""
- return self.get_task_string(len(self.history) - 1)
-
- def latest_action_string(self) -> str:
- """Get the latest action."""
- if not self.history:
- return ""
- return self.get_action_string(len(self.history) - 1)
-
- def latest_findings_string(self) -> str:
- """Get the latest findings."""
- if not self.history:
- return ""
- return self.get_findings_string(len(self.history) - 1)
-
- def latest_thought_string(self) -> str:
- """Get the latest thought."""
- if not self.history:
- return ""
- return self.get_thought_string(len(self.history) - 1)
-
-
-class ReportPlanSection(BaseModel):
- """A section of the report that needs to be written."""
-
- title: str = Field(description="The title of the section")
- key_question: str = Field(description="The key question to be addressed in the section")
-
- model_config = {"frozen": True}
-
-
-class ReportPlan(BaseModel):
- """Output from the Report Planner Agent."""
-
- background_context: str = Field(
- description="A summary of supporting context that can be passed onto the research agents"
- )
- report_outline: list[ReportPlanSection] = Field(
- description="List of sections that need to be written in the report"
- )
- report_title: str = Field(description="The title of the report")
-
- model_config = {"frozen": True}
-
-
-class KnowledgeGapOutput(BaseModel):
- """Output from the Knowledge Gap Agent."""
-
- research_complete: bool = Field(
- description="Whether the research and findings are complete enough to end the research loop"
- )
- outstanding_gaps: list[str] = Field(
- description="List of knowledge gaps that still need to be addressed"
- )
-
- model_config = {"frozen": True}
-
-
-class AgentTask(BaseModel):
- """A task for a specific agent to address knowledge gaps."""
-
- gap: str | None = Field(description="The knowledge gap being addressed", default=None)
- agent: str = Field(description="The name of the agent to use")
- query: str = Field(description="The specific query for the agent")
- entity_website: str | None = Field(
- description="The website of the entity being researched, if known",
- default=None,
- )
-
- model_config = {"frozen": True}
-
-
-class AgentSelectionPlan(BaseModel):
- """Plan for which agents to use for knowledge gaps."""
-
- tasks: list[AgentTask] = Field(description="List of agent tasks to address knowledge gaps")
-
- model_config = {"frozen": True}
-
-
-class ReportDraftSection(BaseModel):
- """A section of the report that needs to be written."""
-
- section_title: str = Field(description="The title of the section")
- section_content: str = Field(description="The content of the section")
-
- model_config = {"frozen": True}
-
-
-class ReportDraft(BaseModel):
- """Output from the Report Planner Agent."""
-
- sections: list[ReportDraftSection] = Field(
- description="List of sections that are in the report"
- )
-
- model_config = {"frozen": True}
-
-
-class ToolAgentOutput(BaseModel):
- """Standard output for all tool agents."""
-
- output: str = Field(description="The output from the tool agent")
- sources: list[str] = Field(description="List of source URLs", default_factory=list)
-
- model_config = {"frozen": True}
-
-
-class ParsedQuery(BaseModel):
- """Parsed and improved user query with research mode detection."""
-
- original_query: str = Field(description="The original user query")
- improved_query: str = Field(description="Improved/refined query")
- research_mode: Literal["iterative", "deep"] = Field(description="Detected research mode")
- key_entities: list[str] = Field(
- default_factory=list,
- description="Key entities extracted from query",
- )
- research_questions: list[str] = Field(
- default_factory=list,
- description="Specific research questions extracted",
- )
-
- model_config = {"frozen": True}
diff --git a/src/utils/report_generator.py b/src/utils/report_generator.py
deleted file mode 100644
index 330f2ad281867acc08f21b9a04fc1617a0ad7029..0000000000000000000000000000000000000000
--- a/src/utils/report_generator.py
+++ /dev/null
@@ -1,179 +0,0 @@
-"""Utility functions for generating reports from evidence when LLM fails."""
-
-from typing import TYPE_CHECKING
-
-import structlog
-
-if TYPE_CHECKING:
- from src.utils.models import Citation, Evidence
-
-logger = structlog.get_logger()
-
-
-def _format_authors(citation: "Citation") -> str:
- """Format authors string from citation."""
- authors = ", ".join(citation.authors[:3])
- if len(citation.authors) > 3:
- authors += " et al."
- elif not authors:
- authors = "Unknown"
- return authors
-
-
-def _add_evidence_section(report_parts: list[str], evidence: list["Evidence"]) -> None:
- """Add evidence summary section to report."""
- from src.utils.models import SourceName
-
- report_parts.append("## Evidence Summary\n")
- report_parts.append(f"**Total Sources Found:** {len(evidence)}\n\n")
-
- # Group evidence by source
- by_source: dict[SourceName, list[Evidence]] = {}
- for ev in evidence:
- source = ev.citation.source
- if source not in by_source:
- by_source[source] = []
- by_source[source].append(ev)
-
- # Organize by source
- for source in sorted(by_source.keys()): # type: ignore[assignment]
- source_evidence = by_source[source]
- report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n")
-
- for i, ev in enumerate(source_evidence, 1):
- authors = _format_authors(ev.citation)
- report_parts.append(f"#### {i}. {ev.citation.title}\n")
- if authors and authors != "Unknown":
- report_parts.append(f"**Authors:** {authors} \n")
- report_parts.append(f"**Date:** {ev.citation.date} \n")
- report_parts.append(f"**Source:** {ev.citation.source.upper()} \n")
- report_parts.append(f"**URL:** {ev.citation.url} \n\n")
-
- # Content (truncated if too long)
- content = ev.content
- if len(content) > 500:
- content = content[:500] + "... [truncated]"
- report_parts.append(f"{content}\n\n")
-
-
-def _add_key_findings(report_parts: list[str], evidence: list["Evidence"]) -> None:
- """Add key findings section to report."""
- report_parts.append("## Key Findings\n\n")
- report_parts.append(
- "Based on the evidence collected, the following key points were identified:\n\n"
- )
-
- # Extract key points from evidence (first sentence or summary)
- key_points: list[str] = []
- for ev in evidence[:10]: # Limit to top 10
- # Try to extract first meaningful sentence
- content = ev.content.strip()
- if content:
- # Find first sentence
- first_period = content.find(".")
- if first_period > 0 and first_period < 200:
- key_point = content[: first_period + 1].strip()
- else:
- # Fallback: first 150 chars
- key_point = content[:150].strip()
- if len(content) > 150:
- key_point += "..."
- key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)")
-
- if key_points:
- report_parts.append("\n".join(key_points))
- report_parts.append("\n\n")
- else:
- report_parts.append("*No specific key findings could be extracted from the evidence.*\n\n")
-
-
-def _add_references(report_parts: list[str], evidence: list["Evidence"]) -> None:
- """Add references section to report."""
- report_parts.append("## References\n\n")
- for i, ev in enumerate(evidence, 1):
- authors = _format_authors(ev.citation)
- report_parts.append(
- f"[{i}] {authors} ({ev.citation.date}). "
- f"*{ev.citation.title}*. "
- f"{ev.citation.source.upper()}. "
- f"Available at: {ev.citation.url}\n\n"
- )
-
-
-def generate_report_from_evidence(
- query: str,
- evidence: list["Evidence"] | None = None,
- findings: str | None = None,
-) -> str:
- """
-    Generate a structured markdown report from evidence or findings when LLM synthesis fails.
-
- This function creates a proper report structure even without LLM assistance,
- formatting the collected evidence into a readable, well-structured document.
-
- Args:
- query: The original research query
- evidence: List of Evidence objects (preferred if available)
- findings: Pre-formatted findings string (fallback if evidence not available)
-
- Returns:
- Markdown formatted report string
- """
- report_parts: list[str] = []
-
- # Title
- report_parts.append(f"# Research Report: {query}\n")
-
- # Introduction
- report_parts.append("## Introduction\n")
- report_parts.append(f"This report addresses the following research query: **{query}**\n")
- report_parts.append(
- "*Note: This report was generated from collected evidence. "
- "LLM-based synthesis was unavailable due to API limitations.*\n\n"
- )
-
- # Evidence Summary
- if evidence and len(evidence) > 0:
- _add_evidence_section(report_parts, evidence)
- _add_key_findings(report_parts, evidence)
-
- elif findings:
- # Fallback: use findings string if evidence not available
- report_parts.append("## Research Findings\n\n")
- # Truncate if too long
- if len(findings) > 10000:
- findings = findings[:10000] + "\n\n[Content truncated due to length]"
- report_parts.append(f"{findings}\n\n")
- else:
- report_parts.append("## Research Findings\n\n")
- report_parts.append(
- "*No evidence or findings were collected during the research process.*\n\n"
- )
-
- # References Section
- if evidence and len(evidence) > 0:
- _add_references(report_parts, evidence)
-
- # Conclusion
- report_parts.append("## Conclusion\n\n")
- if evidence and len(evidence) > 0:
- report_parts.append(
- f"This report synthesized information from {len(evidence)} sources "
- f"to address the research query: **{query}**\n\n"
- )
- report_parts.append(
- "*Note: Due to API limitations, this report was generated directly from "
- "collected evidence without LLM-based synthesis. For a more comprehensive "
- "analysis, please retry when API access is available.*\n"
- )
- else:
- report_parts.append(
- "This report could not be fully generated due to limited evidence collection "
- "and API access issues.\n\n"
- )
- report_parts.append(
- "*Please retry your query when API access is available for a more "
- "comprehensive research report.*\n"
- )
-
- return "".join(report_parts)
diff --git a/src/utils/text_utils.py b/src/utils/text_utils.py
deleted file mode 100644
index da37ce8eeece0575b2009e483f75a3cf4c2ab9fc..0000000000000000000000000000000000000000
--- a/src/utils/text_utils.py
+++ /dev/null
@@ -1,132 +0,0 @@
-"""Text processing utilities for evidence handling."""
-
-from typing import TYPE_CHECKING
-
-import numpy as np
-
-if TYPE_CHECKING:
- from src.services.embeddings import EmbeddingService
- from src.utils.models import Evidence
-
-
-def truncate_at_sentence(text: str, max_chars: int = 300) -> str:
- """Truncate text at sentence boundary, preserving meaning.
-
- Args:
- text: The text to truncate
- max_chars: Maximum characters (default 300)
-
- Returns:
- Text truncated at last complete sentence within limit
- """
- if len(text) <= max_chars:
- return text
-
- # Find truncation point
- truncated = text[:max_chars]
-
- # Look for sentence endings: . ! ? followed by space or end
-    # rfind picks the last such separator within the truncated text
- for sep in [". ", "! ", "? ", ".\n", "!\n", "?\n"]:
- last_sep = truncated.rfind(sep)
-        if last_sep > max_chars // 2:  # Only accept a break past the halfway point
- return text[: last_sep + 1].strip()
-
- # Fallback: find last period (even if not followed by space, e.g. end of string)
- last_period = truncated.rfind(".")
- if last_period > max_chars // 2:
- return text[: last_period + 1].strip()
-
- # Last resort: truncate at word boundary
- last_space = truncated.rfind(" ")
- if last_space > 0:
- return text[:last_space].strip() + "..."
-
- return truncated + "..."
-
-
-async def select_diverse_evidence(
- evidence: list["Evidence"], n: int, query: str, embeddings: "EmbeddingService | None" = None
-) -> list["Evidence"]:
- """Select n most diverse and relevant evidence items.
-
-    Uses Maximal Marginal Relevance (MMR) when an embedding service is available,
-    and falls back to sorting by the stored relevance score otherwise.
-
- Args:
- evidence: All available evidence
- n: Number of items to select
- query: Original query for relevance scoring
- embeddings: Optional EmbeddingService for semantic diversity
-
- Returns:
- Selected evidence items, diverse and relevant
- """
- if not evidence:
- return []
-
- if n >= len(evidence):
- return evidence
-
- # Fallback: sort by relevance score if no embeddings
- if embeddings is None:
- return sorted(
- evidence,
- key=lambda e: e.relevance, # Use .relevance (from Pydantic model)
- reverse=True,
- )[:n]
-
- # MMR: Maximal Marginal Relevance for diverse selection
- # Score = λ * relevance - (1-λ) * max_similarity_to_selected
- lambda_param = 0.7 # Balance relevance vs diversity
-
- # Get query embedding
- query_emb = await embeddings.embed(query)
-
- # Get all evidence embeddings
- evidence_embs = await embeddings.embed_batch([e.content for e in evidence])
-
- # Cosine similarity helper
- def cosine(a: list[float], b: list[float]) -> float:
- arr_a, arr_b = np.array(a), np.array(b)
- denominator = float(np.linalg.norm(arr_a) * np.linalg.norm(arr_b))
- if denominator == 0:
- return 0.0
- return float(np.dot(arr_a, arr_b) / denominator)
-
- # Compute relevance scores (cosine similarity to query)
- # Note: We use semantic relevance to query, not the keyword search 'relevance' score
- relevance_scores = [cosine(query_emb, emb) for emb in evidence_embs]
-
- # Greedy MMR selection
- selected_indices: list[int] = []
- remaining = set(range(len(evidence)))
-
- for _ in range(n):
- best_score = float("-inf")
- best_idx = -1
-
- for idx in remaining:
- # Relevance component
- relevance = relevance_scores[idx]
-
- # Diversity component: max similarity to already selected
- if selected_indices:
- max_sim = max(
- cosine(evidence_embs[idx], evidence_embs[sel]) for sel in selected_indices
- )
- else:
-                max_sim = 0.0
-
- # MMR score
- mmr_score = lambda_param * relevance - (1 - lambda_param) * max_sim
-
- if mmr_score > best_score:
- best_score = mmr_score
- best_idx = idx
-
- if best_idx >= 0:
- selected_indices.append(best_idx)
- remaining.remove(best_idx)
-
- return [evidence[i] for i in selected_indices]
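For a feel of the trade-off, the greedy loop above scores each remaining item as `λ·relevance − (1 − λ)·max_similarity_to_selected` with λ = 0.7. Below is a standalone arithmetic sketch with invented numbers, showing how a near-duplicate of an already-selected item loses to a slightly less relevant but novel one.

```python
# Illustration of the MMR score used in select_diverse_evidence; the values are made up.
LAMBDA = 0.7  # same relevance/diversity balance as lambda_param above


def mmr_score(relevance: float, max_sim_to_selected: float) -> float:
    return LAMBDA * relevance - (1 - LAMBDA) * max_sim_to_selected


print(mmr_score(0.90, 0.95))  # highly relevant near-duplicate -> ~0.345
print(mmr_score(0.80, 0.10))  # slightly less relevant but novel -> ~0.53
```

With λ = 0.7 the scoring still leans toward relevance, but the diversity penalty is enough to push duplicates below fresh material.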