diff --git a/.cursorrules b/.cursorrules deleted file mode 100644 index 8fbe6def025d95d15c47f657eafbbbf0643a5ca5..0000000000000000000000000000000000000000 --- a/.cursorrules +++ /dev/null @@ -1,240 +0,0 @@ -# DeepCritical Project - Cursor Rules - -## Project-Wide Rules - -**Architecture**: Multi-agent research system using Pydantic AI for agent orchestration, supporting iterative and deep research patterns. Uses middleware for state management, budget tracking, and workflow coordination. - -**Type Safety**: ALWAYS use complete type hints. All functions must have parameter and return type annotations. Use `mypy --strict` compliance. Use `TYPE_CHECKING` imports for circular dependencies: `from typing import TYPE_CHECKING; if TYPE_CHECKING: from src.services.embeddings import EmbeddingService` - -**Async Patterns**: ALL I/O operations must be async (`async def`, `await`). Use `asyncio.gather()` for parallel operations. CPU-bound work must use `run_in_executor()`: `loop = asyncio.get_running_loop(); result = await loop.run_in_executor(None, cpu_bound_function, args)`. Never block the event loop. - -**Error Handling**: Use custom exceptions from `src/utils/exceptions.py`: `DeepCriticalError`, `SearchError`, `RateLimitError`, `JudgeError`, `ConfigurationError`. Always chain exceptions: `raise SearchError(...) from e`. Log with structlog: `logger.error("Operation failed", error=str(e), context=value)`. - -**Logging**: Use `structlog` for ALL logging (NOT `print` or `logging`). Import: `import structlog; logger = structlog.get_logger()`. Log with structured data: `logger.info("event", key=value)`. Use appropriate levels: DEBUG, INFO, WARNING, ERROR. - -**Pydantic Models**: All data exchange uses Pydantic models from `src/utils/models.py`. Models are frozen (`model_config = {"frozen": True}`) for immutability. Use `Field()` with descriptions. Validate with `ge=`, `le=`, `min_length=`, `max_length=` constraints. - -**Code Style**: Ruff with 100-char line length. Ignore rules: `PLR0913` (too many arguments), `PLR0912` (too many branches), `PLR0911` (too many returns), `PLR2004` (magic values), `PLW0603` (global statement), `PLC0415` (lazy imports). - -**Docstrings**: Google-style docstrings for all public functions. Include Args, Returns, Raises sections. Use type hints in docstrings only if needed for clarity. - -**Testing**: Unit tests in `tests/unit/` (mocked, fast). Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`). Use `respx` for httpx mocking, `pytest-mock` for general mocking. - -**State Management**: Use `ContextVar` in middleware for thread-safe isolation. Never use global mutable state (except singletons via `@lru_cache`). Use `WorkflowState` from `src/middleware/state_machine.py` for workflow state. - -**Citation Validation**: ALWAYS validate references before returning reports. Use `validate_references()` from `src/utils/citation_validator.py`. Remove hallucinated citations. Log warnings for removed citations. - ---- - -## src/agents/ - Agent Implementation Rules - -**Pattern**: All agents use Pydantic AI `Agent` class. Agents have structured output types (Pydantic models) or return strings. Use factory functions in `src/agent_factory/agents.py` for creation. 
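For illustration, a minimal sketch of that structure and factory convention might look like the following (assuming `KnowledgeGapOutput` and `get_model()` as referenced in these rules; exact pydantic-ai parameter and attribute names such as `output_type`, `retries`, and `.output` vary by version):

```python
# Illustrative sketch only; not the repository's actual implementation.
from datetime import datetime
from typing import Any

from pydantic_ai import Agent

from src.agent_factory.judges import get_model  # assumed helper, per the rules above
from src.utils.models import KnowledgeGapOutput  # assumed structured output model

SYSTEM_PROMPT = (
    "You evaluate whether research is complete and list outstanding gaps. "
    f"Today's date is {datetime.now().strftime('%Y-%m-%d')}."
)


class KnowledgeGapAgent:
    """Evaluates research completeness and returns a structured KnowledgeGapOutput."""

    def __init__(self, model: Any | None = None) -> None:
        self._agent = Agent(
            model or get_model(),
            output_type=KnowledgeGapOutput,  # may be result_type in older pydantic-ai
            system_prompt=SYSTEM_PROMPT,
            retries=3,
        )

    async def evaluate(self, conversation: str) -> KnowledgeGapOutput:
        result = await self._agent.run(conversation)
        return result.output  # .data in older pydantic-ai versions


def create_knowledge_gap_agent(model: Any | None = None) -> KnowledgeGapAgent:
    """Factory function, matching the create_agent_name() convention."""
    return KnowledgeGapAgent(model)
```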
- -**Agent Structure**: -- System prompt as module-level constant (with date injection: `datetime.now().strftime("%Y-%m-%d")`) -- Agent class with `__init__(model: Any | None = None)` -- Main method (e.g., `async def evaluate()`, `async def write_report()`) -- Factory function: `def create_agent_name(model: Any | None = None) -> AgentName` - -**Model Initialization**: Use `get_model()` from `src/agent_factory/judges.py` if no model provided. Support OpenAI/Anthropic/HF Inference via settings. - -**Error Handling**: Return fallback values (e.g., `KnowledgeGapOutput(research_complete=False, outstanding_gaps=[...])`) on failure. Log errors with context. Use retry logic (3 retries) in Pydantic AI Agent initialization. - -**Input Validation**: Validate query/inputs are not empty. Truncate very long inputs with warnings. Handle None values gracefully. - -**Output Types**: Use structured output types from `src/utils/models.py` (e.g., `KnowledgeGapOutput`, `AgentSelectionPlan`, `ReportDraft`). For text output (writer agents), return `str` directly. - -**Agent-Specific Rules**: -- `knowledge_gap.py`: Outputs `KnowledgeGapOutput`. Evaluates research completeness. -- `tool_selector.py`: Outputs `AgentSelectionPlan`. Selects tools (RAG/web/database). -- `writer.py`: Returns markdown string. Includes citations in numbered format. -- `long_writer.py`: Uses `ReportDraft` input/output. Handles section-by-section writing. -- `proofreader.py`: Takes `ReportDraft`, returns polished markdown. -- `thinking.py`: Returns observation string from conversation history. -- `input_parser.py`: Outputs `ParsedQuery` with research mode detection. - ---- - -## src/tools/ - Search Tool Rules - -**Protocol**: All tools implement `SearchTool` protocol from `src/tools/base.py`: `name` property and `async def search(query, max_results) -> list[Evidence]`. - -**Rate Limiting**: Use `@retry` decorator from tenacity: `@retry(stop=stop_after_attempt(3), wait=wait_exponential(...))`. Implement `_rate_limit()` method for APIs with limits. Use shared rate limiters from `src/tools/rate_limiter.py`. - -**Error Handling**: Raise `SearchError` or `RateLimitError` on failures. Handle HTTP errors (429, 500, timeout). Return empty list on non-critical errors (log warning). - -**Query Preprocessing**: Use `preprocess_query()` from `src/tools/query_utils.py` to remove noise and expand synonyms. - -**Evidence Conversion**: Convert API responses to `Evidence` objects with `Citation`. Extract metadata (title, url, date, authors). Set relevance scores (0.0-1.0). Handle missing fields gracefully. - -**Tool-Specific Rules**: -- `pubmed.py`: Use NCBI E-utilities (ESearch → EFetch). Rate limit: 0.34s between requests. Parse XML with `xmltodict`. Handle single vs. multiple articles. -- `clinicaltrials.py`: Use `requests` library (NOT httpx - WAF blocks httpx). Run in thread pool: `await asyncio.to_thread(requests.get, ...)`. Filter: Only interventional studies, active/completed. -- `europepmc.py`: Handle preprint markers: `[PREPRINT - Not peer-reviewed]`. Build URLs from DOI or PMID. -- `rag_tool.py`: Wraps `LlamaIndexRAGService`. Returns Evidence from RAG results. Handles ingestion. -- `search_handler.py`: Orchestrates parallel searches across multiple tools. Uses `asyncio.gather()` with `return_exceptions=True`. Aggregates results into `SearchResult`. - ---- - -## src/middleware/ - Middleware Rules - -**State Management**: Use `ContextVar` for thread-safe isolation. `WorkflowState` uses `ContextVar[WorkflowState | None]`. 
Initialize with `init_workflow_state(embedding_service)`. Access with `get_workflow_state()` (auto-initializes if missing). - -**WorkflowState**: Tracks `evidence: list[Evidence]`, `conversation: Conversation`, `embedding_service: Any`. Methods: `add_evidence()` (deduplicates by URL), `async search_related()` (semantic search). - -**WorkflowManager**: Manages parallel research loops. Methods: `add_loop()`, `run_loops_parallel()`, `update_loop_status()`, `sync_loop_evidence_to_state()`. Uses `asyncio.gather()` for parallel execution. Handles errors per loop (don't fail all if one fails). - -**BudgetTracker**: Tracks tokens, time, iterations per loop and globally. Methods: `create_budget()`, `add_tokens()`, `start_timer()`, `update_timer()`, `increment_iteration()`, `check_budget()`, `can_continue()`. Token estimation: `estimate_tokens(text)` (~4 chars per token), `estimate_llm_call_tokens(prompt, response)`. - -**Models**: All middleware models in `src/utils/models.py`. `IterationData`, `Conversation`, `ResearchLoop`, `BudgetStatus` are used by middleware. - ---- - -## src/orchestrator/ - Orchestration Rules - -**Research Flows**: Two patterns: `IterativeResearchFlow` (single loop) and `DeepResearchFlow` (plan → parallel loops → synthesis). Both support agent chains (`use_graph=False`) and graph execution (`use_graph=True`). - -**IterativeResearchFlow**: Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Judge → Continue/Complete. Uses `KnowledgeGapAgent`, `ToolSelectorAgent`, `ThinkingAgent`, `WriterAgent`, `JudgeHandler`. Tracks iterations, time, budget. - -**DeepResearchFlow**: Pattern: Planner → Parallel iterative loops per section → Synthesizer. Uses `PlannerAgent`, `IterativeResearchFlow` (per section), `LongWriterAgent` or `ProofreaderAgent`. Uses `WorkflowManager` for parallel execution. - -**Graph Orchestrator**: Uses Pydantic AI Graphs (when available) or agent chains (fallback). Routes based on research mode (iterative/deep/auto). Streams `AgentEvent` objects for UI. - -**State Initialization**: Always call `init_workflow_state()` before running flows. Initialize `BudgetTracker` per loop. Use `WorkflowManager` for parallel coordination. - -**Event Streaming**: Yield `AgentEvent` objects during execution. Event types: "started", "search_complete", "judge_complete", "hypothesizing", "synthesizing", "complete", "error". Include iteration numbers and data payloads. - ---- - -## src/services/ - Service Rules - -**EmbeddingService**: Local sentence-transformers (NO API key required). All operations async-safe via `run_in_executor()`. ChromaDB for vector storage. Deduplication threshold: 0.85 (85% similarity = duplicate). - -**LlamaIndexRAGService**: Uses OpenAI embeddings (requires `OPENAI_API_KEY`). Methods: `ingest_evidence()`, `retrieve()`, `query()`. Returns documents with metadata (source, title, url, date, authors). Lazy initialization with graceful fallback. - -**StatisticalAnalyzer**: Generates Python code via LLM. Executes in Modal sandbox (secure, isolated). Library versions pinned in `SANDBOX_LIBRARIES` dict. Returns `AnalysisResult` with verdict (SUPPORTED/REFUTED/INCONCLUSIVE). - -**Singleton Pattern**: Use `@lru_cache(maxsize=1)` for singletons: `@lru_cache(maxsize=1); def get_service() -> Service: return Service()`. Lazy initialization to avoid requiring dependencies at import time. - ---- - -## src/utils/ - Utility Rules - -**Models**: All Pydantic models in `src/utils/models.py`. 
Use frozen models (`model_config = {"frozen": True}`) except where mutation needed. Use `Field()` with descriptions. Validate with constraints. - -**Config**: Settings via Pydantic Settings (`src/utils/config.py`). Load from `.env` automatically. Use `settings` singleton: `from src.utils.config import settings`. Validate API keys with properties: `has_openai_key`, `has_anthropic_key`. - -**Exceptions**: Custom exception hierarchy in `src/utils/exceptions.py`. Base: `DeepCriticalError`. Specific: `SearchError`, `RateLimitError`, `JudgeError`, `ConfigurationError`. Always chain exceptions. - -**LLM Factory**: Centralized LLM model creation in `src/utils/llm_factory.py`. Supports OpenAI, Anthropic, HF Inference. Use `get_model()` or factory functions. Check requirements before initialization. - -**Citation Validator**: Use `validate_references()` from `src/utils/citation_validator.py`. Removes hallucinated citations (URLs not in evidence). Logs warnings. Returns validated report string. - ---- - -## src/orchestrator_factory.py Rules - -**Purpose**: Factory for creating orchestrators. Supports "simple" (legacy) and "advanced" (magentic) modes. Auto-detects mode based on API key availability. - -**Pattern**: Lazy import for optional dependencies (`_get_magentic_orchestrator_class()`). Handles `ImportError` gracefully with clear error messages. - -**Mode Detection**: `_determine_mode()` checks explicit mode or auto-detects: "advanced" if `settings.has_openai_key`, else "simple". Maps "magentic" → "advanced". - -**Function Signature**: `create_orchestrator(search_handler, judge_handler, config, mode) -> Any`. Simple mode requires handlers. Advanced mode uses MagenticOrchestrator. - -**Error Handling**: Raise `ValueError` with clear messages if requirements not met. Log mode selection with structlog. - ---- - -## src/orchestrator_hierarchical.py Rules - -**Purpose**: Hierarchical orchestrator using middleware and sub-teams. Adapts Magentic ChatAgent to SubIterationTeam protocol. - -**Pattern**: Uses `SubIterationMiddleware` with `ResearchTeam` and `LLMSubIterationJudge`. Event-driven via callback queue. - -**State Initialization**: Initialize embedding service with graceful fallback. Use `init_magentic_state()` (deprecated, but kept for compatibility). - -**Event Streaming**: Uses `asyncio.Queue` for event coordination. Yields `AgentEvent` objects. Handles event callback pattern with `asyncio.wait()`. - -**Error Handling**: Log errors with context. Yield error events. Process remaining events after task completion. - ---- - -## src/orchestrator_magentic.py Rules - -**Purpose**: Magentic-based orchestrator using ChatAgent pattern. Each agent has internal LLM. Manager orchestrates agents. - -**Pattern**: Uses `MagenticBuilder` with participants (searcher, hypothesizer, judge, reporter). Manager uses `OpenAIChatClient`. Workflow built in `_build_workflow()`. - -**Event Processing**: `_process_event()` converts Magentic events to `AgentEvent`. Handles: `MagenticOrchestratorMessageEvent`, `MagenticAgentMessageEvent`, `MagenticFinalResultEvent`, `MagenticAgentDeltaEvent`, `WorkflowOutputEvent`. - -**Text Extraction**: `_extract_text()` defensively extracts text from messages. Priority: `.content` → `.text` → `str(message)`. Handles buggy message objects. - -**State Initialization**: Initialize embedding service with graceful fallback. Use `init_magentic_state()` (deprecated). - -**Requirements**: Must call `check_magentic_requirements()` in `__init__`. Requires `agent-framework-core` and OpenAI API key. 
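A rough sketch of the defensive extraction order described under Text Extraction above (the attribute checks are assumptions about Magentic message objects, not the actual implementation):

```python
# Sketch of the .content -> .text -> str(message) fallback chain; the exact
# shape of Magentic message objects is assumed here for illustration.
from typing import Any


def _extract_text(message: Any) -> str:
    content = getattr(message, "content", None)
    if isinstance(content, str) and content.strip():
        return content
    text = getattr(message, "text", None)
    if isinstance(text, str) and text.strip():
        return text
    return str(message)
```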
- -**Event Types**: Maps agent names to event types: "search" → "search_complete", "judge" → "judge_complete", "hypothes" → "hypothesizing", "report" → "synthesizing". - ---- - -## src/agent_factory/ - Factory Rules - -**Pattern**: Factory functions for creating agents and handlers. Lazy initialization for optional dependencies. Support OpenAI/Anthropic/HF Inference. - -**Judges**: `create_judge_handler()` creates `JudgeHandler` with structured output (`JudgeAssessment`). Supports `MockJudgeHandler`, `HFInferenceJudgeHandler` as fallbacks. - -**Agents**: Factory functions in `agents.py` for all Pydantic AI agents. Pattern: `create_agent_name(model: Any | None = None) -> AgentName`. Use `get_model()` if model not provided. - -**Graph Builder**: `graph_builder.py` contains utilities for building research graphs. Supports iterative and deep research graph construction. - -**Error Handling**: Raise `ConfigurationError` if required API keys missing. Log agent creation. Handle import errors gracefully. - ---- - -## src/prompts/ - Prompt Rules - -**Pattern**: System prompts stored as module-level constants. Include date injection: `datetime.now().strftime("%Y-%m-%d")`. Format evidence with truncation (1500 chars per item). - -**Judge Prompts**: In `judge.py`. Handle empty evidence case separately. Always request structured JSON output. - -**Hypothesis Prompts**: In `hypothesis.py`. Use diverse evidence selection (MMR algorithm). Sentence-aware truncation. - -**Report Prompts**: In `report.py`. Include full citation details. Use diverse evidence selection (n=20). Emphasize citation validation rules. - ---- - -## Testing Rules - -**Structure**: Unit tests in `tests/unit/` (mocked, fast). Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`). - -**Mocking**: Use `respx` for httpx mocking. Use `pytest-mock` for general mocking. Mock LLM calls in unit tests (use `MockJudgeHandler`). - -**Fixtures**: Common fixtures in `tests/conftest.py`: `mock_httpx_client`, `mock_llm_response`. - -**Coverage**: Aim for >80% coverage. Test error handling, edge cases, and integration paths. - ---- - -## File-Specific Agent Rules - -**knowledge_gap.py**: Outputs `KnowledgeGapOutput`. System prompt evaluates research completeness. Handles conversation history. Returns fallback on error. - -**writer.py**: Returns markdown string. System prompt includes citation format examples. Validates inputs. Truncates long findings. Retry logic for transient failures. - -**long_writer.py**: Uses `ReportDraft` input/output. Writes sections iteratively. Reformats references (deduplicates, renumbers). Reformats section headings. - -**proofreader.py**: Takes `ReportDraft`, returns polished markdown. Removes duplicates. Adds summary. Preserves references. - -**tool_selector.py**: Outputs `AgentSelectionPlan`. System prompt lists available agents (WebSearchAgent, SiteCrawlerAgent, RAGAgent). Guidelines for when to use each. - -**thinking.py**: Returns observation string. Generates observations from conversation history. Uses query and background context. - -**input_parser.py**: Outputs `ParsedQuery`. Detects research mode (iterative/deep). Extracts entities and research questions. Improves/refines query. 
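To make the middleware state rules above concrete, a simplified sketch of the `ContextVar`-backed pattern (field names and the auto-initialization behavior are assumptions, not the actual `src/middleware/state_machine.py` code):

```python
# Simplified sketch of ContextVar-backed workflow state; not the real module.
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any


@dataclass
class WorkflowState:
    evidence: list[Any] = field(default_factory=list)
    embedding_service: Any | None = None

    def add_evidence(self, item: Any) -> None:
        # Deduplicate by URL, as the rules above require.
        known = {getattr(e, "url", None) for e in self.evidence}
        if getattr(item, "url", None) not in known:
            self.evidence.append(item)


_workflow_state: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None)


def init_workflow_state(embedding_service: Any | None = None) -> WorkflowState:
    state = WorkflowState(embedding_service=embedding_service)
    _workflow_state.set(state)
    return state


def get_workflow_state() -> WorkflowState:
    state = _workflow_state.get()
    if state is None:  # auto-initialize if missing, per the middleware rules
        state = init_workflow_state()
    return state
```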
- - - - - - - diff --git a/.env.example b/.env.example deleted file mode 100644 index 442ff75d33f92422e78850b3c9d6d49af6f1d6e3..0000000000000000000000000000000000000000 --- a/.env.example +++ /dev/null @@ -1,107 +0,0 @@ -# HuggingFace -HF_TOKEN=your_huggingface_token_here - -# OpenAI (optional) -OPENAI_API_KEY=your_openai_key_here - -# Anthropic (optional) -ANTHROPIC_API_KEY=your_anthropic_key_here - -# Model names (optional - sensible defaults set in config.py) -# ANTHROPIC_MODEL=claude-sonnet-4-5-20250929 -# OPENAI_MODEL=gpt-5.1 - - -# ============================================ -# Audio Processing Configuration (TTS) -# ============================================ -# Kokoro TTS Model Configuration -TTS_MODEL=hexgrad/Kokoro-82M -TTS_VOICE=af_heart -TTS_SPEED=1.0 -TTS_GPU=T4 -TTS_TIMEOUT=60 - -# Available TTS Voices: -# American English Female: af_heart, af_bella, af_nicole, af_aoede, af_kore, af_sarah, af_nova, af_sky, af_alloy, af_jessica, af_river -# American English Male: am_michael, am_fenrir, am_puck, am_echo, am_eric, am_liam, am_onyx, am_santa, am_adam - -# Available GPU Types (Modal): -# T4 - Cheapest, good for testing (default) -# A10 - Good balance of cost/performance -# A100 - Fastest, most expensive -# L4 - NVIDIA L4 GPU -# L40S - NVIDIA L40S GPU -# Note: GPU type is set at function definition time. Changes require app restart. - -# ============================================ -# Audio Processing Configuration (STT) -# ============================================ -# Speech-to-Text API Configuration -STT_API_URL=nvidia/canary-1b-v2 -STT_SOURCE_LANG=English -STT_TARGET_LANG=English - -# Available STT Languages: -# English, Bulgarian, Croatian, Czech, Danish, Dutch, Estonian, Finnish, French, German, Greek, Hungarian, Italian, Latvian, Lithuanian, Maltese, Polish, Portuguese, Romanian, Slovak, Slovenian, Spanish, Swedish, Russian, Ukrainian - -# ============================================ -# Audio Feature Flags -# ============================================ -ENABLE_AUDIO_INPUT=true -ENABLE_AUDIO_OUTPUT=true - -# ============================================ -# Image OCR Configuration -# ============================================ -OCR_API_URL=prithivMLmods/Multimodal-OCR3 -ENABLE_IMAGE_INPUT=true - -# ============== EMBEDDINGS ============== - -# OpenAI Embedding Model (used if LLM_PROVIDER is openai and performing RAG/Embeddings) -OPENAI_EMBEDDING_MODEL=text-embedding-3-small - -# Local Embedding Model (used for local/offline embeddings) -LOCAL_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 - -# ============== HUGGINGFACE (FREE TIER) ============== - -# HuggingFace Token - enables Llama 3.1 (best quality free model) -# Get yours at: https://huggingface.co/settings/tokens -# -# WITHOUT HF_TOKEN: Falls back to ungated models (zephyr-7b-beta) -# WITH HF_TOKEN: Uses Llama 3.1 8B Instruct (requires accepting license) -# -# For HuggingFace Spaces deployment: -# Set this as a "Secret" in Space Settings -> Variables and secrets -# Users/judges don't need their own token - the Space secret is used -# -HF_TOKEN=hf_your-token-here - -# ============== AGENT CONFIGURATION ============== - -MAX_ITERATIONS=10 -SEARCH_TIMEOUT=30 -LOG_LEVEL=INFO - -# ============================================ -# Modal Configuration (Required for TTS) -# ============================================ -# Modal credentials are required for TTS (Text-to-Speech) functionality -# Get your credentials from: https://modal.com/ -MODAL_TOKEN_ID=your_modal_token_id_here 
-MODAL_TOKEN_SECRET=your_modal_token_secret_here - -# ============== EXTERNAL SERVICES ============== - -# PubMed (optional - higher rate limits) -NCBI_API_KEY=your-ncbi-key-here - -# Vector Database (optional - for LlamaIndex RAG) -CHROMA_DB_PATH=./chroma_db -# Neo4j Knowledge Graph -NEO4J_URI=bolt://localhost:7687 -NEO4J_USER=neo4j -NEO4J_PASSWORD=your_neo4j_password_here -NEO4J_DATABASE=your_database_name diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 9a982ecba57a76ced6b098ad431436049f052514..0000000000000000000000000000000000000000 --- a/.gitignore +++ /dev/null @@ -1,84 +0,0 @@ -folder/ -site/ -.cursor/ -.ruff_cache/ -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# Virtual environments -.venv/ -venv/ -ENV/ -env/ - -# IDE -.vscode/ -.idea/ -*.swp -*.swo - -# Environment -.env -.env.local -*.local - -# Claude -.claude/ - -# Burner docs (working drafts, not for commit) -burner_docs/ - -# Reference repos (clone locally, don't commit) -reference_repos/autogen-microsoft/ -reference_repos/claude-agent-sdk/ -reference_repos/pydanticai-research-agent/ -reference_repos/pubmed-mcp-server/ -reference_repos/DeepCritical/ - -# Keep the README in reference_repos -!reference_repos/README.md - -# Development directory -dev/ - -# OS -.DS_Store -Thumbs.db - -# Logs -*.log -logs/ - -# Testing -.pytest_cache/ -.mypy_cache/ -.coverage -htmlcov/ -test_output*.txt - -# Database files -chroma_db/ -*.sqlite3 - - -# Trigger rebuild Wed Nov 26 17:51:41 EST 2025 -.env diff --git a/.pre-commit-hooks/run_pytest.ps1 b/.pre-commit-hooks/run_pytest.ps1 deleted file mode 100644 index 3df4f371b845a48ce3a1ea32e307218abbd5a033..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest.ps1 +++ /dev/null @@ -1,19 +0,0 @@ -# PowerShell pytest runner for pre-commit (Windows) -# Uses uv if available, otherwise falls back to pytest - -if (Get-Command uv -ErrorAction SilentlyContinue) { - # Sync dependencies before running tests - uv sync - uv run pytest $args -} else { - Write-Warning "uv not found, using system pytest (may have missing dependencies)" - pytest $args -} - - - - - - - - diff --git a/.pre-commit-hooks/run_pytest.sh b/.pre-commit-hooks/run_pytest.sh deleted file mode 100644 index b2a4be920113fd340631f64602c24042e8c81086..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Cross-platform pytest runner for pre-commit -# Uses uv if available, otherwise falls back to pytest - -if command -v uv >/dev/null 2>&1; then - # Sync dependencies before running tests - uv sync - uv run pytest "$@" -else - echo "Warning: uv not found, using system pytest (may have missing dependencies)" - pytest "$@" -fi - - - - - - - - diff --git a/.pre-commit-hooks/run_pytest_embeddings.ps1 b/.pre-commit-hooks/run_pytest_embeddings.ps1 deleted file mode 100644 index 47a3e32a202240c42e5a205d2afd778a23292db7..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_embeddings.ps1 +++ /dev/null @@ -1,14 +0,0 @@ -# PowerShell wrapper to sync embeddings dependencies and run embeddings tests - -$ErrorActionPreference = "Stop" - -if (Get-Command uv -ErrorAction SilentlyContinue) { - Write-Host "Syncing embeddings dependencies..." - uv sync --extra embeddings - Write-Host "Running embeddings tests..." 
- uv run pytest tests/ -v -m local_embeddings --tb=short -p no:logfire -} else { - Write-Error "uv not found" - exit 1 -} - diff --git a/.pre-commit-hooks/run_pytest_embeddings.sh b/.pre-commit-hooks/run_pytest_embeddings.sh deleted file mode 100644 index 6f1b80746217244367ee86fcd7d69837df648b40..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_embeddings.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# Wrapper script to sync embeddings dependencies and run embeddings tests - -set -e - -if command -v uv >/dev/null 2>&1; then - echo "Syncing embeddings dependencies..." - uv sync --extra embeddings - echo "Running embeddings tests..." - uv run pytest tests/ -v -m local_embeddings --tb=short -p no:logfire -else - echo "Error: uv not found" - exit 1 -fi - diff --git a/.pre-commit-hooks/run_pytest_unit.ps1 b/.pre-commit-hooks/run_pytest_unit.ps1 deleted file mode 100644 index c1196d22e86fe66a56d12f673c003ac88aa6b09f..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_unit.ps1 +++ /dev/null @@ -1,14 +0,0 @@ -# PowerShell wrapper to sync dependencies and run unit tests - -$ErrorActionPreference = "Stop" - -if (Get-Command uv -ErrorAction SilentlyContinue) { - Write-Host "Syncing dependencies..." - uv sync - Write-Host "Running unit tests..." - uv run pytest tests/unit/ -v -m "not openai and not embedding_provider" --tb=short -p no:logfire -} else { - Write-Error "uv not found" - exit 1 -} - diff --git a/.pre-commit-hooks/run_pytest_unit.sh b/.pre-commit-hooks/run_pytest_unit.sh deleted file mode 100644 index 173ab1b607647ecf4b4a1de6b75abd47fc0130ec..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_unit.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# Wrapper script to sync dependencies and run unit tests - -set -e - -if command -v uv >/dev/null 2>&1; then - echo "Syncing dependencies..." - uv sync - echo "Running unit tests..." - uv run pytest tests/unit/ -v -m "not openai and not embedding_provider" --tb=short -p no:logfire -else - echo "Error: uv not found" - exit 1 -fi - diff --git a/.pre-commit-hooks/run_pytest_with_sync.ps1 b/.pre-commit-hooks/run_pytest_with_sync.ps1 deleted file mode 100644 index 546a5096bc6e4b9a46d039f5761234022b8658dd..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_with_sync.ps1 +++ /dev/null @@ -1,25 +0,0 @@ -# PowerShell wrapper for pytest runner -# Ensures uv is available and runs the Python script - -param( - [Parameter(Position=0)] - [string]$TestType = "unit" -) - -$ErrorActionPreference = "Stop" - -# Check if uv is available -if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { - Write-Error "uv not found. 
Please install uv: https://github.com/astral-sh/uv" - exit 1 -} - -# Get the script directory -$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path -$PythonScript = Join-Path $ScriptDir "run_pytest_with_sync.py" - -# Run the Python script using uv -uv run python $PythonScript $TestType - -exit $LASTEXITCODE - diff --git a/.pre-commit-hooks/run_pytest_with_sync.py b/.pre-commit-hooks/run_pytest_with_sync.py deleted file mode 100644 index 70c4eb52f8239d167868ba47b2d4bb80d9ac3173..0000000000000000000000000000000000000000 --- a/.pre-commit-hooks/run_pytest_with_sync.py +++ /dev/null @@ -1,235 +0,0 @@ -#!/usr/bin/env python3 -"""Cross-platform pytest runner that syncs dependencies before running tests.""" - -import shutil -import subprocess -import sys -from pathlib import Path - - -def clean_caches(project_root: Path) -> None: - """Remove pytest and Python cache directories and files. - - Comprehensively removes all cache files and directories to ensure - clean test runs. Only scans specific directories to avoid resource - exhaustion from scanning large directories like .venv on Windows. - """ - # Directories to scan for caches (only project code, not dependencies) - scan_dirs = ["src", "tests", ".pre-commit-hooks"] - - # Directories to exclude (to avoid resource issues) - exclude_dirs = { - ".venv", - "venv", - "ENV", - "env", - ".git", - "node_modules", - "dist", - "build", - ".eggs", - "reference_repos", - "folder", - } - - # Comprehensive list of cache patterns to remove - cache_patterns = [ - ".pytest_cache", - "__pycache__", - "*.pyc", - "*.pyo", - "*.pyd", - ".mypy_cache", - ".ruff_cache", - ".coverage", - "coverage.xml", - "htmlcov", - ".hypothesis", # Hypothesis testing framework cache - ".tox", # Tox cache (if used) - ".cache", # General Python cache - ] - - def should_exclude(path: Path) -> bool: - """Check if a path should be excluded from cache cleanup.""" - # Check if any parent directory is in exclude list - for parent in path.parents: - if parent.name in exclude_dirs: - return True - # Check if the path itself is excluded - if path.name in exclude_dirs: - return True - return False - - cleaned = [] - - # Only scan specific directories to avoid resource exhaustion - for scan_dir in scan_dirs: - scan_path = project_root / scan_dir - if not scan_path.exists(): - continue - - for pattern in cache_patterns: - if "*" in pattern: - # Handle glob patterns for files - try: - for cache_file in scan_path.rglob(pattern): - if should_exclude(cache_file): - continue - try: - if cache_file.is_file(): - cache_file.unlink() - cleaned.append(str(cache_file.relative_to(project_root))) - except OSError: - pass # Ignore errors (file might be locked or already deleted) - except OSError: - pass # Ignore errors during directory traversal - else: - # Handle directory patterns - try: - for cache_dir in scan_path.rglob(pattern): - if should_exclude(cache_dir): - continue - try: - if cache_dir.is_dir(): - shutil.rmtree(cache_dir, ignore_errors=True) - cleaned.append(str(cache_dir.relative_to(project_root))) - except OSError: - pass # Ignore errors (directory might be locked) - except OSError: - pass # Ignore errors during directory traversal - - # Also clean root-level caches (like .pytest_cache in project root) - root_cache_patterns = [ - ".pytest_cache", - ".mypy_cache", - ".ruff_cache", - ".coverage", - "coverage.xml", - "htmlcov", - ".hypothesis", - ".tox", - ".cache", - ".pytest", - ] - for pattern in root_cache_patterns: - cache_path = project_root / pattern - if cache_path.exists(): - try: 
- if cache_path.is_dir(): - shutil.rmtree(cache_path, ignore_errors=True) - elif cache_path.is_file(): - cache_path.unlink() - cleaned.append(pattern) - except OSError: - pass - - # Also remove any .pyc files in root directory - try: - for pyc_file in project_root.glob("*.pyc"): - try: - pyc_file.unlink() - cleaned.append(pyc_file.name) - except OSError: - pass - except OSError: - pass - - if cleaned: - print( - f"Cleaned {len(cleaned)} cache items: {', '.join(cleaned[:10])}{'...' if len(cleaned) > 10 else ''}" - ) - else: - print("No cache files found to clean") - - -def run_command( - cmd: list[str], check: bool = True, shell: bool = False, cwd: str | None = None -) -> int: - """Run a command and return exit code.""" - try: - result = subprocess.run( - cmd, - check=check, - shell=shell, - cwd=cwd, - env=None, # Use current environment, uv will handle venv - ) - return result.returncode - except subprocess.CalledProcessError as e: - return e.returncode - except FileNotFoundError: - print(f"Error: Command not found: {cmd[0]}") - return 1 - - -def main() -> int: - """Main entry point.""" - import os - - # Get the project root (where pyproject.toml is) - script_dir = Path(__file__).parent - project_root = script_dir.parent - - # Change to project root to ensure uv works correctly - os.chdir(project_root) - - # Clean caches before running tests - print("Cleaning pytest and Python caches...") - clean_caches(project_root) - - # Check if uv is available - if run_command(["uv", "--version"], check=False) != 0: - print("Error: uv not found. Please install uv: https://github.com/astral-sh/uv") - return 1 - - # Parse arguments - test_type = sys.argv[1] if len(sys.argv) > 1 else "unit" - extra_args = sys.argv[2:] if len(sys.argv) > 2 else [] - - # Sync dependencies - always include dev - # Note: embeddings dependencies are now in main dependencies, not optional - # Use --extra dev for [project.optional-dependencies].dev (not --dev which is for [dependency-groups]) - sync_cmd = ["uv", "sync", "--extra", "dev"] - - print(f"Syncing dependencies for {test_type} tests...") - if run_command(sync_cmd, cwd=project_root) != 0: - return 1 - - # Build pytest command - use uv run to ensure correct environment - if test_type == "unit": - pytest_args = [ - "tests/unit/", - "-v", - "-m", - "not openai and not embedding_provider", - "--tb=short", - "-p", - "no:logfire", - "--cache-clear", # Clear pytest cache before running - ] - elif test_type == "embeddings": - pytest_args = [ - "tests/", - "-v", - "-m", - "local_embeddings", - "--tb=short", - "-p", - "no:logfire", - "--cache-clear", # Clear pytest cache before running - ] - else: - pytest_args = [] - - pytest_args.extend(extra_args) - - # Use uv run python -m pytest to ensure we use the venv's pytest - # This is more reliable than uv run pytest which might find system pytest - pytest_cmd = ["uv", "run", "python", "-m", "pytest", *pytest_args] - - print(f"Running {test_type} tests...") - return run_command(pytest_cmd, cwd=project_root) - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/.python-version b/.python-version deleted file mode 100644 index 2c0733315e415bfb5e5b353f9996ecd964d395b2..0000000000000000000000000000000000000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.11 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 9a93fd1f812752141d49e4e27efee17405ed9563..0000000000000000000000000000000000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,494 +0,0 @@ -# Contributing to The DETERMINATOR - -Thank you for 
your interest in contributing to The DETERMINATOR! This guide will help you get started. - -## Table of Contents - -- [Git Workflow](#git-workflow) -- [Getting Started](#getting-started) -- [Development Commands](#development-commands) -- [MCP Integration](#mcp-integration) -- [Common Pitfalls](#common-pitfalls) -- [Key Principles](#key-principles) -- [Pull Request Process](#pull-request-process) - -> **Note**: Additional sections (Code Style, Error Handling, Testing, Implementation Patterns, Code Quality, and Prompt Engineering) are available as separate pages in the [documentation](https://deepcritical.github.io/GradioDemo/contributing/). -> **Note on Project Names**: "The DETERMINATOR" is the product name, "DeepCritical" is the organization/project name, and "determinator" is the Python package name. - -## Repository Information - -- **GitHub Repository**: [`DeepCritical/GradioDemo`](https://github.com/DeepCritical/GradioDemo) (source of truth, PRs, code review) -- **HuggingFace Space**: [`DataQuests/DeepCritical`](https://huggingface.co/spaces/DataQuests/DeepCritical) (deployment/demo) -- **Package Name**: `determinator` (Python package name in `pyproject.toml`) - -## Git Workflow - -- `main`: Production-ready (GitHub) -- `dev`: Development integration (GitHub) -- Use feature branches: `yourname-dev` -- **NEVER** push directly to `main` or `dev` on HuggingFace -- GitHub is source of truth; HuggingFace is for deployment - -### Dual Repository Setup - -This project uses a dual repository setup: - -- **GitHub (`DeepCritical/GradioDemo`)**: Source of truth for code, PRs, and code review -- **HuggingFace (`DataQuests/DeepCritical`)**: Deployment target for the Gradio demo - -#### Remote Configuration - -When cloning, set up remotes as follows: - -```bash -# Clone from GitHub -git clone https://github.com/DeepCritical/GradioDemo.git -cd GradioDemo - -# Add HuggingFace remote (optional, for deployment) -git remote add huggingface-upstream https://huggingface.co/spaces/DataQuests/DeepCritical -``` - -**Important**: Never push directly to `main` or `dev` on HuggingFace. Always work through GitHub PRs. GitHub is the source of truth; HuggingFace is for deployment/demo only. - -## Getting Started - -1. **Fork the repository** on GitHub: [`DeepCritical/GradioDemo`](https://github.com/DeepCritical/GradioDemo) -2. **Clone your fork**: - - ```bash - git clone https://github.com/yourusername/GradioDemo.git - cd GradioDemo - ``` - -3. **Install dependencies**: - - ```bash - uv sync --all-extras - uv run pre-commit install - ``` - -4. **Create a feature branch**: - - ```bash - git checkout -b yourname-feature-name - ``` - -5. **Make your changes** following the guidelines below -6. **Run checks**: - - ```bash - uv run ruff check src tests - uv run mypy src - uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire - ``` - -7. **Commit and push**: - - ```bash - git commit -m "Description of changes" - git push origin yourname-feature-name - ``` - -8. **Create a pull request** on GitHub - -## Package Manager - -This project uses [`uv`](https://github.com/astral-sh/uv) as the package manager. All commands should be prefixed with `uv run` to ensure they run in the correct environment. 
- -### Installation - -```bash -# Install uv if you haven't already (recommended: standalone installer) -# Unix/macOS/Linux: -curl -LsSf https://astral.sh/uv/install.sh | sh - -# Windows (PowerShell): -powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" - -# Alternative: pipx install uv -# Or: pip install uv - -# Sync all dependencies including dev extras -uv sync --all-extras - -# Install pre-commit hooks -uv run pre-commit install -``` - -## Development Commands - -```bash -# Installation -uv sync --all-extras # Install all dependencies including dev -uv run pre-commit install # Install pre-commit hooks - -# Code Quality Checks (run all before committing) -uv run ruff check src tests # Lint with ruff -uv run ruff format src tests # Format with ruff -uv run mypy src # Type checking -uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire # Tests with coverage - -# Testing Commands -uv run pytest tests/unit/ -v -m "not openai" -p no:logfire # Run unit tests (excludes OpenAI tests) -uv run pytest tests/ -v -m "huggingface" -p no:logfire # Run HuggingFace tests -uv run pytest tests/ -v -p no:logfire # Run all tests -uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire # Tests with terminal coverage -uv run pytest --cov=src --cov-report=html -p no:logfire # Generate HTML coverage report (opens htmlcov/index.html) - -# Documentation Commands -uv run mkdocs build # Build documentation -uv run mkdocs serve # Serve documentation locally (http://127.0.0.1:8000) -``` - -### Test Markers - -The project uses pytest markers to categorize tests. See [Testing Guidelines](docs/contributing/testing.md) for details: - -- `unit`: Unit tests (mocked, fast) -- `integration`: Integration tests (real APIs) -- `slow`: Slow tests -- `openai`: Tests requiring OpenAI API key -- `huggingface`: Tests requiring HuggingFace API key -- `embedding_provider`: Tests requiring API-based embedding providers -- `local_embeddings`: Tests using local embeddings - -**Note**: The `-p no:logfire` flag disables the logfire plugin to avoid conflicts during testing. 
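A hypothetical marker-gated unit test tying the markers and mocking guidance together (the `MockJudgeHandler` import path and constructor are assumptions for illustration):

```python
# Hypothetical example; MockJudgeHandler's exact location and signature are assumed.
import pytest

from src.agent_factory.judges import MockJudgeHandler  # assumed import path


@pytest.mark.unit
async def test_judge_returns_assessment_without_llm() -> None:
    # Runs in the default unit selection: -m "not openai and not embedding_provider"
    handler = MockJudgeHandler()
    assessment = await handler.assess(question="Does metformin activate AMPK?", evidence=[])
    assert assessment is not None
```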
- -## Code Style & Conventions - -### Type Safety - -- **ALWAYS** use type hints for all function parameters and return types -- Use `mypy --strict` compliance (no `Any` unless absolutely necessary) -- Use `TYPE_CHECKING` imports for circular dependencies: - - -[TYPE_CHECKING Import Pattern](../src/utils/citation_validator.py) start_line:8 end_line:11 - - -### Pydantic Models - -- All data exchange uses Pydantic models (`src/utils/models.py`) -- Models are frozen (`model_config = {"frozen": True}`) for immutability -- Use `Field()` with descriptions for all model fields -- Validate with `ge=`, `le=`, `min_length=`, `max_length=` constraints - -### Async Patterns - -- **ALL** I/O operations must be async (`async def`, `await`) -- Use `asyncio.gather()` for parallel operations -- CPU-bound work (embeddings, parsing) must use `run_in_executor()`: - -```python -loop = asyncio.get_running_loop() -result = await loop.run_in_executor(None, cpu_bound_function, args) -``` - -- Never block the event loop with synchronous I/O - -### Linting - -- Ruff with 100-char line length -- Ignore rules documented in `pyproject.toml`: - - `PLR0913`: Too many arguments (agents need many params) - - `PLR0912`: Too many branches (complex orchestrator logic) - - `PLR0911`: Too many return statements (complex agent logic) - - `PLR2004`: Magic values (statistical constants) - - `PLW0603`: Global statement (singleton pattern) - - `PLC0415`: Lazy imports for optional dependencies - -### Pre-commit - -- Pre-commit hooks run automatically on commit -- Must pass: lint + typecheck + test-cov -- Install hooks with: `uv run pre-commit install` -- Note: `uv sync --all-extras` installs the pre-commit package, but you must run `uv run pre-commit install` separately to set up the git hooks - -## Error Handling & Logging - -### Exception Hierarchy - -Use custom exception hierarchy (`src/utils/exceptions.py`): - - -[Exception Hierarchy](../src/utils/exceptions.py) start_line:4 end_line:31 - - -### Error Handling Rules - -- Always chain exceptions: `raise SearchError(...) from e` -- Log errors with context using `structlog`: - -```python -logger.error("Operation failed", error=str(e), context=value) -``` - -- Never silently swallow exceptions -- Provide actionable error messages - -### Logging - -- Use `structlog` for all logging (NOT `print` or `logging`) -- Import: `import structlog; logger = structlog.get_logger()` -- Log with structured data: `logger.info("event", key=value)` -- Use appropriate levels: DEBUG, INFO, WARNING, ERROR - -### Logging Examples - -```python -logger.info("Starting search", query=query, tools=[t.name for t in tools]) -logger.warning("Search tool failed", tool=tool.name, error=str(result)) -logger.error("Assessment failed", error=str(e)) -``` - -### Error Chaining - -Always preserve exception context: - -```python -try: - result = await api_call() -except httpx.HTTPError as e: - raise SearchError(f"API call failed: {e}") from e -``` - -## Testing Requirements - -### Test Structure - -- Unit tests in `tests/unit/` (mocked, fast) -- Integration tests in `tests/integration/` (real APIs, marked `@pytest.mark.integration`) -- Use markers: `unit`, `integration`, `slow` - -### Mocking - -- Use `respx` for httpx mocking -- Use `pytest-mock` for general mocking -- Mock LLM calls in unit tests (use `MockJudgeHandler`) -- Fixtures in `tests/conftest.py`: `mock_httpx_client`, `mock_llm_response` - -### TDD Workflow - -1. Write failing test in `tests/unit/` -2. Implement in `src/` -3. Ensure test passes -4. 
Run checks: `uv run ruff check src tests && uv run mypy src && uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire` - -### Test Examples - -```python -@pytest.mark.unit -async def test_pubmed_search(mock_httpx_client): - tool = PubMedTool() - results = await tool.search("metformin", max_results=5) - assert len(results) > 0 - assert all(isinstance(r, Evidence) for r in results) - -@pytest.mark.integration -async def test_real_pubmed_search(): - tool = PubMedTool() - results = await tool.search("metformin", max_results=3) - assert len(results) <= 3 -``` - -### Test Coverage - -- Run `uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire` for coverage report -- Run `uv run pytest --cov=src --cov-report=html -p no:logfire` for HTML coverage report (opens `htmlcov/index.html`) -- Aim for >80% coverage on critical paths -- Exclude: `__init__.py`, `TYPE_CHECKING` blocks - -## Implementation Patterns - -### Search Tools - -All tools implement `SearchTool` protocol (`src/tools/base.py`): - -- Must have `name` property -- Must implement `async def search(query, max_results) -> list[Evidence]` -- Use `@retry` decorator from tenacity for resilience -- Rate limiting: Implement `_rate_limit()` for APIs with limits (e.g., PubMed) -- Error handling: Raise `SearchError` or `RateLimitError` on failures - -Example pattern: - -```python -class MySearchTool: - @property - def name(self) -> str: - return "mytool" - - @retry(stop=stop_after_attempt(3), wait=wait_exponential(...)) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - # Implementation - return evidence_list -``` - -### Judge Handlers - -- Implement `JudgeHandlerProtocol` (`async def assess(question, evidence) -> JudgeAssessment`) -- Use pydantic-ai `Agent` with `output_type=JudgeAssessment` -- System prompts in `src/prompts/judge.py` -- Support fallback handlers: `MockJudgeHandler`, `HFInferenceJudgeHandler` -- Always return valid `JudgeAssessment` (never raise exceptions) - -### Agent Factory Pattern - -- Use factory functions for creating agents (`src/agent_factory/`) -- Lazy initialization for optional dependencies (e.g., embeddings, Modal) -- Check requirements before initialization: - - -[Check Magentic Requirements](../src/utils/llm_factory.py) start_line:152 end_line:170 - - -### State Management - -- **Magentic Mode**: Use `ContextVar` for thread-safe state (`src/agents/state.py`) -- **Simple Mode**: Pass state via function parameters -- Never use global mutable state (except singletons via `@lru_cache`) - -### Singleton Pattern - -Use `@lru_cache(maxsize=1)` for singletons: - - -[Singleton Pattern Example](../src/services/statistical_analyzer.py) start_line:252 end_line:255 - - -- Lazy initialization to avoid requiring dependencies at import time - -## Code Quality & Documentation - -### Docstrings - -- Google-style docstrings for all public functions -- Include Args, Returns, Raises sections -- Use type hints in docstrings only if needed for clarity - -Example: - - -[Search Method Docstring Example](../src/tools/pubmed.py) start_line:51 end_line:58 - - -### Code Comments - -- Explain WHY, not WHAT -- Document non-obvious patterns (e.g., why `requests` not `httpx` for ClinicalTrials) -- Mark critical sections: `# CRITICAL: ...` -- Document rate limiting rationale -- Explain async patterns when non-obvious - -## Prompt Engineering & Citation Validation - -### Judge Prompts - -- System prompt in `src/prompts/judge.py` -- Format 
evidence with truncation (1500 chars per item) -- Handle empty evidence case separately -- Always request structured JSON output -- Use `format_user_prompt()` and `format_empty_evidence_prompt()` helpers - -### Hypothesis Prompts - -- Use diverse evidence selection (MMR algorithm) -- Sentence-aware truncation (`truncate_at_sentence()`) -- Format: Drug → Target → Pathway → Effect -- System prompt emphasizes mechanistic reasoning -- Use `format_hypothesis_prompt()` with embeddings for diversity - -### Report Prompts - -- Include full citation details for validation -- Use diverse evidence selection (n=20) -- **CRITICAL**: Emphasize citation validation rules -- Format hypotheses with support/contradiction counts -- System prompt includes explicit JSON structure requirements - -### Citation Validation - -- **ALWAYS** validate references before returning reports -- Use `validate_references()` from `src/utils/citation_validator.py` -- Remove hallucinated citations (URLs not in evidence) -- Log warnings for removed citations -- Never trust LLM-generated citations without validation - -### Citation Validation Rules - -1. Every reference URL must EXACTLY match a provided evidence URL -2. Do NOT invent, fabricate, or hallucinate any references -3. Do NOT modify paper titles, authors, dates, or URLs -4. If unsure about a citation, OMIT it rather than guess -5. Copy URLs exactly as provided - do not create similar-looking URLs - -### Evidence Selection - -- Use `select_diverse_evidence()` for MMR-based selection -- Balance relevance vs diversity (lambda=0.7 default) -- Sentence-aware truncation preserves meaning -- Limit evidence per prompt to avoid context overflow - -## MCP Integration - -### MCP Tools - -- Functions in `src/mcp_tools.py` for Claude Desktop -- Full type hints required -- Google-style docstrings with Args/Returns sections -- Formatted string returns (markdown) - -### Gradio MCP Server - -- Enable with `mcp_server=True` in `demo.launch()` -- Endpoint: `/gradio_api/mcp/` -- Use `ssr_mode=False` to fix hydration issues in HF Spaces - -## Common Pitfalls - -1. **Blocking the event loop**: Never use sync I/O in async functions -2. **Missing type hints**: All functions must have complete type annotations -3. **Hallucinated citations**: Always validate references -4. **Global mutable state**: Use ContextVar or pass via parameters -5. **Import errors**: Lazy-load optional dependencies (magentic, modal, embeddings) -6. **Rate limiting**: Always implement for external APIs -7. **Error chaining**: Always use `from e` when raising exceptions - -## Key Principles - -1. **Type Safety First**: All code must pass `mypy --strict` -2. **Async Everything**: All I/O must be async -3. **Test-Driven**: Write tests before implementation -4. **No Hallucinations**: Validate all citations -5. **Graceful Degradation**: Support free tier (HF Inference) when no API keys -6. **Lazy Loading**: Don't require optional dependencies at import time -7. **Structured Logging**: Use structlog, never print() -8. **Error Chaining**: Always preserve exception context - -## Pull Request Process - -1. Ensure all checks pass: `uv run ruff check src tests && uv run mypy src && uv run pytest --cov=src --cov-report=term-missing tests/unit/ -v -m "not openai" -p no:logfire` -2. Update documentation if needed -3. Add tests for new features -4. Update CHANGELOG if applicable -5. Request review from maintainers -6. Address review feedback -7. 
Wait for approval before merging - -## Project Structure - -- `src/`: Main source code -- `tests/`: Test files (`unit/` and `integration/`) -- `docs/`: Documentation source files (MkDocs) -- `examples/`: Example usage scripts -- `pyproject.toml`: Project configuration and dependencies -- `.pre-commit-config.yaml`: Pre-commit hook configuration - -## Questions? - -- Open an issue on [GitHub](https://github.com/DeepCritical/GradioDemo) -- Check existing [documentation](https://deepcritical.github.io/GradioDemo/) -- Review code examples in the codebase - -Thank you for contributing to The DETERMINATOR! diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 9d6fc14dce9d1bdbc102a1479304490324313167..0000000000000000000000000000000000000000 --- a/Dockerfile +++ /dev/null @@ -1,52 +0,0 @@ -# Dockerfile for DeepCritical -FROM python:3.11-slim - -# Set working directory -WORKDIR /app - -# Install system dependencies (curl needed for HEALTHCHECK) -RUN apt-get update && apt-get install -y \ - git \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Install uv -RUN pip install uv==0.5.4 - -# Copy project files -COPY pyproject.toml . -COPY uv.lock . -COPY src/ src/ -COPY README.md . - -# Install runtime dependencies only (no dev/test tools) -RUN uv sync --frozen --no-dev --extra embeddings --extra magentic - -# Create non-root user BEFORE downloading models -RUN useradd --create-home --shell /bin/bash appuser - -# Set cache directory for HuggingFace models (must be writable by appuser) -ENV HF_HOME=/app/.cache -ENV TRANSFORMERS_CACHE=/app/.cache - -# Create cache dir with correct ownership -RUN mkdir -p /app/.cache && chown -R appuser:appuser /app/.cache - -# Pre-download the embedding model during build (as appuser to set correct ownership) -USER appuser -RUN uv run python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')" - -# Expose port -EXPOSE 7860 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:7860/ || exit 1 - -# Set environment variables -ENV GRADIO_SERVER_NAME=0.0.0.0 -ENV GRADIO_SERVER_PORT=7860 -ENV PYTHONPATH=/app - -# Run the app -CMD ["uv", "run", "python", "-m", "src.app"] diff --git a/LICENSE.md b/LICENSE.md deleted file mode 100644 index a1f9be9c2733fb22fc43dfc4e5f23c62dbfb02ad..0000000000000000000000000000000000000000 --- a/LICENSE.md +++ /dev/null @@ -1,25 +0,0 @@ -# License - -DeepCritical is licensed under the MIT License. - -## MIT License - -Copyright (c) 2024 DeepCritical Team - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md index a4c0e90eae994942af1b1ea0f5cdfe967a29b8f5..e4c9a9698223d89be48e808d1f511cc4a2182141 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,15 @@ --- -title: The DETERMINATOR -emoji: 🐉 -colorFrom: red -colorTo: yellow +title: DeepCritical +emoji: 📈 +colorFrom: blue +colorTo: purple sdk: gradio -sdk_version: "6.0.1" -python_version: "3.11" +sdk_version: 6.0.0 app_file: src/app.py -hf_oauth: true -hf_oauth_expiration_minutes: 480 -hf_oauth_scopes: - # Required for HuggingFace Inference API (includes all third-party providers) - # This scope grants access to: - # - HuggingFace's own Inference API - # - Third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.) - # - All models available through the Inference Providers API - - inference-api - # Optional: Uncomment if you need to access user's billing information - # - read-billing -pinned: true +pinned: false license: mit -tags: - - mcp-in-action-track-enterprise - - mcp-hackathon - - deep-research - - biomedical-ai - - pydantic-ai - - llamaindex - - modal - - building-mcp-track-enterprise - - building-mcp-track-consumer - - mcp-in-action-track-enterprise - - mcp-in-action-track-consumer - - building-mcp-track-modal - - building-mcp-track-blaxel - - building-mcp-track-llama-index - - building-mcp-track-HUGGINGFACE +short_description: Deep Search for Critical Research [BigData] -> [Actionable] --- -> [!IMPORTANT] -> **You are reading the Gradio Demo README!** -> -> - 📚 **Documentation**: See our [technical documentation](https://deepcritical.github.io/GradioDemo/) for detailed information -> - 📖 **Complete README**: Check out the [Github README](.github/README.md) for setup, configuration, and contribution guidelines -> - ⚠️**This README is for our Gradio Demo Only !** +### DeepCritical -
- -[![GitHub](https://img.shields.io/github/stars/DeepCritical/GradioDemo?style=for-the-badge&logo=github&logoColor=white&label=GitHub&labelColor=181717&color=181717)](https://github.com/DeepCritical/GradioDemo) -[![Documentation](https://img.shields.io/badge/Docs-0080FF?style=for-the-badge&logo=readthedocs&logoColor=white&labelColor=0080FF&color=0080FF)](deepcritical.github.io/GradioDemo/) -[![Demo](https://img.shields.io/badge/Demo-FFD21E?style=for-the-badge&logo=huggingface&logoColor=white&labelColor=FFD21E&color=FFD21E)](https://huggingface.co/spaces/DataQuests/DeepCritical) -[![codecov](https://codecov.io/gh/DeepCritical/GradioDemo/graph/badge.svg?token=B1f05RCGpz)](https://codecov.io/gh/DeepCritical/GradioDemo) -[![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) - - -
- -# The DETERMINATOR - -## About - -The DETERMINATOR is a powerful generalist deep research agent system that stops at nothing until finding precise answers to complex questions. It uses iterative search-and-judge loops to comprehensively investigate any research question from any domain. diff --git a/deployments/README.md b/deployments/README.md deleted file mode 100644 index 3a4f4a4d8d7a3cccf6beacd56ce135117f3ad07a..0000000000000000000000000000000000000000 --- a/deployments/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# Deployments - -This directory contains infrastructure deployment scripts for DeepCritical services. - -## Modal Deployments - -### TTS Service (`modal_tts.py`) - -Deploys the Kokoro TTS (Text-to-Speech) function to Modal's GPU infrastructure. - -**Deploy:** -```bash -modal deploy deployments/modal_tts.py -``` - -**Features:** -- Kokoro 82M TTS model -- GPU-accelerated (T4) -- Voice options: af_heart, af_bella, am_michael, etc. -- Configurable speech speed - -**Requirements:** -- Modal account and credentials (`MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` in `.env`) -- GPU quota on Modal - -**After Deployment:** -The function will be available at: -- App: `deepcritical-tts` -- Function: `kokoro_tts_function` - -The main application (`src/services/tts_modal.py`) will call this deployed function. - ---- - -## Adding New Deployments - -When adding new deployment scripts: - -1. Create a new file: `deployments/.py` -2. Use Modal's app pattern: - ```python - import modal - app = modal.App("deepcritical-") - ``` -3. Document in this README -4. Test deployment: `modal deploy deployments/.py` diff --git a/deployments/modal_tts.py b/deployments/modal_tts.py deleted file mode 100644 index 9987a339f6b89eb63cd512eb594dd6a6d488f42a..0000000000000000000000000000000000000000 --- a/deployments/modal_tts.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Deploy Kokoro TTS function to Modal. - -This script deploys the TTS function to Modal so it can be called -from the main DeepCritical application. - -Usage: - modal deploy deploy_modal_tts.py - -After deployment, the function will be available at: - App: deepcritical-tts - Function: kokoro_tts_function -""" - -import modal -import numpy as np - -# Create Modal app -app = modal.App("deepcritical-tts") - -# Define Kokoro TTS dependencies -KOKORO_DEPENDENCIES = [ - "torch>=2.0.0", - "transformers>=4.30.0", - "numpy<2.0", -] - -# Create Modal image with Kokoro -tts_image = ( - modal.Image.debian_slim(python_version="3.11") - .apt_install("git") # Install git first for pip install from github - .pip_install(*KOKORO_DEPENDENCIES) - .pip_install("git+https://github.com/hexgrad/kokoro.git") -) - - -@app.function( - image=tts_image, - gpu="T4", - timeout=60, -) -def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]: - """Modal GPU function for Kokoro TTS. - - This function runs on Modal's GPU infrastructure. 
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS - - Args: - text: Text to synthesize - voice: Voice ID (e.g., af_heart, af_bella, am_michael) - speed: Speech speed multiplier (0.5-2.0) - - Returns: - Tuple of (sample_rate, audio_array) - """ - import numpy as np - - try: - import torch - from kokoro import KModel, KPipeline - - # Initialize model (cached on GPU) - model = KModel().to("cuda").eval() - pipeline = KPipeline(lang_code=voice[0]) - pack = pipeline.load_voice(voice) - - # Generate audio - accumulate all chunks - audio_chunks = [] - for _, ps, _ in pipeline(text, voice, speed): - ref_s = pack[len(ps) - 1] - audio = model(ps, ref_s, speed) - audio_chunks.append(audio.numpy()) - - # Concatenate all audio chunks - if audio_chunks: - full_audio = np.concatenate(audio_chunks) - return (24000, full_audio) - - # If no audio generated, return empty - return (24000, np.zeros(1, dtype=np.float32)) - - except ImportError as e: - raise RuntimeError( - f"Kokoro not installed: {e}. " - "Install with: pip install git+https://github.com/hexgrad/kokoro.git" - ) from e - except Exception as e: - raise RuntimeError(f"TTS synthesis failed: {e}") from e - - -# Optional: Add a test entrypoint -@app.local_entrypoint() -def test(): - """Test the TTS function.""" - print("Testing Modal TTS function...") - sample_rate, audio = kokoro_tts_function.remote("Hello, this is a test.", "af_heart", 1.0) - print(f"Generated audio: {sample_rate}Hz, shape={audio.shape}") - print("✓ TTS function works!") diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 6f59cac5edf5e957ee88a6fec36ad45db3040bd0..0000000000000000000000000000000000000000 --- a/pyproject.toml +++ /dev/null @@ -1,205 +0,0 @@ -[project] -name = "determinator" -version = "0.1.0" -description = "The DETERMINATOR - the Deep Research Agent that Stops at Nothing" -readme = "README.md" -requires-python = ">=3.11" -dependencies = [ - "pydantic>=2.7", - "pydantic-settings>=2.2", # For BaseSettings (config) - "pydantic-ai>=0.0.16", # Agent framework - "openai>=1.0.0", - "anthropic>=0.18.0", - "httpx>=0.27", # Async HTTP client (PubMed) - "beautifulsoup4>=4.12", # HTML parsing - "xmltodict>=0.13", # PubMed XML -> dict - "huggingface-hub>=0.20.0", # Hugging Face Inference API - "gradio[mcp,oauth]>=6.0.0", # Chat interface with MCP server support (6.0 required for css in launch()) - "python-dotenv>=1.0", # .env loading - "tenacity>=8.2", # Retry logic - "structlog>=24.1", # Structured logging - "requests>=2.32.5", # ClinicalTrials.gov (httpx blocked by WAF) - "pydantic-graph>=1.22.0", - "limits>=3.0", # Web search - "llama-index-llms-huggingface>=0.6.1", - "llama-index-llms-huggingface-api>=0.6.1", - "llama-index-vector-stores-chroma>=0.5.3", - "llama-index>=0.14.8", - "gradio-client>=1.0.0", # For STT/OCR API calls - "soundfile>=0.12.0", # For audio file I/O - "pillow>=10.0.0", # For image processing - "torch>=2.0.0", # Required by Kokoro TTS - "transformers>=4.57.2", # Required by Kokoro TTS - "modal>=0.63.0", # Required for TTS GPU execution - "tokenizers>=0.22.0,<=0.23.0", - "rpds-py>=0.29.0", - "pydantic-ai-slim[huggingface]>=0.0.18", - "agent-framework-core>=1.0.0b251120,<2.0.0", - "chromadb>=0.4.0", - "sentence-transformers>=2.2.0", - "numpy<2.0", - "llama-index-llms-openai>=0.6.9", - "llama-index-embeddings-openai>=0.5.1", - "ddgs>=9.9.2", - "aiohttp>=3.13.2", - "lxml>=6.0.2", - "fake-useragent==2.2.0", - "socksio==1.0.0", - "neo4j>=6.0.3", - "md2pdf>=1.0.1", -] - -[project.optional-dependencies] -dev = [ - # Testing - 
"pytest>=8.0", - "pytest-asyncio>=0.23", - "pytest-sugar>=1.0", - "pytest-cov>=5.0", - "pytest-mock>=3.12", - "respx>=0.21", # Mock httpx requests - "typer>=0.9.0", # Gradio CLI dependency for smoke tests - - # Quality - "ruff>=0.4.0", - "mypy>=1.10", - "pre-commit>=3.7", - - # Documentation - "mkdocs>=1.6.0", - "mkdocs-material>=9.0.0", - "mkdocs-mermaid2-plugin>=1.1.0", - "mkdocs-codeinclude-plugin>=0.2.0", - "mkdocs-git-revision-date-localized-plugin>=1.2.0", - "mkdocs-minify-plugin>=0.8.0", - "pymdown-extensions>=10.17.2", -] -magentic = [ - "agent-framework-core>=1.0.0b251120,<2.0.0", # Microsoft Agent Framework (PyPI) -] -embeddings = [ - "chromadb>=0.4.0", - "sentence-transformers>=2.2.0", - "numpy<2.0", # chromadb compatibility: uses np.float_ removed in NumPy 2.0 -] -modal = [ - # Mario's Modal code execution + LlamaIndex RAG - # Note: modal>=0.63.0 is now in main dependencies for TTS support - "llama-index>=0.11.0", - "llama-index-llms-openai>=0.6.9", - "llama-index-embeddings-openai>=0.5.1", - "llama-index-vector-stores-chroma", - "chromadb>=0.4.0", - "numpy<2.0", # chromadb compatibility: uses np.float_ removed in NumPy 2.0 -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src"] - -# ============== RUFF CONFIG ============== -[tool.ruff] -line-length = 100 -target-version = "py311" -src = ["src"] -exclude = [ - "tests/", - "examples/", - "reference_repos/", - "folder/", -] - -[tool.ruff.lint] -select = [ - "E", # pycodestyle errors - "F", # pyflakes - "B", # flake8-bugbear - "I", # isort - "N", # pep8-naming - "UP", # pyupgrade - "PL", # pylint - "RUF", # ruff-specific -] -ignore = [ - "PLR0913", # Too many arguments (agents need many params) - "PLR0912", # Too many branches (complex orchestrator logic) - "PLR0911", # Too many return statements (complex agent logic) - "PLR0915", # Too many statements (Gradio UI setup functions) - "PLR2004", # Magic values (statistical constants like p-values) - "PLW0603", # Global statement (singleton pattern for Modal) - "PLC0415", # Lazy imports for optional dependencies - "E402", # Module level import not at top (needed for pytest.importorskip) - "E501", # Line too long (ignore line length violations) - "RUF100", # Unused noqa (version differences between local/CI) -] - -[tool.ruff.lint.isort] -known-first-party = ["src"] - -# ============== MYPY CONFIG ============== -[tool.mypy] -python_version = "3.11" -strict = true -ignore_missing_imports = true -disallow_untyped_defs = true -warn_return_any = true -warn_unused_ignores = false -explicit_package_bases = true -mypy_path = "." 
-exclude = [ - "^reference_repos/", - "^examples/", - "^folder/", - "^src/app.py", -] - -# ============== PYTEST CONFIG ============== -[tool.pytest.ini_options] -testpaths = ["tests"] -asyncio_mode = "auto" -addopts = [ - "-v", - "--tb=short", - "--strict-markers", - "-p", - "no:logfire", -] -markers = [ - "unit: Unit tests (mocked)", - "integration: Integration tests (real APIs)", - "slow: Slow tests", - "openai: Tests that require OpenAI API key", - "huggingface: Tests that require HuggingFace API key or use HuggingFace models", - "embedding_provider: Tests that require API-based embedding providers (OpenAI, etc.)", - "local_embeddings: Tests that use local embeddings (sentence-transformers, ChromaDB)", -] - -# ============== COVERAGE CONFIG ============== -[tool.coverage.run] -source = ["src"] -omit = ["*/__init__.py"] - -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "if TYPE_CHECKING:", - "raise NotImplementedError", -] - -[dependency-groups] -dev = [ - "mkdocs>=1.6.1", - "mkdocs-codeinclude-plugin>=0.2.1", - "mkdocs-material>=9.7.0", - "mkdocs-mermaid2-plugin>=1.2.3", - "mkdocs-git-revision-date-localized-plugin>=1.2.0", - "mkdocs-minify-plugin>=0.8.0", - "structlog>=25.5.0", - "ty>=0.0.1a28", -] - -# Note: agent-framework-core is optional for magentic mode (multi-agent orchestration) -# Version pinned to 1.0.0b* to avoid breaking changes. CI skips tests via pytest.importorskip diff --git a/requirements.txt b/requirements.txt index cf1b66bb960423122615ba3340799ef6e9d92435..3dfd393b3eb853eff9cee0d4e8ad62f73ba27ff0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,92 +1 @@ -########################## -# DO NOT USE THIS FILE -# FOR GRADIO DEMO ONLY -########################## - - -#Core dependencies for HuggingFace Spaces -pydantic>=2.7 -pydantic-settings>=2.2 -pydantic-ai>=0.0.16 - -# OPTIONAL AI Providers -openai>=1.0.0 -anthropic>=0.18.0 - -# HTTP & Parsing -httpx>=0.27 -aiohttp>=3.13.2 # Required for website crawling -beautifulsoup4>=4.12 -lxml>=6.0.2 # Required for BeautifulSoup lxml parser (faster than html.parser) -xmltodict>=0.13 - -# HuggingFace Hub -huggingface-hub>=0.20.0 - -# UI (Gradio with MCP server support) -gradio[mcp,oauth]>=6.0.0 - -# Utils -python-dotenv>=1.0 -tenacity>=8.2 -structlog>=24.1 -requests>=2.32.5 -limits>=3.0 # Rate limiting -pydantic-graph>=1.22.0 - -# Web search -ddgs>=9.9.2 # duckduckgo-search has been renamed to ddgs -fake-useragent==2.2.0 -socksio==1.0.0 -# LlamaIndex RAG -llama-index-llms-huggingface>=0.6.1 -llama-index-llms-huggingface-api>=0.6.1 -llama-index-vector-stores-chroma>=0.5.3 -llama-index>=0.14.8 - -# Audio/Image processing -gradio-client>=1.0.0 # For STT/OCR API calls -soundfile>=0.12.0 # For audio file I/O -pillow>=10.0.0 # For image processing - -# TTS dependencies (for Modal GPU TTS) -torch>=2.0.0 # Required by Kokoro TTS -transformers>=4.57.2 # Required by Kokoro TTS -modal>=0.63.0 # Required for TTS GPU execution -# Note: Kokoro is installed in Modal image from: git+https://github.com/hexgrad/kokoro.git - -# Embeddings & Vector Store -tokenizers>=0.22.0,<=0.23.0 -rpds-py>=0.29.0 # Python implementation of rpds (required by chromadb on Windows) -chromadb>=0.4.0 -sentence-transformers>=2.2.0 -numpy<2.0 # chromadb compatibility: uses np.float_ removed in NumPy 2.0 -neo4j>=6.0.3 - -### DOCUMENT STUFF - -cssselect2==0.8.0 -docopt==0.6.2 -fonttools==4.61.0 -markdown2==2.5.4 -md2pdf==1.0.1 -pydyf==0.11.0 -pyphen==0.17.2 -tinycss2==1.5.1 -tinyhtml5==2.0.0 -weasyprint==66.0 -webencodings==0.5.1 -zopfli==0.4.0 - 
-# Optional: Modal for code execution -modal>=0.63.0 - -# Pydantic AI with HuggingFace support -pydantic-ai-slim[huggingface]>=0.0.18 - -# Multi-agent orchestration (Advanced mode) -agent-framework-core>=1.0.0b251120,<2.0.0 - -# LlamaIndex RAG - OpenAI -llama-index-llms-openai>=0.6.9 -llama-index-embeddings-openai>=0.5.1 +deepcritical \ No newline at end of file diff --git a/src/agent_factory/agents.py b/src/agent_factory/agents.py index c676140e509d833431b9b280b5f2695de82ec181..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/agent_factory/agents.py +++ b/src/agent_factory/agents.py @@ -1,361 +0,0 @@ -"""Agent factory functions for creating research agents. - -Provides factory functions for creating all Pydantic AI agents used in -the research workflows, following the pattern from judges.py. -""" - -from typing import TYPE_CHECKING, Any - -import structlog - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -if TYPE_CHECKING: - from src.agent_factory.graph_builder import GraphBuilder - from src.agents.input_parser import InputParserAgent - from src.agents.knowledge_gap import KnowledgeGapAgent - from src.agents.long_writer import LongWriterAgent - from src.agents.proofreader import ProofreaderAgent - from src.agents.thinking import ThinkingAgent - from src.agents.tool_selector import ToolSelectorAgent - from src.agents.writer import WriterAgent - from src.orchestrator.graph_orchestrator import GraphOrchestrator - from src.orchestrator.planner_agent import PlannerAgent - from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow - -logger = structlog.get_logger() - - -def create_input_parser_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "InputParserAgent": - """ - Create input parser agent for query analysis and research mode detection. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured InputParserAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.input_parser import create_input_parser_agent as _create_agent - - try: - logger.debug("Creating input parser agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create input parser agent", error=str(e)) - raise ConfigurationError(f"Failed to create input parser agent: {e}") from e - - -def create_planner_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "PlannerAgent": - """ - Create planner agent with web search and crawl tools. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. 
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured PlannerAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - # Lazy import to avoid circular dependencies - from src.orchestrator.planner_agent import create_planner_agent as _create_planner_agent - - try: - logger.debug("Creating planner agent") - return _create_planner_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create planner agent", error=str(e)) - raise ConfigurationError(f"Failed to create planner agent: {e}") from e - - -def create_knowledge_gap_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "KnowledgeGapAgent": - """ - Create knowledge gap agent for evaluating research completeness. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured KnowledgeGapAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.knowledge_gap import create_knowledge_gap_agent as _create_agent - - try: - logger.debug("Creating knowledge gap agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create knowledge gap agent", error=str(e)) - raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e - - -def create_tool_selector_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "ToolSelectorAgent": - """ - Create tool selector agent for choosing tools to address gaps. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ToolSelectorAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.tool_selector import create_tool_selector_agent as _create_agent - - try: - logger.debug("Creating tool selector agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create tool selector agent", error=str(e)) - raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e - - -def create_thinking_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "ThinkingAgent": - """ - Create thinking agent for generating observations. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ThinkingAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.thinking import create_thinking_agent as _create_agent - - try: - logger.debug("Creating thinking agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create thinking agent", error=str(e)) - raise ConfigurationError(f"Failed to create thinking agent: {e}") from e - - -def create_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> "WriterAgent": - """ - Create writer agent for generating final reports. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. 
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured WriterAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.writer import create_writer_agent as _create_agent - - try: - logger.debug("Creating writer agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create writer agent", error=str(e)) - raise ConfigurationError(f"Failed to create writer agent: {e}") from e - - -def create_long_writer_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "LongWriterAgent": - """ - Create long writer agent for iteratively writing report sections. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured LongWriterAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.long_writer import create_long_writer_agent as _create_agent - - try: - logger.debug("Creating long writer agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create long writer agent", error=str(e)) - raise ConfigurationError(f"Failed to create long writer agent: {e}") from e - - -def create_proofreader_agent( - model: Any | None = None, oauth_token: str | None = None -) -> "ProofreaderAgent": - """ - Create proofreader agent for finalizing report drafts. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ProofreaderAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - from src.agents.proofreader import create_proofreader_agent as _create_agent - - try: - logger.debug("Creating proofreader agent") - return _create_agent(model=model, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create proofreader agent", error=str(e)) - raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e - - -def create_iterative_flow( - max_iterations: int = 5, - max_time_minutes: int = 10, - verbose: bool = True, - use_graph: bool | None = None, -) -> "IterativeResearchFlow": - """ - Create iterative research flow. - - Args: - max_iterations: Maximum number of iterations - max_time_minutes: Maximum time in minutes - verbose: Whether to log progress - use_graph: Whether to use graph execution. 
If None, reads from settings.use_graph_execution - - Returns: - Configured IterativeResearchFlow instance - """ - from src.orchestrator.research_flow import IterativeResearchFlow - - try: - # Use settings default if not explicitly provided - if use_graph is None: - use_graph = settings.use_graph_execution - - logger.debug("Creating iterative research flow", use_graph=use_graph) - return IterativeResearchFlow( - max_iterations=max_iterations, - max_time_minutes=max_time_minutes, - verbose=verbose, - use_graph=use_graph, - ) - except Exception as e: - logger.error("Failed to create iterative flow", error=str(e)) - raise ConfigurationError(f"Failed to create iterative flow: {e}") from e - - -def create_deep_flow( - max_iterations: int = 5, - max_time_minutes: int = 10, - verbose: bool = True, - use_long_writer: bool = True, - use_graph: bool | None = None, -) -> "DeepResearchFlow": - """ - Create deep research flow. - - Args: - max_iterations: Maximum iterations per section - max_time_minutes: Maximum time per section - verbose: Whether to log progress - use_long_writer: Whether to use long writer (True) or proofreader (False) - use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution - - Returns: - Configured DeepResearchFlow instance - """ - from src.orchestrator.research_flow import DeepResearchFlow - - try: - # Use settings default if not explicitly provided - if use_graph is None: - use_graph = settings.use_graph_execution - - logger.debug("Creating deep research flow", use_graph=use_graph) - return DeepResearchFlow( - max_iterations=max_iterations, - max_time_minutes=max_time_minutes, - verbose=verbose, - use_long_writer=use_long_writer, - use_graph=use_graph, - ) - except Exception as e: - logger.error("Failed to create deep flow", error=str(e)) - raise ConfigurationError(f"Failed to create deep flow: {e}") from e - - -def create_graph_orchestrator( - mode: str = "auto", - max_iterations: int = 5, - max_time_minutes: int = 10, - use_graph: bool = True, -) -> "GraphOrchestrator": - """ - Create graph orchestrator. - - Args: - mode: Research mode ("iterative", "deep", or "auto") - max_iterations: Maximum iterations per loop - max_time_minutes: Maximum time per loop - use_graph: Whether to use graph execution (True) or agent chains (False) - - Returns: - Configured GraphOrchestrator instance - """ - from src.orchestrator.graph_orchestrator import create_graph_orchestrator as _create - - try: - logger.debug("Creating graph orchestrator", mode=mode, use_graph=use_graph) - return _create( - mode=mode, # type: ignore[arg-type] - max_iterations=max_iterations, - max_time_minutes=max_time_minutes, - use_graph=use_graph, - ) - except Exception as e: - logger.error("Failed to create graph orchestrator", error=str(e)) - raise ConfigurationError(f"Failed to create graph orchestrator: {e}") from e - - -def create_graph_builder() -> "GraphBuilder": - """ - Create a graph builder instance. 
- - Returns: - GraphBuilder instance - """ - from src.agent_factory.graph_builder import GraphBuilder - - try: - logger.debug("Creating graph builder") - return GraphBuilder() - except Exception as e: - logger.error("Failed to create graph builder", error=str(e)) - raise ConfigurationError(f"Failed to create graph builder: {e}") from e diff --git a/src/agent_factory/graph_builder.py b/src/agent_factory/graph_builder.py deleted file mode 100644 index c06c1b0ac0d38a943f957bc100bef6cdfc09f390..0000000000000000000000000000000000000000 --- a/src/agent_factory/graph_builder.py +++ /dev/null @@ -1,635 +0,0 @@ -"""Graph builder utilities for constructing research workflow graphs. - -Provides classes and utilities for building graph-based orchestration systems -using Pydantic AI agents as nodes. -""" - -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Literal - -import structlog -from pydantic import BaseModel, Field - -if TYPE_CHECKING: - from pydantic_ai import Agent - - from src.middleware.state_machine import WorkflowState - -logger = structlog.get_logger() - - -# ============================================================================ -# Graph Node Models -# ============================================================================ - - -class GraphNode(BaseModel): - """Base class for graph nodes.""" - - node_id: str = Field(description="Unique identifier for the node") - node_type: Literal["agent", "state", "decision", "parallel"] = Field(description="Type of node") - description: str = Field(default="", description="Human-readable description of the node") - - model_config = {"frozen": True} - - -class AgentNode(GraphNode): - """Node that executes a Pydantic AI agent.""" - - node_type: Literal["agent"] = "agent" - agent: Any = Field(description="Pydantic AI agent to execute") - input_transformer: Callable[[Any], Any] | None = Field( - default=None, description="Transform input before passing to agent" - ) - output_transformer: Callable[[Any], Any] | None = Field( - default=None, description="Transform output after agent execution" - ) - - model_config = {"arbitrary_types_allowed": True} - - -class StateNode(GraphNode): - """Node that updates or reads workflow state.""" - - node_type: Literal["state"] = "state" - state_updater: Callable[[Any, Any], Any] = Field( - description="Function to update workflow state" - ) - state_reader: Callable[[Any], Any] | None = Field( - default=None, description="Function to read state (optional)" - ) - - model_config = {"arbitrary_types_allowed": True} - - -class DecisionNode(GraphNode): - """Node that makes routing decisions based on conditions.""" - - node_type: Literal["decision"] = "decision" - decision_function: Callable[[Any], str] = Field( - description="Function that returns next node ID based on input" - ) - options: list[str] = Field(description="List of possible next node IDs", min_length=1) - - model_config = {"arbitrary_types_allowed": True} - - -class ParallelNode(GraphNode): - """Node that executes multiple nodes in parallel.""" - - node_type: Literal["parallel"] = "parallel" - parallel_nodes: list[str] = Field( - description="List of node IDs to run in parallel", min_length=0 - ) - aggregator: Callable[[list[Any]], Any] | None = Field( - default=None, description="Function to aggregate parallel results" - ) - - model_config = {"arbitrary_types_allowed": True} - - -# ============================================================================ -# Graph Edge Models -# 
============================================================================ - - -class GraphEdge(BaseModel): - """Base class for graph edges.""" - - from_node: str = Field(description="Source node ID") - to_node: str = Field(description="Target node ID") - condition: Callable[[Any], bool] | None = Field( - default=None, description="Optional condition function" - ) - weight: float = Field(default=1.0, description="Edge weight for routing decisions") - - model_config = {"arbitrary_types_allowed": True} - - -class SequentialEdge(GraphEdge): - """Edge that is always traversed (no condition).""" - - condition: None = None - - -class ConditionalEdge(GraphEdge): - """Edge that is traversed based on a condition.""" - - condition: Callable[[Any], bool] = Field(description="Required condition function") - condition_description: str = Field( - default="", description="Human-readable description of condition" - ) - - -class ParallelEdge(GraphEdge): - """Edge used for parallel execution branches.""" - - condition: None = None - - -# ============================================================================ -# Research Graph Class -# ============================================================================ - - -class ResearchGraph(BaseModel): - """Represents a research workflow graph with nodes and edges.""" - - nodes: dict[str, GraphNode] = Field(default_factory=dict, description="All nodes in the graph") - edges: dict[str, list[GraphEdge]] = Field( - default_factory=dict, description="Edges by source node ID" - ) - entry_node: str = Field(description="Starting node ID") - exit_nodes: list[str] = Field(default_factory=list, description="Terminal node IDs") - - model_config = {"arbitrary_types_allowed": True} - - def add_node(self, node: GraphNode) -> None: - """Add a node to the graph. - - Args: - node: The node to add - - Raises: - ValueError: If node ID already exists - """ - if node.node_id in self.nodes: - raise ValueError(f"Node {node.node_id} already exists in graph") - self.nodes[node.node_id] = node - logger.debug("Node added to graph", node_id=node.node_id, type=node.node_type) - - def add_edge(self, edge: GraphEdge) -> None: - """Add an edge to the graph. - - Args: - edge: The edge to add - - Raises: - ValueError: If source or target node doesn't exist - """ - if edge.from_node not in self.nodes: - raise ValueError(f"Source node {edge.from_node} not found in graph") - if edge.to_node not in self.nodes: - raise ValueError(f"Target node {edge.to_node} not found in graph") - - if edge.from_node not in self.edges: - self.edges[edge.from_node] = [] - self.edges[edge.from_node].append(edge) - logger.debug( - "Edge added to graph", - from_node=edge.from_node, - to_node=edge.to_node, - ) - - def get_node(self, node_id: str) -> GraphNode | None: - """Get a node by ID. - - Args: - node_id: The node ID - - Returns: - The node, or None if not found - """ - return self.nodes.get(node_id) - - def get_next_nodes(self, node_id: str, context: Any = None) -> list[tuple[str, GraphEdge]]: - """Get all possible next nodes from a given node. 
- - Args: - node_id: The current node ID - context: Optional context for evaluating conditions - - Returns: - List of (node_id, edge) tuples for valid next nodes - """ - if node_id not in self.edges: - return [] - - next_nodes = [] - for edge in self.edges[node_id]: - # Evaluate condition if present - if edge.condition is None or edge.condition(context): - next_nodes.append((edge.to_node, edge)) - - return next_nodes - - def validate_structure(self) -> list[str]: - """Validate the graph structure. - - Returns: - List of validation error messages (empty if valid) - """ - errors = [] - - # Check entry node exists - if self.entry_node not in self.nodes: - errors.append(f"Entry node {self.entry_node} not found in graph") - - # Check exit nodes exist and at least one is defined - if not self.exit_nodes: - errors.append("At least one exit node must be defined") - for exit_node in self.exit_nodes: - if exit_node not in self.nodes: - errors.append(f"Exit node {exit_node} not found in graph") - - # Check all edges reference valid nodes - for from_node, edge_list in self.edges.items(): - if from_node not in self.nodes: - errors.append(f"Edge source node {from_node} not found") - for edge in edge_list: - if edge.to_node not in self.nodes: - errors.append(f"Edge target node {edge.to_node} not found") - - # Check all nodes are reachable from entry node (basic check) - if self.entry_node in self.nodes: - reachable = {self.entry_node} - queue = [self.entry_node] - while queue: - current = queue.pop(0) - for next_node, _ in self.get_next_nodes(current): - if next_node not in reachable: - reachable.add(next_node) - queue.append(next_node) - - unreachable = set(self.nodes.keys()) - reachable - if unreachable: - errors.append(f"Unreachable nodes from entry node: {', '.join(unreachable)}") - - return errors - - -# ============================================================================ -# Graph Builder Class -# ============================================================================ - - -class GraphBuilder: - """Builder for constructing research workflow graphs.""" - - def __init__(self) -> None: - """Initialize the graph builder.""" - self.graph = ResearchGraph(entry_node="", exit_nodes=[]) - - def add_agent_node( - self, - node_id: str, - agent: "Agent[Any, Any]", - description: str = "", - input_transformer: Callable[[Any], Any] | None = None, - output_transformer: Callable[[Any], Any] | None = None, - ) -> "GraphBuilder": - """Add an agent node to the graph. - - Args: - node_id: Unique identifier for the node - agent: Pydantic AI agent to execute - description: Human-readable description - input_transformer: Optional input transformation function - output_transformer: Optional output transformation function - - Returns: - Self for method chaining - """ - node = AgentNode( - node_id=node_id, - agent=agent, - description=description, - input_transformer=input_transformer, - output_transformer=output_transformer, - ) - self.graph.add_node(node) - return self - - def add_state_node( - self, - node_id: str, - state_updater: Callable[["WorkflowState", Any], "WorkflowState"], - description: str = "", - state_reader: Callable[["WorkflowState"], Any] | None = None, - ) -> "GraphBuilder": - """Add a state node to the graph. 
- - Args: - node_id: Unique identifier for the node - state_updater: Function to update workflow state - description: Human-readable description - state_reader: Optional function to read state - - Returns: - Self for method chaining - """ - node = StateNode( - node_id=node_id, - state_updater=state_updater, - description=description, - state_reader=state_reader, - ) - self.graph.add_node(node) - return self - - def add_decision_node( - self, - node_id: str, - decision_function: Callable[[Any], str], - options: list[str], - description: str = "", - ) -> "GraphBuilder": - """Add a decision node to the graph. - - Args: - node_id: Unique identifier for the node - decision_function: Function that returns next node ID - options: List of possible next node IDs - description: Human-readable description - - Returns: - Self for method chaining - """ - node = DecisionNode( - node_id=node_id, - decision_function=decision_function, - options=options, - description=description, - ) - self.graph.add_node(node) - return self - - def add_parallel_node( - self, - node_id: str, - parallel_nodes: list[str], - description: str = "", - aggregator: Callable[[list[Any]], Any] | None = None, - ) -> "GraphBuilder": - """Add a parallel node to the graph. - - Args: - node_id: Unique identifier for the node - parallel_nodes: List of node IDs to run in parallel - description: Human-readable description - aggregator: Optional function to aggregate results - - Returns: - Self for method chaining - """ - node = ParallelNode( - node_id=node_id, - parallel_nodes=parallel_nodes, - description=description, - aggregator=aggregator, - ) - self.graph.add_node(node) - return self - - def connect_nodes( - self, - from_node: str, - to_node: str, - condition: Callable[[Any], bool] | None = None, - condition_description: str = "", - ) -> "GraphBuilder": - """Connect two nodes with an edge. - - Args: - from_node: Source node ID - to_node: Target node ID - condition: Optional condition function - condition_description: Description of condition (if conditional) - - Returns: - Self for method chaining - """ - if condition is None: - edge: GraphEdge = SequentialEdge(from_node=from_node, to_node=to_node) - else: - edge = ConditionalEdge( - from_node=from_node, - to_node=to_node, - condition=condition, - condition_description=condition_description, - ) - self.graph.add_edge(edge) - return self - - def set_entry_node(self, node_id: str) -> "GraphBuilder": - """Set the entry node for the graph. - - Args: - node_id: The entry node ID - - Returns: - Self for method chaining - """ - self.graph.entry_node = node_id - return self - - def set_exit_nodes(self, node_ids: list[str]) -> "GraphBuilder": - """Set the exit nodes for the graph. - - Args: - node_ids: List of exit node IDs - - Returns: - Self for method chaining - """ - self.graph.exit_nodes = node_ids - return self - - def build(self) -> ResearchGraph: - """Finalize graph construction and validate. 
- - Returns: - The constructed ResearchGraph - - Raises: - ValueError: If graph validation fails - """ - errors = self.graph.validate_structure() - if errors: - error_msg = "Graph validation failed:\n" + "\n".join(f" - {e}" for e in errors) - logger.error("Graph validation failed", errors=errors) - raise ValueError(error_msg) - - logger.info( - "Graph built successfully", - nodes=len(self.graph.nodes), - edges=sum(len(edges) for edges in self.graph.edges.values()), - entry_node=self.graph.entry_node, - exit_nodes=self.graph.exit_nodes, - ) - return self.graph - - -# ============================================================================ -# Factory Functions -# ============================================================================ - - -def create_iterative_graph( - knowledge_gap_agent: "Agent[Any, Any]", - tool_selector_agent: "Agent[Any, Any]", - thinking_agent: "Agent[Any, Any]", - writer_agent: "Agent[Any, Any]", -) -> ResearchGraph: - """Create a graph for iterative research flow. - - Args: - knowledge_gap_agent: Agent for evaluating knowledge gaps - tool_selector_agent: Agent for selecting tools - thinking_agent: Agent for generating observations - writer_agent: Agent for writing final report - - Returns: - Constructed ResearchGraph for iterative research - """ - builder = GraphBuilder() - - # Add nodes - builder.add_agent_node("thinking", thinking_agent, "Generate observations") - builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps") - - def _decision_function(result: Any) -> str: - """Decision function for continue_decision node. - - Args: - result: Result from knowledge_gap node (KnowledgeGapOutput or tuple) - - Returns: - Next node ID: "writer" if research complete, "tool_selector" otherwise - """ - # Handle case where result might be a tuple (validation error) - if isinstance(result, tuple): - # Try to extract research_complete from tuple - if len(result) == 2 and isinstance(result[0], str) and result[0] == "research_complete": - # Format: ('research_complete', False) - return "writer" if result[1] else "tool_selector" - # Try to find boolean value in tuple - for item in result: - if isinstance(item, bool): - return "writer" if item else "tool_selector" - elif isinstance(item, dict) and "research_complete" in item: - return "writer" if item["research_complete"] else "tool_selector" - # Default to continuing research if we can't determine - return "tool_selector" - - # Normal case: result is KnowledgeGapOutput object - research_complete = getattr(result, "research_complete", False) - return "writer" if research_complete else "tool_selector" - - builder.add_decision_node( - "continue_decision", - decision_function=_decision_function, - options=["tool_selector", "writer"], - description="Decide whether to continue research or write report", - ) - builder.add_agent_node("tool_selector", tool_selector_agent, "Select tools to address gap") - builder.add_state_node( - "execute_tools", - state_updater=lambda state, - tasks: state, # Placeholder - actual execution handled separately - description="Execute selected tools", - ) - builder.add_agent_node("writer", writer_agent, "Write final report") - - # Add edges - builder.connect_nodes("thinking", "knowledge_gap") - builder.connect_nodes("knowledge_gap", "continue_decision") - builder.connect_nodes("continue_decision", "tool_selector") - builder.connect_nodes("continue_decision", "writer") - builder.connect_nodes("tool_selector", "execute_tools") - builder.connect_nodes("execute_tools", 
"thinking") # Loop back - - # Set entry and exit - builder.set_entry_node("thinking") - builder.set_exit_nodes(["writer"]) - - return builder.build() - - -def create_deep_graph( - planner_agent: "Agent[Any, Any]", - knowledge_gap_agent: "Agent[Any, Any]", - tool_selector_agent: "Agent[Any, Any]", - thinking_agent: "Agent[Any, Any]", - writer_agent: "Agent[Any, Any]", - long_writer_agent: "Agent[Any, Any]", -) -> ResearchGraph: - """Create a graph for deep research flow. - - The graph structure: planner → store_plan → parallel_loops → collect_drafts → synthesizer - - Args: - planner_agent: Agent for creating report plan - knowledge_gap_agent: Agent for evaluating knowledge gaps (not used directly, but needed for iterative flows) - tool_selector_agent: Agent for selecting tools (not used directly, but needed for iterative flows) - thinking_agent: Agent for generating observations (not used directly, but needed for iterative flows) - writer_agent: Agent for writing section reports (not used directly, but needed for iterative flows) - long_writer_agent: Agent for synthesizing final report - - Returns: - Constructed ResearchGraph for deep research - """ - from src.utils.models import ReportPlan - - builder = GraphBuilder() - - # Add nodes - # 1. Planner agent - creates report plan - builder.add_agent_node("planner", planner_agent, "Create report plan with sections") - - # 2. State node - store report plan in workflow state - def store_plan(state: "WorkflowState", plan: ReportPlan) -> "WorkflowState": - """Store report plan in state for parallel loops to access.""" - # Store plan in a custom attribute (we'll need to extend WorkflowState or use a dict) - # For now, we'll store it in the context's node_results - # The actual storage will happen in the graph execution - return state - - builder.add_state_node( - "store_plan", - state_updater=store_plan, - description="Store report plan in state", - ) - - # 3. Parallel node - will execute iterative research flows for each section - # The actual execution will be handled dynamically in _execute_parallel_node() - # We use a special node ID that the executor will recognize - builder.add_parallel_node( - "parallel_loops", - parallel_nodes=[], # Will be populated dynamically based on report plan - description="Execute parallel iterative research loops for each section", - aggregator=lambda results: results, # Collect all section drafts - ) - - # 4. State node - collect section drafts into ReportDraft - def collect_drafts(state: "WorkflowState", section_drafts: list[str]) -> "WorkflowState": - """Collect section drafts into state for synthesizer.""" - # Store drafts in state (will be accessed by synthesizer) - return state - - builder.add_state_node( - "collect_drafts", - state_updater=collect_drafts, - description="Collect section drafts for synthesis", - ) - - # 5. 
Synthesizer agent - creates final report from drafts - builder.add_agent_node( - "synthesizer", long_writer_agent, "Synthesize final report from section drafts" - ) - - # Add edges - builder.connect_nodes("planner", "store_plan") - builder.connect_nodes("store_plan", "parallel_loops") - builder.connect_nodes("parallel_loops", "collect_drafts") - builder.connect_nodes("collect_drafts", "synthesizer") - - # Set entry and exit - builder.set_entry_node("planner") - builder.set_exit_nodes(["synthesizer"]) - - return builder.build() - - -# No need to rebuild models since we're using Any types -# The models will work correctly with arbitrary_types_allowed=True diff --git a/src/agent_factory/judges.py b/src/agent_factory/judges.py index 59ccdc08f3d8292be1161908a47b9451672121d2..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/agent_factory/judges.py +++ b/src/agent_factory/judges.py @@ -1,537 +0,0 @@ -"""Judge handler for evidence assessment using PydanticAI.""" - -import asyncio -import json -from typing import Any, ClassVar - -import structlog -from huggingface_hub import InferenceClient -from pydantic_ai import Agent -from pydantic_ai.models.anthropic import AnthropicModel -from pydantic_ai.models.huggingface import HuggingFaceModel -from pydantic_ai.models.openai import OpenAIChatModel as OpenAIModel -from pydantic_ai.providers.anthropic import AnthropicProvider -from pydantic_ai.providers.huggingface import HuggingFaceProvider -from pydantic_ai.providers.openai import OpenAIProvider -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential - -from src.prompts.judge import ( - SYSTEM_PROMPT, - format_empty_evidence_prompt, - format_user_prompt, -) -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError -from src.utils.models import AssessmentDetails, Evidence, JudgeAssessment - -logger = structlog.get_logger() - - -def get_model(oauth_token: str | None = None) -> Any: - """Get the LLM model based on configuration. - - Explicitly passes API keys from settings to avoid requiring - users to export environment variables manually. - - Priority order: - 1. HuggingFace (if OAuth token or API key available - preferred for free tier) - 2. OpenAI (if API key available) - 3. Anthropic (if API key available) - - If OAuth token is available, prefer HuggingFace (even if provider is set to OpenAI). - This ensures users logged in via HuggingFace Spaces get the free tier. 
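    Example (illustrative sketch; assumes at least one provider is configured
    and mirrors how JudgeHandler wires the returned model into a Pydantic AI
    agent below):

        model = get_model()  # resolves HF OAuth/token, then OpenAI, then Anthropic
        agent = Agent(model=model, output_type=JudgeAssessment, system_prompt=SYSTEM_PROMPT)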
- - Args: - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured Pydantic AI model - - Raises: - ConfigurationError: If no LLM provider is available - """ - from src.utils.hf_error_handler import log_token_info, validate_hf_token - - # Priority: oauth_token > settings.hf_token > settings.huggingface_api_key - effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key - - # Validate and log token information - if effective_hf_token: - log_token_info(effective_hf_token, context="get_model") - is_valid, error_msg = validate_hf_token(effective_hf_token) - if not is_valid: - logger.warning( - "Token validation failed", - error=error_msg, - has_oauth=bool(oauth_token), - ) - # Continue anyway - let the API call fail with a clear error - - # Try HuggingFace first (preferred for free tier) - if effective_hf_token: - model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" - hf_provider = HuggingFaceProvider(api_key=effective_hf_token) - logger.info( - "using_huggingface_with_token", - has_oauth=bool(oauth_token), - has_settings_token=bool(settings.hf_token or settings.huggingface_api_key), - model=model_name, - ) - return HuggingFaceModel(model_name, provider=hf_provider) - - # Fallback to OpenAI if available - if settings.has_openai_key: - assert settings.openai_api_key is not None # Type narrowing - model_name = settings.openai_model - openai_provider = OpenAIProvider(api_key=settings.openai_api_key) - logger.info("using_openai", model=model_name) - return OpenAIModel(model_name, provider=openai_provider) - - # Fallback to Anthropic if available - if settings.has_anthropic_key: - assert settings.anthropic_api_key is not None # Type narrowing - model_name = settings.anthropic_model - anthropic_provider = AnthropicProvider(api_key=settings.anthropic_api_key) - logger.info("using_anthropic", model=model_name) - return AnthropicModel(model_name, provider=anthropic_provider) - - # No provider available - raise ConfigurationError( - "No LLM provider available. Please configure one of:\n" - "1. HuggingFace: Log in via OAuth (recommended for Spaces) or set HF_TOKEN\n" - "2. OpenAI: Set OPENAI_API_KEY environment variable\n" - "3. Anthropic: Set ANTHROPIC_API_KEY environment variable" - ) - - -class JudgeHandler: - """ - Handles evidence assessment using an LLM with structured output. - - Uses PydanticAI to ensure responses match the JudgeAssessment schema. - """ - - def __init__(self, model: Any = None) -> None: - """ - Initialize the JudgeHandler. - - Args: - model: Optional PydanticAI model. If None, uses config default. - """ - self.model = model or get_model() - self.agent = Agent( - model=self.model, - output_type=JudgeAssessment, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def assess( - self, - question: str, - evidence: list[Evidence], - ) -> JudgeAssessment: - """ - Assess evidence and determine if it's sufficient. 
- - Args: - question: The user's research question - evidence: List of Evidence objects from search - - Returns: - JudgeAssessment with evaluation results - - Raises: - JudgeError: If assessment fails after retries - """ - logger.info( - "Starting evidence assessment", - question=question[:100], - evidence_count=len(evidence), - ) - - # Format the prompt based on whether we have evidence - if evidence: - user_prompt = format_user_prompt(question, evidence) - else: - user_prompt = format_empty_evidence_prompt(question) - - try: - # Run the agent with structured output - result = await self.agent.run(user_prompt) - assessment = result.output - - logger.info( - "Assessment complete", - sufficient=assessment.sufficient, - recommendation=assessment.recommendation, - confidence=assessment.confidence, - ) - - return assessment - - except Exception as e: - # Extract error details for better logging and handling - from src.utils.hf_error_handler import ( - extract_error_details, - get_user_friendly_error_message, - ) - - error_details = extract_error_details(e) - logger.error( - "Assessment failed", - error=str(e), - status_code=error_details.get("status_code"), - model_name=error_details.get("model_name"), - is_auth_error=error_details.get("is_auth_error"), - is_model_error=error_details.get("is_model_error"), - ) - - # Log user-friendly message for debugging - if error_details.get("is_auth_error") or error_details.get("is_model_error"): - user_msg = get_user_friendly_error_message(e, error_details.get("model_name")) - logger.warning("API error details", user_message=user_msg[:200]) - - # Return a safe default assessment on failure - return self._create_fallback_assessment(question, str(e)) - - def _create_fallback_assessment( - self, - question: str, - error: str, - ) -> JudgeAssessment: - """ - Create a fallback assessment when LLM fails. - - Args: - question: The original question - error: The error message - - Returns: - Safe fallback JudgeAssessment - """ - return JudgeAssessment( - details=AssessmentDetails( - mechanism_score=0, - mechanism_reasoning="Assessment failed due to LLM error", - clinical_evidence_score=0, - clinical_reasoning="Assessment failed due to LLM error", - drug_candidates=[], - key_findings=[], - ), - sufficient=False, - confidence=0.0, - recommendation="continue", - next_search_queries=[ - f"{question} mechanism", - f"{question} clinical trials", - f"{question} drug candidates", - ], - reasoning=f"Assessment failed: {error}. Recommend retrying with refined queries.", - ) - - -class HFInferenceJudgeHandler: - """ - JudgeHandler using HuggingFace Inference API for FREE LLM calls. - Defaults to Llama-3.1-8B-Instruct (requires HF_TOKEN) or falls back to public models. - """ - - FALLBACK_MODELS: ClassVar[list[str]] = [ - "meta-llama/Llama-3.1-8B-Instruct", # Primary (Gated) - "mistralai/Mistral-7B-Instruct-v0.3", # Secondary - "HuggingFaceH4/zephyr-7b-beta", # Fallback (Ungated) - ] - - def __init__(self, model_id: str | None = None, api_key: str | None = None) -> None: - """ - Initialize with HF Inference client. - - Args: - model_id: Optional specific model ID. If None, uses FALLBACK_MODELS chain. - api_key: Optional HuggingFace API key/token. If None, uses HF_TOKEN from env. 
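        Example (illustrative sketch; `question` is a str and `evidence` is a
        list of Evidence objects, as expected by `assess` below):

            handler = HFInferenceJudgeHandler()  # walks FALLBACK_MODELS until one succeeds
            assessment = await handler.assess(question, evidence)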
- """ - self.model_id = model_id - # Pass api_key to InferenceClient if provided, otherwise it will use HF_TOKEN from env - self.client = InferenceClient(api_key=api_key) if api_key else InferenceClient() - self.call_count = 0 - self.last_question: str | None = None - self.last_evidence: list[Evidence] | None = None - - async def assess( - self, - question: str, - evidence: list[Evidence], - ) -> JudgeAssessment: - """ - Assess evidence using HuggingFace Inference API. - Attempts models in order until one succeeds. - """ - self.call_count += 1 - self.last_question = question - self.last_evidence = evidence - - # Format the user prompt - if evidence: - user_prompt = format_user_prompt(question, evidence) - else: - user_prompt = format_empty_evidence_prompt(question) - - models_to_try: list[str] = [self.model_id] if self.model_id else self.FALLBACK_MODELS - last_error: Exception | None = None - - for model in models_to_try: - try: - return await self._call_with_retry(model, user_prompt, question) - except Exception as e: - logger.warning("Model failed", model=model, error=str(e)) - last_error = e - continue - - # All models failed - logger.error("All HF models failed", error=str(last_error)) - return self._create_fallback_assessment(question, str(last_error)) - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=4), - retry=retry_if_exception_type(Exception), - reraise=True, - ) - async def _call_with_retry(self, model: str, prompt: str, question: str) -> JudgeAssessment: - """Make API call with retry logic using chat_completion.""" - loop = asyncio.get_running_loop() - - # Build messages for chat_completion (model-agnostic) - messages = [ - { - "role": "system", - "content": f"""{SYSTEM_PROMPT} - -IMPORTANT: Respond with ONLY valid JSON matching this schema: -{{ - "details": {{ - "mechanism_score": , - "mechanism_reasoning": "", - "clinical_evidence_score": , - "clinical_reasoning": "", - "drug_candidates": ["", ...], - "key_findings": ["", ...] - }}, - "sufficient": , - "confidence": , - "recommendation": "continue" | "synthesize", - "next_search_queries": ["", ...], - "reasoning": "" -}}""", - }, - {"role": "user", "content": prompt}, - ] - - # Use chat_completion (conversational task - supported by all models) - response = await loop.run_in_executor( - None, - lambda: self.client.chat_completion( - messages=messages, - model=model, - max_tokens=1024, - temperature=0.1, - ), - ) - - # Extract content from response - content = response.choices[0].message.content - if not content: - raise ValueError("Empty response from model") - - # Extract and parse JSON - json_data = self._extract_json(content) - if not json_data: - raise ValueError("No valid JSON found in response") - - return JudgeAssessment(**json_data) - - def _extract_json(self, text: str) -> dict[str, Any] | None: - """ - Robust JSON extraction that handles markdown blocks and nested braces. 
- """ - text = text.strip() - - # Remove markdown code blocks if present (with bounds checking) - if "```json" in text: - parts = text.split("```json", 1) - if len(parts) > 1: - inner_parts = parts[1].split("```", 1) - text = inner_parts[0] - elif "```" in text: - parts = text.split("```", 1) - if len(parts) > 1: - inner_parts = parts[1].split("```", 1) - text = inner_parts[0] - - text = text.strip() - - # Find first '{' - start_idx = text.find("{") - if start_idx == -1: - return None - - # Stack-based parsing ignoring chars in strings - count = 0 - in_string = False - escape = False - - for i, char in enumerate(text[start_idx:], start=start_idx): - if in_string: - if escape: - escape = False - elif char == "\\": - escape = True - elif char == '"': - in_string = False - elif char == '"': - in_string = True - elif char == "{": - count += 1 - elif char == "}": - count -= 1 - if count == 0: - try: - result = json.loads(text[start_idx : i + 1]) - if isinstance(result, dict): - return result - return None - except json.JSONDecodeError: - return None - - return None - - def _create_fallback_assessment( - self, - question: str, - error: str, - ) -> JudgeAssessment: - """Create a fallback assessment when inference fails.""" - return JudgeAssessment( - details=AssessmentDetails( - mechanism_score=0, - mechanism_reasoning=f"Assessment failed: {error}", - clinical_evidence_score=0, - clinical_reasoning=f"Assessment failed: {error}", - drug_candidates=[], - key_findings=[], - ), - sufficient=False, - confidence=0.0, - recommendation="continue", - next_search_queries=[ - f"{question} mechanism", - f"{question} clinical trials", - f"{question} drug candidates", - ], - reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.", - ) - - -def create_judge_handler() -> JudgeHandler: - """Create a judge handler based on configuration. - - Returns: - Configured JudgeHandler instance - """ - return JudgeHandler() - - -class MockJudgeHandler: - """ - Mock JudgeHandler for demo mode without LLM calls. - - Extracts meaningful information from real search results - to provide a useful demo experience without requiring API keys. - """ - - def __init__(self, mock_response: JudgeAssessment | None = None) -> None: - """ - Initialize with optional mock response. - - Args: - mock_response: The assessment to return. If None, extracts from evidence. - """ - self.mock_response = mock_response - self.call_count = 0 - self.last_question: str | None = None - self.last_evidence: list[Evidence] | None = None - - def _extract_key_findings(self, evidence: list[Evidence], max_findings: int = 5) -> list[str]: - """Extract key findings from evidence titles.""" - findings = [] - for e in evidence[:max_findings]: - # Use first 150 chars of title as a finding - title = e.citation.title - if len(title) > 150: - title = title[:147] + "..." 
- findings.append(title) - return findings if findings else ["No specific findings extracted (demo mode)"] - - def _extract_drug_candidates(self, question: str, evidence: list[Evidence]) -> list[str]: - """Extract drug candidates - demo mode returns honest message.""" - # Don't attempt heuristic extraction - it produces garbage like "Oral", "Kidney" - # Real drug extraction requires LLM analysis - return [ - "Drug identification requires AI analysis", - "Enter API key above for full results", - ] - - async def assess( - self, - question: str, - evidence: list[Evidence], - ) -> JudgeAssessment: - """Return assessment based on actual evidence (demo mode).""" - self.call_count += 1 - self.last_question = question - self.last_evidence = evidence - - if self.mock_response: - return self.mock_response - - min_evidence = 3 - evidence_count = len(evidence) - - # Extract meaningful data from actual evidence - drug_candidates = self._extract_drug_candidates(question, evidence) - key_findings = self._extract_key_findings(evidence) - - # Calculate scores based on evidence quantity - mechanism_score = min(10, evidence_count * 2) if evidence_count > 0 else 0 - clinical_score = min(10, evidence_count) if evidence_count > 0 else 0 - - return JudgeAssessment( - details=AssessmentDetails( - mechanism_score=mechanism_score, - mechanism_reasoning=( - f"Demo mode: Found {evidence_count} sources. " - "Configure LLM API key for detailed mechanism analysis." - ), - clinical_evidence_score=clinical_score, - clinical_reasoning=( - f"Demo mode: {evidence_count} sources retrieved from PubMed, " - "ClinicalTrials.gov, and Europe PMC. Full analysis requires LLM API key." - ), - drug_candidates=drug_candidates, - key_findings=key_findings, - ), - sufficient=evidence_count >= min_evidence, - confidence=min(0.5, evidence_count * 0.1) if evidence_count > 0 else 0.0, - recommendation="synthesize" if evidence_count >= min_evidence else "continue", - next_search_queries=( - [f"{question} mechanism", f"{question} clinical trials"] - if evidence_count < min_evidence - else [] - ), - reasoning=( - f"Demo mode assessment based on {evidence_count} real search results. " - "For AI-powered analysis with drug candidate identification and " - "evidence synthesis, configure OPENAI_API_KEY or ANTHROPIC_API_KEY." - ), - ) diff --git a/src/agents/__init__.py b/src/agents/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/agents/analysis_agent.py b/src/agents/analysis_agent.py deleted file mode 100644 index e8bbc7fec996050d96af1f9edd3c8605744fc830..0000000000000000000000000000000000000000 --- a/src/agents/analysis_agent.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Analysis agent for statistical analysis using Modal code execution. - -This agent wraps StatisticalAnalyzer for use in magentic multi-agent mode. -The core logic is in src/services/statistical_analyzer.py to avoid -coupling agent_framework to the simple orchestrator. 
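A minimal usage sketch (illustrative; assumes a shared evidence_store dict with
Evidence objects under the "current" key, as populated by the orchestrator):

    agent = AnalysisAgent(evidence_store={"current": evidence, "hypotheses": []})
    response = await agent.run("example research question")
    print(response.messages[0].text)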
-""" - -from collections.abc import AsyncIterable -from typing import TYPE_CHECKING, Any - -from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - BaseAgent, - ChatMessage, - Role, -) - -from src.services.statistical_analyzer import ( - AnalysisResult, - get_statistical_analyzer, -) - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - - -class AnalysisAgent(BaseAgent): # type: ignore[misc] - """Wraps StatisticalAnalyzer for magentic multi-agent mode.""" - - def __init__( - self, - evidence_store: dict[str, Any], - embedding_service: "EmbeddingService | None" = None, - ) -> None: - super().__init__( - name="AnalysisAgent", - description="Performs statistical analysis using Modal sandbox", - ) - self._evidence_store = evidence_store - self._embeddings = embedding_service - self._analyzer = get_statistical_analyzer() - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - """Analyze evidence and return verdict.""" - query = self._extract_query(messages) - hypotheses = self._evidence_store.get("hypotheses", []) - evidence = self._evidence_store.get("current", []) - - if not evidence: - return self._error_response("No evidence available.") - - # Get primary hypothesis if available - hypothesis_dict = None - if hypotheses: - h = hypotheses[0] - hypothesis_dict = { - "drug": getattr(h, "drug", "Unknown"), - "target": getattr(h, "target", "?"), - "pathway": getattr(h, "pathway", "?"), - "effect": getattr(h, "effect", "?"), - "confidence": getattr(h, "confidence", 0.5), - } - - # Delegate to StatisticalAnalyzer - result = await self._analyzer.analyze( - query=query, - evidence=evidence, - hypothesis=hypothesis_dict, - ) - - # Store in shared context - self._evidence_store["analysis"] = result.model_dump() - - # Format response - response_text = self._format_response(result) - - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)], - response_id=f"analysis-{result.verdict.lower()}", - additional_properties={"analysis": result.model_dump()}, - ) - - def _format_response(self, result: AnalysisResult) -> str: - """Format analysis result as markdown.""" - lines = [ - "## Statistical Analysis Complete\n", - f"### Verdict: **{result.verdict}**", - f"**Confidence**: {result.confidence:.0%}\n", - "### Key Findings", - ] - for finding in result.key_findings: - lines.append(f"- {finding}") - - lines.extend( - [ - "\n### Statistical Evidence", - "```", - result.statistical_evidence, - "```", - ] - ) - return "\n".join(lines) - - def _error_response(self, message: str) -> AgentRunResponse: - """Create error response.""" - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=f"**Error**: {message}")], - response_id="analysis-error", - ) - - def _extract_query( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None, - ) -> str: - """Extract query from messages.""" - if isinstance(messages, str): - return messages - elif isinstance(messages, ChatMessage): - return messages.text or "" - elif isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, ChatMessage) and msg.role == Role.USER: - return msg.text or "" - elif isinstance(msg, str): - return msg - return "" - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - 
**kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - """Streaming wrapper.""" - result = await self.run(messages, thread=thread, **kwargs) - yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id) diff --git a/src/agents/audio_refiner.py b/src/agents/audio_refiner.py deleted file mode 100644 index 257c6f5e6df42b16687f181a98ab36178dc9b26a..0000000000000000000000000000000000000000 --- a/src/agents/audio_refiner.py +++ /dev/null @@ -1,402 +0,0 @@ -"""Audio Refiner Agent - Cleans markdown reports for TTS audio clarity. - -This agent transforms markdown-formatted research reports into clean, -audio-friendly plain text suitable for text-to-speech synthesis. -""" - -import re - -import structlog -from pydantic_ai import Agent - -from src.utils.llm_factory import get_pydantic_ai_model - -logger = structlog.get_logger(__name__) - - -class AudioRefiner: - """Refines markdown reports for optimal TTS audio output. - - Handles common formatting issues that make text difficult to listen to: - - Markdown syntax (headers, bold, italic, links) - - Citations and reference markers - - Roman numerals in medical contexts - - Multiple References sections - - Special characters and formatting artifacts - """ - - # Roman numeral to integer mapping - ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000} - - # Number to word mapping (1-20, common in medical literature) - NUMBER_TO_WORD = { - 1: "One", - 2: "Two", - 3: "Three", - 4: "Four", - 5: "Five", - 6: "Six", - 7: "Seven", - 8: "Eight", - 9: "Nine", - 10: "Ten", - 11: "Eleven", - 12: "Twelve", - 13: "Thirteen", - 14: "Fourteen", - 15: "Fifteen", - 16: "Sixteen", - 17: "Seventeen", - 18: "Eighteen", - 19: "Nineteen", - 20: "Twenty", - } - - async def refine_for_audio(self, markdown_text: str, use_llm_polish: bool = False) -> str: - """Transform markdown report into audio-friendly plain text. - - Args: - markdown_text: Markdown-formatted research report - use_llm_polish: If True, apply LLM-based final polish (optional) - - Returns: - Clean plain text optimized for TTS audio - """ - logger.info("Refining report for audio output", use_llm_polish=use_llm_polish) - - text = markdown_text - - # Step 1: Remove References sections first (before other processing) - text = self._remove_references_sections(text) - - # Step 2: Remove markdown formatting - text = self._remove_markdown_syntax(text) - - # Step 3: Convert roman numerals to words - text = self._convert_roman_numerals(text) - - # Step 4: Remove citations - text = self._remove_citations(text) - - # Step 5: Clean up special characters and artifacts - text = self._clean_special_characters(text) - - # Step 6: Normalize whitespace - text = self._normalize_whitespace(text) - - # Step 7 (Optional): LLM polish for edge cases - if use_llm_polish: - text = await self._llm_polish(text) - - logger.info( - "Audio refinement complete", - original_length=len(markdown_text), - refined_length=len(text), - llm_polish_applied=use_llm_polish, - ) - - return text.strip() - - def _remove_references_sections(self, text: str) -> str: - """Remove References sections while preserving other content. - - Removes the References section and its content until the next section - heading or end of document. Handles multiple References sections. 
- - Matches various References heading formats: - - # References - - ## References - - **References:** - - **Additional References:** - - References: (plain text) - """ - # Pattern to match References section heading (case-insensitive) - # Matches: markdown headers (# References), bold (**References:**), or plain text (References:) - references_pattern = r"\n(?:#+\s*References?:?\s*\n|\*\*\s*(?:Additional\s+)?References?:?\s*\*\*\s*\n|References?:?\s*\n)" - - # Find all References sections - while True: - match = re.search(references_pattern, text, re.IGNORECASE) - if not match: - break - - # Find the start of the References section - section_start = match.start() - - # Find the next section (markdown header or bold heading) or end of document - # Match: "# Header", "## Header", or "**Header**" - next_section_patterns = [ - r"\n#+\s+\w+", # Markdown headers (# Section, ## Section) - r"\n\*\*[A-Z][^*]+\*\*", # Bold headings (**Section Name**) - ] - - remaining_text = text[match.end() :] - next_section_match = None - - # Try all patterns and find the earliest match - earliest_match = None - for pattern in next_section_patterns: - m = re.search(pattern, remaining_text) - if m and (earliest_match is None or m.start() < earliest_match.start()): - earliest_match = m - - next_section_match = earliest_match - - if next_section_match: - # Remove from References heading to next section - section_end = match.end() + next_section_match.start() - else: - # No next section - remove to end of document - section_end = len(text) - - # Remove the References section - text = text[:section_start] + text[section_end:] - logger.debug("Removed References section", removed_chars=section_end - section_start) - - return text - - def _remove_markdown_syntax(self, text: str) -> str: - """Remove markdown formatting syntax.""" - - # Headers (# ## ###) - text = re.sub(r"^\s*#+\s+", "", text, flags=re.MULTILINE) - - # Bold (**text** or __text__) - text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) - text = re.sub(r"__([^_]+)__", r"\1", text) - - # Italic (*text* or _text_) - text = re.sub(r"\*([^*]+)\*", r"\1", text) - text = re.sub(r"_([^_]+)_", r"\1", text) - - # Links [text](url) → text - text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text) - - # Inline code `code` → code - text = re.sub(r"`([^`]+)`", r"\1", text) - - # Strikethrough ~~text~~ - text = re.sub(r"~~([^~]+)~~", r"\1", text) - - # Blockquotes (> text) - text = re.sub(r"^\s*>\s+", "", text, flags=re.MULTILINE) - - # Horizontal rules (---, ***, ___) - text = re.sub(r"^\s*[-*_]{3,}\s*$", "", text, flags=re.MULTILINE) - - # List markers (-, *, 1., 2.) - text = re.sub(r"^\s*[-*]\s+", "", text, flags=re.MULTILINE) - text = re.sub(r"^\s*\d+\.\s+", "", text, flags=re.MULTILINE) - - return text - - def _roman_to_int(self, roman: str) -> int | None: - """Convert roman numeral string to integer. - - Args: - roman: Roman numeral string (e.g., 'IV', 'XII') - - Returns: - Integer value, or None if invalid roman numeral - """ - roman = roman.upper() - result = 0 - prev_value = 0 - - for char in reversed(roman): - if char not in self.ROMAN_VALUES: - return None - - value = self.ROMAN_VALUES[char] - - # Subtractive notation (IV = 4, IX = 9) - if value < prev_value: - result -= value - else: - result += value - - prev_value = value - - return result - - def _int_to_word(self, num: int) -> str: - """Convert integer to word representation. 
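(Annotation, not part of the deleted file: a compact sketch of the subtractive-notation scan implemented by `_roman_to_int` above and the regex replacement it feeds. `roman_to_int`, `spell_out`, and the truncated `WORDS` table are illustrative stand-ins for the class attributes `ROMAN_VALUES` and `NUMBER_TO_WORD`.)

```python
import re

ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
WORDS = {1: "One", 2: "Two", 3: "Three", 4: "Four", 5: "Five"}  # truncated for brevity


def roman_to_int(roman: str) -> int | None:
    """Scan right-to-left: subtract a value that precedes a larger one (IV -> 4)."""
    total, prev = 0, 0
    for ch in reversed(roman.upper()):
        value = ROMAN_VALUES.get(ch)
        if value is None:
            return None
        total += value if value >= prev else -value
        prev = value
    return total


def spell_out(match: re.Match[str]) -> str:
    prefix, roman = match.group(1), match.group(2)
    num = roman_to_int(roman)
    if num is None:
        return match.group(0)  # leave unrecognised sequences untouched
    word = WORDS.get(num, str(num))
    return f"{prefix} {word}" if prefix else word


pattern = r"\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)?\s*([IVXLCDM]+)\b"
print(re.sub(pattern, spell_out, "A Phase III trial and a Stage IV cohort"))
# -> "A Phase Three trial and a Stage Four cohort"
```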
- - Args: - num: Integer to convert (1-20 supported) - - Returns: - Word representation (e.g., 'One', 'Twelve') - """ - if num in self.NUMBER_TO_WORD: - return self.NUMBER_TO_WORD[num] - else: - # For numbers > 20, just return the digit - return str(num) - - def _convert_roman_numerals(self, text: str) -> str: - """Convert roman numerals to words for better TTS pronunciation. - - Handles patterns like: - - Phase I, Phase II, Phase III - - Trial I, Trial II - - Type I, Type II - - Stage I, Stage II - - Standalone I, II, III (with word boundaries) - """ - - def replace_roman(match: re.Match[str]) -> str: - """Callback to replace matched roman numeral.""" - prefix = match.group(1) # Word before roman numeral (if any) - roman = match.group(2) # The roman numeral - - # Convert to integer - num = self._roman_to_int(roman) - if num is None: - return match.group(0) # Return original if invalid - - # Convert to word - word = self._int_to_word(num) - - # Return with prefix if present - if prefix: - return f"{prefix} {word}" - else: - return word - - # Pattern: Optional word + space + roman numeral - # Matches: "Phase I", "Trial II", standalone "I", "II" - # Uses word boundaries to avoid matching "I" in "INVALID" - pattern = r"\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)?\s*([IVXLCDM]+)\b" - - text = re.sub(pattern, replace_roman, text) - - return text - - def _remove_citations(self, text: str) -> str: - """Remove citation markers and references.""" - - # Numbered citations [1], [2], [1,2], [1-3] - text = re.sub(r"\[\d+(?:[-,]\d+)*\]", "", text) - - # Author citations (Smith et al., 2023) or (Smith et al. 2023) - text = re.sub(r"\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)", "", text) - - # Simple year citations (2023) - text = re.sub(r"\(\d{4}\)", "", text) - - # Author-year (Smith, 2023) - text = re.sub(r"\([A-Z][a-z]+,?\s+\d{4}\)", "", text) - - # Footnote markers (¹, ², ³) - text = re.sub(r"[¹²³⁴⁵⁶⁷⁸⁹⁰]+", "", text) - - return text - - def _clean_special_characters(self, text: str) -> str: - """Clean up special characters and formatting artifacts.""" - - # Replace em dashes with regular dashes - text = text.replace("\u2014", "-") # em dash - text = text.replace("\u2013", "-") # en dash - - # Replace smart quotes with regular quotes - text = text.replace("\u201c", '"') # left double quote - text = text.replace("\u201d", '"') # right double quote - text = text.replace("\u2018", "'") # left single quote - text = text.replace("\u2019", "'") # right single quote - - # Remove excessive punctuation (!!!, ???) - text = re.sub(r"([!?]){2,}", r"\1", text) - - # Remove asterisks used for footnotes - text = re.sub(r"\*+", "", text) - - # Remove hash symbols (from headers) - text = text.replace("#", "") - - # Remove excessive dots (...) - text = re.sub(r"\.{4,}", "...", text) - - return text - - def _normalize_whitespace(self, text: str) -> str: - """Normalize whitespace for clean audio output.""" - - # Replace multiple spaces with single space - text = re.sub(r" {2,}", " ", text) - - # Replace multiple newlines with double newline (paragraph break) - text = re.sub(r"\n{3,}", "\n\n", text) - - # Remove trailing/leading whitespace from lines - text = "\n".join(line.strip() for line in text.split("\n")) - - # Remove empty lines at start/end - text = text.strip() - - return text - - async def _llm_polish(self, text: str) -> str: - """Apply LLM-based final polish to catch edge cases. - - This is a lightweight pass that removes any remaining formatting - artifacts the rule-based methods might have missed. 
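(Annotation, not part of the deleted file: a quick end-to-end illustration of how the citation-stripping regexes above compose with the whitespace pass applied later by `_normalize_whitespace`. The sample sentence is invented.)

```python
import re

SAMPLE = "Metformin [1] lowers tau phosphorylation [2,3] (Smith et al., 2023) in mouse models."

cleaned = SAMPLE
cleaned = re.sub(r"\[\d+(?:[-,]\d+)*\]", "", cleaned)                     # [1], [1,2], [1-3]
cleaned = re.sub(r"\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)", "", cleaned)  # (Smith et al., 2023)
cleaned = re.sub(r"\(\d{4}\)", "", cleaned)                               # (2023)
cleaned = re.sub(r" {2,}", " ", cleaned).strip()                          # collapse leftover gaps
print(cleaned)
# -> "Metformin lowers tau phosphorylation in mouse models."
```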
- - Args: - text: Pre-cleaned text from rule-based methods - - Returns: - Final polished text ready for TTS - """ - try: - # Create a simple agent for text cleanup - model = get_pydantic_ai_model() - polish_agent = Agent( - model=model, - system_prompt=( - "You are a text cleanup assistant. Your ONLY job is to remove " - "any remaining formatting artifacts (markdown, citations, special " - "characters) that make text unsuitable for text-to-speech audio. " - "DO NOT rewrite, improve, or change the content. " - "DO NOT add explanations. " - "ONLY output the cleaned text." - ), - ) - - # Run asynchronously - result = await polish_agent.run( - f"Clean this text for audio (remove any formatting artifacts):\n\n{text}" - ) - - polished_text = result.output.strip() - - logger.info( - "llm_polish_applied", original_length=len(text), polished_length=len(polished_text) - ) - - return polished_text - - except Exception as e: - logger.warning( - "llm_polish_failed", error=str(e), message="Falling back to rule-based output" - ) - # Graceful fallback: return original text if LLM fails - return text - - -# Singleton instance for easy import -audio_refiner = AudioRefiner() - - -async def refine_text_for_audio(markdown_text: str, use_llm_polish: bool = False) -> str: - """Convenience function to refine markdown text for audio. - - Args: - markdown_text: Markdown-formatted text - use_llm_polish: If True, apply LLM-based final polish (optional) - - Returns: - Audio-friendly plain text - """ - return await audio_refiner.refine_for_audio(markdown_text, use_llm_polish=use_llm_polish) diff --git a/src/agents/code_executor_agent.py b/src/agents/code_executor_agent.py deleted file mode 100644 index e86eab2225b9071e3421ed4f62173867d43ae1ea..0000000000000000000000000000000000000000 --- a/src/agents/code_executor_agent.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Code execution agent using Modal.""" - -import asyncio -from typing import Any - -import structlog -from agent_framework import ChatAgent, ai_function - -from src.tools.code_execution import get_code_executor -from src.utils.llm_factory import get_chat_client_for_agent - -logger = structlog.get_logger() - - -@ai_function # type: ignore[arg-type, misc] -async def execute_python_code(code: str) -> str: - """Execute Python code in a secure sandbox. - - Args: - code: The Python code to execute. - - Returns: - The standard output and standard error of the execution. - """ - logger.info("Code execution starting", code_length=len(code)) - executor = get_code_executor() - loop = asyncio.get_running_loop() - - # Run in executor to avoid blocking - try: - result = await loop.run_in_executor(None, lambda: executor.execute(code)) - if result["success"]: - logger.info("Code execution succeeded") - return f"Stdout:\n{result['stdout']}" - else: - logger.warning("Code execution failed", error=result.get("error")) - return f"Error:\n{result['error']}\nStderr:\n{result['stderr']}" - except Exception as e: - logger.error("Code execution exception", error=str(e)) - return f"Execution failed: {e}" - - -def create_code_executor_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a code executor agent. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for code execution. 
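(Annotation, not part of the deleted file: for context on the `execute_python_code` tool above, a self-contained sketch of the thread-pool offload it relies on. `blocking_execute` is a stand-in for the real synchronous `executor.execute(...)` sandbox call; only the result-dict shape mirrors the code in the diff.)

```python
import asyncio


def blocking_execute(code: str) -> dict[str, object]:
    """Stand-in for the synchronous sandbox call (executor.execute)."""
    return {"success": True, "stdout": f"ran {len(code)} chars", "stderr": "", "error": None}


async def run_without_blocking(code: str) -> str:
    loop = asyncio.get_running_loop()
    # Offload the blocking call so the event loop keeps serving other agents.
    result = await loop.run_in_executor(None, lambda: blocking_execute(code))
    if result["success"]:
        return f"Stdout:\n{result['stdout']}"
    return f"Error:\n{result['error']}\nStderr:\n{result['stderr']}"


print(asyncio.run(run_without_blocking("print('hi')")))
```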
- """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="CodeExecutorAgent", - description="Executes Python code for data analysis, calculation, and simulation.", - instructions="""You are a code execution expert. -When asked to analyze data or perform calculations, write Python code and execute it. -Use libraries like pandas, numpy, scipy, matplotlib. - -Always output the code you want to execute using the `execute_python_code` tool. -Check the output and interpret the results.""", - chat_client=client, - tools=[execute_python_code], - temperature=0.0, # Strict code generation - ) diff --git a/src/agents/hypothesis_agent.py b/src/agents/hypothesis_agent.py deleted file mode 100644 index 6619fae1da264e80f296d7aa56528a17c1aa6815..0000000000000000000000000000000000000000 --- a/src/agents/hypothesis_agent.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Hypothesis agent for mechanistic reasoning.""" - -from collections.abc import AsyncIterable -from typing import TYPE_CHECKING, Any - -from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - BaseAgent, - ChatMessage, - Role, -) -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.prompts.hypothesis import SYSTEM_PROMPT, format_hypothesis_prompt -from src.utils.models import HypothesisAssessment - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - - -class HypothesisAgent(BaseAgent): # type: ignore[misc] - """Generates mechanistic hypotheses based on evidence.""" - - def __init__( - self, - evidence_store: dict[str, Any], - embedding_service: "EmbeddingService | None" = None, # NEW: for diverse selection - ) -> None: - super().__init__( - name="HypothesisAgent", - description="Generates scientific hypotheses about drug mechanisms to guide research", - ) - self._evidence_store = evidence_store - self._embeddings = embedding_service # Used for MMR evidence selection - self._agent: Agent[None, HypothesisAssessment] | None = None # Lazy init - - def _get_agent(self) -> Agent[None, HypothesisAssessment]: - """Lazy initialization of LLM agent to avoid requiring API keys at import.""" - if self._agent is None: - self._agent = Agent( - model=get_model(), # Uses configured LLM (OpenAI/Anthropic) - output_type=HypothesisAssessment, - system_prompt=SYSTEM_PROMPT, - ) - return self._agent - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - """Generate hypotheses based on current evidence.""" - # Extract query - query = self._extract_query(messages) - - # Get current evidence - evidence = self._evidence_store.get("current", []) - - if not evidence: - return AgentRunResponse( - messages=[ - ChatMessage( - role=Role.ASSISTANT, - text="No evidence available yet. 
Search for evidence first.", - ) - ], - response_id="hypothesis-no-evidence", - ) - - # Generate hypotheses with diverse evidence selection - prompt = await format_hypothesis_prompt(query, evidence, embeddings=self._embeddings) - result = await self._get_agent().run(prompt) - assessment = result.output # pydantic-ai returns .output for structured output - - # Store hypotheses in shared context - existing = self._evidence_store.get("hypotheses", []) - self._evidence_store["hypotheses"] = existing + assessment.hypotheses - - # Format response - response_text = self._format_response(assessment) - - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)], - response_id=f"hypothesis-{len(assessment.hypotheses)}", - additional_properties={"assessment": assessment.model_dump()}, - ) - - def _format_response(self, assessment: HypothesisAssessment) -> str: - """Format hypothesis assessment as markdown.""" - lines = ["## Generated Hypotheses\n"] - - for i, h in enumerate(assessment.hypotheses, 1): - lines.append(f"### Hypothesis {i} (Confidence: {h.confidence:.0%})") - lines.append(f"**Mechanism**: {h.drug} -> {h.target} -> {h.pathway} -> {h.effect}") - lines.append(f"**Suggested searches**: {', '.join(h.search_suggestions)}\n") - - if assessment.primary_hypothesis: - lines.append("### Primary Hypothesis") - h = assessment.primary_hypothesis - lines.append(f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}\n") - - if assessment.knowledge_gaps: - lines.append("### Knowledge Gaps") - for gap in assessment.knowledge_gaps: - lines.append(f"- {gap}") - - if assessment.recommended_searches: - lines.append("\n### Recommended Next Searches") - for search in assessment.recommended_searches: - lines.append(f"- `{search}`") - - return "\n".join(lines) - - def _extract_query( - self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None - ) -> str: - """Extract query from messages.""" - if isinstance(messages, str): - return messages - elif isinstance(messages, ChatMessage): - return messages.text or "" - elif isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, ChatMessage) and msg.role == Role.USER: - return msg.text or "" - elif isinstance(msg, str): - return msg - return "" - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - """Streaming wrapper.""" - result = await self.run(messages, thread=thread, **kwargs) - yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id) diff --git a/src/agents/input_parser.py b/src/agents/input_parser.py deleted file mode 100644 index bb6b920d415287e35d8a471a1c7aaf64f371ec6c..0000000000000000000000000000000000000000 --- a/src/agents/input_parser.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Input parser agent for analyzing and improving user queries. - -Determines research mode (iterative vs deep) and extracts key information -from user queries to improve research quality. 
-""" - -from typing import TYPE_CHECKING, Any, Literal - -import structlog -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError, JudgeError -from src.utils.models import ParsedQuery - -if TYPE_CHECKING: - pass - -logger = structlog.get_logger() - -# System prompt for the input parser agent -SYSTEM_PROMPT = """ -You are an expert research query analyzer for a generalist deep research agent. Your job is to analyze user queries and determine: -1. Whether the query requires iterative research (single focused question) or deep research (multiple sections/topics) -2. Whether the query requires medical/biomedical knowledge sources (PubMed, ClinicalTrials.gov) or general knowledge sources (web search) -3. Improve and refine the query for better research results -4. Extract key entities (drugs, diseases, companies, technologies, concepts, etc.) -5. Extract specific research questions - -Guidelines for determining research mode: -- **Iterative mode**: Single focused question, straightforward research goal, can be answered with a focused search loop - Examples: "What is the mechanism of metformin?", "How does quantum computing work?", "What are the latest AI models?" - -- **Deep mode**: Complex query requiring multiple sections, comprehensive report, multiple related topics - Examples: "Write a comprehensive report on diabetes treatment", "Analyze the market for quantum computing", "Review the state of AI in healthcare" - Indicators: words like "comprehensive", "report", "sections", "analyze", "market analysis", "overview" - -Guidelines for determining if medical knowledge is needed: -- **Medical knowledge needed**: Queries about diseases, treatments, drugs, clinical trials, medical conditions, biomedical mechanisms, health outcomes, etc. - Examples: "Alzheimer's treatment", "metformin mechanism", "cancer clinical trials", "diabetes research" - -- **General knowledge sufficient**: Queries about technology, business, science (non-medical), history, current events, etc. - Examples: "quantum computing", "AI models", "market analysis", "historical events" - -Your output must be valid JSON matching the ParsedQuery schema. Always provide: -- original_query: The exact input query -- improved_query: A refined, clearer version of the query -- research_mode: Either "iterative" or "deep" -- key_entities: List of important entities (drugs, diseases, companies, technologies, etc.) -- research_questions: List of specific questions to answer - -Only output JSON. Do not output anything else. -""" - - -class InputParserAgent: - """ - Input parser agent that analyzes queries and determines research mode. - - Uses Pydantic AI to generate structured ParsedQuery output with research - mode detection, query improvement, and entity extraction. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the input parser agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent - self.agent = Agent( - model=self.model, - output_type=ParsedQuery, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def parse(self, query: str) -> ParsedQuery: - """ - Parse and analyze a user query. 
- - Args: - query: The user's research query - - Returns: - ParsedQuery with research mode, improved query, entities, and questions - - Raises: - JudgeError: If parsing fails after retries - ConfigurationError: If agent configuration is invalid - """ - self.logger.info("Parsing user query", query=query[:100]) - - user_message = f"QUERY: {query}" - - try: - # Run the agent - result = await self.agent.run(user_message) - parsed_query = result.output - - # Validate parsed query - if not parsed_query.original_query: - self.logger.warning("Parsed query missing original_query", query=query[:100]) - raise JudgeError("Parsed query must have original_query") - - if not parsed_query.improved_query: - self.logger.warning("Parsed query missing improved_query", query=query[:100]) - # Use original as fallback - parsed_query = ParsedQuery( - original_query=parsed_query.original_query, - improved_query=parsed_query.original_query, - research_mode=parsed_query.research_mode, - key_entities=parsed_query.key_entities, - research_questions=parsed_query.research_questions, - ) - - self.logger.info( - "Query parsed successfully", - mode=parsed_query.research_mode, - entities=len(parsed_query.key_entities), - questions=len(parsed_query.research_questions), - ) - - return parsed_query - - except Exception as e: - self.logger.error("Query parsing failed", error=str(e), query=query[:100]) - - # Fallback: return basic parsed query with heuristic mode detection - if isinstance(e, JudgeError | ConfigurationError): - raise - - # Heuristic fallback - query_lower = query.lower() - research_mode: Literal["iterative", "deep"] = "iterative" - if any( - keyword in query_lower - for keyword in [ - "comprehensive", - "report", - "sections", - "analyze", - "analysis", - "overview", - "market", - ] - ): - research_mode = "deep" - - return ParsedQuery( - original_query=query, - improved_query=query, - research_mode=research_mode, - key_entities=[], - research_questions=[], - ) - - -def create_input_parser_agent( - model: Any | None = None, oauth_token: str | None = None -) -> InputParserAgent: - """ - Factory function to create an input parser agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. 
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured InputParserAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - # Get model from settings if not provided - if model is None: - model = get_model(oauth_token=oauth_token) - - # Create and return input parser agent - return InputParserAgent(model=model) - - except Exception as e: - logger.error("Failed to create input parser agent", error=str(e)) - raise ConfigurationError(f"Failed to create input parser agent: {e}") from e diff --git a/src/agents/judge_agent.py b/src/agents/judge_agent.py deleted file mode 100644 index 9bab6dbbfcf9850f33aac8c0a8b267fc7a16fc58..0000000000000000000000000000000000000000 --- a/src/agents/judge_agent.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Judge agent wrapper for Magentic integration.""" - -from collections.abc import AsyncIterable -from typing import Any - -from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - BaseAgent, - ChatMessage, - Role, -) - -from src.legacy_orchestrator import JudgeHandlerProtocol -from src.utils.models import Evidence, JudgeAssessment - - -class JudgeAgent(BaseAgent): # type: ignore[misc] - """Wraps JudgeHandler as an AgentProtocol for Magentic.""" - - def __init__( - self, - judge_handler: JudgeHandlerProtocol, - evidence_store: dict[str, list[Evidence]], - ) -> None: - super().__init__( - name="JudgeAgent", - description="Evaluates evidence quality and determines if sufficient for synthesis", - ) - self._handler = judge_handler - self._evidence_store = evidence_store # Shared state for evidence - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - """Assess evidence quality.""" - # Extract original question from messages - question = "" - if isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, ChatMessage) and msg.role == Role.USER and msg.text: - question = msg.text - break - elif isinstance(msg, str): - question = msg - break - elif isinstance(messages, str): - question = messages - elif isinstance(messages, ChatMessage) and messages.text: - question = messages.text - - # Get evidence from shared store - evidence = self._evidence_store.get("current", []) - - # Assess - assessment: JudgeAssessment = await self._handler.assess(question, evidence) - - # Format response - response_text = f"""## Assessment - -**Sufficient**: {assessment.sufficient} -**Confidence**: {assessment.confidence:.0%} -**Recommendation**: {assessment.recommendation} - -### Scores -- Mechanism: {assessment.details.mechanism_score}/10 -- Clinical: {assessment.details.clinical_evidence_score}/10 - -### Reasoning -{assessment.reasoning} -""" - - if assessment.next_search_queries: - response_text += "\n### Next Queries\n" + "\n".join( - f"- {q}" for q in assessment.next_search_queries - ) - - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)], - response_id=f"judge-{assessment.recommendation}", - additional_properties={"assessment": assessment.model_dump()}, - ) - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - """Streaming wrapper for judge.""" - result = await self.run(messages, thread=thread, **kwargs) - 
yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id) diff --git a/src/agents/judge_agent_llm.py b/src/agents/judge_agent_llm.py deleted file mode 100644 index 12453e09489d9379fdf2e61c213591dcc27a5a27..0000000000000000000000000000000000000000 --- a/src/agents/judge_agent_llm.py +++ /dev/null @@ -1,45 +0,0 @@ -"""LLM Judge for sub-iterations.""" - -from typing import Any - -import structlog -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.utils.models import JudgeAssessment - -logger = structlog.get_logger() - - -class LLMSubIterationJudge: - """Judge that uses an LLM to assess sub-iteration results.""" - - def __init__(self) -> None: - self.model = get_model() - self.agent = Agent( - model=self.model, - output_type=JudgeAssessment, - system_prompt="""You are a strict judge evaluating a research task. - -Evaluate if the result is sufficient to answer the task. -Provide scores and detailed reasoning. -If not sufficient, suggest next steps.""", - retries=3, - ) - - async def assess(self, task: str, result: Any, history: list[Any]) -> JudgeAssessment: - """Assess the result using LLM.""" - logger.info("LLM judge assessing result", task=task[:100], history_len=len(history)) - - prompt = f"""Task: {task} - -Current Result: -{str(result)[:4000]} - -History of previous attempts: {len(history)} - -Evaluate validity and sufficiency.""" - - run_result = await self.agent.run(prompt) - logger.info("LLM judge assessment complete", sufficient=run_result.output.sufficient) - return run_result.output diff --git a/src/agents/knowledge_gap.py b/src/agents/knowledge_gap.py deleted file mode 100644 index 4ceab6cc2740eb22df06212ca3dc29bdb94dbc62..0000000000000000000000000000000000000000 --- a/src/agents/knowledge_gap.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Knowledge gap agent for evaluating research completeness. - -Converts the folder/knowledge_gap_agent.py implementation to use Pydantic AI. -""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError -from src.utils.models import KnowledgeGapOutput - -logger = structlog.get_logger() - - -# System prompt for the knowledge gap agent -SYSTEM_PROMPT = f""" -You are a Research State Evaluator. Today's date is {datetime.now().strftime("%Y-%m-%d")}. -Your job is to critically analyze the current state of a research report, -identify what knowledge gaps still exist and determine the best next step to take. - -You will be given: -1. The original user query and any relevant background context to the query -2. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process - -Your task is to: -1. Carefully review the findings and thoughts, particularly from the latest iteration, and assess their completeness in answering the original query -2. Determine if the findings are sufficiently complete to end the research loop -3. If not, identify up to 3 knowledge gaps that need to be addressed in sequence in order to continue with research - these should be relevant to the original query - -Be specific in the gaps you identify and include relevant information as this will be passed onto another agent to process without additional context. - -Only output JSON. 
Follow the JSON schema for KnowledgeGapOutput. Do not output anything else. -""" - - -class KnowledgeGapAgent: - """ - Agent that evaluates research state and identifies knowledge gaps. - - Uses Pydantic AI to generate structured KnowledgeGapOutput indicating - whether research is complete and what gaps remain. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the knowledge gap agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent - self.agent = Agent( - model=self.model, - output_type=KnowledgeGapOutput, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def evaluate( - self, - query: str, - background_context: str = "", - conversation_history: str = "", - message_history: list[ModelMessage] | None = None, - iteration: int = 0, - time_elapsed_minutes: float = 0.0, - max_time_minutes: int = 10, - ) -> KnowledgeGapOutput: - """ - Evaluate research state and identify knowledge gaps. - - Args: - query: The original research query - background_context: Optional background context - conversation_history: History of actions, findings, and thoughts (backward compat) - message_history: Optional user conversation history (Pydantic AI format) - iteration: Current iteration number - time_elapsed_minutes: Time elapsed so far - max_time_minutes: Maximum time allowed - - Returns: - KnowledgeGapOutput with research completeness and outstanding gaps - - Raises: - JudgeError: If evaluation fails after retries - """ - self.logger.info( - "Evaluating knowledge gaps", - query=query[:100], - iteration=iteration, - ) - - background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else "" - - user_message = f""" -Current Iteration Number: {iteration} -Time Elapsed: {time_elapsed_minutes:.2f} minutes of maximum {max_time_minutes} minutes - -ORIGINAL QUERY: -{query} - -{background} - -HISTORY OF ACTIONS, FINDINGS AND THOUGHTS: -{conversation_history or "No previous actions, findings or thoughts available."} -""" - - try: - # Run the agent with message_history if provided - if message_history: - result = await self.agent.run(user_message, message_history=message_history) - else: - result = await self.agent.run(user_message) - evaluation = result.output - - self.logger.info( - "Knowledge gap evaluation complete", - research_complete=evaluation.research_complete, - gaps_count=len(evaluation.outstanding_gaps), - ) - - return evaluation - - except Exception as e: - self.logger.error("Knowledge gap evaluation failed", error=str(e)) - # Return fallback: research not complete, suggest continuing - return KnowledgeGapOutput( - research_complete=False, - outstanding_gaps=[f"Continue research on: {query}"], - ) - - -def create_knowledge_gap_agent( - model: Any | None = None, oauth_token: str | None = None -) -> KnowledgeGapAgent: - """ - Factory function to create a knowledge gap agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. 
- oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured KnowledgeGapAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return KnowledgeGapAgent(model=model) - - except Exception as e: - logger.error("Failed to create knowledge gap agent", error=str(e)) - raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e diff --git a/src/agents/long_writer.py b/src/agents/long_writer.py deleted file mode 100644 index 5b0553aae60d4daf89f7c6879d024a44a1ea1a42..0000000000000000000000000000000000000000 --- a/src/agents/long_writer.py +++ /dev/null @@ -1,466 +0,0 @@ -"""Long writer agent for iteratively writing report sections. - -Converts the folder/long_writer_agent.py implementation to use Pydantic AI. -""" - -import re -from datetime import datetime -from typing import Any - -import structlog -from pydantic import BaseModel, Field -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError -from src.utils.models import ReportDraft - -logger = structlog.get_logger() - - -# LongWriterOutput model for structured output -class LongWriterOutput(BaseModel): - """Output from the long writer agent for a single section.""" - - next_section_markdown: str = Field( - description="The final draft of the next section in markdown format" - ) - references: list[str] = Field( - description="A list of URLs and their corresponding reference numbers for the section" - ) - - model_config = {"frozen": True} - - -# System prompt for the long writer agent -SYSTEM_PROMPT = f""" -You are an expert report writer tasked with iteratively writing each section of a report. -Today's date is {datetime.now().strftime("%Y-%m-%d")}. -You will be provided with: -1. The original research query -2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet) -3. A first draft of the next section of the report to be written - -OBJECTIVE: -1. Write a final draft of the next section of the report with numbered citations in square brackets in the body of the report -2. Produce a list of references to be appended to the end of the report - -CITATIONS/REFERENCES: -The citations should be in numerical order, written in numbered square brackets in the body of the report. -Separately, a list of all URLs and their corresponding reference numbers will be included at the end of the report. -Follow the example below for formatting. - -LongWriterOutput( - next_section_markdown="The company specializes in IT consulting [1]. 
It operates in the software services market which is expected to grow at 10% per year [2].", - references=["[1] https://example.com/first-source-url", "[2] https://example.com/second-source-url"] -) - -GUIDELINES: -- You can reformat and reorganize the flow of the content and headings within a section to flow logically, but DO NOT remove details that were included in the first draft -- Only remove text from the first draft if it is already mentioned earlier in the report, or if it should be covered in a later section per the table of contents -- Ensure the heading for the section matches the table of contents -- Format the final output and references section as markdown -- Do not include a title for the reference section, just a list of numbered references - -Only output JSON. Follow the JSON schema for LongWriterOutput. Do not output anything else. -""" - - -class LongWriterAgent: - """ - Agent that iteratively writes report sections with proper citations. - - Uses Pydantic AI to generate structured LongWriterOutput for each section. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the long writer agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent - self.agent = Agent( - model=self.model, - output_type=LongWriterOutput, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def write_next_section( - self, - original_query: str, - report_draft: str, - next_section_title: str, - next_section_draft: str, - ) -> LongWriterOutput: - """ - Write the next section of the report. - - Args: - original_query: The original research query - report_draft: Current report draft (all sections written so far) - next_section_title: Title of the section to write - next_section_draft: Draft content for the next section - - Returns: - LongWriterOutput with formatted section and references - - Raises: - ConfigurationError: If writing fails - """ - # Input validation - if not original_query or not original_query.strip(): - self.logger.warning("Empty query provided, using default") - original_query = "Research query" - - if not next_section_title or not next_section_title.strip(): - self.logger.warning("Empty section title provided, using default") - next_section_title = "Section" - - if next_section_draft is None: - next_section_draft = "" - - if report_draft is None: - report_draft = "" - - # Truncate very long inputs - max_draft_length = 30000 - if len(report_draft) > max_draft_length: - self.logger.warning( - "Report draft too long, truncating", - original_length=len(report_draft), - ) - report_draft = report_draft[:max_draft_length] + "\n\n[Content truncated]" - - if len(next_section_draft) > max_draft_length: - self.logger.warning( - "Section draft too long, truncating", - original_length=len(next_section_draft), - ) - next_section_draft = next_section_draft[:max_draft_length] + "\n\n[Content truncated]" - - self.logger.info( - "Writing next section", - section_title=next_section_title, - query=original_query[:100], - ) - - user_message = f""" - -{original_query} - - - -{report_draft or "No draft yet"} - - - -{next_section_title} - - - -{next_section_draft} - -""" - - # Retry logic for transient failures - max_retries = 3 - last_exception: Exception | None = None - - for attempt in range(max_retries): - try: - # Run the agent - result = await self.agent.run(user_message) - output = result.output - - # Validate output - if not output or not 
isinstance(output, LongWriterOutput): - raise ValueError("Invalid output format") - - if not output.next_section_markdown or not output.next_section_markdown.strip(): - self.logger.warning("Empty section generated, using fallback") - raise ValueError("Empty section generated") - - self.logger.info( - "Section written", - section_title=next_section_title, - references_count=len(output.references), - attempt=attempt + 1, - ) - - return output - - except (TimeoutError, ConnectionError) as e: - # Transient errors - retry - last_exception = e - if attempt < max_retries - 1: - self.logger.warning( - "Transient error, retrying", - error=str(e), - attempt=attempt + 1, - max_retries=max_retries, - ) - continue - else: - self.logger.error("Max retries exceeded for transient error", error=str(e)) - break - - except Exception as e: - # Non-transient errors - don't retry - last_exception = e - self.logger.error( - "Section writing failed", - error=str(e), - error_type=type(e).__name__, - ) - break - - # Return fallback section if all attempts failed - self.logger.error( - "Section writing failed after all attempts", - error=str(last_exception) if last_exception else "Unknown error", - ) - - # Try to enhance fallback with evidence if available - try: - from src.middleware.state_machine import get_workflow_state - - state = get_workflow_state() - if state and state.evidence: - # Include evidence citations in fallback - evidence_refs: list[str] = [] - for i, ev in enumerate(state.evidence[:10], 1): # Limit to 10 - authors = ( - ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown" - ) - evidence_refs.append( - f"[{i}] {authors}. *{ev.citation.title}*. {ev.citation.url}" - ) - - enhanced_draft = f"## {next_section_title}\n\n{next_section_draft}" - if evidence_refs: - enhanced_draft += "\n\n### Sources\n\n" + "\n".join(evidence_refs) - - return LongWriterOutput( - next_section_markdown=enhanced_draft, - references=evidence_refs, - ) - except Exception as e: - self.logger.warning( - "Failed to enhance fallback with evidence", - error=str(e), - ) - - # Basic fallback - return LongWriterOutput( - next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}", - references=[], - ) - - async def write_report( - self, - original_query: str, - report_title: str, - report_draft: ReportDraft, - ) -> str: - """ - Write the final report by iteratively writing each section. - - Args: - original_query: The original research query - report_title: Title of the report - report_draft: ReportDraft with all sections - - Returns: - Complete markdown report string - - Raises: - ConfigurationError: If writing fails - """ - # Input validation - if not original_query or not original_query.strip(): - self.logger.warning("Empty query provided, using default") - original_query = "Research query" - - if not report_title or not report_title.strip(): - self.logger.warning("Empty report title provided, using default") - report_title = "Research Report" - - if not report_draft or not report_draft.sections: - self.logger.warning("Empty report draft provided, returning minimal report") - return f"# {report_title}\n\n## Query\n{original_query}\n\n*No sections available.*" - - self.logger.info( - "Writing full report", - report_title=report_title, - sections_count=len(report_draft.sections), - ) - - # Initialize the final draft with title and table of contents - final_draft = ( - f"# {report_title}\n\n## Table of Contents\n\n" - + "\n".join( - [ - f"{i + 1}. 
{section.section_title}" - for i, section in enumerate(report_draft.sections) - ] - ) - + "\n\n" - ) - all_references: list[str] = [] - - for section in report_draft.sections: - # Write each section - next_section_output = await self.write_next_section( - original_query, - final_draft, - section.section_title, - section.section_content, - ) - - # Reformat references and update section markdown - section_markdown, all_references = self._reformat_references( - next_section_output.next_section_markdown, - next_section_output.references, - all_references, - ) - - # Reformat section headings - section_markdown = self._reformat_section_headings(section_markdown) - - # Add to final draft - final_draft += section_markdown + "\n\n" - - # Add final references - final_draft += "## References:\n\n" + " \n".join(all_references) - - self.logger.info("Full report written", length=len(final_draft)) - - return final_draft - - def _reformat_references( - self, - section_markdown: str, - section_references: list[str], - all_references: list[str], - ) -> tuple[str, list[str]]: - """ - Reformat references: re-number, de-duplicate, and update markdown. - - Args: - section_markdown: Markdown content with inline references [1], [2] - section_references: List of references for this section - all_references: Accumulated references from previous sections - - Returns: - Tuple of (updated markdown, updated all_references) - """ - - # Convert reference lists to maps (URL -> ref_num) - def convert_ref_list_to_map(ref_list: list[str]) -> dict[str, int]: - ref_map: dict[str, int] = {} - for ref in ref_list: - try: - # Parse "[1] https://example.com" format - parts = ref.split("]", 1) - if len(parts) == 2: - ref_num = int(parts[0].strip("[")) - url = parts[1].strip() - ref_map[url] = ref_num - except (ValueError, IndexError): - logger.warning("Invalid reference format", ref=ref) - continue - return ref_map - - section_ref_map = convert_ref_list_to_map(section_references) - report_ref_map = convert_ref_list_to_map(all_references) - section_to_report_ref_map: dict[int, int] = {} - - report_urls = set(report_ref_map.keys()) - ref_count = max(report_ref_map.values() or [0]) - - # Map section references to report references - for url, section_ref_num in section_ref_map.items(): - if url in report_urls: - # URL already exists - reuse its reference number - section_to_report_ref_map[section_ref_num] = report_ref_map[url] - else: - # New URL - assign next reference number - ref_count += 1 - section_to_report_ref_map[section_ref_num] = ref_count - all_references.append(f"[{ref_count}] {url}") - - # Replace reference numbers in markdown - def replace_reference(match: re.Match[str]) -> str: - ref_num = int(match.group(1)) - mapped_ref_num = section_to_report_ref_map.get(ref_num) - if mapped_ref_num: - return f"[{mapped_ref_num}]" - return "" - - updated_markdown = re.sub(r"\[(\d+)\]", replace_reference, section_markdown) - - return updated_markdown, all_references - - def _reformat_section_headings(self, section_markdown: str) -> str: - """ - Reformat section headings to be consistent (level-2 for main heading). 
- - Args: - section_markdown: Markdown content with headings - - Returns: - Updated markdown with adjusted heading levels - """ - if not section_markdown.strip(): - return section_markdown - - # Find first heading level - first_heading_match = re.search(r"^(#+)\s", section_markdown, re.MULTILINE) - if not first_heading_match: - return section_markdown - - # Calculate level adjustment needed (target is level 2) - first_heading_level = len(first_heading_match.group(1)) - level_adjustment = 2 - first_heading_level - - def adjust_heading_level(match: re.Match[str]) -> str: - hashes = match.group(1) - content = match.group(2) - new_level = max(2, len(hashes) + level_adjustment) - return "#" * new_level + " " + content - - # Apply heading adjustment - return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE) - - -def create_long_writer_agent( - model: Any | None = None, oauth_token: str | None = None -) -> LongWriterAgent: - """ - Factory function to create a long writer agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured LongWriterAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return LongWriterAgent(model=model) - - except Exception as e: - logger.error("Failed to create long writer agent", error=str(e)) - raise ConfigurationError(f"Failed to create long writer agent: {e}") from e diff --git a/src/agents/magentic_agents.py b/src/agents/magentic_agents.py deleted file mode 100644 index 926e84a9b3c4bf00cf36bfa99d643b637ece3d1f..0000000000000000000000000000000000000000 --- a/src/agents/magentic_agents.py +++ /dev/null @@ -1,177 +0,0 @@ -"""Magentic-compatible agents using ChatAgent pattern.""" - -from typing import Any - -from agent_framework import ChatAgent - -from src.agents.tools import ( - get_bibliography, - search_clinical_trials, - search_preprints, - search_pubmed, -) -from src.utils.llm_factory import get_chat_client_for_agent - - -def create_search_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a search agent with internal LLM and search tools. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for biomedical search - """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="SearchAgent", - description=( - "Searches biomedical databases (PubMed, ClinicalTrials.gov, Europe PMC) " - "for research evidence" - ), - instructions="""You are a biomedical search specialist. When asked to find evidence: - -1. Analyze the request to determine what to search for -2. Extract key search terms (drug names, disease names, mechanisms) -3. Use the appropriate search tools: - - search_pubmed for peer-reviewed papers - - search_clinical_trials for clinical studies - - search_preprints for cutting-edge findings -4. Summarize what you found and highlight key evidence - -Be thorough - search multiple databases when appropriate. 
-Focus on finding: mechanisms of action, clinical evidence, and specific drug candidates.""", - chat_client=client, - tools=[search_pubmed, search_clinical_trials, search_preprints], - temperature=0.3, # More deterministic for tool use - ) - - -def create_judge_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a judge agent that evaluates evidence quality. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for evidence assessment - """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="JudgeAgent", - description="Evaluates evidence quality and determines if sufficient for synthesis", - instructions="""You are an evidence quality assessor. When asked to evaluate: - -1. Review all evidence presented in the conversation -2. Score on two dimensions (0-10 each): - - Mechanism Score: How well is the biological mechanism explained? - - Clinical Score: How strong is the clinical/preclinical evidence? -3. Determine if evidence is SUFFICIENT for a final report: - - Sufficient: Clear mechanism + supporting clinical data - - Insufficient: Gaps in mechanism OR weak clinical evidence -4. If insufficient, suggest specific search queries to fill gaps - -Be rigorous but fair. Look for: -- Molecular targets and pathways -- Animal model studies -- Human clinical trials -- Safety data -- Drug-drug interactions""", - chat_client=client, - temperature=0.2, # Consistent judgments - ) - - -def create_hypothesis_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a hypothesis generation agent. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for hypothesis generation - """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="HypothesisAgent", - description="Generates mechanistic hypotheses for research investigation", - instructions="""You are a biomedical hypothesis generator. Based on evidence: - -1. Identify the key molecular targets involved -2. Map the biological pathways affected -3. Generate testable hypotheses in this format: - - DRUG -> TARGET -> PATHWAY -> THERAPEUTIC EFFECT - - Example: - Metformin -> AMPK activation -> mTOR inhibition -> Reduced tau phosphorylation - -4. Explain the rationale for each hypothesis -5. Suggest what additional evidence would support or refute it - -Focus on mechanistic plausibility and existing evidence.""", - chat_client=client, - temperature=0.5, # Some creativity for hypothesis generation - ) - - -def create_report_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a report synthesis agent. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for report generation - """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="ReportAgent", - description="Synthesizes research findings into structured reports", - instructions="""You are a scientific report writer. 
When asked to synthesize: - -Generate a structured report with these sections: - -## Executive Summary -Brief overview of findings and recommendation - -## Methodology -Databases searched, queries used, evidence reviewed - -## Key Findings -### Mechanism of Action -- Molecular targets -- Biological pathways -- Proposed mechanism - -### Clinical Evidence -- Preclinical studies -- Clinical trials -- Safety profile - -## Drug Candidates -List specific drugs with repurposing potential - -## Limitations -Gaps in evidence, conflicting data, caveats - -## Conclusion -Final recommendation with confidence level - -## References -Use the 'get_bibliography' tool to fetch the complete list of citations. -Format them as a numbered list. - -Be comprehensive but concise. Cite evidence for all claims.""", - chat_client=client, - tools=[get_bibliography], - temperature=0.3, - ) diff --git a/src/agents/proofreader.py b/src/agents/proofreader.py deleted file mode 100644 index 7a209c365ea04313dceee645e23ff0dd8b846817..0000000000000000000000000000000000000000 --- a/src/agents/proofreader.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Proofreader agent for finalizing report drafts. - -Converts the folder/proofreader_agent.py implementation to use Pydantic AI. -""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError -from src.utils.models import ReportDraft - -logger = structlog.get_logger() - - -# System prompt for the proofreader agent -SYSTEM_PROMPT = f""" -You are a research expert who proofreads and edits research reports. -Today's date is {datetime.now().strftime("%Y-%m-%d")}. - -You are given: -1. The original query topic for the report -2. A first draft of the report in ReportDraft format containing each section in sequence - -Your task is to: -1. **Combine sections:** Concatenate the sections into a single string -2. **Add section titles:** Add the section titles to the beginning of each section in markdown format, as well as a main title for the report -3. **De-duplicate:** Remove duplicate content across sections to avoid repetition -4. **Remove irrelevant sections:** If any sections or sub-sections are completely irrelevant to the query, remove them -5. **Refine wording:** Edit the wording of the report to be polished, concise and punchy, but **without eliminating any detail** or large chunks of text -6. **Add a summary:** Add a short report summary / outline to the beginning of the report to provide an overview of the sections and what is discussed -7. **Preserve sources:** Preserve all sources / references - move the long list of references to the end of the report -8. **Update reference numbers:** Continue to include reference numbers in square brackets ([1], [2], [3], etc.) in the main body of the report, but update the numbering to match the new order of references at the end of the report -9. 
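# Editor's note: a small sketch of the reference-renumbering step the
# proofreader prompt above asks for (point 8): in-text markers like [3] are
# rewritten to match the new order of the reference list. `renumber_citations`
# and the mapping argument are names invented for this illustration only.
import re


def renumber_citations(body: str, old_to_new: dict[int, int]) -> str:
    """Rewrite [old] -> [new] for every known reference number in the body."""

    def _swap(match: re.Match[str]) -> str:
        old = int(match.group(1))
        return f"[{old_to_new.get(old, old)}]"  # leave unknown numbers untouched

    return re.sub(r"\[(\d+)\]", _swap, body)


# Example: after de-duplication, old reference 3 becomes the first entry.
assert renumber_citations("Shown in mice [3] and humans [1].", {3: 1, 1: 2}) == (
    "Shown in mice [1] and humans [2]."
)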
**Output final report:** Output the final report in markdown format (do not wrap it in a code block) - -Guidelines: -- Do not add any new facts or data to the report -- Do not remove any content from the report unless it is very clearly wrong, contradictory or irrelevant -- Remove or reformat any redundant or excessive headings, and ensure that the final nesting of heading levels is correct -- Ensure that the final report flows well and has a logical structure -- Include all sources and references that are present in the final report -""" - - -class ProofreaderAgent: - """ - Agent that proofreads and finalizes report drafts. - - Uses Pydantic AI to generate polished markdown reports from draft sections. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the proofreader agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent (no structured output - returns markdown text) - self.agent = Agent( - model=self.model, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def proofread( - self, - query: str, - report_draft: ReportDraft, - ) -> str: - """ - Proofread and finalize a report draft. - - Args: - query: The original research query - report_draft: ReportDraft with all sections - - Returns: - Final polished markdown report string - - Raises: - ConfigurationError: If proofreading fails - """ - # Input validation - if not query or not query.strip(): - self.logger.warning("Empty query provided, using default") - query = "Research query" - - if not report_draft or not report_draft.sections: - self.logger.warning("Empty report draft provided, returning minimal report") - return f"# Research Report\n\n## Query\n{query}\n\n*No sections available.*" - - # Validate section structure - valid_sections = [] - for section in report_draft.sections: - if section.section_title and section.section_title.strip(): - valid_sections.append(section) - else: - self.logger.warning("Skipping section with empty title") - - if not valid_sections: - self.logger.warning("No valid sections in draft, returning minimal report") - return f"# Research Report\n\n## Query\n{query}\n\n*No valid sections available.*" - - self.logger.info( - "Proofreading report", - query=query[:100], - sections_count=len(valid_sections), - ) - - # Create validated draft - validated_draft = ReportDraft(sections=valid_sections) - - user_message = f""" -QUERY: -{query} - -REPORT DRAFT: -{validated_draft.model_dump_json()} -""" - - # Retry logic for transient failures - max_retries = 3 - last_exception: Exception | None = None - - for attempt in range(max_retries): - try: - # Run the agent - result = await self.agent.run(user_message) - final_report = result.output - - # Validate output - if not final_report or not final_report.strip(): - self.logger.warning("Empty report generated, using fallback") - raise ValueError("Empty report generated") - - self.logger.info("Report proofread", length=len(final_report), attempt=attempt + 1) - - return final_report - - except (TimeoutError, ConnectionError) as e: - # Transient errors - retry - last_exception = e - if attempt < max_retries - 1: - self.logger.warning( - "Transient error, retrying", - error=str(e), - attempt=attempt + 1, - max_retries=max_retries, - ) - continue - else: - self.logger.error("Max retries exceeded for transient error", error=str(e)) - break - - except Exception as e: - # Non-transient errors - don't retry - 
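# Editor's note: the retry loop above re-runs the agent only for transient
# failures (timeouts, dropped connections) and falls through to a fallback for
# everything else. A generic sketch of that pattern, assuming nothing beyond
# asyncio; `run_with_transient_retry` is a hypothetical helper name.
import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")


async def run_with_transient_retry(
    op: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> T:
    """Retry only TimeoutError/ConnectionError, with exponential backoff."""
    last_exc: Exception | None = None
    for attempt in range(max_retries):
        try:
            return await op()
        except (TimeoutError, ConnectionError) as exc:
            last_exc = exc
            if attempt < max_retries - 1:
                await asyncio.sleep(base_delay * 2**attempt)
    raise last_exc if last_exc is not None else RuntimeError("no attempts made")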
last_exception = e - self.logger.error( - "Proofreading failed", - error=str(e), - error_type=type(e).__name__, - ) - break - - # Return fallback: combine sections manually - self.logger.error( - "Proofreading failed after all attempts", - error=str(last_exception) if last_exception else "Unknown error", - ) - sections = [ - f"## {section.section_title}\n\n{section.section_content or 'Content unavailable.'}" - for section in valid_sections - ] - return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections) - - -def create_proofreader_agent( - model: Any | None = None, oauth_token: str | None = None -) -> ProofreaderAgent: - """ - Factory function to create a proofreader agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ProofreaderAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return ProofreaderAgent(model=model) - - except Exception as e: - logger.error("Failed to create proofreader agent", error=str(e)) - raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e diff --git a/src/agents/report_agent.py b/src/agents/report_agent.py deleted file mode 100644 index 5454fed54f67356dca0674985c6c90a7f34517c6..0000000000000000000000000000000000000000 --- a/src/agents/report_agent.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Report agent for generating structured research reports.""" - -from collections.abc import AsyncIterable -from typing import TYPE_CHECKING, Any - -from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - BaseAgent, - ChatMessage, - Role, -) -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.prompts.report import SYSTEM_PROMPT, format_report_prompt -from src.utils.citation_validator import validate_references -from src.utils.models import Evidence, ResearchReport - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - - -class ReportAgent(BaseAgent): # type: ignore[misc] - """Generates structured scientific reports from evidence and hypotheses.""" - - def __init__( - self, - evidence_store: dict[str, Any], - embedding_service: "EmbeddingService | None" = None, # For diverse selection - ) -> None: - super().__init__( - name="ReportAgent", - description="Generates structured scientific research reports with citations", - ) - self._evidence_store = evidence_store - self._embeddings = embedding_service - self._agent: Agent[None, ResearchReport] | None = None # Lazy init - - def _get_agent(self) -> Agent[None, ResearchReport]: - """Lazy initialization of LLM agent to avoid requiring API keys at import.""" - if self._agent is None: - self._agent = Agent( - model=get_model(), - output_type=ResearchReport, - system_prompt=SYSTEM_PROMPT, - ) - return self._agent - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - """Generate research report.""" - query = self._extract_query(messages) - - # Gather all context - evidence: list[Evidence] = self._evidence_store.get("current", []) - hypotheses = self._evidence_store.get("hypotheses", []) - assessment = self._evidence_store.get("last_assessment", {}) - - if not evidence: - return AgentRunResponse( - messages=[ - 
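# Editor's note: ReportAgent above defers building its pydantic-ai Agent until
# first use so that importing the module never requires API keys. A stripped
# down sketch of that lazy-initialisation pattern; `ExpensiveClient` and
# `LazyHolder` are placeholder names invented purely for the illustration.
class ExpensiveClient:
    def __init__(self) -> None:
        # Imagine credential lookup / network setup happening here.
        self.ready = True


class LazyHolder:
    def __init__(self) -> None:
        self._client: ExpensiveClient | None = None  # nothing built at import time

    def get_client(self) -> ExpensiveClient:
        """Build the client on first use and cache it for subsequent calls."""
        if self._client is None:
            self._client = ExpensiveClient()
        return self._client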
ChatMessage( - role=Role.ASSISTANT, - text="Cannot generate report: No evidence collected.", - ) - ], - response_id="report-no-evidence", - ) - - # Build metadata - metadata = { - "sources": list(set(e.citation.source for e in evidence)), - "iterations": self._evidence_store.get("iteration_count", 0), - } - - # Generate report (format_report_prompt is now async) - prompt = await format_report_prompt( - query=query, - evidence=evidence, - hypotheses=hypotheses, - assessment=assessment, - metadata=metadata, - embeddings=self._embeddings, - ) - - result = await self._get_agent().run(prompt) - report = result.output - - # ═══════════════════════════════════════════════════════════════════ - # 🚨 CRITICAL: Validate citations to prevent hallucination - # ═══════════════════════════════════════════════════════════════════ - report = validate_references(report, evidence) - - # Store validated report - self._evidence_store["final_report"] = report - - # Return markdown version - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=report.to_markdown())], - response_id="report-complete", - additional_properties={"report": report.model_dump()}, - ) - - def _extract_query( - self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None - ) -> str: - """Extract query from messages.""" - if isinstance(messages, str): - return messages - elif isinstance(messages, ChatMessage): - return messages.text or "" - elif isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, ChatMessage) and msg.role == Role.USER: - return msg.text or "" - elif isinstance(msg, str): - return msg - return "" - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - """Streaming wrapper.""" - result = await self.run(messages, thread=thread, **kwargs) - yield AgentRunResponseUpdate( - messages=result.messages, - response_id=result.response_id, - additional_properties=result.additional_properties, - ) diff --git a/src/agents/retrieval_agent.py b/src/agents/retrieval_agent.py deleted file mode 100644 index 63a3a6bff9c86342e8f932a97cad51e8b9c5f58b..0000000000000000000000000000000000000000 --- a/src/agents/retrieval_agent.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Retrieval agent for web search and context management.""" - -from typing import Any - -import structlog -from agent_framework import ChatAgent, ai_function - -from src.agents.state import get_magentic_state -from src.tools.web_search import WebSearchTool -from src.utils.llm_factory import get_chat_client_for_agent - -logger = structlog.get_logger() - -_web_search = WebSearchTool() - - -@ai_function # type: ignore[arg-type, misc] -async def search_web(query: str, max_results: int = 10) -> str: - """Search the web using DuckDuckGo. - - Args: - query: Search keywords. - max_results: Maximum results to return (default 10). - - Returns: - Formatted search results. 
- """ - logger.info("Web search starting", query=query, max_results=max_results) - state = get_magentic_state() - - evidence = await _web_search.search(query, max_results) - if not evidence: - logger.info("Web search returned no results", query=query) - return f"No web results found for: {query}" - - # Update state - # We add *all* found results to state - new_count = state.add_evidence(evidence) - logger.info( - "Web search complete", - query=query, - results_found=len(evidence), - new_evidence=new_count, - ) - - # Use embedding service for deduplication/indexing if available - if state.embedding_service: - # This method also adds to vector DB as a side effect for unique items - await state.embedding_service.deduplicate(evidence) - - output = [f"Found {len(evidence)} web results ({new_count} new stored):\n"] - for i, r in enumerate(evidence[:max_results], 1): - output.append(f"{i}. **{r.citation.title}**") - output.append(f" Source: {r.citation.url}") - output.append(f" {r.content[:300]}...\n") - - return "\n".join(output) - - -def create_retrieval_agent(chat_client: Any | None = None) -> ChatAgent: - """Create a retrieval agent. - - Args: - chat_client: Optional custom chat client. If None, uses factory default - (HuggingFace preferred, OpenAI fallback). - - Returns: - ChatAgent configured for retrieval. - """ - client = chat_client or get_chat_client_for_agent() - - return ChatAgent( - name="RetrievalAgent", - description="Searches the web and manages context/evidence.", - instructions="""You are a retrieval specialist. -Use `search_web` to find information on the internet. -Your goal is to gather relevant evidence for the research task. -Always summarize what you found.""", - chat_client=client, - tools=[search_web], - ) diff --git a/src/agents/search_agent.py b/src/agents/search_agent.py deleted file mode 100644 index 9c28242458109bebc7824d122b54f9b4b1d1ea0a..0000000000000000000000000000000000000000 --- a/src/agents/search_agent.py +++ /dev/null @@ -1,154 +0,0 @@ -from collections.abc import AsyncIterable -from typing import TYPE_CHECKING, Any - -from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - BaseAgent, - ChatMessage, - Role, -) - -from src.legacy_orchestrator import SearchHandlerProtocol -from src.utils.models import Citation, Evidence, SearchResult - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - - -class SearchAgent(BaseAgent): # type: ignore[misc] - """Wraps SearchHandler as an AgentProtocol for Magentic.""" - - def __init__( - self, - search_handler: SearchHandlerProtocol, - evidence_store: dict[str, list[Evidence]], - embedding_service: "EmbeddingService | None" = None, - ) -> None: - super().__init__( - name="SearchAgent", - description="Searches PubMed for biomedical research evidence", - ) - self._handler = search_handler - self._evidence_store = evidence_store - self._embeddings = embedding_service - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - """Execute search based on the last user message.""" - # Extract query from messages - query = "" - if isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, ChatMessage) and msg.role == Role.USER and msg.text: - query = msg.text - break - elif isinstance(msg, str): - query = msg - break - elif isinstance(messages, str): - query = messages - elif isinstance(messages, ChatMessage) and 
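# Editor's note: a self-contained sketch of the query-extraction logic used in
# SearchAgent.run above, with plain strings and (role, text) tuples standing in
# for the framework's ChatMessage type so it runs anywhere.
# `extract_last_user_query` is an illustrative name only.
def extract_last_user_query(messages: str | list[str | tuple[str, str]] | None) -> str:
    """Return the most recent user utterance, or '' if none is present."""
    if messages is None:
        return ""
    if isinstance(messages, str):
        return messages
    for msg in reversed(messages):
        if isinstance(msg, tuple) and msg[0] == "user" and msg[1]:
            return msg[1]
        if isinstance(msg, str):
            return msg
    return ""


assert extract_last_user_query([("assistant", "hi"), ("user", "metformin trials")]) == (
    "metformin trials"
)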
messages.text: - query = messages.text - - if not query: - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text="No query provided")], - response_id="search-no-query", - ) - - # Execute search - result: SearchResult = await self._handler.execute(query, max_results_per_tool=10) - - # Track what to show in response (initialized to search results as default) - evidence_to_show: list[Evidence] = result.evidence - total_new = 0 - - # Update shared evidence store - if self._embeddings: - # Deduplicate by semantic similarity (async-safe) - unique_evidence = await self._embeddings.deduplicate(result.evidence) - - # Also search for semantically related evidence (async-safe) - related = await self._embeddings.search_similar(query, n_results=5) - - # Merge related evidence not already in results - existing_urls = {e.citation.url for e in unique_evidence} - - # Reconstruct Evidence objects from stored vector DB data - related_evidence: list[Evidence] = [] - for item in related: - if item["id"] not in existing_urls: - meta = item.get("metadata", {}) - # Parse authors (stored as comma-separated string) - authors_str = meta.get("authors", "") - authors = [a.strip() for a in authors_str.split(",") if a.strip()] - - ev = Evidence( - content=item["content"], - citation=Citation( - title=meta.get("title", "Related Evidence"), - url=item["id"], - source="pubmed", - date=meta.get("date", "n.d."), - authors=authors, - ), - # Convert distance to relevance (lower distance = higher relevance) - relevance=max(0.0, 1.0 - item.get("distance", 0.5)), - ) - related_evidence.append(ev) - - # Combine unique from search + related from vector DB - final_new_evidence = unique_evidence + related_evidence - - # Add to global store (deduping against global store) - global_urls = {e.citation.url for e in self._evidence_store["current"]} - really_new = [e for e in final_new_evidence if e.citation.url not in global_urls] - self._evidence_store["current"].extend(really_new) - - total_new = len(really_new) - evidence_to_show = unique_evidence + related_evidence - - else: - # Fallback to URL-based deduplication (no embeddings) - existing_urls = {e.citation.url for e in self._evidence_store["current"]} - new_unique = [e for e in result.evidence if e.citation.url not in existing_urls] - self._evidence_store["current"].extend(new_unique) - total_new = len(new_unique) - evidence_to_show = result.evidence - - evidence_text = "\n".join( - [ - f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..." 
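# Editor's note: the block above converts vector-store distances into relevance
# scores and merges freshly-searched items with related stored items without
# duplicating URLs. A compact sketch of both steps; the 0.5 default distance
# mirrors the code above, everything else is illustrative.
def distance_to_relevance(distance: float | None) -> float:
    """Lower distance means more similar, so invert and clamp to [0, 1]."""
    d = 0.5 if distance is None else distance
    return max(0.0, min(1.0, 1.0 - d))


def merge_by_url(primary: list[dict[str, str]], extra: list[dict[str, str]]) -> list[dict[str, str]]:
    """Append extra items whose 'url' is not already present in primary."""
    seen = {item["url"] for item in primary}
    merged = list(primary)
    for item in extra:
        if item["url"] not in seen:
            merged.append(item)
            seen.add(item["url"])
    return merged


assert distance_to_relevance(0.25) == 0.75
assert len(merge_by_url([{"url": "a"}], [{"url": "a"}, {"url": "b"}])) == 2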
- for e in evidence_to_show[:5] - ] - ) - - response_text = ( - f"Found {result.total_found} sources ({total_new} new added to context):\n\n" - f"{evidence_text}" - ) - - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)], - response_id=f"search-{result.total_found}", - additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]}, - ) - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - """Streaming wrapper for search (search itself isn't streaming).""" - result = await self.run(messages, thread=thread, **kwargs) - # Yield single update with full result - yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id) diff --git a/src/agents/state.py b/src/agents/state.py deleted file mode 100644 index 8bab5f7b9563b5c263c376780c2856c8a058c2f0..0000000000000000000000000000000000000000 --- a/src/agents/state.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Thread-safe state management for Magentic agents. - -DEPRECATED: This module is deprecated. Use src.middleware.state_machine instead. - -This file is kept for backward compatibility and will be removed in a future version. -""" - -import warnings -from contextvars import ContextVar -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel, Field - -from src.utils.models import Citation, Evidence - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - - -def _deprecation_warning() -> None: - """Emit deprecation warning for this module.""" - warnings.warn( - "src.agents.state is deprecated. Use src.middleware.state_machine instead.", - DeprecationWarning, - stacklevel=3, - ) - - -class MagenticState(BaseModel): - """Mutable state for a Magentic workflow session. - - DEPRECATED: Use WorkflowState from src.middleware.state_machine instead. - """ - - evidence: list[Evidence] = Field(default_factory=list) - # Type as Any to avoid circular imports/runtime resolution issues - # The actual object injected will be an EmbeddingService instance - embedding_service: Any = None - - model_config = {"arbitrary_types_allowed": True} - - def add_evidence(self, new_evidence: list[Evidence]) -> int: - """Add new evidence, deduplicating by URL. - - Returns: - Number of *new* items added. 
- """ - existing_urls = {e.citation.url for e in self.evidence} - count = 0 - for item in new_evidence: - if item.citation.url not in existing_urls: - self.evidence.append(item) - existing_urls.add(item.citation.url) - count += 1 - return count - - async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]: - """Search for semantically related evidence using the embedding service.""" - if not self.embedding_service: - return [] - - results = await self.embedding_service.search_similar(query, n_results=n_results) - - # Convert dict results back to Evidence objects - evidence_list = [] - for item in results: - meta = item.get("metadata", {}) - authors_str = meta.get("authors", "") - authors = [a.strip() for a in authors_str.split(",") if a.strip()] - - ev = Evidence( - content=item["content"], - citation=Citation( - title=meta.get("title", "Related Evidence"), - url=item["id"], - source="pubmed", # Defaulting to pubmed if unknown - date=meta.get("date", "n.d."), - authors=authors, - ), - relevance=max(0.0, 1.0 - item.get("distance", 0.5)), - ) - evidence_list.append(ev) - - return evidence_list - - -# The ContextVar holds the MagenticState for the current execution context -_magentic_state_var: ContextVar[MagenticState | None] = ContextVar("magentic_state", default=None) - - -def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState: - """Initialize a new state for the current context. - - DEPRECATED: Use init_workflow_state from src.middleware.state_machine instead. - """ - _deprecation_warning() - state = MagenticState(embedding_service=embedding_service) - _magentic_state_var.set(state) - return state - - -def get_magentic_state() -> MagenticState: - """Get the current state. Raises RuntimeError if not initialized. - - DEPRECATED: Use get_workflow_state from src.middleware.state_machine instead. - """ - _deprecation_warning() - state = _magentic_state_var.get() - if state is None: - # Auto-initialize if missing (e.g. during tests or simple scripts) - return init_magentic_state() - return state diff --git a/src/agents/thinking.py b/src/agents/thinking.py deleted file mode 100644 index eff08e5111e7700ea2d4940d730ee1dfd3225324..0000000000000000000000000000000000000000 --- a/src/agents/thinking.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Thinking agent for generating observations and reflections. - -Converts the folder/thinking_agent.py implementation to use Pydantic AI. -""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - - -# System prompt for the thinking agent -SYSTEM_PROMPT = f""" -You are a research expert who is managing a research process in iterations. Today's date is {datetime.now().strftime("%Y-%m-%d")}. - -You are given: -1. The original research query along with some supporting background context -2. A history of the tasks, actions, findings and thoughts you've made up until this point in the research process (on iteration 1 you will be at the start of the research process, so this will be empty) - -Your objective is to reflect on the research process so far and share your latest thoughts. 
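# Editor's note: the ContextVar pattern above gives each asyncio task its own
# state without global mutable variables. A tiny runnable sketch of that
# isolation guarantee; `session_evidence` is an invented variable name used
# only for this demonstration.
import asyncio
from contextvars import ContextVar

session_evidence: ContextVar[list[str] | None] = ContextVar("session_evidence", default=None)


async def _session(name: str) -> list[str]:
    session_evidence.set([])  # fresh state for this task's context
    bucket = session_evidence.get()
    assert bucket is not None
    bucket.append(f"evidence-from-{name}")
    await asyncio.sleep(0)  # yield control so the two tasks interleave
    return bucket


async def _demo() -> None:
    a, b = await asyncio.gather(_session("a"), _session("b"))
    # Each task saw only its own list, despite sharing the module-level ContextVar.
    assert a == ["evidence-from-a"] and b == ["evidence-from-b"]


asyncio.run(_demo())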
- -Specifically, your thoughts should include reflections on questions such as: -- What have you learned from the last iteration? -- What new areas would you like to explore next, or existing topics you'd like to go deeper into? -- Were you able to retrieve the information you were looking for in the last iteration? -- If not, should we change our approach or move to the next topic? -- Is there any info that is contradictory or conflicting? - -Guidelines: -- Share your stream of consciousness on the above questions as raw text -- Keep your response concise and informal -- Focus most of your thoughts on the most recent iteration and how that influences this next iteration -- Our aim is to do very deep and thorough research - bear this in mind when reflecting on the research process -- DO NOT produce a draft of the final report. This is not your job. -- If this is the first iteration (i.e. no data from prior iterations), provide thoughts on what info we need to gather in the first iteration to get started -""" - - -class ThinkingAgent: - """ - Agent that generates observations and reflections on the research process. - - Uses Pydantic AI to generate unstructured text observations about - the current state of research and next steps. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the thinking agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent (no structured output - returns text) - self.agent = Agent( - model=self.model, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def generate_observations( - self, - query: str, - background_context: str = "", - conversation_history: str = "", - message_history: list[ModelMessage] | None = None, - iteration: int = 1, - ) -> str: - """ - Generate observations about the research process. - - Args: - query: The original research query - background_context: Optional background context - conversation_history: History of actions, findings, and thoughts (backward compat) - message_history: Optional user conversation history (Pydantic AI format) - iteration: Current iteration number - - Returns: - String containing observations and reflections - - Raises: - ConfigurationError: If generation fails - """ - self.logger.info( - "Generating observations", - query=query[:100], - iteration=iteration, - ) - - background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else "" - - user_message = f""" -You are starting iteration {iteration} of your research process. - -ORIGINAL QUERY: -{query} - -{background} - -HISTORY OF ACTIONS, FINDINGS AND THOUGHTS: -{conversation_history or "No previous actions, findings or thoughts available."} -""" - - try: - # Run the agent with message_history if provided - if message_history: - result = await self.agent.run(user_message, message_history=message_history) - else: - result = await self.agent.run(user_message) - observations = result.output - - self.logger.info("Observations generated", length=len(observations)) - - return observations - - except Exception as e: - self.logger.error("Observation generation failed", error=str(e)) - # Return fallback observations - return f"Starting iteration {iteration}. Need to gather information about: {query}" - - -def create_thinking_agent( - model: Any | None = None, oauth_token: str | None = None -) -> ThinkingAgent: - """ - Factory function to create a thinking agent. 
- - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ThinkingAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return ThinkingAgent(model=model) - - except Exception as e: - logger.error("Failed to create thinking agent", error=str(e)) - raise ConfigurationError(f"Failed to create thinking agent: {e}") from e diff --git a/src/agents/tool_selector.py b/src/agents/tool_selector.py deleted file mode 100644 index 0da35f43696306bac1093e1722ccc2fc981e598b..0000000000000000000000000000000000000000 --- a/src/agents/tool_selector.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Tool selector agent for choosing which tools to use for knowledge gaps. - -Converts the folder/tool_selector_agent.py implementation to use Pydantic AI. -""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError -from src.utils.models import AgentSelectionPlan - -logger = structlog.get_logger() - - -# System prompt for the tool selector agent -SYSTEM_PROMPT = f""" -You are a Tool Selector responsible for determining which specialized agents should address a knowledge gap in a research project. -Today's date is {datetime.now().strftime("%Y-%m-%d")}. - -You will be given: -1. The original user query -2. A knowledge gap identified in the research -3. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process - -Your task is to decide: -1. Which specialized agents are best suited to address the gap -2. What specific queries should be given to the agents (keep this short - 3-6 words) - -Available specialized agents: -- WebSearchAgent: General web search for broad topics (can be called multiple times with different queries) -- SiteCrawlerAgent: Crawl the pages of a specific website to retrieve information about it - use this if you want to find out something about a particular company, entity or product -- RAGAgent: Semantic search within previously collected evidence - use when you need to find information from evidence already gathered in this research session. Best for finding connections, summarizing collected evidence, or retrieving specific details from earlier findings. 
- -Guidelines: -- Aim to call at most 3 agents at a time in your final output -- You can list the WebSearchAgent multiple times with different queries if needed to cover the full scope of the knowledge gap -- Be specific and concise (3-6 words) with the agent queries - they should target exactly what information is needed -- If you know the website or domain name of an entity being researched, always include it in the query -- Use RAGAgent when: (1) You need to search within evidence already collected, (2) You want to find connections between different findings, (3) You need to retrieve specific details from earlier research iterations -- Use WebSearchAgent or SiteCrawlerAgent when: (1) You need fresh information from the web, (2) You're starting a new research direction, (3) You need information not yet in the collected evidence -- If a gap doesn't clearly match any agent's capability, default to the WebSearchAgent -- Use the history of actions / tool calls as a guide - try not to repeat yourself if an approach didn't work previously - -Only output JSON. Follow the JSON schema for AgentSelectionPlan. Do not output anything else. -""" - - -class ToolSelectorAgent: - """ - Agent that selects appropriate tools to address knowledge gaps. - - Uses Pydantic AI to generate structured AgentSelectionPlan with - specific tasks for web search and crawl agents. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the tool selector agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent - self.agent = Agent( - model=self.model, - output_type=AgentSelectionPlan, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def select_tools( - self, - gap: str, - query: str, - background_context: str = "", - conversation_history: str = "", - message_history: list[ModelMessage] | None = None, - ) -> AgentSelectionPlan: - """ - Select tools to address a knowledge gap. 
- - Args: - gap: The knowledge gap to address - query: The original research query - background_context: Optional background context - conversation_history: History of actions, findings, and thoughts (backward compat) - message_history: Optional user conversation history (Pydantic AI format) - - Returns: - AgentSelectionPlan with tasks for selected agents - - Raises: - ConfigurationError: If selection fails - """ - self.logger.info("Selecting tools for gap", gap=gap[:100], query=query[:100]) - - background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else "" - - user_message = f""" -ORIGINAL QUERY: -{query} - -KNOWLEDGE GAP TO ADDRESS: -{gap} - -{background} - -HISTORY OF ACTIONS, FINDINGS AND THOUGHTS: -{conversation_history or "No previous actions, findings or thoughts available."} -""" - - try: - # Run the agent with message_history if provided - if message_history: - result = await self.agent.run(user_message, message_history=message_history) - else: - result = await self.agent.run(user_message) - selection_plan = result.output - - self.logger.info( - "Tool selection complete", - tasks_count=len(selection_plan.tasks), - agents=[task.agent for task in selection_plan.tasks], - ) - - return selection_plan - - except Exception as e: - self.logger.error("Tool selection failed", error=str(e)) - # Return fallback: use web search - from src.utils.models import AgentTask - - return AgentSelectionPlan( - tasks=[ - AgentTask( - gap=gap, - agent="WebSearchAgent", - query=gap[:50], # Use gap as query - entity_website=None, - ) - ] - ) - - -def create_tool_selector_agent( - model: Any | None = None, oauth_token: str | None = None -) -> ToolSelectorAgent: - """ - Factory function to create a tool selector agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured ToolSelectorAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return ToolSelectorAgent(model=model) - - except Exception as e: - logger.error("Failed to create tool selector agent", error=str(e)) - raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e diff --git a/src/agents/tools.py b/src/agents/tools.py deleted file mode 100644 index 9ce9f908eba4d07e539b971f00dab8a420697956..0000000000000000000000000000000000000000 --- a/src/agents/tools.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Tool functions for Magentic agents. - -These functions are decorated with @ai_function to be callable by the ChatAgent's internal LLM. -They also interact with the thread-safe MagenticState to persist evidence. -""" - -from agent_framework import ai_function - -from src.agents.state import get_magentic_state -from src.tools.clinicaltrials import ClinicalTrialsTool -from src.tools.europepmc import EuropePMCTool -from src.tools.pubmed import PubMedTool - -# Singleton tool instances (stateless wrappers) -_pubmed = PubMedTool() -_clinicaltrials = ClinicalTrialsTool() -_europepmc = EuropePMCTool() - - -@ai_function # type: ignore[arg-type, misc] -async def search_pubmed(query: str, max_results: int = 10) -> str: - """Search PubMed for biomedical research papers. - - Use this tool to find peer-reviewed scientific literature about - drugs, diseases, mechanisms of action, and clinical studies. 
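# Editor's note: a simplified, self-contained stand-in for the structured
# output used by the fallback above. The real AgentTask / AgentSelectionPlan
# live in src/utils/models.py; the field names follow the fallback construction
# shown above, but validators and any extra fields are omitted for brevity.
from pydantic import BaseModel, Field


class AgentTaskSketch(BaseModel):
    gap: str = Field(description="Knowledge gap this task should address")
    agent: str = Field(description="Name of the agent to run, e.g. WebSearchAgent")
    query: str = Field(description="Short 3-6 word query for that agent")
    entity_website: str | None = Field(default=None, description="Known site, if any")


class AgentSelectionPlanSketch(BaseModel):
    tasks: list[AgentTaskSketch] = Field(default_factory=list)


def fallback_plan(gap: str) -> AgentSelectionPlanSketch:
    """Mirror of the code path above: default to a single web search task."""
    return AgentSelectionPlanSketch(
        tasks=[AgentTaskSketch(gap=gap, agent="WebSearchAgent", query=gap[:50])]
    )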
- - Args: - query: Search keywords (e.g., "metformin alzheimer mechanism") - max_results: Maximum results to return (default 10) - - Returns: - Formatted list of papers with titles, abstracts, and citations - """ - state = get_magentic_state() - - # 1. Execute raw search - results = await _pubmed.search(query, max_results) - if not results: - return f"No PubMed results found for: {query}" - - # 2. Semantic Deduplication & Expansion (The "Digital Twin" Brain) - display_results = results - if state.embedding_service: - # Deduplicate against what we just found vs what's in the DB - unique_results = await state.embedding_service.deduplicate(results) - - # Search for related context in the vector DB (previous searches) - related = await state.search_related(query, n_results=3) - - # Combine unique new results + relevant historical results - display_results = unique_results + related - - # 3. Update State (Persist for ReportAgent) - # We add *all* found results to state, not just the displayed ones - new_count = state.add_evidence(results) - - # 4. Format Output for LLM - output = [f"Found {len(results)} results ({new_count} new stored):\n"] - - # Limit display to avoid context window overflow, but state has everything - limit = min(len(display_results), max_results) - - for i, r in enumerate(display_results[:limit], 1): - title = r.citation.title - date = r.citation.date - source = r.citation.source - content_clean = r.content[:300].replace("\n", " ") - url = r.citation.url - - output.append(f"{i}. **{title}** ({date})") - output.append(f" Source: {source} | {url}") - output.append(f" {content_clean}...") - output.append("") - - return "\n".join(output) - - -@ai_function # type: ignore[arg-type, misc] -async def search_clinical_trials(query: str, max_results: int = 10) -> str: - """Search ClinicalTrials.gov for clinical studies. - - Use this tool to find ongoing and completed clinical trials - for research investigation. - - Args: - query: Search terms (e.g., "metformin cancer phase 3") - max_results: Maximum results to return (default 10) - - Returns: - Formatted list of clinical trials with status and details - """ - state = get_magentic_state() - - results = await _clinicaltrials.search(query, max_results) - if not results: - return f"No clinical trials found for: {query}" - - # Update state - new_count = state.add_evidence(results) - - output = [f"Found {len(results)} clinical trials ({new_count} new stored):\n"] - for i, r in enumerate(results[:max_results], 1): - title = r.citation.title - date = r.citation.date - source = r.citation.source - content_clean = r.content[:300].replace("\n", " ") - url = r.citation.url - - output.append(f"{i}. **{title}**") - output.append(f" Status: {source} | Date: {date}") - output.append(f" {content_clean}...") - output.append(f" URL: {url}\n") - - return "\n".join(output) - - -@ai_function # type: ignore[arg-type, misc] -async def search_preprints(query: str, max_results: int = 10) -> str: - """Search Europe PMC for preprints and papers. - - Use this tool to find the latest research including preprints - from bioRxiv, medRxiv, and peer-reviewed papers. 
- - Args: - query: Search terms (e.g., "long covid treatment") - max_results: Maximum results to return (default 10) - - Returns: - Formatted list of papers with abstracts and links - """ - state = get_magentic_state() - - results = await _europepmc.search(query, max_results) - if not results: - return f"No papers found for: {query}" - - # Update state - new_count = state.add_evidence(results) - - output = [f"Found {len(results)} papers ({new_count} new stored):\n"] - for i, r in enumerate(results[:max_results], 1): - title = r.citation.title - date = r.citation.date - source = r.citation.source - content_clean = r.content[:300].replace("\n", " ") - url = r.citation.url - - output.append(f"{i}. **{title}**") - output.append(f" Source: {source} | Date: {date}") - output.append(f" {content_clean}...") - output.append(f" URL: {url}\n") - - return "\n".join(output) - - -@ai_function # type: ignore[arg-type, misc] -async def get_bibliography() -> str: - """Get the full list of collected evidence for the bibliography. - - Use this tool when generating the final report to get the complete - list of references. - - Returns: - Formatted bibliography string. - """ - state = get_magentic_state() - if not state.evidence: - return "No evidence collected." - - output = ["## References"] - for i, ev in enumerate(state.evidence, 1): - output.append(f"{i}. {ev.citation.formatted}") - output.append(f" URL: {ev.citation.url}") - - return "\n".join(output) diff --git a/src/agents/writer.py b/src/agents/writer.py deleted file mode 100644 index 9fc2edffc55ebf5420c920827745ba878ffe7842..0000000000000000000000000000000000000000 --- a/src/agents/writer.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Writer agent for generating final reports from findings. - -Converts the folder/writer_agent.py implementation to use Pydantic AI. -""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - - -# System prompt for the writer agent -SYSTEM_PROMPT = f""" -You are a senior researcher tasked with comprehensively answering a research query. -Today's date is {datetime.now().strftime("%Y-%m-%d")}. -You will be provided with the original query along with research findings put together by a research assistant. -Your objective is to generate the final response in markdown format. -The response should be as lengthy and detailed as possible with the information provided, focusing on answering the original query. -In your final output, include references to the source URLs for all information and data gathered. -This should be formatted in the form of a numbered square bracket next to the relevant information, -followed by a list of URLs at the end of the response, per the example below. - -EXAMPLE REFERENCE FORMAT: -The company has XYZ products [1]. It operates in the software services market which is expected to grow at 10% per year [2]. - -References: -[1] https://example.com/first-source-url -[2] https://example.com/second-source-url - -GUIDELINES: -* Answer the query directly, do not include unrelated or tangential information. -* Adhere to any instructions on the length of your final response if provided in the user prompt. -* If any additional guidelines are provided in the user prompt, follow them exactly and give them precedence over these system instructions. 
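# Editor's note: a minimal sketch of the numbered-reference format the prompts
# above ask for, working from plain (title, url) pairs instead of Evidence
# objects so it stays self-contained. `build_reference_list` is an illustrative
# name, and the example titles below are invented for the demo; the URLs reuse
# the placeholders from the writer prompt above.
def build_reference_list(citations: list[tuple[str, str]]) -> str:
    """Render citations as the numbered list expected at the end of a report."""
    lines = ["References:"]
    for i, (title, url) in enumerate(citations, 1):
        lines.append(f"[{i}] {title} - {url}")
    return "\n".join(lines)


print(
    build_reference_list(
        [
            ("Example source one", "https://example.com/first-source-url"),
            ("Example source two", "https://example.com/second-source-url"),
        ]
    )
)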
-""" - - -class WriterAgent: - """ - Agent that generates final reports from research findings. - - Uses Pydantic AI to generate markdown reports with citations. - """ - - def __init__(self, model: Any | None = None) -> None: - """ - Initialize the writer agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - """ - self.model = model or get_model() - self.logger = logger - - # Initialize Pydantic AI Agent (no structured output - returns markdown text) - self.agent = Agent( - model=self.model, - system_prompt=SYSTEM_PROMPT, - retries=3, - ) - - async def write_report( - self, - query: str, - findings: str, - output_length: str = "", - output_instructions: str = "", - ) -> str: - """ - Write a final report from findings. - - Args: - query: The original research query - findings: All findings collected during research - output_length: Optional description of desired output length - output_instructions: Optional additional instructions - - Returns: - Markdown formatted report string - - Raises: - ConfigurationError: If writing fails - """ - # Input validation - if not query or not query.strip(): - self.logger.warning("Empty query provided, using default") - query = "Research query" - - if findings is None: - self.logger.warning("None findings provided, using empty string") - findings = "No findings available." - - # Truncate very long inputs to prevent context overflow - max_findings_length = 50000 # ~12k tokens - if len(findings) > max_findings_length: - self.logger.warning( - "Findings too long, truncating", - original_length=len(findings), - truncated_length=max_findings_length, - ) - findings = findings[:max_findings_length] + "\n\n[Content truncated due to length]" - - self.logger.info("Writing final report", query=query[:100], findings_length=len(findings)) - - length_str = ( - f"* The full response should be approximately {output_length}.\n" - if output_length - else "" - ) - instructions_str = f"* {output_instructions}" if output_instructions else "" - guidelines_str = ( - ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n") - if length_str or instructions_str - else "" - ) - - user_message = f""" -Provide a response based on the query and findings below with as much detail as possible. 
{guidelines_str} - -QUERY: {query} - -FINDINGS: -{findings} -""" - - # Retry logic for transient failures - max_retries = 3 - last_exception: Exception | None = None - - for attempt in range(max_retries): - try: - # Run the agent - result = await self.agent.run(user_message) - report = result.output - - # Validate output - if not report or not report.strip(): - self.logger.warning("Empty report generated, using fallback") - raise ValueError("Empty report generated") - - self.logger.info("Report written", length=len(report), attempt=attempt + 1) - - return report - - except (TimeoutError, ConnectionError) as e: - # Transient errors - retry - last_exception = e - if attempt < max_retries - 1: - self.logger.warning( - "Transient error, retrying", - error=str(e), - attempt=attempt + 1, - max_retries=max_retries, - ) - continue - else: - self.logger.error("Max retries exceeded for transient error", error=str(e)) - break - - except Exception as e: - # Non-transient errors - don't retry - last_exception = e - self.logger.error( - "Report writing failed", error=str(e), error_type=type(e).__name__ - ) - break - - # Return fallback report if all attempts failed - self.logger.error( - "Report writing failed after all attempts", - error=str(last_exception) if last_exception else "Unknown error", - ) - - # Try to use evidence-based report generator for better fallback - try: - from src.middleware.state_machine import get_workflow_state - from src.utils.report_generator import generate_report_from_evidence - - state = get_workflow_state() - if state and state.evidence: - self.logger.info( - "Using evidence-based report generator for fallback", - evidence_count=len(state.evidence), - ) - return generate_report_from_evidence( - query=query, - evidence=state.evidence, - findings=findings, - ) - except Exception as e: - self.logger.warning( - "Failed to use evidence-based report generator", - error=str(e), - ) - - # Fallback to simple report if evidence generator fails - # Truncate findings in fallback if too long - fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings - return ( - f"# Research Report\n\n" - f"## Query\n{query}\n\n" - f"## Findings\n{fallback_findings}\n\n" - f"*Note: Report generation encountered an error. This is a fallback report.*" - ) - - -def create_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> WriterAgent: - """ - Factory function to create a writer agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured WriterAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - if model is None: - model = get_model(oauth_token=oauth_token) - - return WriterAgent(model=model) - - except Exception as e: - logger.error("Failed to create writer agent", error=str(e)) - raise ConfigurationError(f"Failed to create writer agent: {e}") from e diff --git a/src/app.py b/src/app.py index ba6795ef78ae066b31a6a96a44c6895802f93544..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/app.py +++ b/src/app.py @@ -1,1323 +0,0 @@ -"""Main Gradio application for DeepCritical research agent. 
- -This module provides the Gradio interface with: -- OAuth authentication via HuggingFace -- Multimodal input support (text, images, audio) -- Research agent orchestration -- Real-time event streaming -- MCP server integration -""" - -import os -from collections.abc import AsyncGenerator -from typing import Any - -import gradio as gr -import numpy as np -import structlog - -from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler -from src.orchestrator_factory import create_orchestrator -from src.services.multimodal_processing import get_multimodal_service -from src.utils.config import settings -from src.utils.models import AgentEvent, OrchestratorConfig - -# Import ModelMessage from pydantic_ai with fallback -try: - from pydantic_ai import ModelMessage -except ImportError: - from typing import Any - - ModelMessage = Any # type: ignore[assignment, misc] - -# Type alias for Gradio multimodal input -MultimodalPostprocess = dict[str, Any] | str - -# Import HuggingFace components with graceful fallback -try: - from pydantic_ai.models.huggingface import HuggingFaceModel - from pydantic_ai.providers.huggingface import HuggingFaceProvider - - _HUGGINGFACE_AVAILABLE = True -except ImportError: - _HUGGINGFACE_AVAILABLE = False - HuggingFaceModel = None # type: ignore[assignment, misc] - HuggingFaceProvider = None # type: ignore[assignment, misc] - -try: - from huggingface_hub import AsyncInferenceClient - - _ASYNC_INFERENCE_AVAILABLE = True -except ImportError: - _ASYNC_INFERENCE_AVAILABLE = False - AsyncInferenceClient = None # type: ignore[assignment, misc] - -logger = structlog.get_logger() - - -def configure_orchestrator( - use_mock: bool = False, - mode: str = "simple", - oauth_token: str | None = None, - hf_model: str | None = None, - hf_provider: str | None = None, - graph_mode: str | None = None, - use_graph: bool = True, - web_search_provider: str | None = None, -) -> tuple[Any, str]: - """ - Configure and create the research orchestrator. 
- - Args: - use_mock: Force mock judge handler (for testing) - mode: Orchestrator mode ("simple", "iterative", "deep", "auto", "advanced") - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - hf_model: Optional HuggingFace model ID (overrides settings) - hf_provider: Optional inference provider (currently not used by HuggingFaceProvider) - graph_mode: Optional graph execution mode - use_graph: Whether to use graph execution - web_search_provider: Optional web search provider ("auto", "serper", "duckduckgo") - - Returns: - Tuple of (orchestrator, backend_info_string) - """ - from src.tools.clinicaltrials import ClinicalTrialsTool - from src.tools.europepmc import EuropePMCTool - from src.tools.neo4j_search import Neo4jSearchTool - from src.tools.pubmed import PubMedTool - from src.tools.search_handler import SearchHandler - from src.tools.web_search_factory import create_web_search_tool - - # Create search handler with tools - tools = [] - - # Add biomedical search tools (always available, no API keys required) - tools.append(PubMedTool()) - logger.info("PubMed tool added to search handler") - - tools.append(ClinicalTrialsTool()) - logger.info("ClinicalTrials tool added to search handler") - - tools.append(EuropePMCTool()) - logger.info("EuropePMC tool added to search handler") - - # Add Neo4j knowledge graph search tool (if Neo4j is configured) - neo4j_tool = Neo4jSearchTool() - tools.append(neo4j_tool) - logger.info("Neo4j search tool added to search handler") - - # Add web search tool - web_search_tool = create_web_search_tool(provider=web_search_provider or "auto") - if web_search_tool: - tools.append(web_search_tool) - logger.info("Web search tool added to search handler", provider=web_search_tool.name) - - # Create config if not provided - config = OrchestratorConfig() - - search_handler = SearchHandler( - tools=tools, - timeout=config.search_timeout, - include_rag=True, - auto_ingest_to_rag=True, - oauth_token=oauth_token, - ) - - # Create judge (mock, real, or free tier) - judge_handler: JudgeHandler | MockJudgeHandler | HFInferenceJudgeHandler - backend_info = "Unknown" - - # 1. Forced Mock (Unit Testing) - if use_mock: - judge_handler = MockJudgeHandler() - backend_info = "Mock (Testing)" - - # 2. API Key (OAuth or Env) - HuggingFace only (OAuth provides HF token) - # Priority: oauth_token > env vars - # On HuggingFace Spaces, OAuth token is available via request.oauth_token - # - # OAuth Scope Requirements: - # - 'inference-api': Required for HuggingFace Inference API access - # This scope grants access to: - # * HuggingFace's own Inference API - # * All third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.) - # * All models available through the Inference Providers API - # See: https://huggingface.co/docs/hub/oauth#currently-supported-scopes - # - # Note: The hf_provider parameter is accepted but not used here because HuggingFaceProvider - # from pydantic-ai doesn't support provider selection. Provider selection happens at the - # InferenceClient level (used in HuggingFaceChatClient for advanced mode). 
- effective_api_key = oauth_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") - - # Log which authentication source is being used - if effective_api_key: - auth_source = ( - "OAuth token" - if oauth_token - else ("HF_TOKEN env var" if os.getenv("HF_TOKEN") else "HUGGINGFACE_API_KEY env var") - ) - logger.info( - "Using HuggingFace authentication", - source=auth_source, - has_token=bool(effective_api_key), - ) - - if effective_api_key: - # We have an API key (OAuth or env) - use pydantic-ai with JudgeHandler - # This uses HuggingFace Inference API, which includes access to all third-party providers - # via the Inference Providers API (router.huggingface.co) - model: Any | None = None - # Use selected model or fall back to env var/settings - model_name = ( - hf_model - or os.getenv("HF_MODEL") - or settings.huggingface_model - or "Qwen/Qwen3-Next-80B-A3B-Thinking" - ) - if not _HUGGINGFACE_AVAILABLE: - raise ImportError( - "HuggingFace models are not available in this version of pydantic-ai. " - "Please install with: uv add 'pydantic-ai[huggingface]' to use HuggingFace inference providers." - ) - # Inference API - uses HuggingFace Inference API - # Per https://ai.pydantic.dev/models/huggingface/#configure-the-provider - # HuggingFaceProvider accepts api_key parameter directly - # This is consistent with usage in src/utils/llm_factory.py and src/agent_factory/judges.py - # The OAuth token with 'inference-api' scope provides access to all inference providers - provider = HuggingFaceProvider(api_key=effective_api_key) # type: ignore[misc] - model = HuggingFaceModel(model_name, provider=provider) # type: ignore[misc] - backend_info = "API (HuggingFace OAuth)" if oauth_token else "API (Env Config)" - - judge_handler = JudgeHandler(model=model) - - # 3. Free Tier (HuggingFace Inference) - NO API KEY AVAILABLE - else: - # No API key available - use HFInferenceJudgeHandler with public models - # HFInferenceJudgeHandler will use HF_TOKEN from env if available, otherwise public models - # Note: OAuth token should have been caught in effective_api_key check above - # If we reach here, we truly have no API key, so use public models - judge_handler = HFInferenceJudgeHandler( - model_id=hf_model if hf_model else None, - api_key=None, # Will use HF_TOKEN from env if available, otherwise public models - ) - model_display = hf_model.split("/")[-1] if hf_model else "Default (Public Models)" - backend_info = f"Free Tier ({model_display} - Public Models Only)" - - # Determine effective mode - # If mode is already iterative/deep/auto, use it directly - # If mode is "graph" or "simple", use graph_mode if provided - effective_mode = mode - if mode in ("graph", "simple") and graph_mode: - effective_mode = graph_mode - elif mode == "graph" and not graph_mode: - effective_mode = "auto" # Default to auto if graph mode but no graph_mode specified - - orchestrator = create_orchestrator( - search_handler=search_handler, - judge_handler=judge_handler, - config=config, - mode=effective_mode, # type: ignore - oauth_token=oauth_token, - ) - - return orchestrator, backend_info - - -def _is_file_path(text: str) -> bool: - """Check if text appears to be a file path. - - Args: - text: Text to check - - Returns: - True if text looks like a file path - """ - return ("/" in text or "\\" in text) and ( - "." in text.split("/")[-1] or "." in text.split("\\")[-1] - ) - - -def event_to_chat_message(event: AgentEvent) -> dict[str, Any]: - """Convert AgentEvent to Gradio chat message format. 
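# Editor's note: a pure-function restatement of the mode-resolution rules shown
# above, so the precedence is easy to test in isolation. The function name is
# illustrative; the behaviour mirrors the branch logic in configure_orchestrator.
def resolve_effective_mode(mode: str, graph_mode: str | None) -> str:
    if mode in ("graph", "simple") and graph_mode:
        return graph_mode
    if mode == "graph" and not graph_mode:
        return "auto"  # graph execution requested but no specific graph mode given
    return mode


assert resolve_effective_mode("simple", "deep") == "deep"
assert resolve_effective_mode("graph", None) == "auto"
assert resolve_effective_mode("iterative", None) == "iterative"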
- - Args: - event: AgentEvent to convert - - Returns: - Dictionary with 'role' and 'content' keys for Gradio Chatbot - """ - result: dict[str, Any] = { - "role": "assistant", - "content": event.to_markdown(), - } - - # Add metadata if available - if event.data: - metadata: dict[str, Any] = {} - - # Extract file path if present - if isinstance(event.data, dict): - file_path = event.data.get("file_path") - if file_path: - metadata["file_path"] = file_path - - if metadata: - result["metadata"] = metadata - return result - - -def extract_oauth_info(request: gr.Request | None) -> tuple[str | None, str | None]: - """ - Extract OAuth token and username from Gradio request. - - Args: - request: Gradio request object containing OAuth information - - Returns: - Tuple of (oauth_token, oauth_username) - """ - oauth_token: str | None = None - oauth_username: str | None = None - - if request is None: - return oauth_token, oauth_username - - # Try multiple ways to access OAuth token (Gradio API may vary) - # Pattern 1: request.oauth_token.token - if hasattr(request, "oauth_token") and request.oauth_token is not None: - if hasattr(request.oauth_token, "token"): - oauth_token = request.oauth_token.token - elif isinstance(request.oauth_token, str): - oauth_token = request.oauth_token - # Pattern 2: request.headers (fallback) - elif hasattr(request, "headers"): - # OAuth token might be in headers - auth_header = request.headers.get("authorization") or request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - oauth_token = auth_header.replace("Bearer ", "") - - # Access username from request - if hasattr(request, "username") and request.username: - oauth_username = request.username - # Also try accessing via oauth_profile if available - elif hasattr(request, "oauth_profile") and request.oauth_profile is not None: - if hasattr(request.oauth_profile, "username") and request.oauth_profile.username: - oauth_username = request.oauth_profile.username - elif hasattr(request.oauth_profile, "name") and request.oauth_profile.name: - oauth_username = request.oauth_profile.name - - return oauth_token, oauth_username - - -async def yield_auth_messages( - oauth_username: str | None, - oauth_token: str | None, - has_huggingface: bool, - mode: str, -) -> AsyncGenerator[dict[str, Any], None]: - """ - Yield authentication status messages. - - Args: - oauth_username: OAuth username if available - oauth_token: OAuth token if available - has_huggingface: Whether HuggingFace authentication is available - mode: Research mode - - Yields: - Chat message dictionaries - """ - if oauth_username: - yield { - "role": "assistant", - "content": f"👋 **Welcome, {oauth_username}!**\n\nAuthenticated via HuggingFace OAuth.", - } - - if oauth_token: - yield { - "role": "assistant", - "content": ( - "🔐 **Authentication Status**: ✅ Authenticated\n\n" - "Your OAuth token has been validated. You can now use all AI models and research tools." - ), - } - elif has_huggingface: - yield { - "role": "assistant", - "content": ( - "🔐 **Authentication Status**: ✅ Using environment token\n\n" - "Using HF_TOKEN from environment variables." - ), - } - else: - yield { - "role": "assistant", - "content": ( - "⚠️ **Authentication Status**: ❌ No authentication\n\n" - "Please sign in with HuggingFace or set HF_TOKEN environment variable." 
- ), - } - - yield { - "role": "assistant", - "content": f"🚀 **Mode**: {mode.upper()}\n\nStarting research agent...", - } - - -def _extract_oauth_token(oauth_token: gr.OAuthToken | None) -> str | None: - """Extract token value from OAuth token object.""" - if oauth_token is None: - return None - - if hasattr(oauth_token, "token"): - token_value: str | None = getattr(oauth_token, "token", None) # type: ignore[assignment] - if token_value is None: - return None - logger.debug("OAuth token extracted from oauth_token.token attribute") - - # Validate token format - from src.utils.hf_error_handler import log_token_info, validate_hf_token - - log_token_info(token_value, context="research_agent") - is_valid, error_msg = validate_hf_token(token_value) - if not is_valid: - logger.warning( - "OAuth token validation failed", - error=error_msg, - oauth_token_type=type(oauth_token).__name__, - ) - return token_value - - if isinstance(oauth_token, str): - logger.debug("OAuth token extracted as string") - - # Validate token format - from src.utils.hf_error_handler import log_token_info, validate_hf_token - - log_token_info(oauth_token, context="research_agent") - return oauth_token - - logger.warning( - "OAuth token object present but token extraction failed", - oauth_token_type=type(oauth_token).__name__, - ) - return None - - -def _extract_username(oauth_profile: gr.OAuthProfile | None) -> str | None: - """Extract username from OAuth profile.""" - if oauth_profile is None: - return None - - username: str | None = None - if hasattr(oauth_profile, "username") and oauth_profile.username: - username = str(oauth_profile.username) - elif hasattr(oauth_profile, "name") and oauth_profile.name: - username = str(oauth_profile.name) - - if username: - logger.info("OAuth user authenticated", username=username) - return username - - -async def _process_multimodal_input( - message: str | MultimodalPostprocess, - enable_image_input: bool, - enable_audio_input: bool, - token_value: str | None, -) -> tuple[str, tuple[int, np.ndarray[Any, Any]] | None]: # type: ignore[type-arg] - """Process multimodal input and return processed text and audio data.""" - processed_text = "" - audio_input_data: tuple[int, np.ndarray[Any, Any]] | None = None # type: ignore[type-arg] - - if isinstance(message, dict): - processed_text = message.get("text", "") or "" - files = message.get("files", []) or [] - audio_input_data = message.get("audio") or None - - if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input): - try: - multimodal_service = get_multimodal_service() - processed_text = await multimodal_service.process_multimodal_input( - processed_text, - files=files if enable_image_input else [], - audio_input=audio_input_data if enable_audio_input else None, - hf_token=token_value, - prepend_multimodal=True, - ) - except Exception as e: - logger.warning("multimodal_processing_failed", error=str(e)) - else: - processed_text = str(message) if message else "" - - return processed_text, audio_input_data - - -async def research_agent( - message: str | MultimodalPostprocess, - history: list[dict[str, Any]], - mode: str = "simple", - hf_model: str | None = None, - hf_provider: str | None = None, - graph_mode: str = "auto", - use_graph: bool = True, - enable_image_input: bool = True, - enable_audio_input: bool = True, - web_search_provider: str = "auto", - oauth_token: gr.OAuthToken | None = None, - oauth_profile: gr.OAuthProfile | None = None, -) -> AsyncGenerator[dict[str, Any], None]: - """ - Main research 
agent function that processes queries and streams results. - - Args: - message: User message (text, image, or audio) - history: Conversation history - mode: Orchestrator mode - hf_model: Optional HuggingFace model ID - hf_provider: Optional inference provider - graph_mode: Graph execution mode - use_graph: Whether to use graph execution - enable_image_input: Whether to process image inputs - enable_audio_input: Whether to process audio inputs - web_search_provider: Web search provider selection - oauth_token: Gradio OAuth token (None if user not logged in) - oauth_profile: Gradio OAuth profile (None if user not logged in) - - Yields: - Chat message dictionaries - """ - # Extract OAuth token and username - token_value = _extract_oauth_token(oauth_token) - username = _extract_username(oauth_profile) - - # Check if user is logged in (OAuth token or env var) - # Fallback to env vars for local development or Spaces with HF_TOKEN secret - has_authentication = bool( - token_value or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") - ) - - if not has_authentication: - yield { - "role": "assistant", - "content": ( - "🔐 **Authentication Required**\n\n" - "Please **sign in with HuggingFace** using the login button at the top of the page " - "before using this application.\n\n" - "The login button is required to access the AI models and research tools." - ), - } - return - - # Process multimodal input - processed_text, audio_input_data = await _process_multimodal_input( - message, enable_image_input, enable_audio_input, token_value - ) - - if not processed_text.strip(): - yield { - "role": "assistant", - "content": "Please enter a research question or provide an image/audio input.", - } - return - - # Check available keys (use token_value instead of oauth_token) - has_huggingface = bool(os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") or token_value) - - # Adjust mode if needed - effective_mode = mode - if mode == "advanced": - effective_mode = "simple" - - # Yield authentication and mode status messages - async for msg in yield_auth_messages(username, token_value, has_huggingface, mode): - yield msg - - # Run the agent and stream events - try: - # use_mock=False - let configure_orchestrator decide based on available keys - # It will use: OAuth token > Env vars > HF Inference (free tier) - # Convert empty strings from Textbox to None for defaults - model_id = hf_model if hf_model and hf_model.strip() else None - provider_name = hf_provider if hf_provider and hf_provider.strip() else None - - # Log authentication source for debugging - auth_source = ( - "OAuth" - if token_value - else ( - "Env (HF_TOKEN)" - if os.getenv("HF_TOKEN") - else ("Env (HUGGINGFACE_API_KEY)" if os.getenv("HUGGINGFACE_API_KEY") else "None") - ) - ) - logger.info( - "Configuring orchestrator", - mode=effective_mode, - auth_source=auth_source, - has_oauth_token=bool(token_value), - model=model_id or "default", - provider=provider_name or "auto", - ) - - # Convert empty string to None for web_search_provider - web_search_provider_value = ( - web_search_provider if web_search_provider and web_search_provider.strip() else None - ) - - orchestrator, backend_name = configure_orchestrator( - use_mock=False, # Never use mock in production - HF Inference is the free fallback - mode=effective_mode, - oauth_token=token_value, # Use extracted token value - passed to all agents and services - hf_model=model_id, # None will use defaults in configure_orchestrator - hf_provider=provider_name, # None will use defaults in 
configure_orchestrator - graph_mode=graph_mode if graph_mode else None, - use_graph=use_graph, - web_search_provider=web_search_provider_value, # None will use settings default - ) - - yield { - "role": "assistant", - "content": f"🔧 **Backend**: {backend_name}\n\nProcessing your query...", - } - - # Convert history to ModelMessage format if needed - message_history: list[ModelMessage] = [] - if history: - for msg in history: - role = msg.get("role", "user") - content = msg.get("content", "") - if isinstance(content, str) and content.strip(): - message_history.append(ModelMessage(role=role, content=content)) # type: ignore[operator] - - # Run orchestrator and stream events - async for event in orchestrator.run( - processed_text, message_history=message_history if message_history else None - ): - chat_msg = event_to_chat_message(event) - yield chat_msg - - # Note: Audio output is now handled via on-demand TTS button - # Users click "Generate Audio" button to create TTS for the last response - - except Exception as e: - # Return error message without metadata to avoid issues during example caching - # Metadata can cause validation errors when Gradio caches examples - # Gradio Chatbot requires plain text - remove all markdown and special characters - error_msg = str(e).replace("**", "").replace("*", "").replace("`", "") - # Ensure content is a simple string without any special formatting - yield { - "role": "assistant", - "content": f"Error: {error_msg}. Please check your configuration and try again.", - } - - -async def update_model_provider_dropdowns( - oauth_token: gr.OAuthToken | None = None, - oauth_profile: gr.OAuthProfile | None = None, -) -> tuple[dict[str, Any], dict[str, Any], str]: - """Update model and provider dropdowns based on OAuth token. - - This function is called when OAuth token/profile changes (user logs in/out). - It queries HuggingFace API to get available models and providers. 
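- When no OAuth token is available, the dropdowns fall back to the default (empty) choices
- and a "not authenticated" status message.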
- - Args: - oauth_token: Gradio OAuth token - oauth_profile: Gradio OAuth profile - - Returns: - Tuple of (model_dropdown_update, provider_dropdown_update, status_message) - """ - from src.utils.hf_model_validator import ( - get_available_models, - get_available_providers, - validate_oauth_token, - ) - - # Extract token value - token_value: str | None = None - if oauth_token is not None: - if hasattr(oauth_token, "token"): - token_value = oauth_token.token - elif isinstance(oauth_token, str): - token_value = oauth_token - - # Default values (empty = use default) - default_models = [""] - default_providers = [""] - status_msg = "⚠️ Not authenticated - using default models" - - if not token_value: - # No token - return defaults - return ( - gr.update(choices=default_models, value=""), - gr.update(choices=default_providers, value=""), - status_msg, - ) - - try: - # Validate token and get available resources - validation_result = await validate_oauth_token(token_value) - - if not validation_result["is_valid"]: - status_msg = ( - f"❌ Token validation failed: {validation_result.get('error', 'Unknown error')}" - ) - return ( - gr.update(choices=default_models, value=""), - gr.update(choices=default_providers, value=""), - status_msg, - ) - - # Get available models and providers - models = await get_available_models(token=token_value, limit=50) - providers = await get_available_providers(token=token_value) - - # Combine with defaults - model_choices = ["", *models[:49]] # Keep first 49 + empty option - provider_choices = providers # Already includes "auto" - - username = validation_result.get("username", "User") - - # Build status message with warning if scope is missing - scope_warning = "" - if not validation_result["has_inference_api_scope"]: - scope_warning = ( - "⚠️ Token may not have 'inference-api' scope - some models may not work\n\n" - ) - - status_msg = ( - f"{scope_warning}✅ Authenticated as {username}\n\n" - f"📊 Found {len(models)} available models\n" - f"🔧 Found {len(providers)} available providers" - ) - - logger.info( - "Updated model/provider dropdowns", - model_count=len(model_choices), - provider_count=len(provider_choices), - username=username, - ) - - return ( - gr.update(choices=model_choices, value=""), - gr.update(choices=provider_choices, value=""), - status_msg, - ) - - except Exception as e: - logger.error("Failed to update dropdowns", error=str(e)) - status_msg = f"⚠️ Failed to load models: {e!s}" - return ( - gr.update(choices=default_models, value=""), - gr.update(choices=default_providers, value=""), - status_msg, - ) - - -def create_demo() -> gr.Blocks: - """ - Create the Gradio demo interface with MCP support and OAuth login. - - Returns: - Configured Gradio Blocks interface with MCP server and OAuth enabled - """ - with gr.Blocks(title="🔬 The DETERMINATOR", fill_height=True) as demo: - # Add sidebar with login button and information - # Reference: Working implementation pattern from Gradio docs - with gr.Sidebar(): - gr.Markdown("# 🔐 Authentication") - gr.Markdown( - "**Sign in with Hugging Face** to access AI models and research tools.\n\n" - "This application requires authentication to use the inference API." 
- ) - gr.LoginButton("Sign in with Hugging Face") - gr.Markdown("---") - - # About Section - Collapsible with details - with gr.Accordion("ℹ️ About", open=False): - gr.Markdown( - "**The DETERMINATOR** - Generalist Deep Research Agent\n\n" - "Stops at nothing until finding precise answers to complex questions.\n\n" - "**How It Works**:\n" - "- 🔍 Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n" - "- 🧠 Automatic medical knowledge detection\n" - "- 🔄 Iterative refinement with search-judge loops\n" - "- ⏹️ Continues until budget/time/iteration limits\n" - "- 📊 Evidence synthesis with citations\n\n" - "**Multimodal Input**:\n" - "- 📷 **Images**: Click image icon in textbox (OCR)\n" - "- 🎤 **Audio**: Click microphone icon (speech-to-text)\n" - "- 📄 **Files**: Drag & drop or click to upload\n\n" - "**MCP Server**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n" - "⚠️ **Research tool only** - Synthesizes evidence but cannot provide medical advice." - ) - - gr.Markdown("---") - - # Settings Section - Organized in Accordions - gr.Markdown("## ⚙️ Settings") - - # Research Configuration Accordion - with gr.Accordion("🔬 Research Configuration", open=True): - mode_radio = gr.Radio( - choices=["simple", "advanced", "iterative", "deep", "auto"], - value="simple", - label="Orchestrator Mode", - info=( - "Simple: Linear search-judge loop | " - "Advanced: Multi-agent (OpenAI) | " - "Iterative: Knowledge-gap driven | " - "Deep: Parallel sections | " - "Auto: Smart routing" - ), - ) - - graph_mode_radio = gr.Radio( - choices=["iterative", "deep", "auto"], - value="auto", - label="Graph Research Mode", - info="Iterative: Single loop | Deep: Parallel sections | Auto: Detect from query", - ) - - use_graph_checkbox = gr.Checkbox( - value=True, - label="Use Graph Execution", - info="Enable graph-based workflow execution", - ) - - # Model and Provider selection - gr.Markdown("### 🤖 Model & Provider") - - # Status message for model/provider loading - model_provider_status = gr.Markdown( - value="⚠️ Sign in to see available models and providers", - visible=True, - ) - - # Popular models list (will be updated by validator) - popular_models = [ - "", # Empty = use default - "Qwen/Qwen3-Next-80B-A3B-Thinking", - "Qwen/Qwen3-235B-A22B-Instruct-2507", - "zai-org/GLM-4.5-Air", - "meta-llama/Llama-3.1-8B-Instruct", - "meta-llama/Llama-3.1-70B-Instruct", - "mistralai/Mistral-7B-Instruct-v0.2", - "google/gemma-2-9b-it", - ] - - hf_model_dropdown = gr.Dropdown( - choices=popular_models, - value="", # Empty string - will be converted to None in research_agent - label="Reasoning Model", - info="Select a HuggingFace model (leave empty for default). Sign in to see all available models.", - allow_custom_value=True, # Allow users to type custom model IDs - ) - - # Provider list from README (will be updated by validator) - providers = [ - "", # Empty string = auto-select - "nebius", - "together", - "scaleway", - "hyperbolic", - "novita", - "nscale", - "sambanova", - "ovh", - "fireworks", - ] - - hf_provider_dropdown = gr.Dropdown( - choices=providers, - value="", # Empty string - will be converted to None in research_agent - label="Inference Provider", - info="Select inference provider (leave empty for auto-select). 
Sign in to see all available providers.", - ) - - # Refresh button for updating models/providers after login - def refresh_models_and_providers( - request: gr.Request, - ) -> tuple[dict[str, Any], dict[str, Any], str]: - """Handle refresh button click and update dropdowns.""" - import asyncio - - # Extract OAuth token and profile from request - oauth_token: gr.OAuthToken | None = None - oauth_profile: gr.OAuthProfile | None = None - - if request is not None: - # Try to get OAuth token from request - if hasattr(request, "oauth_token"): - oauth_token = request.oauth_token - if hasattr(request, "oauth_profile"): - oauth_profile = request.oauth_profile - - # Run async function in sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - update_model_provider_dropdowns(oauth_token, oauth_profile) - ) - return result - finally: - loop.close() - - refresh_models_btn = gr.Button( - value="🔄 Refresh Available Models", - visible=True, - size="sm", - ) - - # Pass request to get OAuth token from Gradio context - refresh_models_btn.click( - fn=refresh_models_and_providers, - inputs=[], # Request is automatically available in Gradio context - outputs=[hf_model_dropdown, hf_provider_dropdown, model_provider_status], - ) - - # Web Search Provider selection - gr.Markdown("### 🔍 Web Search Provider") - - # Available providers with labels indicating availability - # Format: (display_label, value) - Gradio Dropdown supports tuples - web_search_provider_options = [ - ("Auto-detect (Recommended)", "auto"), - ("Serper (Google Search + Full Content)", "serper"), - ("DuckDuckGo (Free, Snippets Only)", "duckduckgo"), - ("SearchXNG (Self-hosted) - Coming Soon", "searchxng"), # Not fully implemented - ("Brave - Coming Soon", "brave"), # Not implemented - ("Tavily - Coming Soon", "tavily"), # Not implemented - ] - - # Create Dropdown with label-value pairs - # Gradio will display labels but return values - # Disabled options are marked with "Coming Soon" in the label - # The factory will handle "not implemented" cases gracefully - web_search_provider_dropdown = gr.Dropdown( - choices=web_search_provider_options, - value="auto", - label="Web Search Provider", - info="Select web search provider. 'Auto' detects best available.", - ) - - # Multimodal Input Configuration - gr.Markdown("### 📷🎤 Multimodal Input") - - enable_image_input_checkbox = gr.Checkbox( - value=settings.enable_image_input, - label="Enable Image Input (OCR)", - info="Process uploaded images with OCR", - ) - - enable_audio_input_checkbox = gr.Checkbox( - value=settings.enable_audio_input, - label="Enable Audio Input (STT)", - info="Process uploaded/recorded audio with speech-to-text", - ) - - # Audio Output Configuration - Collapsible - with gr.Accordion("🔊 Audio Output (TTS)", open=False): - gr.Markdown( - "**Generate audio for research responses on-demand.**\n\n" - "Enter Modal keys below or set `MODAL_TOKEN_ID`/`MODAL_TOKEN_SECRET` in `.env` for local development." - ) - - with gr.Accordion("🔑 Modal Credentials (Optional)", open=False): - modal_token_id_input = gr.Textbox( - label="Modal Token ID", - placeholder="ak-... (leave empty to use .env)", - type="password", - value="", - ) - - modal_token_secret_input = gr.Textbox( - label="Modal Token Secret", - placeholder="as-... 
(leave empty to use .env)", - type="password", - value="", - ) - - with gr.Accordion("🎚️ Voice & Quality Settings", open=False): - tts_voice_dropdown = gr.Dropdown( - choices=[ - "af_heart", - "af_bella", - "af_sarah", - "af_sky", - "af_nova", - "af_shimmer", - "af_echo", - "af_fable", - "af_onyx", - "af_angel", - "af_asteria", - "af_jessica", - "af_elli", - "af_domi", - "af_gigi", - "af_freya", - "af_glinda", - "af_cora", - "af_serena", - "af_liv", - "af_naomi", - "af_rachel", - "af_antoni", - "af_thomas", - "af_charlie", - "af_emily", - "af_george", - "af_arnold", - "af_adam", - "af_sam", - "af_paul", - "af_josh", - "af_daniel", - "af_liam", - "af_dave", - "af_fin", - "af_sarah", - "af_glinda", - "af_grace", - "af_dorothy", - "af_michael", - "af_james", - "af_joseph", - "af_jeremy", - "af_ryan", - "af_oliver", - "af_harry", - "af_kyle", - "af_leo", - "af_otto", - "af_owen", - "af_pepper", - "af_phil", - "af_raven", - "af_rocky", - "af_rusty", - "af_serena", - "af_sky", - "af_spark", - "af_stella", - "af_storm", - "af_taylor", - "af_vera", - "af_will", - "af_aria", - "af_ash", - "af_ballad", - "af_bella", - "af_breeze", - "af_cove", - "af_dusk", - "af_ember", - "af_flash", - "af_flow", - "af_glow", - "af_harmony", - "af_journey", - "af_lullaby", - "af_lyra", - "af_melody", - "af_midnight", - "af_moon", - "af_muse", - "af_music", - "af_narrator", - "af_nightingale", - "af_poet", - "af_rain", - "af_redwood", - "af_rewind", - "af_river", - "af_sage", - "af_seashore", - "af_shadow", - "af_silver", - "af_song", - "af_starshine", - "af_story", - "af_summer", - "af_sun", - "af_thunder", - "af_tide", - "af_time", - "af_valentino", - "af_verdant", - "af_verse", - "af_vibrant", - "af_vivid", - "af_warmth", - "af_whisper", - "af_wilderness", - "af_willow", - "af_winter", - "af_wit", - "af_witness", - "af_wren", - "af_writer", - "af_zara", - "af_zeus", - "af_ziggy", - "af_zoom", - "af_river", - "am_michael", - "am_fenrir", - "am_puck", - "am_echo", - "am_eric", - "am_liam", - "am_onyx", - "am_santa", - "am_adam", - ], - value=settings.tts_voice, - label="TTS Voice", - info="Select TTS voice (American English voices: af_*, am_*)", - ) - - tts_speed_slider = gr.Slider( - minimum=0.5, - maximum=2.0, - value=settings.tts_speed, - step=0.1, - label="TTS Speech Speed", - info="Adjust TTS speech speed (0.5x to 2.0x)", - ) - - gr.Dropdown( - choices=["T4", "A10", "A100", "L4", "L40S"], - value=settings.tts_gpu or "T4", - label="TTS GPU Type", - info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). 
Note: GPU changes require app restart.", - visible=settings.modal_available, - interactive=False, # GPU type set at function definition time, requires restart - ) - - tts_use_llm_polish_checkbox = gr.Checkbox( - value=settings.tts_use_llm_polish, - label="Use LLM Polish for Audio", - info="Apply LLM-based final polish to remove remaining formatting artifacts (costs API calls)", - ) - - tts_generate_button = gr.Button( - "🎵 Generate Audio for Last Response", - variant="primary", - size="lg", - ) - - tts_status_text = gr.Markdown( - "Click the button above to generate audio for the last research response.", - elem_classes="tts-status", - ) - - # Audio output component (for TTS response) - audio_output = gr.Audio( - label="🔊 Audio Output", - visible=True, - ) - - # TTS on-demand generation handler - async def handle_tts_generation( - history: list[dict[str, Any]], - modal_token_id: str, - modal_token_secret: str, - voice: str, - speed: float, - use_llm_polish: bool, - ) -> tuple[Any | None, str]: - """Generate audio on-demand for the last response. - - Args: - history: Chat history - modal_token_id: Modal token ID from UI - modal_token_secret: Modal token secret from UI - voice: TTS voice selection - speed: TTS speed - use_llm_polish: Enable LLM polish - - Returns: - Tuple of (audio_output, status_message) - """ - from src.services.tts_modal import generate_audio_on_demand - - # Get last assistant message from history - # History is a list of tuples: [(user_msg, assistant_msg), ...] - if not history: - logger.warning("tts_no_history", history=history) - return None, "❌ No messages in history to generate audio for" - - # Debug: Log history format - logger.info( - "tts_history_debug", - history_type=type(history).__name__, - history_length=len(history) if isinstance(history, list) else 0, - first_entry_type=type(history[0]).__name__ - if isinstance(history, list) and len(history) > 0 - else None, - first_entry_sample=str(history[0])[:200] - if isinstance(history, list) and len(history) > 0 - else None, - ) - - # Get the last assistant message (second element of last tuple) - last_message = None - if isinstance(history, list) and len(history) > 0: - last_entry = history[-1] - # ChatInterface format: (user_message, assistant_message) - if isinstance(last_entry, (tuple, list)) and len(last_entry) >= 2: - last_message = last_entry[1] - logger.info( - "tts_extracted_from_tuple", message_type=type(last_message).__name__ - ) - # Dict format: {"role": "assistant", "content": "..."} - elif isinstance(last_entry, dict): - if last_entry.get("role") == "assistant": - content = last_entry.get("content", "") - # Content might be a list (multimodal) or string - if isinstance(content, list): - # Extract text from multimodal content list - last_message = " ".join(str(item) for item in content if item) - else: - last_message = content - logger.info( - "tts_extracted_from_dict", - message_type=type(content).__name__, - message_length=len(last_message) - if isinstance(last_message, str) - else 0, - ) - else: - logger.warning( - "tts_unknown_format", - entry_type=type(last_entry).__name__, - entry=str(last_entry)[:200], - ) - - # Also handle if last_message itself is a list - if isinstance(last_message, list): - last_message = " ".join(str(item) for item in last_message if item) - - if not last_message or not isinstance(last_message, str) or not last_message.strip(): - logger.error( - "tts_no_message_found", - last_message_type=type(last_message).__name__ if last_message else None, - 
last_message_value=str(last_message)[:100] if last_message else None, - ) - return None, "❌ No assistant response found in history" - - # Generate audio - audio_output, status_message = await generate_audio_on_demand( - text=last_message, - modal_token_id=modal_token_id, - modal_token_secret=modal_token_secret, - voice=voice, - speed=speed, - use_llm_polish=use_llm_polish, - ) - - return audio_output, status_message - - # Chat interface with multimodal support - # Examples are provided but will NOT run at startup (cache_examples=False) - # Users must log in first before using examples or submitting queries - chat_interface = gr.ChatInterface( - fn=research_agent, - multimodal=True, # Enable multimodal input (text + images + audio) - title="🔬 The DETERMINATOR", - description=( - "*Generalist Deep Research Agent — stops at nothing until finding precise answers*\n\n" - "💡 **Quick Start**: Type your research question below. Use 📷 for images, 🎤 for audio.\n\n" - "⚠️ **Sign in with HuggingFace** (sidebar) before starting." - ), - examples=[ - # When additional_inputs are provided, examples must be lists of lists - # Each inner list: [message, mode, hf_model, hf_provider, graph_mode, multimodal_enabled] - # Using actual model IDs and provider names from inference_models.py - # Note: Provider is optional - if empty, HF will auto-select - # These examples will NOT run at startup - users must click them after logging in - # All examples require deep iterative search and information retrieval across multiple sources - [ - # Medical research example (only one medical example) - "Create a comprehensive report on Long COVID treatments including clinical trials, mechanisms, and safety.", - "deep", - "zai-org/GLM-4.5-Air", - "nebius", - "deep", - True, - ], - [ - # Technical/Engineering example requiring deep research - "Analyze the current state of quantum computing architectures: compare different qubit technologies, error correction methods, and scalability challenges across major platforms including IBM, Google, and IonQ.", - "deep", - "Qwen/Qwen3-Next-80B-A3B-Thinking", - "nebius", - "deep", - True, - ], - [ - # Historical/Social Science example - "Research and synthesize information about the economic impact of the Industrial Revolution on European social structures, including changes in class dynamics, urbanization patterns, and labor movements from 1750-1900.", - "deep", - "meta-llama/Llama-3.1-70B-Instruct", - "together", - "deep", - True, - ], - [ - # Scientific/Physics example - "Investigate the latest developments in fusion energy research: compare ITER, SPARC, and other major projects, analyze recent breakthroughs in plasma confinement, and assess the timeline to commercial fusion power.", - "deep", - "Qwen/Qwen3-235B-A22B-Instruct-2507", - "hyperbolic", - "deep", - True, - ], - [ - # Technology/Business example - "Research the competitive landscape of AI chip manufacturers: analyze NVIDIA, AMD, Intel, and emerging players, compare architectures (GPU vs. TPU vs. 
NPU), and assess market positioning and future trends.", - "deep", - "zai-org/GLM-4.5-Air", - "fireworks", - "deep", - True, - ], - ], - additional_inputs=[ - mode_radio, - hf_model_dropdown, - hf_provider_dropdown, - graph_mode_radio, - use_graph_checkbox, - enable_image_input_checkbox, - enable_audio_input_checkbox, - web_search_provider_dropdown, - # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters - ], - cache_examples=False, # Don't cache examples - requires authentication - ) - - # Wire up TTS generation button - tts_generate_button.click( - fn=handle_tts_generation, - inputs=[ - chat_interface.chatbot, # Get chat history from ChatInterface - modal_token_id_input, - modal_token_secret_input, - tts_voice_dropdown, - tts_speed_slider, - tts_use_llm_polish_checkbox, - ], - outputs=[audio_output, tts_status_text], - ) - - return demo # type: ignore[no-any-return] - - -if __name__ == "__main__": - demo = create_demo() - demo.launch(server_name="0.0.0.0", server_port=7860) diff --git a/src/legacy_orchestrator.py b/src/legacy_orchestrator.py deleted file mode 100644 index b41ba8aad566bee202c37ddbc142e677a7a9fce4..0000000000000000000000000000000000000000 --- a/src/legacy_orchestrator.py +++ /dev/null @@ -1,453 +0,0 @@ -"""Orchestrator - the agent loop connecting Search and Judge.""" - -import asyncio -from collections.abc import AsyncGenerator -from typing import Any, Protocol - -import structlog - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.utils.config import settings -from src.utils.models import ( - AgentEvent, - Evidence, - JudgeAssessment, - OrchestratorConfig, - SearchResult, -) - -logger = structlog.get_logger() - - -class SearchHandlerProtocol(Protocol): - """Protocol for search handler.""" - - async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult: ... - - -class JudgeHandlerProtocol(Protocol): - """Protocol for judge handler.""" - - async def assess(self, question: str, evidence: list[Evidence]) -> JudgeAssessment: ... - - -class Orchestrator: - """ - The agent orchestrator - runs the Search -> Judge -> Loop cycle. - - This is a generator-based design that yields events for real-time UI updates. - """ - - def __init__( - self, - search_handler: SearchHandlerProtocol, - judge_handler: JudgeHandlerProtocol, - config: OrchestratorConfig | None = None, - enable_analysis: bool = False, - enable_embeddings: bool = True, - ): - """ - Initialize the orchestrator. - - Args: - search_handler: Handler for executing searches - judge_handler: Handler for assessing evidence - config: Optional configuration (uses defaults if not provided) - enable_analysis: Whether to perform statistical analysis (if Modal available) - enable_embeddings: Whether to use semantic search for ranking/dedup - """ - self.search = search_handler - self.judge = judge_handler - self.config = config or OrchestratorConfig() - self.history: list[dict[str, Any]] = [] - self._enable_analysis = enable_analysis and settings.modal_available - self._enable_embeddings = enable_embeddings - - # Lazy-load services - self._analyzer: Any = None - self._embeddings: Any = None - - def _get_analyzer(self) -> Any: - """Lazy initialization of StatisticalAnalyzer. - - Note: This imports from src.services, NOT src.agents, - so it works without the magentic optional dependency. 
- """ - if self._analyzer is None: - from src.services.statistical_analyzer import get_statistical_analyzer - - self._analyzer = get_statistical_analyzer() - return self._analyzer - - def _get_embeddings(self) -> Any: - """Lazy initialization of EmbeddingService. - - Uses local sentence-transformers - NO API key required. - """ - if self._embeddings is None and self._enable_embeddings: - try: - from src.services.embeddings import get_embedding_service - - self._embeddings = get_embedding_service() - logger.info("Embedding service enabled for semantic ranking") - except Exception as e: - logger.warning("Embeddings unavailable, using basic ranking", error=str(e)) - self._enable_embeddings = False - return self._embeddings - - async def _deduplicate_and_rank(self, evidence: list[Evidence], query: str) -> list[Evidence]: - """Use embeddings to deduplicate and rank evidence by relevance.""" - embeddings = self._get_embeddings() - if not embeddings or not evidence: - return evidence - - try: - # Deduplicate using semantic similarity - unique_evidence: list[Evidence] = await embeddings.deduplicate(evidence, threshold=0.85) - logger.info( - "Deduplicated evidence", - before=len(evidence), - after=len(unique_evidence), - ) - return unique_evidence - except Exception as e: - logger.warning("Deduplication failed, using original", error=str(e)) - return evidence - - async def _run_analysis_phase( - self, query: str, evidence: list[Evidence], iteration: int - ) -> AsyncGenerator[AgentEvent, None]: - """Run the optional analysis phase.""" - if not self._enable_analysis: - return - - yield AgentEvent( - type="analyzing", - message="Running statistical analysis in Modal sandbox...", - data={}, - iteration=iteration, - ) - - try: - analyzer = self._get_analyzer() - - # Run Modal analysis (no agent_framework needed!) - analysis_result = await analyzer.analyze( - query=query, - evidence=evidence, - hypothesis=None, # Could add hypothesis generation later - ) - - yield AgentEvent( - type="analysis_complete", - message=f"Analysis verdict: {analysis_result.verdict}", - data=analysis_result.model_dump(), - iteration=iteration, - ) - - except Exception as e: - logger.error("Modal analysis failed", error=str(e)) - yield AgentEvent( - type="error", - message=f"Modal analysis failed: {e}", - data={"error": str(e)}, - iteration=iteration, - ) - - async def run( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915 - """ - Run the agent loop for a query. - - Yields AgentEvent objects for each step, allowing real-time UI updates. 
- - Args: - query: The user's research question - message_history: Optional user conversation history (for compatibility) - - Yields: - AgentEvent objects for each step of the process - """ - logger.info( - "Starting orchestrator", - query=query, - has_history=bool(message_history), - ) - - yield AgentEvent( - type="started", - message=f"Starting research for: {query}", - iteration=0, - ) - - all_evidence: list[Evidence] = [] - current_queries = [query] - iteration = 0 - - while iteration < self.config.max_iterations: - iteration += 1 - logger.info("Iteration", iteration=iteration, queries=current_queries) - - # === SEARCH PHASE === - yield AgentEvent( - type="searching", - message=f"Searching for: {', '.join(current_queries[:3])}...", - iteration=iteration, - ) - - try: - # Execute searches for all current queries - search_tasks = [ - self.search.execute(q, self.config.max_results_per_tool) - for q in current_queries[:3] # Limit to 3 queries per iteration - ] - search_results = await asyncio.gather(*search_tasks, return_exceptions=True) - - # Collect evidence from successful searches - new_evidence: list[Evidence] = [] - errors: list[str] = [] - - for q, result in zip(current_queries[:3], search_results, strict=False): - if isinstance(result, Exception): - errors.append(f"Search for '{q}' failed: {result!s}") - elif isinstance(result, SearchResult): - new_evidence.extend(result.evidence) - errors.extend(result.errors) - else: - # Should not happen with return_exceptions=True but safe fallback - errors.append(f"Unknown result type for '{q}': {type(result)}") - - # Deduplicate evidence by URL (fast, basic) - seen_urls = {e.citation.url for e in all_evidence} - unique_new = [e for e in new_evidence if e.citation.url not in seen_urls] - all_evidence.extend(unique_new) - - # Semantic deduplication and ranking (if embeddings available) - all_evidence = await self._deduplicate_and_rank(all_evidence, query) - - yield AgentEvent( - type="search_complete", - message=f"Found {len(unique_new)} new sources ({len(all_evidence)} total)", - data={ - "new_count": len(unique_new), - "total_count": len(all_evidence), - }, - iteration=iteration, - ) - - if errors: - logger.warning("Search errors", errors=errors) - - except Exception as e: - logger.error("Search phase failed", error=str(e)) - yield AgentEvent( - type="error", - message=f"Search failed: {e!s}", - iteration=iteration, - ) - continue - - # === JUDGE PHASE === - yield AgentEvent( - type="judging", - message=f"Evaluating {len(all_evidence)} sources...", - iteration=iteration, - ) - - try: - assessment = await self.judge.assess(query, all_evidence) - - yield AgentEvent( - type="judge_complete", - message=( - f"Assessment: {assessment.recommendation} " - f"(confidence: {assessment.confidence:.0%})" - ), - data={ - "sufficient": assessment.sufficient, - "confidence": assessment.confidence, - "mechanism_score": assessment.details.mechanism_score, - "clinical_score": assessment.details.clinical_evidence_score, - }, - iteration=iteration, - ) - - # Record this iteration in history - self.history.append( - { - "iteration": iteration, - "queries": current_queries, - "evidence_count": len(all_evidence), - "assessment": assessment.model_dump(), - } - ) - - # === DECISION PHASE === - if assessment.sufficient and assessment.recommendation == "synthesize": - # Optional Analysis Phase - async for event in self._run_analysis_phase(query, all_evidence, iteration): - yield event - - yield AgentEvent( - type="synthesizing", - message="Evidence sufficient! 
Preparing synthesis...", - iteration=iteration, - ) - - # Generate final response - final_response = self._generate_synthesis(query, all_evidence, assessment) - - yield AgentEvent( - type="complete", - message=final_response, - data={ - "evidence_count": len(all_evidence), - "iterations": iteration, - "drug_candidates": assessment.details.drug_candidates, - "key_findings": assessment.details.key_findings, - }, - iteration=iteration, - ) - return - - else: - # Need more evidence - prepare next queries - current_queries = assessment.next_search_queries or [ - f"{query} mechanism of action", - f"{query} clinical evidence", - ] - - yield AgentEvent( - type="looping", - message=( - f"Need more evidence. " - f"Next searches: {', '.join(current_queries[:2])}..." - ), - data={"next_queries": current_queries}, - iteration=iteration, - ) - - except Exception as e: - logger.error("Judge phase failed", error=str(e)) - yield AgentEvent( - type="error", - message=f"Assessment failed: {e!s}", - iteration=iteration, - ) - continue - - # Max iterations reached - yield AgentEvent( - type="complete", - message=self._generate_partial_synthesis(query, all_evidence), - data={ - "evidence_count": len(all_evidence), - "iterations": iteration, - "max_reached": True, - }, - iteration=iteration, - ) - - def _generate_synthesis( - self, - query: str, - evidence: list[Evidence], - assessment: JudgeAssessment, - ) -> str: - """ - Generate the final synthesis response. - - Args: - query: The original question - evidence: All collected evidence - assessment: The final assessment - - Returns: - Formatted synthesis as markdown - """ - drug_list = ( - "\n".join([f"- **{d}**" for d in assessment.details.drug_candidates]) - or "- No specific candidates identified" - ) - findings_list = ( - "\n".join([f"- {f}" for f in assessment.details.key_findings]) or "- See evidence below" - ) - - citations = "\n".join( - [ - f"{i + 1}. [{e.citation.title}]({e.citation.url}) " - f"({e.citation.source.upper()}, {e.citation.date})" - for i, e in enumerate(evidence[:10]) # Limit to 10 citations - ] - ) - - return f"""## Research Analysis - -### Question -{query} - -### Drug Candidates -{drug_list} - -### Key Findings -{findings_list} - -### Assessment -- **Mechanism Score**: {assessment.details.mechanism_score}/10 -- **Clinical Evidence Score**: {assessment.details.clinical_evidence_score}/10 -- **Confidence**: {assessment.confidence:.0%} - -### Reasoning -{assessment.reasoning} - -### Citations ({len(evidence)} sources) -{citations} - ---- -*Analysis based on {len(evidence)} sources across {len(self.history)} iterations.* -""" - - def _generate_partial_synthesis( - self, - query: str, - evidence: list[Evidence], - ) -> str: - """ - Generate a partial synthesis when max iterations reached. - - Args: - query: The original question - evidence: All collected evidence - - Returns: - Formatted partial synthesis as markdown - """ - citations = "\n".join( - [ - f"{i + 1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})" - for i, e in enumerate(evidence[:10]) - ] - ) - - return f"""## Partial Analysis (Max Iterations Reached) - -### Question -{query} - -### Status -Maximum search iterations reached. The evidence gathered may be incomplete. - -### Evidence Collected -Found {len(evidence)} sources. Consider refining your query for more specific results. 
- -### Citations -{citations} - ---- -*Consider searching with more specific terms or drug names.* -""" diff --git a/src/mcp_tools.py b/src/mcp_tools.py deleted file mode 100644 index 7b59f9be0c47a3c9ae4ebd29f2d633202d21c8ef..0000000000000000000000000000000000000000 --- a/src/mcp_tools.py +++ /dev/null @@ -1,301 +0,0 @@ -"""MCP tool wrappers for The DETERMINATOR search tools. - -These functions expose our search tools via MCP protocol. -Each function follows the MCP tool contract: -- Full type hints -- Google-style docstrings with Args section -- Formatted string returns -""" - -from src.tools.clinicaltrials import ClinicalTrialsTool -from src.tools.europepmc import EuropePMCTool -from src.tools.pubmed import PubMedTool - -# Singleton instances (avoid recreating on each call) -_pubmed = PubMedTool() -_trials = ClinicalTrialsTool() -_europepmc = EuropePMCTool() - - -async def search_pubmed(query: str, max_results: int = 10) -> str: - """Search PubMed for peer-reviewed biomedical literature. - - Searches NCBI PubMed database for scientific papers matching your query. - Returns titles, authors, abstracts, and citation information. - - Args: - query: Search query (e.g., "metformin alzheimer", "cancer treatment mechanisms") - max_results: Maximum results to return (1-50, default 10) - - Returns: - Formatted search results with paper titles, authors, dates, and abstracts - """ - max_results = max(1, min(50, max_results)) # Clamp to valid range - - results = await _pubmed.search(query, max_results) - - if not results: - return f"No PubMed results found for: {query}" - - formatted = [f"## PubMed Results for: {query}\n"] - for i, evidence in enumerate(results, 1): - formatted.append(f"### {i}. {evidence.citation.title}") - formatted.append(f"**Authors**: {', '.join(evidence.citation.authors[:3])}") - formatted.append(f"**Date**: {evidence.citation.date}") - formatted.append(f"**URL**: {evidence.citation.url}") - formatted.append(f"\n{evidence.content}\n") - - return "\n".join(formatted) - - -async def search_clinical_trials(query: str, max_results: int = 10) -> str: - """Search ClinicalTrials.gov for clinical trial data. - - Searches the ClinicalTrials.gov database for trials matching your query. - Returns trial titles, phases, status, conditions, and interventions. - - Args: - query: Search query (e.g., "metformin alzheimer", "diabetes phase 3") - max_results: Maximum results to return (1-50, default 10) - - Returns: - Formatted clinical trial information with NCT IDs, phases, and status - """ - max_results = max(1, min(50, max_results)) - - results = await _trials.search(query, max_results) - - if not results: - return f"No clinical trials found for: {query}" - - formatted = [f"## Clinical Trials for: {query}\n"] - for i, evidence in enumerate(results, 1): - formatted.append(f"### {i}. {evidence.citation.title}") - formatted.append(f"**URL**: {evidence.citation.url}") - formatted.append(f"**Date**: {evidence.citation.date}") - formatted.append(f"\n{evidence.content}\n") - - return "\n".join(formatted) - - -async def search_europepmc(query: str, max_results: int = 10) -> str: - """Search Europe PMC for preprints and papers. - - Searches Europe PMC, which includes bioRxiv, medRxiv, and peer-reviewed content. - Useful for finding cutting-edge preprints and open access papers. 
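- Results are formatted the same way as search_pubmed (title, authors, date, URL, abstract).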
- - Args: - query: Search query (e.g., "metformin neuroprotection", "long covid treatment") - max_results: Maximum results to return (1-50, default 10) - - Returns: - Formatted results with titles, authors, and abstracts - """ - max_results = max(1, min(50, max_results)) - - results = await _europepmc.search(query, max_results) - - if not results: - return f"No Europe PMC results found for: {query}" - - formatted = [f"## Europe PMC Results for: {query}\n"] - for i, evidence in enumerate(results, 1): - formatted.append(f"### {i}. {evidence.citation.title}") - formatted.append(f"**Authors**: {', '.join(evidence.citation.authors[:3])}") - formatted.append(f"**Date**: {evidence.citation.date}") - formatted.append(f"**URL**: {evidence.citation.url}") - formatted.append(f"\n{evidence.content}\n") - - return "\n".join(formatted) - - -async def search_all_sources(query: str, max_per_source: int = 5) -> str: - """Search all biomedical sources simultaneously. - - Performs parallel search across PubMed, ClinicalTrials.gov, and Europe PMC. - This is the most comprehensive search option for deep medical research inquiry. - - Args: - query: Search query (e.g., "metformin alzheimer", "aspirin cancer prevention") - max_per_source: Maximum results per source (1-20, default 5) - - Returns: - Combined results from all sources with source labels - """ - import asyncio - - max_per_source = max(1, min(20, max_per_source)) - - # Run all searches in parallel - pubmed_task = search_pubmed(query, max_per_source) - trials_task = search_clinical_trials(query, max_per_source) - europepmc_task = search_europepmc(query, max_per_source) - - pubmed_results, trials_results, europepmc_results = await asyncio.gather( - pubmed_task, trials_task, europepmc_task, return_exceptions=True - ) - - formatted = [f"# Comprehensive Search: {query}\n"] - - # Add each result section (handle exceptions gracefully) - if isinstance(pubmed_results, str): - formatted.append(pubmed_results) - else: - formatted.append(f"## PubMed\n*Error: {pubmed_results}*\n") - - if isinstance(trials_results, str): - formatted.append(trials_results) - else: - formatted.append(f"## Clinical Trials\n*Error: {trials_results}*\n") - - if isinstance(europepmc_results, str): - formatted.append(europepmc_results) - else: - formatted.append(f"## Europe PMC\n*Error: {europepmc_results}*\n") - - return "\n---\n".join(formatted) - - -async def analyze_hypothesis( - drug: str, - condition: str, - evidence_summary: str, -) -> str: - """Perform statistical analysis of research hypothesis using Modal. - - Executes AI-generated Python code in a secure Modal sandbox to analyze - the statistical evidence for a research hypothesis. - - Args: - drug: The drug being evaluated (e.g., "metformin") - condition: The target condition (e.g., "Alzheimer's disease") - evidence_summary: Summary of evidence to analyze - - Returns: - Analysis result with verdict (SUPPORTED/REFUTED/INCONCLUSIVE) and statistics - """ - from src.services.statistical_analyzer import get_statistical_analyzer - from src.utils.config import settings - from src.utils.models import Citation, Evidence - - if not settings.modal_available: - return "Error: Modal credentials not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET." 
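-
- # The analyzer operates on Evidence objects, so the caller-supplied free-text summary is
- # wrapped in a single Evidence entry with a placeholder citation below.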
- - # Create evidence from summary - evidence = [ - Evidence( - content=evidence_summary, - citation=Citation( - source="pubmed", - title=f"Evidence for {drug} in {condition}", - url="https://example.com", - date="2024-01-01", - authors=["User Provided"], - ), - relevance=0.9, - ) - ] - - analyzer = get_statistical_analyzer() - result = await analyzer.analyze( - query=f"Can {drug} treat {condition}?", - evidence=evidence, - hypothesis={"drug": drug, "target": "unknown", "pathway": "unknown", "effect": condition}, - ) - - return f"""## Statistical Analysis: {drug} for {condition} - -### Verdict: **{result.verdict}** -**Confidence**: {result.confidence:.0%} - -### Key Findings -{chr(10).join(f"- {f}" for f in result.key_findings) or "- No specific findings extracted"} - -### Execution Output -``` -{result.execution_output} -``` - -### Generated Code -```python -{result.code_generated} -``` - -**Executed in Modal Sandbox** - Isolated, secure, reproducible. -""" - - -async def extract_text_from_image( - image_path: str, model: str | None = None, hf_token: str | None = None -) -> str: - """Extract text from an image using OCR. - - Uses the Multimodal-OCR3 Gradio Space to extract text from images. - Supports various image formats (PNG, JPG, etc.) and can extract text - from scanned documents, screenshots, and other image types. - - Args: - image_path: Path to image file - model: Optional model selection (default: None, uses API default) - - Returns: - Extracted text from the image - """ - from src.services.image_ocr import get_image_ocr_service - from src.utils.config import settings - - try: - ocr_service = get_image_ocr_service() - # Use provided token or fallback to env vars - token = hf_token or settings.hf_token or settings.huggingface_api_key - extracted_text = await ocr_service.extract_text(image_path, model=model, hf_token=token) - - if not extracted_text: - return f"No text found in image: {image_path}" - - return f"## Extracted Text from Image\n\n{extracted_text}" - - except Exception as e: - return f"Error extracting text from image: {e}" - - -async def transcribe_audio_file( - audio_path: str, - source_lang: str | None = None, - target_lang: str | None = None, - hf_token: str | None = None, -) -> str: - """Transcribe audio file to text using speech-to-text. - - Uses the NVIDIA Canary Gradio Space to transcribe audio files. - Supports various audio formats (WAV, MP3, etc.) and multiple languages. 
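- A HuggingFace token can be passed explicitly via hf_token or resolved from settings
- (hf_token / huggingface_api_key).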
- - Args: - audio_path: Path to audio file - source_lang: Source language (default: "English") - target_lang: Target language (default: "English") - - Returns: - Transcribed text from the audio file - """ - from src.services.stt_gradio import get_stt_service - from src.utils.config import settings - - try: - stt_service = get_stt_service() - # Use provided token or fallback to env vars - token = hf_token or settings.hf_token or settings.huggingface_api_key - transcribed_text = await stt_service.transcribe_file( - audio_path, - source_lang=source_lang, - target_lang=target_lang, - hf_token=token, - ) - - if not transcribed_text: - return f"No transcription found in audio: {audio_path}" - - return f"## Audio Transcription\n\n{transcribed_text}" - - except Exception as e: - return f"Error transcribing audio: {e}" diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py deleted file mode 100644 index 2d296a27b28bd09955d82289397d4d889ece5b62..0000000000000000000000000000000000000000 --- a/src/middleware/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Middleware for workflow state management, parallel loop coordination, and budget tracking. - -This module provides: -- WorkflowState: Thread-safe state management using ContextVar -- WorkflowManager: Coordination of parallel research loops -- BudgetTracker: Token, time, and iteration budget tracking -""" - -from src.middleware.budget_tracker import BudgetStatus, BudgetTracker -from src.middleware.state_machine import ( - WorkflowState, - get_workflow_state, - init_workflow_state, -) -from src.middleware.workflow_manager import ( - LoopStatus, - ResearchLoop, - WorkflowManager, -) - -__all__ = [ - "BudgetStatus", - "BudgetTracker", - "LoopStatus", - "ResearchLoop", - "WorkflowManager", - "WorkflowState", - "get_workflow_state", - "init_workflow_state", -] diff --git a/src/middleware/budget_tracker.py b/src/middleware/budget_tracker.py deleted file mode 100644 index 7f7c89ece57b73c3c1a526a43238480f5e65fed5..0000000000000000000000000000000000000000 --- a/src/middleware/budget_tracker.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Budget tracking for research loops. - -Tracks token usage, time elapsed, and iteration counts per loop and globally. -Enforces budget constraints to prevent infinite loops and excessive resource usage. -""" - -import time - -import structlog -from pydantic import BaseModel, Field - -logger = structlog.get_logger() - - -class BudgetStatus(BaseModel): - """Status of a budget (tokens, time, iterations).""" - - tokens_used: int = Field(default=0, description="Total tokens used") - tokens_limit: int = Field(default=100000, description="Token budget limit", ge=0) - time_elapsed_seconds: float = Field(default=0.0, description="Time elapsed", ge=0.0) - time_limit_seconds: float = Field( - default=600.0, description="Time budget limit (10 min default)", ge=0.0 - ) - iterations: int = Field(default=0, description="Number of iterations completed", ge=0) - iterations_limit: int = Field(default=10, description="Maximum iterations", ge=1) - iteration_tokens: dict[int, int] = Field( - default_factory=dict, - description="Tokens used per iteration (iteration number -> token count)", - ) - - def is_exceeded(self) -> bool: - """Check if any budget limit has been exceeded. - - Returns: - True if any limit is exceeded, False otherwise. 
- """ - return ( - self.tokens_used >= self.tokens_limit - or self.time_elapsed_seconds >= self.time_limit_seconds - or self.iterations >= self.iterations_limit - ) - - def remaining_tokens(self) -> int: - """Get remaining token budget. - - Returns: - Remaining tokens (may be negative if exceeded). - """ - return self.tokens_limit - self.tokens_used - - def remaining_time_seconds(self) -> float: - """Get remaining time budget. - - Returns: - Remaining time in seconds (may be negative if exceeded). - """ - return self.time_limit_seconds - self.time_elapsed_seconds - - def remaining_iterations(self) -> int: - """Get remaining iteration budget. - - Returns: - Remaining iterations (may be negative if exceeded). - """ - return self.iterations_limit - self.iterations - - def add_iteration_tokens(self, iteration: int, tokens: int) -> None: - """Add tokens for a specific iteration. - - Args: - iteration: Iteration number (1-indexed). - tokens: Number of tokens to add. - """ - if iteration not in self.iteration_tokens: - self.iteration_tokens[iteration] = 0 - self.iteration_tokens[iteration] += tokens - # Also add to total tokens - self.tokens_used += tokens - - def get_iteration_tokens(self, iteration: int) -> int: - """Get tokens used for a specific iteration. - - Args: - iteration: Iteration number. - - Returns: - Token count for the iteration, or 0 if not found. - """ - return self.iteration_tokens.get(iteration, 0) - - -class BudgetTracker: - """Tracks budgets per loop and globally.""" - - def __init__(self) -> None: - """Initialize the budget tracker.""" - self._budgets: dict[str, BudgetStatus] = {} - self._start_times: dict[str, float] = {} - self._global_budget: BudgetStatus | None = None - - def create_budget( - self, - loop_id: str, - tokens_limit: int = 100000, - time_limit_seconds: float = 600.0, - iterations_limit: int = 10, - ) -> BudgetStatus: - """Create a budget for a specific loop. - - Args: - loop_id: Unique identifier for the loop. - tokens_limit: Maximum tokens allowed. - time_limit_seconds: Maximum time allowed in seconds. - iterations_limit: Maximum iterations allowed. - - Returns: - The created BudgetStatus instance. - """ - budget = BudgetStatus( - tokens_limit=tokens_limit, - time_limit_seconds=time_limit_seconds, - iterations_limit=iterations_limit, - ) - self._budgets[loop_id] = budget - logger.debug( - "Budget created", - loop_id=loop_id, - tokens_limit=tokens_limit, - time_limit=time_limit_seconds, - iterations_limit=iterations_limit, - ) - return budget - - def get_budget(self, loop_id: str) -> BudgetStatus | None: - """Get the budget for a specific loop. - - Args: - loop_id: Unique identifier for the loop. - - Returns: - The BudgetStatus instance, or None if not found. - """ - return self._budgets.get(loop_id) - - def add_tokens(self, loop_id: str, tokens: int) -> None: - """Add tokens to a loop's budget. - - Args: - loop_id: Unique identifier for the loop. - tokens: Number of tokens to add (can be negative). - """ - if loop_id not in self._budgets: - logger.warning("Budget not found for loop", loop_id=loop_id) - return - self._budgets[loop_id].tokens_used += tokens - logger.debug("Tokens added", loop_id=loop_id, tokens=tokens) - - def add_iteration_tokens(self, loop_id: str, iteration: int, tokens: int) -> None: - """Add tokens for a specific iteration. - - Args: - loop_id: Loop identifier. - iteration: Iteration number (1-indexed). - tokens: Number of tokens to add. 
- """ - if loop_id not in self._budgets: - logger.warning("Budget not found for loop", loop_id=loop_id) - return - - budget = self._budgets[loop_id] - budget.add_iteration_tokens(iteration, tokens) - - logger.debug( - "Iteration tokens added", - loop_id=loop_id, - iteration=iteration, - tokens=tokens, - total_iteration=budget.get_iteration_tokens(iteration), - ) - - def get_iteration_tokens(self, loop_id: str, iteration: int) -> int: - """Get tokens used for a specific iteration. - - Args: - loop_id: Loop identifier. - iteration: Iteration number. - - Returns: - Token count for the iteration, or 0 if not found. - """ - if loop_id not in self._budgets: - return 0 - - return self._budgets[loop_id].get_iteration_tokens(iteration) - - def start_timer(self, loop_id: str) -> None: - """Start the timer for a loop. - - Args: - loop_id: Unique identifier for the loop. - """ - self._start_times[loop_id] = time.time() - logger.debug("Timer started", loop_id=loop_id) - - def update_timer(self, loop_id: str) -> None: - """Update the elapsed time for a loop. - - Args: - loop_id: Unique identifier for the loop. - """ - if loop_id not in self._start_times: - logger.warning("Timer not started for loop", loop_id=loop_id) - return - if loop_id not in self._budgets: - logger.warning("Budget not found for loop", loop_id=loop_id) - return - - elapsed = time.time() - self._start_times[loop_id] - self._budgets[loop_id].time_elapsed_seconds = elapsed - logger.debug("Timer updated", loop_id=loop_id, elapsed=elapsed) - - def increment_iteration(self, loop_id: str) -> None: - """Increment the iteration count for a loop. - - Args: - loop_id: Unique identifier for the loop. - """ - if loop_id not in self._budgets: - logger.warning("Budget not found for loop", loop_id=loop_id) - return - self._budgets[loop_id].iterations += 1 - logger.debug( - "Iteration incremented", - loop_id=loop_id, - iterations=self._budgets[loop_id].iterations, - ) - - def check_budget(self, loop_id: str) -> tuple[bool, str]: - """Check if a loop's budget has been exceeded. - - Args: - loop_id: Unique identifier for the loop. - - Returns: - Tuple of (exceeded: bool, reason: str). Reason is empty if not exceeded. - """ - if loop_id not in self._budgets: - return False, "" - - budget = self._budgets[loop_id] - self.update_timer(loop_id) # Update time before checking - - if budget.is_exceeded(): - reasons = [] - if budget.tokens_used >= budget.tokens_limit: - reasons.append("tokens") - if budget.time_elapsed_seconds >= budget.time_limit_seconds: - reasons.append("time") - if budget.iterations >= budget.iterations_limit: - reasons.append("iterations") - reason = f"Budget exceeded: {', '.join(reasons)}" - logger.warning("Budget exceeded", loop_id=loop_id, reason=reason) - return True, reason - - return False, "" - - def can_continue(self, loop_id: str) -> bool: - """Check if a loop can continue based on budget. - - Args: - loop_id: Unique identifier for the loop. - - Returns: - True if the loop can continue, False if budget is exceeded. - """ - exceeded, _ = self.check_budget(loop_id) - return not exceeded - - def get_budget_summary(self, loop_id: str) -> str: - """Get a formatted summary of a loop's budget status. - - Args: - loop_id: Unique identifier for the loop. - - Returns: - Formatted string summary. 
- """ - if loop_id not in self._budgets: - return f"Budget not found for loop: {loop_id}" - - budget = self._budgets[loop_id] - self.update_timer(loop_id) - - return ( - f"Loop {loop_id}: " - f"Tokens: {budget.tokens_used}/{budget.tokens_limit} " - f"({budget.remaining_tokens()} remaining), " - f"Time: {budget.time_elapsed_seconds:.1f}/{budget.time_limit_seconds:.1f}s " - f"({budget.remaining_time_seconds():.1f}s remaining), " - f"Iterations: {budget.iterations}/{budget.iterations_limit} " - f"({budget.remaining_iterations()} remaining)" - ) - - def reset_budget(self, loop_id: str) -> None: - """Reset the budget for a loop. - - Args: - loop_id: Unique identifier for the loop. - """ - if loop_id in self._budgets: - old_budget = self._budgets[loop_id] - # Preserve iteration_tokens when resetting - old_iteration_tokens = old_budget.iteration_tokens - self._budgets[loop_id] = BudgetStatus( - tokens_limit=old_budget.tokens_limit, - time_limit_seconds=old_budget.time_limit_seconds, - iterations_limit=old_budget.iterations_limit, - iteration_tokens=old_iteration_tokens, # Restore old iteration tokens - ) - if loop_id in self._start_times: - self._start_times[loop_id] = time.time() - logger.debug("Budget reset", loop_id=loop_id) - - def set_global_budget( - self, - tokens_limit: int = 100000, - time_limit_seconds: float = 600.0, - iterations_limit: int = 10, - ) -> None: - """Set a global budget that applies to all loops. - - Args: - tokens_limit: Maximum tokens allowed globally. - time_limit_seconds: Maximum time allowed in seconds. - iterations_limit: Maximum iterations allowed globally. - """ - self._global_budget = BudgetStatus( - tokens_limit=tokens_limit, - time_limit_seconds=time_limit_seconds, - iterations_limit=iterations_limit, - ) - logger.debug( - "Global budget set", - tokens_limit=tokens_limit, - time_limit=time_limit_seconds, - iterations_limit=iterations_limit, - ) - - def get_global_budget(self) -> BudgetStatus | None: - """Get the global budget. - - Returns: - The global BudgetStatus instance, or None if not set. - """ - return self._global_budget - - def add_global_tokens(self, tokens: int) -> None: - """Add tokens to the global budget. - - Args: - tokens: Number of tokens to add (can be negative). - """ - if self._global_budget is None: - logger.warning("Global budget not set") - return - self._global_budget.tokens_used += tokens - logger.debug("Global tokens added", tokens=tokens) - - def estimate_tokens(self, text: str) -> int: - """Estimate token count from text (rough estimate: ~4 chars per token). - - Args: - text: Text to estimate tokens for. - - Returns: - Estimated token count. - """ - return len(text) // 4 - - def estimate_llm_call_tokens(self, prompt: str, response: str) -> int: - """Estimate token count for an LLM call. - - Args: - prompt: The prompt text. - response: The response text. - - Returns: - Estimated total token count (prompt + response). - """ - return self.estimate_tokens(prompt) + self.estimate_tokens(response) diff --git a/src/middleware/state_machine.py b/src/middleware/state_machine.py deleted file mode 100644 index 8fbdf793b5ebc2d64a98501f543de4a3979a8132..0000000000000000000000000000000000000000 --- a/src/middleware/state_machine.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Thread-safe state management for workflow agents. - -Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users -searching simultaneously via Gradio). Refactored from MagenticState to support both -iterative and deep research patterns. 
-""" - -from contextvars import ContextVar -from typing import TYPE_CHECKING, Any - -import structlog -from pydantic import BaseModel, Field - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.utils.models import Citation, Conversation, Evidence - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - -logger = structlog.get_logger() - - -class WorkflowState(BaseModel): - """Mutable state for a workflow session. - - Supports both iterative and deep research patterns by tracking evidence, - conversation history, and providing semantic search capabilities. - """ - - evidence: list[Evidence] = Field(default_factory=list) - conversation: Conversation = Field(default_factory=Conversation) - user_message_history: list[ModelMessage] = Field( - default_factory=list, - description="User conversation history (multi-turn interactions)", - ) - # Type as Any to avoid circular imports/runtime resolution issues - # The actual object injected will be an EmbeddingService instance - embedding_service: Any = Field(default=None) - - model_config = {"arbitrary_types_allowed": True} - - def add_evidence(self, new_evidence: list[Evidence]) -> int: - """Add new evidence, deduplicating by URL. - - Args: - new_evidence: List of Evidence objects to add. - - Returns: - Number of *new* items added (excluding duplicates). - """ - existing_urls = {e.citation.url for e in self.evidence} - count = 0 - for item in new_evidence: - if item.citation.url not in existing_urls: - self.evidence.append(item) - existing_urls.add(item.citation.url) - count += 1 - return count - - async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]: - """Search for semantically related evidence using the embedding service. - - Args: - query: Search query string. - n_results: Maximum number of results to return. - - Returns: - List of Evidence objects, ordered by relevance. - """ - if not self.embedding_service: - logger.warning("Embedding service not available, returning empty results") - return [] - - results = await self.embedding_service.search_similar(query, n_results=n_results) - - # Convert dict results back to Evidence objects - evidence_list = [] - for item in results: - meta = item.get("metadata", {}) - authors_str = meta.get("authors", "") - authors = [a.strip() for a in authors_str.split(",") if a.strip()] - - ev = Evidence( - content=item["content"], - citation=Citation( - title=meta.get("title", "Related Evidence"), - url=item["id"], - source="pubmed", # Defaulting to pubmed if unknown - date=meta.get("date", "n.d."), - authors=authors, - ), - relevance=max(0.0, 1.0 - item.get("distance", 0.5)), - ) - evidence_list.append(ev) - - return evidence_list - - def add_user_message(self, message: ModelMessage) -> None: - """Add a user message to conversation history. - - Args: - message: Message to add - """ - self.user_message_history.append(message) - - def get_user_history(self, max_messages: int | None = None) -> list[ModelMessage]: - """Get user conversation history. 
- - Args: - max_messages: Maximum messages to return (None for all) - - Returns: - List of messages - """ - if max_messages is None: - return self.user_message_history.copy() - return ( - self.user_message_history[-max_messages:] - if len(self.user_message_history) > max_messages - else self.user_message_history.copy() - ) - - -# The ContextVar holds the WorkflowState for the current execution context -_workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None) - - -def init_workflow_state( - embedding_service: "EmbeddingService | None" = None, - message_history: list[ModelMessage] | None = None, -) -> WorkflowState: - """Initialize a new state for the current context. - - Args: - embedding_service: Optional embedding service for semantic search. - message_history: Optional user conversation history. - - Returns: - The initialized WorkflowState instance. - """ - state = WorkflowState(embedding_service=embedding_service) - if message_history: - state.user_message_history = message_history.copy() - _workflow_state_var.set(state) - logger.debug( - "Workflow state initialized", - has_embeddings=embedding_service is not None, - has_history=bool(message_history), - ) - return state - - -def get_workflow_state() -> WorkflowState: - """Get the current state. Auto-initializes if not set. - - Returns: - The current WorkflowState instance. - - Raises: - RuntimeError: If state is not initialized and auto-initialization fails. - """ - state = _workflow_state_var.get() - if state is None: - # Auto-initialize if missing (e.g. during tests or simple scripts) - logger.debug("Workflow state not found, auto-initializing") - return init_workflow_state() - return state \ No newline at end of file diff --git a/src/middleware/sub_iteration.py b/src/middleware/sub_iteration.py deleted file mode 100644 index 801a3686a6d023c39615d01548766e4c24098c66..0000000000000000000000000000000000000000 --- a/src/middleware/sub_iteration.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Middleware for orchestrating sub-iterations with research teams and judges.""" - -from typing import Any, Protocol - -import structlog - -from src.utils.models import AgentEvent, JudgeAssessment - -logger = structlog.get_logger() - - -class SubIterationTeam(Protocol): - """Protocol for a research team that executes a sub-task.""" - - async def execute(self, task: str) -> Any: - """Execute the sub-task and return a result.""" - ... - - -class SubIterationJudge(Protocol): - """Protocol for a judge that evaluates the sub-task result.""" - - async def assess(self, task: str, result: Any, history: list[Any]) -> JudgeAssessment: - """Assess the quality of the result.""" - ... - - -class SubIterationMiddleware: - """ - Middleware that manages a sub-iteration loop: - 1. Orchestrator delegates to a Research Team. - 2. Research Team produces a result. - 3. Judge evaluates the result. - 4. Loop continues until Judge approves or max iterations reached. - """ - - def __init__( - self, - team: SubIterationTeam, - judge: SubIterationJudge, - max_iterations: int = 3, - ): - self.team = team - self.judge = judge - self.max_iterations = max_iterations - - async def run( - self, - task: str, - event_callback: Any = None, # Optional callback for streaming events - ) -> tuple[Any, JudgeAssessment | None]: - """ - Run the sub-iteration loop. - - Args: - task: The research task or question. - event_callback: Async callable to report events (e.g. to UI). - - Returns: - Tuple of (best_result, final_assessment). 
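-
-        Example (illustrative; team and judge are any objects implementing the
-        SubIterationTeam and SubIterationJudge protocols defined above):
-
-            middleware = SubIterationMiddleware(team=team, judge=judge, max_iterations=3)
-            result, assessment = await middleware.run("Summarize the evidence for drug X")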
- """ - history: list[Any] = [] - best_result: Any = None - final_assessment: JudgeAssessment | None = None - - for i in range(1, self.max_iterations + 1): - logger.info("Sub-iteration starting", iteration=i, task=task) - - if event_callback: - await event_callback( - AgentEvent( - type="looping", - message=f"Sub-iteration {i}: Executing task...", - iteration=i, - ) - ) - - # 1. Team Execution - try: - result = await self.team.execute(task) - history.append(result) - best_result = result # Assume latest is best for now - except Exception as e: - logger.error("Sub-iteration execution failed", error=str(e)) - if event_callback: - await event_callback( - AgentEvent( - type="error", - message=f"Sub-iteration execution failed: {e}", - iteration=i, - ) - ) - return best_result, final_assessment - - # 2. Judge Assessment - try: - assessment = await self.judge.assess(task, result, history) - final_assessment = assessment - except Exception as e: - logger.error("Sub-iteration judge failed", error=str(e)) - if event_callback: - await event_callback( - AgentEvent( - type="error", - message=f"Sub-iteration judge failed: {e}", - iteration=i, - ) - ) - return best_result, final_assessment - - # 3. Decision - if assessment.sufficient: - logger.info("Sub-iteration sufficient", iteration=i) - return best_result, assessment - - # If not sufficient, we might refine the task for the next iteration - # For this implementation, we assume the team is smart enough or the task stays same - # but we could append feedback to the task. - - feedback = assessment.reasoning - logger.info("Sub-iteration insufficient", feedback=feedback) - - if event_callback: - await event_callback( - AgentEvent( - type="looping", - message=( - f"Sub-iteration {i} result insufficient. Feedback: {feedback[:100]}..." - ), - iteration=i, - ) - ) - - logger.warning("Sub-iteration max iterations reached", task=task) - return best_result, final_assessment diff --git a/src/middleware/workflow_manager.py b/src/middleware/workflow_manager.py deleted file mode 100644 index b817cbf27a49fd19b4ad682416eae61cf4d90a1d..0000000000000000000000000000000000000000 --- a/src/middleware/workflow_manager.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Workflow manager for coordinating parallel research loops. - -Manages multiple research loops running in parallel, tracks their status, -and synchronizes evidence between loops and the global state. 
-""" - -import asyncio -from collections.abc import Callable -from typing import Any, Literal - -import structlog -from pydantic import BaseModel, Field - -from src.middleware.state_machine import get_workflow_state -from src.utils.models import Evidence - -logger = structlog.get_logger() - -LoopStatus = Literal["pending", "running", "completed", "failed", "cancelled"] - - -class ResearchLoop(BaseModel): - """Represents a single research loop.""" - - loop_id: str = Field(description="Unique identifier for the loop") - query: str = Field(description="The research query for this loop") - status: LoopStatus = Field(default="pending") - evidence: list[Evidence] = Field(default_factory=list) - iteration_count: int = Field(default=0, ge=0) - error: str | None = Field(default=None) - - model_config = {"frozen": False} # Mutable for status updates - - -class WorkflowManager: - """Manages parallel research loops and state synchronization.""" - - def __init__(self) -> None: - """Initialize the workflow manager.""" - self._loops: dict[str, ResearchLoop] = {} - - async def add_loop(self, loop_id: str, query: str) -> ResearchLoop: - """Add a new research loop. - - Args: - loop_id: Unique identifier for the loop. - query: The research query for this loop. - - Returns: - The created ResearchLoop instance. - """ - loop = ResearchLoop(loop_id=loop_id, query=query, status="pending") - self._loops[loop_id] = loop - logger.info("Loop added", loop_id=loop_id, query=query) - return loop - - async def get_loop(self, loop_id: str) -> ResearchLoop | None: - """Get a research loop by ID. - - Args: - loop_id: Unique identifier for the loop. - - Returns: - The ResearchLoop instance, or None if not found. - """ - return self._loops.get(loop_id) - - async def update_loop_status( - self, loop_id: str, status: LoopStatus, error: str | None = None - ) -> None: - """Update the status of a research loop. - - Args: - loop_id: Unique identifier for the loop. - status: New status for the loop. - error: Optional error message if status is "failed". - """ - if loop_id not in self._loops: - logger.warning("Loop not found", loop_id=loop_id) - return - - self._loops[loop_id].status = status - if error: - self._loops[loop_id].error = error - logger.info("Loop status updated", loop_id=loop_id, status=status) - - async def add_loop_evidence(self, loop_id: str, evidence: list[Evidence]) -> None: - """Add evidence to a research loop. - - Args: - loop_id: Unique identifier for the loop. - evidence: List of Evidence objects to add. - """ - if loop_id not in self._loops: - logger.warning("Loop not found", loop_id=loop_id) - return - - self._loops[loop_id].evidence.extend(evidence) - logger.debug( - "Evidence added to loop", - loop_id=loop_id, - evidence_count=len(evidence), - ) - - async def increment_loop_iteration(self, loop_id: str) -> None: - """Increment the iteration count for a research loop. - - Args: - loop_id: Unique identifier for the loop. - """ - if loop_id not in self._loops: - logger.warning("Loop not found", loop_id=loop_id) - return - - self._loops[loop_id].iteration_count += 1 - logger.debug( - "Iteration incremented", - loop_id=loop_id, - iteration=self._loops[loop_id].iteration_count, - ) - - async def run_loops_parallel( - self, - loop_configs: list[dict[str, Any]], - loop_func: Callable[[dict[str, Any]], Any], - judge_handler: Any | None = None, - budget_tracker: Any | None = None, - ) -> list[Any]: - """Run multiple research loops in parallel. 
- - Args: - loop_configs: List of configuration dicts, each must contain 'loop_id' and 'query'. - loop_func: Async function that takes a config dict and returns loop results. - judge_handler: Optional JudgeHandler for early termination based on evidence sufficiency. - budget_tracker: Optional BudgetTracker for budget enforcement. - - Returns: - List of results from each loop (in order of completion, not original order). - """ - logger.info("Starting parallel loops", loop_count=len(loop_configs)) - - # Create loops - for config in loop_configs: - loop_id = config.get("loop_id") - query = config.get("query", "") - if loop_id: - await self.add_loop(loop_id, query) - await self.update_loop_status(loop_id, "running") - - # Run loops in parallel - async def run_single_loop(config: dict[str, Any]) -> Any: - loop_id = config.get("loop_id", "unknown") - query = config.get("query", "") - try: - # Check budget before starting - if budget_tracker: - exceeded, reason = budget_tracker.check_budget(loop_id) - if exceeded: - await self.update_loop_status(loop_id, "cancelled", error=reason) - logger.warning( - "Loop cancelled due to budget", loop_id=loop_id, reason=reason - ) - return None - - # If loop_func supports periodic checkpoints, we could check judge here - # For now, the loop_func itself handles judge checks internally - result = await loop_func(config) - - # Final check with judge if available - if judge_handler and query: - should_complete, reason = await self.check_loop_completion( - loop_id, query, judge_handler - ) - if should_complete: - logger.info( - "Loop completed early based on judge assessment", - loop_id=loop_id, - reason=reason, - ) - - await self.update_loop_status(loop_id, "completed") - return result - except Exception as e: - error_msg = str(e) - await self.update_loop_status(loop_id, "failed", error=error_msg) - logger.error("Loop failed", loop_id=loop_id, error=error_msg) - raise - - results = await asyncio.gather( - *(run_single_loop(config) for config in loop_configs), - return_exceptions=True, - ) - - # Log completion - completed = sum(1 for r in results if not isinstance(r, Exception)) - failed = len(results) - completed - logger.info( - "Parallel loops completed", - total=len(loop_configs), - completed=completed, - failed=failed, - ) - - return results - - async def wait_for_loops( - self, loop_ids: list[str], timeout: float | None = None - ) -> list[ResearchLoop]: - """Wait for loops to complete. - - Args: - loop_ids: List of loop IDs to wait for. - timeout: Optional timeout in seconds. - - Returns: - List of ResearchLoop instances (may be incomplete if timeout occurs). - """ - start_time = asyncio.get_event_loop().time() - - while True: - loops = [self._loops.get(loop_id) for loop_id in loop_ids] - all_complete = all( - loop and loop.status in ("completed", "failed", "cancelled") for loop in loops - ) - - if all_complete: - return [loop for loop in loops if loop is not None] - - if timeout is not None: - elapsed = asyncio.get_event_loop().time() - start_time - if elapsed >= timeout: - logger.warning("Timeout waiting for loops", timeout=timeout) - return [loop for loop in loops if loop is not None] - - await asyncio.sleep(0.1) # Small delay to avoid busy waiting - - async def cancel_loop(self, loop_id: str) -> None: - """Cancel a research loop. - - Args: - loop_id: Unique identifier for the loop. 
- """ - await self.update_loop_status(loop_id, "cancelled") - logger.info("Loop cancelled", loop_id=loop_id) - - async def get_all_loops(self) -> list[ResearchLoop]: - """Get all research loops. - - Returns: - List of all ResearchLoop instances. - """ - return list(self._loops.values()) - - async def sync_loop_evidence_to_state(self, loop_id: str) -> None: - """Synchronize evidence from a loop to the global state. - - Args: - loop_id: Unique identifier for the loop. - """ - if loop_id not in self._loops: - logger.warning("Loop not found", loop_id=loop_id) - return - - loop = self._loops[loop_id] - state = get_workflow_state() - added_count = state.add_evidence(loop.evidence) - logger.debug( - "Loop evidence synced to state", - loop_id=loop_id, - evidence_count=len(loop.evidence), - added_count=added_count, - ) - - async def get_shared_evidence(self) -> list[Evidence]: - """Get evidence from the global state. - - Returns: - List of Evidence objects from the global state. - """ - state = get_workflow_state() - return state.evidence - - async def get_loop_evidence(self, loop_id: str) -> list[Evidence]: - """Get evidence collected by a specific loop. - - Args: - loop_id: Loop identifier. - - Returns: - List of Evidence objects from the loop. - """ - if loop_id not in self._loops: - return [] - - return self._loops[loop_id].evidence - - async def check_loop_completion( - self, loop_id: str, query: str, judge_handler: Any - ) -> tuple[bool, str]: - """Check if a loop should complete using judge assessment. - - Args: - loop_id: Loop identifier. - query: Research query. - judge_handler: JudgeHandler instance. - - Returns: - Tuple of (should_complete: bool, reason: str). - """ - evidence = await self.get_loop_evidence(loop_id) - - if not evidence: - return False, "No evidence collected yet" - - try: - assessment = await judge_handler.assess(query, evidence) - if assessment.sufficient: - return True, f"Judge assessment: {assessment.reasoning}" - return False, f"Judge assessment: {assessment.reasoning}" - except Exception as e: - logger.error("Judge assessment failed", error=str(e), loop_id=loop_id) - return False, f"Judge assessment failed: {e!s}" diff --git a/src/__init__.py b/src/orchestrator.py similarity index 100% rename from src/__init__.py rename to src/orchestrator.py diff --git a/src/orchestrator/__init__.py b/src/orchestrator/__init__.py deleted file mode 100644 index e71f737ab54ccc3391d909d170f6938ac0a9836b..0000000000000000000000000000000000000000 --- a/src/orchestrator/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Orchestrator module for research flows and planner agent. 
- -This module provides: -- PlannerAgent: Creates report plans with sections -- IterativeResearchFlow: Single research loop pattern -- DeepResearchFlow: Parallel research loops pattern -- GraphOrchestrator: Stub for Phase 4 (uses agent chains for now) -- Protocols: SearchHandlerProtocol, JudgeHandlerProtocol (re-exported from legacy_orchestrator) -- Orchestrator: Legacy orchestrator class (re-exported from legacy_orchestrator) -""" - -from typing import TYPE_CHECKING - -# Re-export protocols and Orchestrator from legacy_orchestrator for backward compatibility -from src.legacy_orchestrator import ( - JudgeHandlerProtocol, - Orchestrator, - SearchHandlerProtocol, -) - -# Lazy imports to avoid circular dependencies -if TYPE_CHECKING: - from src.orchestrator.graph_orchestrator import GraphOrchestrator - from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent - from src.orchestrator.research_flow import ( - DeepResearchFlow, - IterativeResearchFlow, - ) - -# Public exports -from src.orchestrator.graph_orchestrator import ( - GraphOrchestrator, - create_graph_orchestrator, -) -from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent -from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow - -__all__ = [ - "DeepResearchFlow", - "GraphOrchestrator", - "IterativeResearchFlow", - "JudgeHandlerProtocol", - "Orchestrator", - "PlannerAgent", - "SearchHandlerProtocol", - "create_graph_orchestrator", - "create_planner_agent", -] diff --git a/src/orchestrator/graph_orchestrator.py b/src/orchestrator/graph_orchestrator.py deleted file mode 100644 index d71a337dde1929301a70217f99e560c5413e34fc..0000000000000000000000000000000000000000 --- a/src/orchestrator/graph_orchestrator.py +++ /dev/null @@ -1,1751 +0,0 @@ -"""Graph orchestrator for Phase 4. - -Implements graph-based orchestration using Pydantic AI agents as nodes. -Supports both iterative and deep research patterns with parallel execution. -""" - -import asyncio -from collections.abc import AsyncGenerator, Callable -from typing import TYPE_CHECKING, Any, Literal - -import structlog - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.agent_factory.agents import ( - create_input_parser_agent, - create_knowledge_gap_agent, - create_long_writer_agent, - create_planner_agent, - create_thinking_agent, - create_tool_selector_agent, - create_writer_agent, -) -from src.agent_factory.graph_builder import ( - AgentNode, - DecisionNode, - ParallelNode, - ResearchGraph, - StateNode, - create_deep_graph, - create_iterative_graph, -) -from src.legacy_orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol -from src.middleware.budget_tracker import BudgetTracker -from src.middleware.state_machine import WorkflowState, init_workflow_state -from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow -from src.services.report_file_service import ReportFileService, get_report_file_service -from src.utils.models import AgentEvent - -if TYPE_CHECKING: - pass - -logger = structlog.get_logger() - - -class GraphExecutionContext: - """Context for managing graph execution state.""" - - def __init__( - self, - state: WorkflowState, - budget_tracker: BudgetTracker, - message_history: list[ModelMessage] | None = None, - ) -> None: - """Initialize execution context. 
- - Args: - state: Current workflow state - budget_tracker: Budget tracker instance - message_history: Optional user conversation history - """ - self.current_node: str = "" - self.visited_nodes: set[str] = set() - self.node_results: dict[str, Any] = {} - self.state = state - self.budget_tracker = budget_tracker - self.iteration_count = 0 - self.message_history: list[ModelMessage] = message_history or [] - - def set_node_result(self, node_id: str, result: Any) -> None: - """Store result from node execution. - - Args: - node_id: The node ID - result: The execution result - """ - self.node_results[node_id] = result - - def get_node_result(self, node_id: str) -> Any: - """Get result from node execution. - - Args: - node_id: The node ID - - Returns: - The stored result, or None if not found - """ - return self.node_results.get(node_id) - - def has_visited(self, node_id: str) -> bool: - """Check if node was visited. - - Args: - node_id: The node ID - - Returns: - True if visited, False otherwise - """ - return node_id in self.visited_nodes - - def mark_visited(self, node_id: str) -> None: - """Mark node as visited. - - Args: - node_id: The node ID - """ - self.visited_nodes.add(node_id) - - def update_state( - self, updater: Callable[[WorkflowState, Any], WorkflowState], data: Any - ) -> None: - """Update workflow state. - - Args: - updater: Function to update state - data: Data to pass to updater - """ - self.state = updater(self.state, data) - - def add_message(self, message: ModelMessage) -> None: - """Add a message to the history. - - Args: - message: Message to add - """ - self.message_history.append(message) - - def get_message_history(self, max_messages: int | None = None) -> list[ModelMessage]: - """Get message history, optionally truncated. - - Args: - max_messages: Maximum messages to return (None for all) - - Returns: - List of messages - """ - if max_messages is None: - return self.message_history.copy() - return ( - self.message_history[-max_messages:] - if len(self.message_history) > max_messages - else self.message_history.copy() - ) - - -class GraphOrchestrator: - """ - Graph orchestrator using Pydantic AI Graphs. - - Executes research workflows as graphs with nodes (agents) and edges (transitions). - Supports parallel execution, conditional routing, and state management. - """ - - def __init__( - self, - mode: Literal["iterative", "deep", "auto"] = "auto", - max_iterations: int = 5, - max_time_minutes: int = 10, - use_graph: bool = True, - search_handler: SearchHandlerProtocol | None = None, - judge_handler: JudgeHandlerProtocol | None = None, - oauth_token: str | None = None, - ) -> None: - """ - Initialize graph orchestrator. 
- - Args: - mode: Research mode ("iterative", "deep", or "auto" to detect) - max_iterations: Maximum iterations per loop - max_time_minutes: Maximum time per loop - use_graph: Whether to use graph execution (True) or agent chains (False) - search_handler: Optional search handler for tool execution - judge_handler: Optional judge handler for evidence assessment - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - """ - self.mode = mode - self.max_iterations = max_iterations - self.max_time_minutes = max_time_minutes - self.use_graph = use_graph - self.search_handler = search_handler - self.judge_handler = judge_handler - self.oauth_token = oauth_token - self.logger = logger - - # Initialize file service (lazy if not provided) - self._file_service: ReportFileService | None = None - - # Initialize flows (for backward compatibility) - self._iterative_flow: IterativeResearchFlow | None = None - self._deep_flow: DeepResearchFlow | None = None - - # Graph execution components (lazy initialization) - self._graph: ResearchGraph | None = None - self._budget_tracker: BudgetTracker | None = None - - def _get_file_service(self) -> ReportFileService | None: - """ - Get file service instance (lazy initialization). - - Returns: - ReportFileService instance or None if disabled - """ - if self._file_service is None: - try: - self._file_service = get_report_file_service() - except Exception as e: - self.logger.warning("Failed to initialize file service", error=str(e)) - return None - return self._file_service - - async def run( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> AsyncGenerator[AgentEvent, None]: - """ - Run the research workflow. - - Args: - query: The user's research query - message_history: Optional user conversation history - - Yields: - AgentEvent objects for real-time UI updates - """ - self.logger.info( - "Starting graph orchestrator", - query=query[:100], - mode=self.mode, - use_graph=self.use_graph, - has_history=bool(message_history), - ) - - yield AgentEvent( - type="started", - message=f"Starting research ({self.mode} mode): {query}", - iteration=0, - ) - - try: - # Determine research mode - research_mode = self.mode - if research_mode == "auto": - research_mode = await self._detect_research_mode(query) - - # Use graph execution if enabled, otherwise fall back to agent chains - if self.use_graph: - async for event in self._run_with_graph(query, research_mode, message_history): - yield event - else: - async for event in self._run_with_chains(query, research_mode, message_history): - yield event - - except Exception as e: - self.logger.error("Graph orchestrator failed", error=str(e), exc_info=True) - yield AgentEvent( - type="error", - message=f"Research failed: {e!s}", - iteration=0, - ) - - async def _run_with_graph( - self, - query: str, - research_mode: Literal["iterative", "deep"], - message_history: list[ModelMessage] | None = None, - ) -> AsyncGenerator[AgentEvent, None]: - """Run workflow using graph execution. 
- - Args: - query: The research query - research_mode: The research mode - message_history: Optional user conversation history - - Yields: - AgentEvent objects - """ - # Initialize state and budget tracker - from src.services.embeddings import get_embedding_service - - embedding_service = get_embedding_service() - state = init_workflow_state( - embedding_service=embedding_service, - message_history=message_history, - ) - budget_tracker = BudgetTracker() - budget_tracker.create_budget( - loop_id="graph_execution", - tokens_limit=100000, - time_limit_seconds=self.max_time_minutes * 60, - iterations_limit=self.max_iterations, - ) - budget_tracker.start_timer("graph_execution") - - context = GraphExecutionContext( - state, - budget_tracker, - message_history=message_history or [], - ) - - # Build graph - self._graph = await self._build_graph(research_mode) - - # Execute graph - async for event in self._execute_graph(query, context): - yield event - - async def _run_with_chains( - self, - query: str, - research_mode: Literal["iterative", "deep"], - message_history: list[ModelMessage] | None = None, - ) -> AsyncGenerator[AgentEvent, None]: - """Run workflow using agent chains (backward compatibility). - - Args: - query: The research query - research_mode: The research mode - message_history: Optional user conversation history - - Yields: - AgentEvent objects - """ - if research_mode == "iterative": - yield AgentEvent( - type="searching", - message="Running iterative research flow...", - iteration=1, - ) - - if self._iterative_flow is None: - self._iterative_flow = IterativeResearchFlow( - max_iterations=self.max_iterations, - max_time_minutes=self.max_time_minutes, - judge_handler=self.judge_handler, - oauth_token=self.oauth_token, - ) - - try: - final_report = await self._iterative_flow.run( - query, message_history=message_history - ) - except Exception as e: - self.logger.error("Iterative flow failed", error=str(e), exc_info=True) - # Yield error event - outer handler will also catch and yield error event - yield AgentEvent( - type="error", - message=f"Iterative research failed: {e!s}", - iteration=1, - ) - # Re-raise so outer handler can also yield error event for consistency - raise - - yield AgentEvent( - type="complete", - message=final_report, - data={"mode": "iterative"}, - iteration=1, - ) - - elif research_mode == "deep": - yield AgentEvent( - type="searching", - message="Running deep research flow...", - iteration=1, - ) - - if self._deep_flow is None: - # DeepResearchFlow creates its own judge_handler internally - # The judge_handler is passed to IterativeResearchFlow in parallel loops - self._deep_flow = DeepResearchFlow( - max_iterations=self.max_iterations, - max_time_minutes=self.max_time_minutes, - oauth_token=self.oauth_token, - ) - - try: - final_report = await self._deep_flow.run(query, message_history=message_history) - except Exception as e: - self.logger.error("Deep flow failed", error=str(e), exc_info=True) - # Yield error event before re-raising so test can capture it - yield AgentEvent( - type="error", - message=f"Deep research failed: {e!s}", - iteration=1, - ) - raise - - yield AgentEvent( - type="complete", - message=final_report, - data={"mode": "deep"}, - iteration=1, - ) - - async def _build_graph(self, mode: Literal["iterative", "deep"]) -> ResearchGraph: - """Build graph for the specified mode. 
- - Args: - mode: Research mode - - Returns: - Constructed ResearchGraph - """ - if mode == "iterative": - # Get agents - pass OAuth token for HuggingFace authentication - knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token) - tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token) - thinking_agent = create_thinking_agent(oauth_token=self.oauth_token) - writer_agent = create_writer_agent(oauth_token=self.oauth_token) - - # Create graph - graph = create_iterative_graph( - knowledge_gap_agent=knowledge_gap_agent.agent, - tool_selector_agent=tool_selector_agent.agent, - thinking_agent=thinking_agent.agent, - writer_agent=writer_agent.agent, - ) - else: # deep - # Get agents - pass OAuth token for HuggingFace authentication - planner_agent = create_planner_agent(oauth_token=self.oauth_token) - knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token) - tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token) - thinking_agent = create_thinking_agent(oauth_token=self.oauth_token) - writer_agent = create_writer_agent(oauth_token=self.oauth_token) - long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token) - - # Create graph - graph = create_deep_graph( - planner_agent=planner_agent.agent, - knowledge_gap_agent=knowledge_gap_agent.agent, - tool_selector_agent=tool_selector_agent.agent, - thinking_agent=thinking_agent.agent, - writer_agent=writer_agent.agent, - long_writer_agent=long_writer_agent.agent, - ) - - return graph - - def _emit_start_event( - self, node: Any, current_node_id: str, iteration: int, context: GraphExecutionContext - ) -> AgentEvent: - """Emit start event for a node. - - Args: - node: The node being executed - current_node_id: Current node ID - iteration: Current iteration number - context: Execution context - - Returns: - AgentEvent for the start of node execution - """ - if node and node.node_id == "planner": - return AgentEvent( - type="searching", - message="Creating report plan...", - iteration=iteration, - ) - elif node and node.node_id == "parallel_loops": - # Get report plan to show section count - report_plan = context.get_node_result("planner") - if report_plan and hasattr(report_plan, "report_outline"): - section_count = len(report_plan.report_outline) - return AgentEvent( - type="looping", - message=f"Running parallel research loops for {section_count} sections...", - iteration=iteration, - data={"sections": section_count}, - ) - return AgentEvent( - type="looping", - message="Running parallel research loops...", - iteration=iteration, - ) - elif node and node.node_id == "synthesizer": - return AgentEvent( - type="synthesizing", - message="Synthesizing final report from section drafts...", - iteration=iteration, - ) - return AgentEvent( - type="looping", - message=f"Executing node: {current_node_id}", - iteration=iteration, - ) - - def _emit_completion_event( - self, node: Any, current_node_id: str, result: Any, iteration: int - ) -> AgentEvent: - """Emit completion event for a node. 
- - Args: - node: The node that was executed - current_node_id: Current node ID - result: Node execution result - iteration: Current iteration number - - Returns: - AgentEvent for the completion of node execution - """ - if not node: - return AgentEvent( - type="looping", - message=f"Completed node: {current_node_id}", - iteration=iteration, - ) - - if node.node_id == "planner": - if isinstance(result, dict) and "report_outline" in result: - section_count = len(result["report_outline"]) - return AgentEvent( - type="search_complete", - message=f"Report plan created with {section_count} sections", - iteration=iteration, - data={"sections": section_count}, - ) - return AgentEvent( - type="search_complete", - message="Report plan created", - iteration=iteration, - ) - elif node.node_id == "parallel_loops": - if isinstance(result, list): - return AgentEvent( - type="search_complete", - message=f"Completed parallel research for {len(result)} sections", - iteration=iteration, - data={"sections_completed": len(result)}, - ) - return AgentEvent( - type="search_complete", - message="Parallel research loops completed", - iteration=iteration, - ) - elif node.node_id == "synthesizer": - return AgentEvent( - type="synthesizing", - message="Final report synthesis completed", - iteration=iteration, - ) - return AgentEvent( - type="searching" if node.node_type == "agent" else "looping", - message=f"Completed {node.node_type} node: {current_node_id}", - iteration=iteration, - ) - - def _get_final_result_from_exit_nodes( - self, context: GraphExecutionContext, current_node_id: str | None - ) -> tuple[Any, str | None]: - """Get final result from exit nodes, prioritizing synthesizer/writer.""" - if not self._graph: - return None, current_node_id - - final_result = None - result_node_id = current_node_id - - # First try to get result from current node (if it's an exit node) - if current_node_id and current_node_id in self._graph.exit_nodes: - final_result = context.get_node_result(current_node_id) - self.logger.debug( - "Final result from current exit node", - node_id=current_node_id, - has_result=final_result is not None, - result_type=type(final_result).__name__ if final_result else None, - ) - - # If no result from current node, check all exit nodes for results - # Prioritize synthesizer (deep research) or writer (iterative research) - if not final_result: - exit_node_priority = ["synthesizer", "writer"] - for exit_node_id in exit_node_priority: - if exit_node_id in self._graph.exit_nodes: - result = context.get_node_result(exit_node_id) - if result: - final_result = result - result_node_id = exit_node_id - self.logger.debug( - "Final result from priority exit node", - node_id=exit_node_id, - result_type=type(final_result).__name__, - ) - break - - # If still no result, check all exit nodes - if not final_result: - for exit_node_id in self._graph.exit_nodes: - result = context.get_node_result(exit_node_id) - if result: - final_result = result - result_node_id = exit_node_id - self.logger.debug( - "Final result from any exit node", - node_id=exit_node_id, - result_type=type(final_result).__name__, - ) - break - - # Log warning if no result found - if not final_result: - self.logger.warning( - "No final result found in exit nodes", - exit_nodes=list(self._graph.exit_nodes), - visited_nodes=list(context.visited_nodes), - all_node_results=list(context.node_results.keys()), - ) - - return final_result, result_node_id - - def _extract_final_message_and_files(self, final_result: Any) -> tuple[str, dict[str, Any]]: 
- """Extract message and file information from final result.""" - event_data: dict[str, Any] = {"mode": self.mode} - message: str = "Research completed" - - if isinstance(final_result, str): - message = final_result - self.logger.debug("Final message extracted from string result", length=len(message)) - elif isinstance(final_result, dict): - # First check for message key (most important) - if "message" in final_result: - message = final_result["message"] - self.logger.debug( - "Final message extracted from dict 'message' key", - length=len(message) if isinstance(message, str) else 0, - ) - - # Then check for file paths - if "file" in final_result: - file_path = final_result["file"] - if isinstance(file_path, str): - event_data["file"] = file_path - # Only override message if not already set from "message" key - if "message" not in final_result: - message = "Report generated. Download available." - self.logger.debug("File path added to event data", file_path=file_path) - - # Check for multiple files - if "files" in final_result: - files = final_result["files"] - if isinstance(files, list): - event_data["files"] = files - self.logger.debug("Multiple files added to event data", count=len(files)) - - return message, event_data - - async def _execute_graph( - self, query: str, context: GraphExecutionContext - ) -> AsyncGenerator[AgentEvent, None]: - """Execute the graph from entry node. - - Args: - query: The research query - context: Execution context - - Yields: - AgentEvent objects - """ - if not self._graph: - raise ValueError("Graph not built") - - current_node_id = self._graph.entry_node - iteration = 0 - - # Execute nodes until we reach an exit node - while current_node_id: - # Check budget - if not context.budget_tracker.can_continue("graph_execution"): - self.logger.warning("Budget exceeded, exiting graph execution") - break - - # Execute current node - iteration += 1 - context.current_node = current_node_id - node = self._graph.get_node(current_node_id) - - # Emit start event - yield self._emit_start_event(node, current_node_id, iteration, context) - - try: - result = await self._execute_node(current_node_id, query, context) - context.set_node_result(current_node_id, result) - context.mark_visited(current_node_id) - - # Yield completion event - yield self._emit_completion_event(node, current_node_id, result, iteration) - - except Exception as e: - self.logger.error("Node execution failed", node_id=current_node_id, error=str(e)) - yield AgentEvent( - type="error", - message=f"Node {current_node_id} failed: {e!s}", - iteration=iteration, - ) - break - - # Check if current node is an exit node - if so, we're done - if current_node_id in self._graph.exit_nodes: - break - - # Get next node(s) - next_nodes = self._get_next_node(current_node_id, context) - - if not next_nodes: - # No more nodes, we've reached a dead end - self.logger.warning("Reached dead end in graph", node_id=current_node_id) - break - - current_node_id = next_nodes[0] # For now, take first next node (handle parallel later) - - # Final event - get result from exit nodes (prioritize synthesizer/writer nodes) - final_result, result_node_id = self._get_final_result_from_exit_nodes( - context, current_node_id - ) - - # Check if final result contains file information - event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration} - message, file_event_data = self._extract_final_message_and_files(final_result) - event_data.update(file_event_data) - - yield AgentEvent( - type="complete", - message=message, - 
data=event_data, - iteration=iteration, - ) - - async def _execute_node(self, node_id: str, query: str, context: GraphExecutionContext) -> Any: - """Execute a single node. - - Args: - node_id: The node ID - query: The research query - context: Execution context - - Returns: - Node execution result - """ - if not self._graph: - raise ValueError("Graph not built") - - node = self._graph.get_node(node_id) - if not node: - raise ValueError(f"Node {node_id} not found") - - if isinstance(node, AgentNode): - return await self._execute_agent_node(node, query, context) - elif isinstance(node, StateNode): - return await self._execute_state_node(node, query, context) - elif isinstance(node, DecisionNode): - return await self._execute_decision_node(node, query, context) - elif isinstance(node, ParallelNode): - return await self._execute_parallel_node(node, query, context) - else: - raise ValueError(f"Unknown node type: {type(node)}") - - async def _execute_synthesizer_node(self, query: str, context: GraphExecutionContext) -> Any: - """Execute synthesizer node for deep research.""" - from src.agent_factory.agents import create_long_writer_agent - from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan - - report_plan = context.get_node_result("planner") - section_drafts = context.get_node_result("parallel_loops") or [] - - if not isinstance(report_plan, ReportPlan): - raise ValueError("ReportPlan not found for synthesizer") - - if not section_drafts: - raise ValueError("Section drafts not found for synthesizer") - - # Create ReportDraft from section drafts - report_draft = ReportDraft( - sections=[ - ReportDraftSection( - section_title=section.title, - section_content=draft, - ) - for section, draft in zip(report_plan.report_outline, section_drafts, strict=False) - ] - ) - - # Get LongWriterAgent instance and call write_report directly - long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token) - final_report = await long_writer_agent.write_report( - original_query=query, - report_title=report_plan.report_title, - report_draft=report_draft, - ) - - # Estimate tokens (rough estimate) - estimated_tokens = len(final_report) // 4 # Rough token estimate - context.budget_tracker.add_tokens("graph_execution", estimated_tokens) - - # Save report to file if enabled (may generate multiple formats) - return self._save_report_and_return_result(final_report, query) - - def _save_report_and_return_result(self, final_report: str, query: str) -> dict[str, Any] | str: - """Save report to file and return result with file paths if available.""" - file_path: str | None = None - pdf_path: str | None = None - try: - file_service = self._get_file_service() - if file_service: - # Use save_report_multiple_formats to get both MD and PDF if enabled - saved_files = file_service.save_report_multiple_formats( - report_content=final_report, - query=query, - ) - file_path = saved_files.get("md") - pdf_path = saved_files.get("pdf") - self.logger.info( - "Report saved to file", - md_path=file_path, - pdf_path=pdf_path, - ) - except Exception as e: - # Don't fail the entire operation if file saving fails - self.logger.warning("Failed to save report to file", error=str(e)) - file_path = None - pdf_path = None - - # Return dict with file paths if available, otherwise return string (backward compatible) - if file_path: - result: dict[str, Any] = { - "message": final_report, - "file": file_path, - } - # Add PDF path if generated - if pdf_path: - result["files"] = [file_path, pdf_path] - return result - 
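-        # No saved file available: fall back to returning the plain report string.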
return final_report - - async def _execute_writer_node(self, query: str, context: GraphExecutionContext) -> Any: - """Execute writer node for iterative research.""" - from src.agent_factory.agents import create_writer_agent - - # Get all evidence from workflow state and convert to findings string - evidence = context.state.evidence - if evidence: - # Convert evidence to findings format (similar to conversation.get_all_findings()) - findings_parts: list[str] = [] - for ev in evidence: - finding = f"**{ev.citation.title}**\n{ev.content}" - if ev.citation.url: - finding += f"\nSource: {ev.citation.url}" - findings_parts.append(finding) - all_findings = "\n\n".join(findings_parts) - else: - all_findings = "No findings available yet." - - # Get WriterAgent instance and call write_report directly - writer_agent = create_writer_agent(oauth_token=self.oauth_token) - final_report = await writer_agent.write_report( - query=query, - findings=all_findings, - output_length="", - output_instructions="", - ) - - # Estimate tokens (rough estimate) - estimated_tokens = len(final_report) // 4 # Rough token estimate - context.budget_tracker.add_tokens("graph_execution", estimated_tokens) - - # Save report to file if enabled (may generate multiple formats) - return self._save_report_and_return_result(final_report, query) - - def _prepare_agent_input( - self, node: AgentNode, query: str, context: GraphExecutionContext - ) -> Any: - """Prepare input data for agent execution.""" - if node.node_id == "planner": - # Planner takes the original query - input_data = query - else: - # Standard: use previous node result or query - prev_result = context.get_node_result(context.current_node) - input_data = prev_result if prev_result is not None else query - - # Apply input transformer if provided - if node.input_transformer: - input_data = node.input_transformer(input_data) - - return input_data - - async def _execute_standard_agent( - self, node: AgentNode, input_data: Any, query: str, context: GraphExecutionContext - ) -> Any: - """Execute standard agent with error handling and fallback models.""" - # Get message history from context (limit to most recent 10 messages for token efficiency) - message_history = context.get_message_history(max_messages=10) - - # Try with the original agent first - try: - # Pass message_history if available (Pydantic AI agents support this) - if message_history: - result = await node.agent.run(input_data, message_history=message_history) - else: - result = await node.agent.run(input_data) - - # Accumulate new messages from agent result if available - if hasattr(result, "new_messages"): - try: - new_messages = result.new_messages() - for msg in new_messages: - context.add_message(msg) - except Exception as e: - # Don't fail if message accumulation fails - self.logger.debug( - "Failed to accumulate messages from agent result", error=str(e) - ) - return result - except Exception as e: - # Check if we should retry with fallback models - from src.utils.hf_error_handler import ( - extract_error_details, - should_retry_with_fallback, - ) - - error_details = extract_error_details(e) - should_retry = should_retry_with_fallback(e) - - # Handle validation errors and API errors for planner node (with fallback) - if node.node_id == "planner": - if should_retry: - self.logger.warning( - "Planner failed, trying fallback models", - original_error=str(e), - status_code=error_details.get("status_code"), - ) - # Try fallback models for planner - fallback_result = await self._try_fallback_models( - node, 
input_data, message_history, query, context, e - ) - if fallback_result is not None: - return fallback_result - # If fallback failed or not applicable, use fallback plan - return self._create_fallback_plan(query, input_data) - - # For other nodes, try fallback models if applicable - if should_retry: - self.logger.warning( - "Agent node failed, trying fallback models", - node_id=node.node_id, - original_error=str(e), - status_code=error_details.get("status_code"), - ) - fallback_result = await self._try_fallback_models( - node, input_data, message_history, query, context, e - ) - if fallback_result is not None: - return fallback_result - - # If fallback didn't work or wasn't applicable, re-raise the exception - raise - - async def _try_fallback_models( - self, - node: AgentNode, - input_data: Any, - message_history: list[Any], - query: str, - context: GraphExecutionContext, - original_error: Exception, - ) -> Any | None: - """Try executing agent with fallback models. - - Args: - node: The agent node that failed - input_data: Input data for the agent - message_history: Message history for the agent - query: The research query - context: Execution context - original_error: The original error that triggered fallback - - Returns: - Agent result if successful, None if all fallbacks failed - """ - from src.utils.hf_error_handler import extract_error_details, get_fallback_models - - error_details = extract_error_details(original_error) - original_model = error_details.get("model_name") - fallback_models = get_fallback_models(original_model) - - # Also try models from settings fallback list - from src.utils.config import settings - - settings_fallbacks = settings.get_hf_fallback_models_list() - for model in settings_fallbacks: - if model not in fallback_models: - fallback_models.append(model) - - self.logger.info( - "Trying fallback models", - node_id=node.node_id, - original_model=original_model, - fallback_count=len(fallback_models), - ) - - # Try each fallback model - for fallback_model in fallback_models: - try: - # Recreate agent with fallback model - fallback_agent = self._recreate_agent_with_model(node.node_id, fallback_model) - if fallback_agent is None: - continue - - # Try running with fallback agent - if message_history: - result = await fallback_agent.run(input_data, message_history=message_history) - else: - result = await fallback_agent.run(input_data) - - self.logger.info( - "Fallback model succeeded", - node_id=node.node_id, - fallback_model=fallback_model, - ) - - # Accumulate new messages from agent result if available - if hasattr(result, "new_messages"): - try: - new_messages = result.new_messages() - for msg in new_messages: - context.add_message(msg) - except Exception as e: - self.logger.debug( - "Failed to accumulate messages from fallback agent result", error=str(e) - ) - - return result - - except Exception as e: - self.logger.warning( - "Fallback model failed", - node_id=node.node_id, - fallback_model=fallback_model, - error=str(e), - ) - continue - - # All fallback models failed - self.logger.error( - "All fallback models failed", - node_id=node.node_id, - fallback_count=len(fallback_models), - ) - return None - - def _recreate_agent_with_model(self, node_id: str, model_name: str) -> Any | None: - """Recreate an agent with a specific model. 
- - Args: - node_id: The node ID (e.g., "thinking", "knowledge_gap") - model_name: The model name to use - - Returns: - Agent instance or None if recreation failed - """ - try: - from pydantic_ai.models.huggingface import HuggingFaceModel - from pydantic_ai.providers.huggingface import HuggingFaceProvider - - # Create model with fallback model name - hf_provider = HuggingFaceProvider(api_key=self.oauth_token) - model = HuggingFaceModel(model_name, provider=hf_provider) - - # Recreate agent based on node_id - if node_id == "thinking": - from src.agent_factory.agents import create_thinking_agent - - agent_wrapper = create_thinking_agent(model=model, oauth_token=self.oauth_token) - return agent_wrapper.agent - elif node_id == "knowledge_gap": - from src.agent_factory.agents import create_knowledge_gap_agent - - agent_wrapper = create_knowledge_gap_agent( # type: ignore[assignment] - model=model, oauth_token=self.oauth_token - ) - return agent_wrapper.agent - elif node_id == "tool_selector": - from src.agent_factory.agents import create_tool_selector_agent - - agent_wrapper = create_tool_selector_agent( # type: ignore[assignment] - model=model, oauth_token=self.oauth_token - ) - return agent_wrapper.agent - elif node_id == "planner": - from src.agent_factory.agents import create_planner_agent - - agent_wrapper = create_planner_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment] - return agent_wrapper.agent - elif node_id == "writer": - from src.agent_factory.agents import create_writer_agent - - agent_wrapper = create_writer_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment] - return agent_wrapper.agent - else: - self.logger.warning("Unknown node_id for agent recreation", node_id=node_id) - return None - - except Exception as e: - self.logger.error( - "Failed to recreate agent with fallback model", - node_id=node_id, - model_name=model_name, - error=str(e), - ) - return None - - def _create_fallback_plan(self, query: str, input_data: Any) -> Any: - """Create fallback ReportPlan when planner fails.""" - from src.utils.models import ReportPlan, ReportPlanSection - - self.logger.error( - "Planner agent execution failed, using fallback plan", - error_type=type(input_data).__name__, - ) - - # Extract query from input_data if possible - fallback_query = query - if isinstance(input_data, str): - # Try to extract query from input string - if "QUERY:" in input_data: - fallback_query = input_data.split("QUERY:")[-1].strip() - - return ReportPlan( - background_context="", - report_outline=[ - ReportPlanSection( - title="Research Findings", - key_question=fallback_query, - ) - ], - report_title=f"Research Report: {fallback_query[:50]}", - ) - - def _extract_agent_output(self, node: AgentNode, result: Any) -> Any: - """Extract and transform output from agent result.""" - # Defensively extract output - handle various result formats - output = result.output if hasattr(result, "output") else result - - # Handle case where output might be a tuple (from pydantic-ai validation errors) - if isinstance(output, tuple): - output = self._handle_tuple_output(node, output, result) - return output - - def _handle_tuple_output(self, node: AgentNode, output: tuple[Any, ...], result: Any) -> Any: - """Handle tuple output from agent (validation errors).""" - # If tuple contains a dict-like structure, try to reconstruct the object - if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete": - # This is likely a validation error format: 
('research_complete', False) - # Try to get the actual output from result - self.logger.warning( - "Agent result output is a tuple, attempting to extract actual output", - node_id=node.node_id, - tuple_value=output, - ) - # Try to get output from result attributes - if hasattr(result, "data"): - return result.data - if hasattr(result, "response"): - return result.response - # Last resort: try to reconstruct from tuple - # This shouldn't happen, but handle gracefully - from src.utils.models import KnowledgeGapOutput - - if node.node_id == "knowledge_gap": - # Reconstruct KnowledgeGapOutput from validation error tuple - reconstructed = KnowledgeGapOutput( - research_complete=output[1] if len(output) > 1 else False, - outstanding_gaps=[], - ) - self.logger.info( - "Reconstructed KnowledgeGapOutput from validation error tuple", - node_id=node.node_id, - research_complete=reconstructed.research_complete, - ) - return reconstructed - - # For other nodes, try to extract meaningful output or use fallback - self.logger.warning( - "Agent node output is tuple format, attempting extraction", - node_id=node.node_id, - tuple_value=output, - ) - # Try to extract first meaningful element - if len(output) > 0: - # If first element is a string or dict, might be the actual output - if isinstance(output[0], str | dict): - return output[0] - # Last resort: use first element - return output[0] - # Empty tuple - use None and let downstream handle it - return None - - async def _execute_agent_node( - self, node: AgentNode, query: str, context: GraphExecutionContext - ) -> Any: - """Execute an agent node. - - Special handling for deep research nodes: - - "planner": Takes query string, returns ReportPlan - - "synthesizer": Takes query + ReportPlan + section drafts, returns final report - - Args: - node: The agent node - query: The research query - context: Execution context - - Returns: - Agent execution result - """ - # Special handling for synthesizer node (deep research) - if node.node_id == "synthesizer": - return await self._execute_synthesizer_node(query, context) - - # Special handling for writer node (iterative research) - if node.node_id == "writer": - return await self._execute_writer_node(query, context) - - # Standard agent execution - input_data = self._prepare_agent_input(node, query, context) - result = await self._execute_standard_agent(node, input_data, query, context) - output = self._extract_agent_output(node, result) - - if node.output_transformer: - output = node.output_transformer(output) - - # Estimate and track tokens - if hasattr(result, "usage") and result.usage: - tokens = result.usage.total_tokens if hasattr(result.usage, "total_tokens") else 0 - context.budget_tracker.add_tokens("graph_execution", tokens) - - # Special handling for knowledge_gap node: optionally call judge_handler - if node.node_id == "knowledge_gap" and self.judge_handler: - # Get evidence from workflow state - evidence = context.state.evidence - if evidence: - try: - from src.utils.models import JudgeAssessment - - # Call judge handler to assess evidence - judge_assessment: JudgeAssessment = await self.judge_handler.assess( - question=query, evidence=evidence - ) - # Store assessment in context for decision node to use - context.set_node_result("judge_assessment", judge_assessment) - self.logger.info( - "Judge assessment completed", - sufficient=judge_assessment.sufficient, - confidence=judge_assessment.confidence, - recommendation=judge_assessment.recommendation, - ) - except Exception as e: - self.logger.warning( - 
"Judge handler assessment failed", - error=str(e), - node_id=node.node_id, - ) - # Continue without judge assessment - - return output - - async def _execute_state_node( - self, node: StateNode, query: str, context: GraphExecutionContext - ) -> Any: - """Execute a state node. - - Special handling for deep research state nodes: - - "store_plan": Stores ReportPlan in context for parallel loops - - "collect_drafts": Stores section drafts in context for synthesizer - - "execute_tools": Executes search using search_handler - - Args: - node: The state node - query: The research query - context: Execution context - - Returns: - State update result - """ - # Special handling for execute_tools node - if node.node_id == "execute_tools": - # Get AgentSelectionPlan from tool_selector node result - tool_selector_result = context.get_node_result("tool_selector") - from src.utils.models import AgentSelectionPlan, SearchResult - - # Extract query from context or use original query - search_query = query - if tool_selector_result and isinstance(tool_selector_result, AgentSelectionPlan): - # Use the gap or query from the selection plan - if tool_selector_result.tasks: - # Use the first task's query if available - first_task = tool_selector_result.tasks[0] - if hasattr(first_task, "query") and first_task.query: - search_query = first_task.query - elif hasattr(first_task, "tool_input") and isinstance( - first_task.tool_input, str - ): - search_query = first_task.tool_input - - # Execute search using search_handler - if self.search_handler: - try: - search_result: SearchResult = await self.search_handler.execute( - query=search_query, max_results_per_tool=10 - ) - # Add evidence to workflow state (add_evidence expects a list) - context.state.add_evidence(search_result.evidence) - # Store evidence list in context for next nodes - context.set_node_result(node.node_id, search_result.evidence) - self.logger.info( - "Tools executed via search_handler", - query=search_query[:100], - evidence_count=len(search_result.evidence), - ) - return search_result.evidence - except Exception as e: - self.logger.error( - "Search handler execution failed", - error=str(e), - query=search_query[:100], - ) - # Return empty list on error to allow graph to continue - return [] - else: - # Fallback: log warning and return empty list - self.logger.warning( - "Search handler not available for execute_tools node", - node_id=node.node_id, - ) - return [] - - # Get previous result for state update - # For "store_plan", get from planner node - # For "collect_drafts", get from parallel_loops node - if node.node_id == "store_plan": - prev_result = context.get_node_result("planner") - elif node.node_id == "collect_drafts": - prev_result = context.get_node_result("parallel_loops") - else: - prev_result = context.get_node_result(context.current_node) - - # Update state - updated_state = node.state_updater(context.state, prev_result) - context.state = updated_state - - # Store result in context for next nodes to access - context.set_node_result(node.node_id, prev_result) - - # Read state if needed - if node.state_reader: - return node.state_reader(context.state) - - return prev_result # Return the stored result for next nodes - - async def _execute_decision_node( - self, node: DecisionNode, query: str, context: GraphExecutionContext - ) -> str: - """Execute a decision node. 
- - Args: - node: The decision node - query: The research query - context: Execution context - - Returns: - Next node ID - """ - # Get previous result for decision - # The decision node needs the result from the node that connects to it - # Find the previous node by searching edges - prev_node_id: str | None = None - if self._graph: - # Find which node connects to this decision node - for from_node, edge_list in self._graph.edges.items(): - for edge in edge_list: - if edge.to_node == node.node_id: - prev_node_id = from_node - break - if prev_node_id: - break - - # Fallback: For continue_decision, it always comes from knowledge_gap - if not prev_node_id and node.node_id == "continue_decision": - prev_node_id = "knowledge_gap" - - # Get result from previous node (or current node if no previous found) - if prev_node_id: - prev_result = context.get_node_result(prev_node_id) - else: - # Fallback: try to get from visited nodes (last visited before current) - visited_list = list(context.visited_nodes) - if len(visited_list) > 0: - prev_node_id = visited_list[-1] - prev_result = context.get_node_result(prev_node_id) - else: - prev_result = context.get_node_result(context.current_node) - - # Handle case where result might be a tuple (from pydantic-ai validation errors) - # Extract the actual result object if it's a tuple - if isinstance(prev_result, tuple) and len(prev_result) > 0: - # Check if first element is a KnowledgeGapOutput-like object - if hasattr(prev_result[0], "research_complete"): - prev_result = prev_result[0] - elif len(prev_result) > 1 and hasattr(prev_result[1], "research_complete"): - prev_result = prev_result[1] - elif ( - len(prev_result) == 2 - and isinstance(prev_result[0], str) - and prev_result[0] == "research_complete" - ): - # Handle validation error format: ('research_complete', False) - # Reconstruct KnowledgeGapOutput from tuple - from src.utils.models import KnowledgeGapOutput - - self.logger.warning( - "Decision node received validation error tuple, reconstructing KnowledgeGapOutput", - node_id=node.node_id, - tuple_value=prev_result, - ) - prev_result = KnowledgeGapOutput( - research_complete=prev_result[1] if len(prev_result) > 1 else False, - outstanding_gaps=[], - ) - else: - # If tuple doesn't contain the object, try to reconstruct or use fallback - self.logger.warning( - "Decision node received unexpected tuple format, attempting reconstruction", - node_id=node.node_id, - tuple_length=len(prev_result), - tuple_types=[type(x).__name__ for x in prev_result], - ) - # Try to reconstruct KnowledgeGapOutput if this is from knowledge_gap node - if prev_node_id == "knowledge_gap": - from src.utils.models import KnowledgeGapOutput - - # Try to extract research_complete from tuple - research_complete = False - for item in prev_result: - if isinstance(item, bool): - research_complete = item - break - elif isinstance(item, dict) and "research_complete" in item: - research_complete = item["research_complete"] - break - prev_result = KnowledgeGapOutput( - research_complete=research_complete, - outstanding_gaps=[], - ) - else: - # For other nodes, use first element as fallback - prev_result = prev_result[0] - - # Make decision - try: - next_node_id = node.decision_function(prev_result) - except Exception as e: - self.logger.error( - "Decision function failed", - node_id=node.node_id, - error=str(e), - prev_result_type=type(prev_result).__name__, - ) - # Default to first option on error - next_node_id = node.options[0] - - # Validate decision - if next_node_id not in 
node.options: - self.logger.warning( - "Decision function returned invalid node", - node_id=node.node_id, - returned=next_node_id, - options=node.options, - ) - # Default to first option - next_node_id = node.options[0] - - return next_node_id - - async def _execute_parallel_node( - self, node: ParallelNode, query: str, context: GraphExecutionContext - ) -> list[Any]: - """Execute a parallel node. - - Special handling for deep research "parallel_loops" node: - - Extracts report plan from previous node result - - Creates IterativeResearchFlow instances for each section - - Executes them in parallel - - Returns section drafts - - Args: - node: The parallel node - query: The research query - context: Execution context - - Returns: - List of results from parallel nodes - """ - # Special handling for deep research parallel_loops node - if node.node_id == "parallel_loops": - return await self._execute_deep_research_parallel_loops(node, query, context) - - # Standard parallel node execution - # Execute all parallel nodes concurrently - tasks = [ - self._execute_node(parallel_node_id, query, context) - for parallel_node_id in node.parallel_nodes - ] - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Handle exceptions - for i, result in enumerate(results): - if isinstance(result, Exception): - self.logger.error( - "Parallel node execution failed", - node_id=node.parallel_nodes[i] if i < len(node.parallel_nodes) else "unknown", - error=str(result), - ) - results[i] = None - - # Aggregate if needed - if node.aggregator: - aggregated = node.aggregator(results) - # Type cast: aggregator returns Any, but we expect list[Any] - return list(aggregated) if isinstance(aggregated, list) else [aggregated] - - return results - - async def _execute_deep_research_parallel_loops( - self, node: ParallelNode, query: str, context: GraphExecutionContext - ) -> list[str]: - """Execute parallel iterative research loops for deep research. 
- - Args: - node: The parallel node (should be "parallel_loops") - query: The research query - context: Execution context - - Returns: - List of section draft strings - """ - from src.agent_factory.judges import create_judge_handler - from src.orchestrator.research_flow import IterativeResearchFlow - from src.utils.models import ReportPlan - - # Get report plan from previous node (store_plan) - # The plan should be stored in context.node_results from the planner node - planner_result = context.get_node_result("planner") - if not isinstance(planner_result, ReportPlan): - self.logger.error( - "Planner result is not a ReportPlan", - type=type(planner_result), - ) - raise ValueError("Planner must return ReportPlan for deep research") - - report_plan: ReportPlan = planner_result - self.logger.info( - "Executing parallel loops for deep research", - sections=len(report_plan.report_outline), - ) - - # Use judge handler from GraphOrchestrator if available, otherwise create new one - judge_handler = self.judge_handler - if judge_handler is None: - judge_handler = create_judge_handler() - - # Create and execute iterative research flows for each section - async def run_section_research(section_index: int) -> str: - """Run iterative research for a single section.""" - section = report_plan.report_outline[section_index] - - try: - # Create iterative research flow - flow = IterativeResearchFlow( - max_iterations=self.max_iterations, - max_time_minutes=self.max_time_minutes, - verbose=False, # Less verbose in parallel execution - use_graph=False, # Use agent chains for section research - judge_handler=self.judge_handler or judge_handler, - oauth_token=self.oauth_token, - ) - - # Run research for this section - section_draft = await flow.run( - query=section.key_question, - background_context=report_plan.background_context, - ) - - self.logger.info( - "Section research completed", - section_index=section_index, - section_title=section.title, - draft_length=len(section_draft), - ) - - return section_draft - - except Exception as e: - self.logger.error( - "Section research failed", - section_index=section_index, - section_title=section.title, - error=str(e), - ) - # Return empty string for failed sections - return f"# {section.title}\n\n[Research failed: {e!s}]" - - # Execute all sections in parallel - section_drafts = await asyncio.gather( - *(run_section_research(i) for i in range(len(report_plan.report_outline))), - return_exceptions=True, - ) - - # Handle exceptions and filter None results - filtered_drafts: list[str] = [] - for i, draft in enumerate(section_drafts): - if isinstance(draft, Exception): - self.logger.error( - "Section research exception", - section_index=i, - error=str(draft), - ) - filtered_drafts.append( - f"# {report_plan.report_outline[i].title}\n\n[Research failed: {draft!s}]" - ) - elif draft is not None: - # Type narrowing: after Exception check, draft is str | None - assert isinstance(draft, str), "Expected str after Exception check" - filtered_drafts.append(draft) - - self.logger.info( - "Parallel loops completed", - sections=len(filtered_drafts), - total_sections=len(report_plan.report_outline), - ) - - return filtered_drafts - - def _get_next_node(self, node_id: str, context: GraphExecutionContext) -> list[str]: - """Get next node(s) from current node. 
- - Args: - node_id: Current node ID - context: Execution context - - Returns: - List of next node IDs - """ - if not self._graph: - return [] - - # Get node result for condition evaluation - node_result = context.get_node_result(node_id) - - # Get next nodes - next_nodes = self._graph.get_next_nodes(node_id, context=node_result) - - # If this was a decision node, use its result - node = self._graph.get_node(node_id) - if isinstance(node, DecisionNode): - decision_result = node_result - if isinstance(decision_result, str): - return [decision_result] - - # Return next node IDs - return [next_node_id for next_node_id, _ in next_nodes] - - async def _detect_research_mode(self, query: str) -> Literal["iterative", "deep"]: - """ - Detect research mode from query using input parser agent. - - Uses input parser agent to analyze query and determine research mode. - Falls back to heuristic if parser fails. - - Args: - query: The research query - - Returns: - Detected research mode - """ - try: - # Use input parser agent for intelligent mode detection - input_parser = create_input_parser_agent(oauth_token=self.oauth_token) - parsed_query = await input_parser.parse(query) - self.logger.info( - "Research mode detected by input parser", - mode=parsed_query.research_mode, - query=query[:100], - ) - return parsed_query.research_mode - except Exception as e: - # Fallback to heuristic if parser fails - self.logger.warning( - "Input parser failed, using heuristic", - error=str(e), - query=query[:100], - ) - query_lower = query.lower() - if any( - keyword in query_lower - for keyword in [ - "section", - "sections", - "report", - "outline", - "structure", - "comprehensive", - "analyze", - "analysis", - ] - ): - return "deep" - return "iterative" - - -def create_graph_orchestrator( - mode: Literal["iterative", "deep", "auto"] = "auto", - max_iterations: int = 5, - max_time_minutes: int = 10, - use_graph: bool = True, - search_handler: SearchHandlerProtocol | None = None, - judge_handler: JudgeHandlerProtocol | None = None, - oauth_token: str | None = None, -) -> GraphOrchestrator: - """ - Factory function to create a graph orchestrator. - - Args: - mode: Research mode - max_iterations: Maximum iterations per loop - max_time_minutes: Maximum time per loop - use_graph: Whether to use graph execution (True) or agent chains (False) - search_handler: Optional search handler for tool execution - judge_handler: Optional judge handler for evidence assessment - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured GraphOrchestrator instance - """ - return GraphOrchestrator( - mode=mode, - max_iterations=max_iterations, - max_time_minutes=max_time_minutes, - use_graph=use_graph, - search_handler=search_handler, - judge_handler=judge_handler, - oauth_token=oauth_token, - ) diff --git a/src/orchestrator/planner_agent.py b/src/orchestrator/planner_agent.py deleted file mode 100644 index 410c19ea682d150d17738c15605768e2484b24e6..0000000000000000000000000000000000000000 --- a/src/orchestrator/planner_agent.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Planner agent for creating report plans with sections and background context. - -Converts the folder/planner_agent.py implementation to use Pydantic AI. 
-""" - -from datetime import datetime -from typing import Any - -import structlog -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.tools.crawl_adapter import crawl_website -from src.tools.web_search_adapter import web_search -from src.utils.exceptions import ConfigurationError, JudgeError -from src.utils.models import ReportPlan, ReportPlanSection - -logger = structlog.get_logger() - - -# System prompt for the planner agent -SYSTEM_PROMPT = f""" -You are a research manager, managing a team of research agents. Today's date is {datetime.now().strftime("%Y-%m-%d")}. -Given a research query, your job is to produce an initial outline of the report (section titles and key questions), -as well as some background context. Each section will be assigned to a different researcher in your team who will then -carry out research on the section. - -You will be given: -- An initial research query - -Your task is to: -1. Produce 1-2 paragraphs of initial background context (if needed) on the query by running web searches or crawling websites -2. Produce an outline of the report that includes a list of section titles and the key question to be addressed in each section -3. Provide a title for the report that will be used as the main heading - -Guidelines: -- Each section should cover a single topic/question that is independent of other sections -- The key question for each section should include both the NAME and DOMAIN NAME / WEBSITE (if available and applicable) if it is related to a company, product or similar -- The background_context should not be more than 2 paragraphs -- The background_context should be very specific to the query and include any information that is relevant for researchers across all sections of the report -- The background_context should be drawn only from web search or crawl results rather than prior knowledge (i.e. it should only be included if you have called tools) -- For example, if the query is about a company, the background context should include some basic information about what the company does -- DO NOT do more than 2 tool calls - -Only output JSON. Follow the JSON schema for ReportPlan. Do not output anything else. -""" - - -class PlannerAgent: - """ - Planner agent that creates report plans with sections and background context. - - Uses Pydantic AI to generate structured ReportPlan output with optional - web search and crawl tool usage for background context. - """ - - def __init__( - self, - model: Any | None = None, - web_search_tool: Any | None = None, - crawl_tool: Any | None = None, - ) -> None: - """ - Initialize the planner agent. - - Args: - model: Optional Pydantic AI model. If None, uses config default. - web_search_tool: Optional web search tool function. If None, uses default. - crawl_tool: Optional crawl tool function. If None, uses default. 
- """ - self.model = model or get_model() - self.web_search_tool = web_search_tool or web_search - self.crawl_tool = crawl_tool or crawl_website - self.logger = logger - - # Validate tools are callable - if not callable(self.web_search_tool): - raise ConfigurationError("web_search_tool must be callable") - if not callable(self.crawl_tool): - raise ConfigurationError("crawl_tool must be callable") - - # Initialize Pydantic AI Agent - self.agent = Agent( - model=self.model, - output_type=ReportPlan, - system_prompt=SYSTEM_PROMPT, - tools=[self.web_search_tool, self.crawl_tool], - retries=3, - ) - - async def run(self, query: str) -> ReportPlan: - """ - Run the planner agent to generate a report plan. - - Args: - query: The user's research query - - Returns: - ReportPlan with sections, background context, and report title - - Raises: - JudgeError: If planning fails after retries - ConfigurationError: If agent configuration is invalid - """ - self.logger.info("Starting report planning", query=query[:100]) - - user_message = f"QUERY: {query}" - - try: - # Run the agent - result = await self.agent.run(user_message) - report_plan = result.output - - # Validate report plan - if not report_plan.report_outline: - self.logger.warning("Report plan has no sections", query=query[:100]) - # Return fallback plan instead of raising error - return ReportPlan( - background_context=report_plan.background_context or "", - report_outline=[ - ReportPlanSection( - title="Overview", - key_question=query, - ) - ], - report_title=report_plan.report_title or f"Research Report: {query[:50]}", - ) - - if not report_plan.report_title: - self.logger.warning("Report plan has no title", query=query[:100]) - raise JudgeError("Report plan must have a title") - - self.logger.info( - "Report plan created", - sections=len(report_plan.report_outline), - has_background=bool(report_plan.background_context), - ) - - return report_plan - - except Exception as e: - self.logger.error("Planning failed", error=str(e), query=query[:100]) - - # Fallback: return minimal report plan - if isinstance(e, JudgeError | ConfigurationError): - raise - - # For other errors, return a minimal plan - return ReportPlan( - background_context="", - report_outline=[ - ReportPlanSection( - title="Research Findings", - key_question=query, - ) - ], - report_title=f"Research Report: {query[:50]}", - ) - - -def create_planner_agent(model: Any | None = None, oauth_token: str | None = None) -> PlannerAgent: - """ - Factory function to create a planner agent. - - Args: - model: Optional Pydantic AI model. If None, uses settings default. - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured PlannerAgent instance - - Raises: - ConfigurationError: If required API keys are missing - """ - try: - # Get model from settings if not provided - if model is None: - model = get_model(oauth_token=oauth_token) - - # Create and return planner agent - return PlannerAgent(model=model) - - except Exception as e: - logger.error("Failed to create planner agent", error=str(e)) - raise ConfigurationError(f"Failed to create planner agent: {e}") from e diff --git a/src/orchestrator/research_flow.py b/src/orchestrator/research_flow.py deleted file mode 100644 index 3ae58be138e570aa16bd10c04f39ec380a18e288..0000000000000000000000000000000000000000 --- a/src/orchestrator/research_flow.py +++ /dev/null @@ -1,1111 +0,0 @@ -"""Research flow implementations for iterative and deep research patterns. 
-
-Converts the folder/iterative_research.py and folder/deep_research.py
-implementations to use Pydantic AI agents.
-"""
-
-import asyncio
-import time
-from typing import Any
-
-import structlog
-
-try:
-    from pydantic_ai import ModelMessage
-except ImportError:
-    ModelMessage = Any  # type: ignore[assignment, misc]
-
-from src.agent_factory.agents import (
-    create_graph_orchestrator,
-    create_knowledge_gap_agent,
-    create_long_writer_agent,
-    create_planner_agent,
-    create_proofreader_agent,
-    create_thinking_agent,
-    create_tool_selector_agent,
-    create_writer_agent,
-)
-from src.agent_factory.judges import create_judge_handler
-from src.middleware.budget_tracker import BudgetTracker
-from src.middleware.state_machine import get_workflow_state, init_workflow_state
-from src.middleware.workflow_manager import WorkflowManager
-from src.services.llamaindex_rag import LlamaIndexRAGService, get_rag_service
-from src.services.report_file_service import ReportFileService, get_report_file_service
-from src.tools.tool_executor import execute_tool_tasks
-from src.utils.exceptions import ConfigurationError
-from src.utils.models import (
-    AgentSelectionPlan,
-    AgentTask,
-    Citation,
-    Conversation,
-    Evidence,
-    JudgeAssessment,
-    KnowledgeGapOutput,
-    ReportDraft,
-    ReportDraftSection,
-    ReportPlan,
-    SourceName,
-    ToolAgentOutput,
-)
-
-logger = structlog.get_logger()
-
-
-class IterativeResearchFlow:
-    """
-    Iterative research flow that runs a single research loop.
-
-    Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Repeat
-    until research is complete or constraints are met.
-    """
-
-    def __init__(
-        self,
-        max_iterations: int = 5,
-        max_time_minutes: int = 10,
-        verbose: bool = True,
-        use_graph: bool = False,
-        judge_handler: Any | None = None,
-        oauth_token: str | None = None,
-    ) -> None:
-        """
-        Initialize iterative research flow.
- - Args: - max_iterations: Maximum number of iterations - max_time_minutes: Maximum time in minutes - verbose: Whether to log progress - use_graph: Whether to use graph-based execution (True) or agent chains (False) - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - """ - self.max_iterations = max_iterations - self.max_time_minutes = max_time_minutes - self.verbose = verbose - self.use_graph = use_graph - self.oauth_token = oauth_token - self.logger = logger - - # Initialize agents (only needed for agent chain execution) - if not use_graph: - self.knowledge_gap_agent = create_knowledge_gap_agent(oauth_token=self.oauth_token) - self.tool_selector_agent = create_tool_selector_agent(oauth_token=self.oauth_token) - self.thinking_agent = create_thinking_agent(oauth_token=self.oauth_token) - self.writer_agent = create_writer_agent(oauth_token=self.oauth_token) - # Initialize judge handler (use provided or create new) - self.judge_handler = judge_handler or create_judge_handler() - - # Initialize state (only needed for agent chain execution) - if not use_graph: - self.conversation = Conversation() - self.iteration = 0 - self.start_time: float | None = None - self.should_continue = True - - # Initialize budget tracker - self.budget_tracker = BudgetTracker() - self.loop_id = "iterative_flow" - self.budget_tracker.create_budget( - loop_id=self.loop_id, - tokens_limit=100000, - time_limit_seconds=max_time_minutes * 60, - iterations_limit=max_iterations, - ) - self.budget_tracker.start_timer(self.loop_id) - - # Initialize RAG service (lazy, may be None if unavailable) - self._rag_service: LlamaIndexRAGService | None = None - - # Graph orchestrator (lazy initialization) - self._graph_orchestrator: Any = None - - # File service (lazy initialization) - self._file_service: ReportFileService | None = None - - def _get_file_service(self) -> ReportFileService | None: - """ - Get file service instance (lazy initialization). - - Returns: - ReportFileService instance or None if disabled - """ - if self._file_service is None: - try: - self._file_service = get_report_file_service() - except Exception as e: - self.logger.warning("Failed to initialize file service", error=str(e)) - return None - return self._file_service - - async def run( - self, - query: str, - background_context: str = "", - output_length: str = "", - output_instructions: str = "", - message_history: list[ModelMessage] | None = None, - ) -> str: - """ - Run the iterative research flow. - - Args: - query: The research query - background_context: Optional background context - output_length: Optional description of desired output length - output_instructions: Optional additional instructions - message_history: Optional user conversation history - - Returns: - Final report string - """ - if self.use_graph: - return await self._run_with_graph( - query, background_context, output_length, output_instructions, message_history - ) - else: - return await self._run_with_chains( - query, background_context, output_length, output_instructions, message_history - ) - - async def _run_with_chains( - self, - query: str, - background_context: str = "", - output_length: str = "", - output_instructions: str = "", - message_history: list[ModelMessage] | None = None, - ) -> str: - """ - Run the iterative research flow using agent chains. 
- - Args: - query: The research query - background_context: Optional background context - output_length: Optional description of desired output length - output_instructions: Optional additional instructions - message_history: Optional user conversation history - - Returns: - Final report string - """ - self.start_time = time.time() - self.logger.info("Starting iterative research (agent chains)", query=query[:100]) - - # Initialize conversation with first iteration - self.conversation.add_iteration() - - # Main research loop - while self.should_continue and self._check_constraints(): - self.iteration += 1 - self.logger.info("Starting iteration", iteration=self.iteration) - - # Add new iteration to conversation - self.conversation.add_iteration() - - # 1. Generate observations - await self._generate_observations(query, background_context, message_history) - - # 2. Evaluate gaps - evaluation = await self._evaluate_gaps(query, background_context, message_history) - - # 3. Assess with judge (after tools execute, we'll assess again) - # For now, check knowledge gap evaluation - # After tool execution, we'll do a full judge assessment - - # Check if research is complete (knowledge gap agent says complete) - if evaluation.research_complete: - self.should_continue = False - self.logger.info("Research marked as complete by knowledge gap agent") - break - - # 4. Select tools for next gap - next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query - selection_plan = await self._select_agents( - next_gap, query, background_context, message_history - ) - - # 5. Execute tools - await self._execute_tools(selection_plan.tasks) - - # 6. Assess evidence sufficiency with judge - judge_assessment = await self._assess_with_judge(query) - - # Check if judge says evidence is sufficient - if judge_assessment.sufficient: - self.should_continue = False - self.logger.info( - "Research marked as complete by judge", - confidence=judge_assessment.confidence, - reasoning=judge_assessment.reasoning[:100], - ) - break - - # Update budget tracker - self.budget_tracker.increment_iteration(self.loop_id) - self.budget_tracker.update_timer(self.loop_id) - - # Create final report - report = await self._create_final_report(query, output_length, output_instructions) - - elapsed = time.time() - (self.start_time or time.time()) - self.logger.info( - "Iterative research completed", - iterations=self.iteration, - elapsed_minutes=elapsed / 60, - ) - - return report - - async def _run_with_graph( - self, - query: str, - background_context: str = "", - output_length: str = "", - output_instructions: str = "", - message_history: list[ModelMessage] | None = None, - ) -> str: - """ - Run the iterative research flow using graph execution. 
-
-        Args:
-            query: The research query
-            background_context: Optional background context (currently ignored in graph execution)
-            output_length: Optional description of desired output length (currently ignored in graph execution)
-            output_instructions: Optional additional instructions (currently ignored in graph execution)
-            message_history: Optional user conversation history (currently ignored in graph execution)
-
-        Returns:
-            Final report string
-        """
-        self.logger.info("Starting iterative research (graph execution)", query=query[:100])
-
-        # Create graph orchestrator (lazy initialization)
-        if self._graph_orchestrator is None:
-            self._graph_orchestrator = create_graph_orchestrator(
-                mode="iterative",
-                max_iterations=self.max_iterations,
-                max_time_minutes=self.max_time_minutes,
-                use_graph=True,
-            )
-
-        # Run orchestrator and collect events
-        final_report = ""
-        async for event in self._graph_orchestrator.run(query):
-            if event.type == "complete":
-                final_report = event.message
-                break
-            elif event.type == "error":
-                self.logger.error("Graph execution error", error=event.message)
-                raise RuntimeError(f"Graph execution failed: {event.message}")
-
-        if not final_report:
-            self.logger.warning("No complete event received from graph orchestrator")
-            final_report = "Research completed but no report was generated."
-
-        self.logger.info("Iterative research completed (graph execution)")
-
-        return final_report
-
-    def _check_constraints(self) -> bool:
-        """Check if we've exceeded constraints."""
-        if self.iteration >= self.max_iterations:
-            self.logger.info("Max iterations reached", max=self.max_iterations)
-            return False
-
-        if self.start_time:
-            elapsed_minutes = (time.time() - self.start_time) / 60
-            if elapsed_minutes >= self.max_time_minutes:
-                self.logger.info("Max time reached", max=self.max_time_minutes)
-                return False
-
-        # Check budget tracker
-        self.budget_tracker.update_timer(self.loop_id)
-        exceeded, reason = self.budget_tracker.check_budget(self.loop_id)
-        if exceeded:
-            self.logger.info("Budget exceeded", reason=reason)
-            return False
-
-        return True
-
-    async def _generate_observations(
-        self,
-        query: str,
-        background_context: str = "",
-        message_history: list[ModelMessage] | None = None,
-    ) -> str:
-        """Generate observations from current research state."""
-        # Build input prompt for token estimation
-        conversation_history = self.conversation.compile_conversation_history()
-        # Build background context section separately to avoid backslash in f-string
-        background_section = (
-            f"BACKGROUND CONTEXT:\n{background_context}\n\n" if background_context else ""
-        )
-        input_prompt = f"""
-You are starting iteration {self.iteration} of your research process.
-
-ORIGINAL QUERY:
-{query}
-
-{background_section}HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
-        observations = await self.thinking_agent.generate_observations(
-            query=query,
-            background_context=background_context,
-            conversation_history=conversation_history,
-            message_history=message_history,
-            iteration=self.iteration,
-        )
-
-        # Track tokens for this iteration
-        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, observations)
-        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
-        self.logger.debug(
-            "Tokens tracked for thinking agent",
-            iteration=self.iteration,
-            tokens=estimated_tokens,
-        )
-
-        self.conversation.set_latest_thought(observations)
-        return observations
-
-    async def _evaluate_gaps(
-        self,
-        query: str,
-        background_context: str = "",
-        message_history: list[ModelMessage] | None = None,
-    ) -> KnowledgeGapOutput:
-        """Evaluate knowledge gaps in current research."""
-        if self.start_time:
-            elapsed_minutes = (time.time() - self.start_time) / 60
-        else:
-            elapsed_minutes = 0.0
-
-        # Build input prompt for token estimation
-        conversation_history = self.conversation.compile_conversation_history()
-        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
-        input_prompt = f"""
-Current Iteration Number: {self.iteration}
-Time Elapsed: {elapsed_minutes:.2f} minutes of maximum {self.max_time_minutes} minutes
-
-ORIGINAL QUERY:
-{query}
-
-{background}
-
-HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
-{conversation_history or "No previous actions, findings or thoughts available."}
-"""
-
-        evaluation = await self.knowledge_gap_agent.evaluate(
-            query=query,
-            background_context=background_context,
-            conversation_history=conversation_history,
-            message_history=message_history,
-            iteration=self.iteration,
-            time_elapsed_minutes=elapsed_minutes,
-            max_time_minutes=self.max_time_minutes,
-        )
-
-        # Track tokens for this iteration
-        evaluation_text = f"research_complete={evaluation.research_complete}, gaps={len(evaluation.outstanding_gaps)}"
-        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
-            input_prompt, evaluation_text
-        )
-        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
-        self.logger.debug(
-            "Tokens tracked for knowledge gap agent",
-            iteration=self.iteration,
-            tokens=estimated_tokens,
-        )
-
-        if not evaluation.research_complete and evaluation.outstanding_gaps:
-            self.conversation.set_latest_gap(evaluation.outstanding_gaps[0])
-
-        return evaluation
-
-    async def _assess_with_judge(self, query: str) -> JudgeAssessment:
-        """Assess evidence sufficiency using JudgeHandler.
- - Args: - query: The research query - - Returns: - JudgeAssessment with sufficiency evaluation - """ - state = get_workflow_state() - evidence = state.evidence # Get all collected evidence - - self.logger.info( - "Assessing evidence with judge", - query=query[:100], - evidence_count=len(evidence), - ) - - assessment = await self.judge_handler.assess(query, evidence) - - # Track tokens for judge call - # Estimate tokens from query + evidence + assessment - evidence_text = "\n".join([e.content[:500] for e in evidence[:10]]) # Sample - estimated_tokens = self.budget_tracker.estimate_llm_call_tokens( - query + evidence_text, str(assessment.reasoning) - ) - self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens) - - self.logger.info( - "Judge assessment complete", - sufficient=assessment.sufficient, - confidence=assessment.confidence, - recommendation=assessment.recommendation, - ) - - return assessment - - async def _select_agents( - self, - gap: str, - query: str, - background_context: str = "", - message_history: list[ModelMessage] | None = None, - ) -> AgentSelectionPlan: - """Select tools to address knowledge gap.""" - # Build input prompt for token estimation - conversation_history = self.conversation.compile_conversation_history() - background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else "" - input_prompt = f""" -ORIGINAL QUERY: -{query} - -KNOWLEDGE GAP TO ADDRESS: -{gap} - -{background} - -HISTORY OF ACTIONS, FINDINGS AND THOUGHTS: -{conversation_history or "No previous actions, findings or thoughts available."} -""" - - selection_plan = await self.tool_selector_agent.select_tools( - gap=gap, - query=query, - background_context=background_context, - conversation_history=conversation_history, - message_history=message_history, - ) - - # Track tokens for this iteration - selection_text = f"tasks={len(selection_plan.tasks)}, agents={[task.agent for task in selection_plan.tasks]}" - estimated_tokens = self.budget_tracker.estimate_llm_call_tokens( - input_prompt, selection_text - ) - self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens) - self.logger.debug( - "Tokens tracked for tool selector agent", - iteration=self.iteration, - tokens=estimated_tokens, - ) - - # Store tool calls in conversation - tool_calls = [ - f"[Agent] {task.agent} [Query] {task.query} [Entity] {task.entity_website or 'null'}" - for task in selection_plan.tasks - ] - self.conversation.set_latest_tool_calls(tool_calls) - - return selection_plan - - def _get_rag_service(self) -> LlamaIndexRAGService | None: - """ - Get or create RAG service instance. 
- - Returns: - RAG service instance, or None if unavailable - """ - if self._rag_service is None: - try: - self._rag_service = get_rag_service(oauth_token=self.oauth_token) - self.logger.info("RAG service initialized for research flow") - except (ConfigurationError, ImportError) as e: - self.logger.warning( - "RAG service unavailable", error=str(e), hint="OPENAI_API_KEY required" - ) - return None - return self._rag_service - - async def _execute_tools(self, tasks: list[AgentTask]) -> dict[str, ToolAgentOutput]: - """Execute selected tools concurrently.""" - try: - results = await execute_tool_tasks(tasks) - except Exception as e: - # Handle tool execution errors gracefully - self.logger.error( - "Tool execution failed", - error=str(e), - task_count=len(tasks), - exc_info=True, - ) - # Return empty results to allow research flow to continue - # The flow can still generate a report based on previous iterations - results = {} - - # Store findings in conversation (only if we have results) - evidence_list: list[Evidence] = [] - if results: - findings = [result.output for result in results.values()] - self.conversation.set_latest_findings(findings) - - # Convert tool outputs to Evidence objects and store in workflow state - evidence_list = self._convert_tool_outputs_to_evidence(results) - - if evidence_list: - state = get_workflow_state() - added_count = state.add_evidence(evidence_list) - self.logger.info( - "Evidence added to workflow state", - count=added_count, - total_evidence=len(state.evidence), - ) - - # Ingest evidence into RAG if available (Phase 6 requirement) - rag_service = self._get_rag_service() - if rag_service is not None: - try: - # ingest_evidence is synchronous, run in executor to avoid blocking - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, rag_service.ingest_evidence, evidence_list) - self.logger.info( - "Evidence ingested into RAG", - count=len(evidence_list), - ) - except Exception as e: - # Don't fail the research loop if RAG ingestion fails - self.logger.warning( - "Failed to ingest evidence into RAG", - error=str(e), - count=len(evidence_list), - ) - - return results - - def _convert_tool_outputs_to_evidence( - self, tool_results: dict[str, ToolAgentOutput] - ) -> list[Evidence]: - """Convert ToolAgentOutput to Evidence objects. - - Args: - tool_results: Dictionary of tool execution results - - Returns: - List of Evidence objects - """ - evidence_list = [] - for key, result in tool_results.items(): - # Extract URLs from sources - if result.sources: - # Create one Evidence object per source URL - for url in result.sources: - # Determine source type from URL or tool name - # Default to "web" for unknown web sources - source_type: SourceName = "web" - if "pubmed" in url.lower() or "ncbi" in url.lower(): - source_type = "pubmed" - elif "clinicaltrials" in url.lower(): - source_type = "clinicaltrials" - elif "europepmc" in url.lower(): - source_type = "europepmc" - elif "biorxiv" in url.lower(): - source_type = "biorxiv" - elif "arxiv" in url.lower() or "preprint" in url.lower(): - source_type = "preprint" - # Note: "web" is now a valid SourceName for general web sources - - citation = Citation( - title=f"Tool Result: {key}", - url=url, - source=source_type, - date="n.d.", - authors=[], - ) - # Truncate content to reasonable length for judge (1500 chars) - content = result.output[:1500] - if len(result.output) > 1500: - content += "... 
[truncated]" - - evidence = Evidence( - content=content, - citation=citation, - relevance=0.5, # Default relevance - ) - evidence_list.append(evidence) - else: - # No URLs, create a single Evidence object with tool output - # Use a placeholder URL based on the tool name - # Determine source type from tool name - tool_source_type: SourceName = "web" # Default for unknown sources - if "RAG" in key: - tool_source_type = "rag" - elif "WebSearch" in key or "SiteCrawler" in key: - tool_source_type = "web" - # "web" is now a valid SourceName for general web sources - - citation = Citation( - title=f"Tool Result: {key}", - url=f"tool://{key}", - source=tool_source_type, - date="n.d.", - authors=[], - ) - content = result.output[:1500] - if len(result.output) > 1500: - content += "... [truncated]" - - evidence = Evidence( - content=content, - citation=citation, - relevance=0.5, - ) - evidence_list.append(evidence) - - return evidence_list - - async def _create_final_report( - self, query: str, length: str = "", instructions: str = "" - ) -> str: - """Create final report from all findings.""" - all_findings = "\n\n".join(self.conversation.get_all_findings()) - if not all_findings: - all_findings = "No findings available yet." - - # Build input prompt for token estimation - length_str = f"* The full response should be approximately {length}.\n" if length else "" - instructions_str = f"* {instructions}" if instructions else "" - guidelines_str = ( - ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n") - if length or instructions - else "" - ) - input_prompt = f""" -Provide a response based on the query and findings below with as much detail as possible. {guidelines_str} - -QUERY: {query} - -FINDINGS: -{all_findings} -""" - - report = await self.writer_agent.write_report( - query=query, - findings=all_findings, - output_length=length, - output_instructions=instructions, - ) - - # Track tokens for final report (not per iteration, just total) - estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, report) - self.budget_tracker.add_tokens(self.loop_id, estimated_tokens) - self.logger.debug( - "Tokens tracked for writer agent (final report)", - tokens=estimated_tokens, - ) - - # Save report to file if enabled - try: - file_service = self._get_file_service() - if file_service: - file_path = file_service.save_report( - report_content=report, - query=query, - ) - self.logger.info("Report saved to file", file_path=file_path) - except Exception as e: - # Don't fail the entire operation if file saving fails - self.logger.warning("Failed to save report to file", error=str(e)) - - # Note: Citation validation for markdown reports would require Evidence objects - # Currently, findings are strings, not Evidence objects. For full validation, - # consider using ResearchReport format or passing Evidence objects separately. - # See src/utils/citation_validator.py for markdown citation validation utilities. - - return report - - -class DeepResearchFlow: - """ - Deep research flow that runs parallel iterative loops per section. - - Pattern: Plan → Parallel Iterative Loops (one per section) → Synthesis - """ - - def __init__( - self, - max_iterations: int = 5, - max_time_minutes: int = 10, - verbose: bool = True, - use_long_writer: bool = True, - use_graph: bool = False, - oauth_token: str | None = None, - ) -> None: - """ - Initialize deep research flow. 
- - Args: - max_iterations: Maximum iterations per section - max_time_minutes: Maximum time per section - verbose: Whether to log progress - use_long_writer: Whether to use long writer (True) or proofreader (False) - use_graph: Whether to use graph-based execution (True) or agent chains (False) - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - """ - self.max_iterations = max_iterations - self.max_time_minutes = max_time_minutes - self.verbose = verbose - self.use_long_writer = use_long_writer - self.use_graph = use_graph - self.oauth_token = oauth_token - self.logger = logger - - # Initialize agents (only needed for agent chain execution) - if not use_graph: - self.planner_agent = create_planner_agent(oauth_token=self.oauth_token) - self.long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token) - self.proofreader_agent = create_proofreader_agent(oauth_token=self.oauth_token) - # Initialize judge handler for section loop completion - self.judge_handler = create_judge_handler() - # Initialize budget tracker for token tracking - self.budget_tracker = BudgetTracker() - self.loop_id = "deep_research_flow" - self.budget_tracker.create_budget( - loop_id=self.loop_id, - tokens_limit=200000, # Higher limit for deep research - time_limit_seconds=max_time_minutes - * 60 - * 2, # Allow more time for parallel sections - iterations_limit=max_iterations * 10, # Allow for multiple sections - ) - self.budget_tracker.start_timer(self.loop_id) - - # Graph orchestrator (lazy initialization) - self._graph_orchestrator: Any = None - - # File service (lazy initialization) - self._file_service: ReportFileService | None = None - - def _get_file_service(self) -> ReportFileService | None: - """ - Get file service instance (lazy initialization). - - Returns: - ReportFileService instance or None if disabled - """ - if self._file_service is None: - try: - self._file_service = get_report_file_service() - except Exception as e: - self.logger.warning("Failed to initialize file service", error=str(e)) - return None - return self._file_service - - async def run(self, query: str, message_history: list[ModelMessage] | None = None) -> str: - """ - Run the deep research flow. - - Args: - query: The research query - message_history: Optional user conversation history - - Returns: - Final report string - """ - if self.use_graph: - return await self._run_with_graph(query, message_history) - else: - return await self._run_with_chains(query, message_history) - - async def _run_with_chains( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> str: - """ - Run the deep research flow using agent chains. - - Args: - query: The research query - message_history: Optional user conversation history - - Returns: - Final report string - """ - self.logger.info("Starting deep research (agent chains)", query=query[:100]) - - # Initialize workflow state for deep research - try: - from src.services.embeddings import get_embedding_service - - embedding_service = get_embedding_service() - except (ImportError, Exception): - # If embedding service is unavailable, initialize without it - embedding_service = None - self.logger.debug("Embedding service unavailable, initializing state without it") - - init_workflow_state(embedding_service=embedding_service, message_history=message_history) - self.logger.debug("Workflow state initialized for deep research") - - # 1. 
Build report plan - report_plan = await self._build_report_plan(query, message_history) - self.logger.info( - "Report plan created", - sections=len(report_plan.report_outline), - title=report_plan.report_title, - ) - - # 2. Run parallel research loops with state synchronization - section_drafts = await self._run_research_loops(report_plan, message_history) - - # Verify state synchronization - log evidence count - state = get_workflow_state() - self.logger.info( - "State synchronization complete", - total_evidence=len(state.evidence), - sections_completed=len(section_drafts), - ) - - # 3. Create final report - final_report = await self._create_final_report(query, report_plan, section_drafts) - - self.logger.info( - "Deep research completed", - sections=len(section_drafts), - final_report_length=len(final_report), - ) - - return final_report - - async def _run_with_graph( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> str: - """ - Run the deep research flow using graph execution. - - Args: - query: The research query - message_history: Optional user conversation history - - Returns: - Final report string - """ - self.logger.info("Starting deep research (graph execution)", query=query[:100]) - - # Create graph orchestrator (lazy initialization) - if self._graph_orchestrator is None: - self._graph_orchestrator = create_graph_orchestrator( - mode="deep", - max_iterations=self.max_iterations, - max_time_minutes=self.max_time_minutes, - use_graph=True, - ) - - # Run orchestrator and collect events - final_report = "" - async for event in self._graph_orchestrator.run(query, message_history=message_history): - if event.type == "complete": - final_report = event.message - break - elif event.type == "error": - self.logger.error("Graph execution error", error=event.message) - raise RuntimeError(f"Graph execution failed: {event.message}") - - if not final_report: - self.logger.warning("No complete event received from graph orchestrator") - final_report = "Research completed but no report was generated." 
- - self.logger.info("Deep research completed (graph execution)") - - return final_report - - async def _build_report_plan( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> ReportPlan: - """Build the initial report plan.""" - self.logger.info("Building report plan") - - # Build input prompt for token estimation - input_prompt = f"QUERY: {query}" - - # Planner agent may not support message_history yet, so we'll pass it if available - # For now, just use the standard run() call - report_plan = await self.planner_agent.run(query) - - # Track tokens for planner agent - if not self.use_graph and hasattr(self, "budget_tracker"): - plan_text = ( - f"title={report_plan.report_title}, sections={len(report_plan.report_outline)}" - ) - estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, plan_text) - self.budget_tracker.add_tokens(self.loop_id, estimated_tokens) - self.logger.debug( - "Tokens tracked for planner agent", - tokens=estimated_tokens, - ) - - self.logger.info( - "Report plan created", - sections=len(report_plan.report_outline), - has_background=bool(report_plan.background_context), - ) - - return report_plan - - async def _run_research_loops( - self, report_plan: ReportPlan, message_history: list[ModelMessage] | None = None - ) -> list[str]: - """Run parallel iterative research loops for each section.""" - self.logger.info("Running research loops", sections=len(report_plan.report_outline)) - - # Create workflow manager for parallel execution - workflow_manager = WorkflowManager() - - # Create loop configurations - loop_configs = [ - { - "loop_id": f"section_{i}", - "query": section.key_question, - "section_title": section.title, - "background_context": report_plan.background_context, - } - for i, section in enumerate(report_plan.report_outline) - ] - - async def run_research_for_section(config: dict[str, Any]) -> str: - """Run iterative research for a single section.""" - loop_id = config.get("loop_id", "unknown") - query = config.get("query", "") - background_context = config.get("background_context", "") - - try: - # Update loop status - await workflow_manager.update_loop_status(loop_id, "running") - - # Create iterative research flow - flow = IterativeResearchFlow( - max_iterations=self.max_iterations, - max_time_minutes=self.max_time_minutes, - verbose=self.verbose, - use_graph=self.use_graph, - judge_handler=self.judge_handler if not self.use_graph else None, - ) - - # Run research with message_history - result = await flow.run( - query=query, - background_context=background_context, - message_history=message_history, - ) - - # Sync evidence from flow to loop - state = get_workflow_state() - if state.evidence: - await workflow_manager.add_loop_evidence(loop_id, state.evidence) - - # Update loop status - await workflow_manager.update_loop_status(loop_id, "completed") - - return result - - except Exception as e: - error_msg = str(e) - await workflow_manager.update_loop_status(loop_id, "failed", error=error_msg) - self.logger.error( - "Section research failed", - loop_id=loop_id, - error=error_msg, - ) - raise - - # Run all sections in parallel using workflow manager - section_drafts = await workflow_manager.run_loops_parallel( - loop_configs=loop_configs, - loop_func=run_research_for_section, - judge_handler=self.judge_handler if not self.use_graph else None, - budget_tracker=self.budget_tracker if not self.use_graph else None, - ) - - # Sync evidence from all loops to global state - for config in loop_configs: - loop_id = 
config.get("loop_id") - if loop_id: - await workflow_manager.sync_loop_evidence_to_state(loop_id) - - # Filter out None results (failed loops) - section_drafts = [draft for draft in section_drafts if draft is not None] - - self.logger.info( - "Research loops completed", - drafts=len(section_drafts), - total_sections=len(report_plan.report_outline), - ) - - return section_drafts - - async def _create_final_report( - self, query: str, report_plan: ReportPlan, section_drafts: list[str] - ) -> str: - """Create final report from section drafts.""" - self.logger.info("Creating final report") - - # Create ReportDraft from section drafts - report_draft = ReportDraft( - sections=[ - ReportDraftSection( - section_title=section.title, - section_content=draft, - ) - for section, draft in zip(report_plan.report_outline, section_drafts, strict=False) - ] - ) - - # Build input prompt for token estimation - draft_text = "\n".join( - [s.section_content[:500] for s in report_draft.sections[:5]] - ) # Sample - input_prompt = f"QUERY: {query}\nTITLE: {report_plan.report_title}\nDRAFT: {draft_text}" - - if self.use_long_writer: - # Use long writer agent - final_report = await self.long_writer_agent.write_report( - original_query=query, - report_title=report_plan.report_title, - report_draft=report_draft, - ) - else: - # Use proofreader agent - final_report = await self.proofreader_agent.proofread( - query=query, - report_draft=report_draft, - ) - - # Track tokens for final report synthesis - if not self.use_graph and hasattr(self, "budget_tracker"): - estimated_tokens = self.budget_tracker.estimate_llm_call_tokens( - input_prompt, final_report - ) - self.budget_tracker.add_tokens(self.loop_id, estimated_tokens) - self.logger.debug( - "Tokens tracked for final report synthesis", - tokens=estimated_tokens, - agent="long_writer" if self.use_long_writer else "proofreader", - ) - - # Save report to file if enabled - try: - file_service = self._get_file_service() - if file_service: - file_path = file_service.save_report( - report_content=final_report, - query=query, - ) - self.logger.info("Report saved to file", file_path=file_path) - except Exception as e: - # Don't fail the entire operation if file saving fails - self.logger.warning("Failed to save report to file", error=str(e)) - - self.logger.info("Final report created", length=len(final_report)) - - return final_report diff --git a/src/orchestrator_factory.py b/src/orchestrator_factory.py deleted file mode 100644 index 44f8b93a5cd75a600a040258ab3da24f0ea9d2f6..0000000000000000000000000000000000000000 --- a/src/orchestrator_factory.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Factory for creating orchestrators.""" - -from typing import Any, Literal - -import structlog - -from src.legacy_orchestrator import ( - JudgeHandlerProtocol, - Orchestrator, - SearchHandlerProtocol, -) -from src.utils.config import settings -from src.utils.models import OrchestratorConfig - -logger = structlog.get_logger() - - -def _get_magentic_orchestrator_class() -> Any: - """Import MagenticOrchestrator lazily to avoid hard dependency.""" - try: - from src.orchestrator_magentic import MagenticOrchestrator - - return MagenticOrchestrator - except ImportError as e: - logger.error("Failed to import MagenticOrchestrator", error=str(e)) - raise ValueError( - "Advanced mode requires agent-framework-core. Please install it or use mode='simple'." 
- ) from e - - -def _get_graph_orchestrator_factory() -> Any: - """Import create_graph_orchestrator lazily to avoid circular dependencies.""" - try: - from src.orchestrator.graph_orchestrator import create_graph_orchestrator - - return create_graph_orchestrator - except ImportError as e: - logger.error("Failed to import create_graph_orchestrator", error=str(e)) - raise ValueError( - "Graph orchestrators require Pydantic Graph. Please check dependencies." - ) from e - - -def create_orchestrator( - search_handler: SearchHandlerProtocol | None = None, - judge_handler: JudgeHandlerProtocol | None = None, - config: OrchestratorConfig | None = None, - mode: Literal["simple", "magentic", "advanced", "iterative", "deep", "auto"] | None = None, - oauth_token: str | None = None, -) -> Any: - """ - Create an orchestrator instance. - - Args: - search_handler: The search handler (required for simple mode) - judge_handler: The judge handler (required for simple mode) - config: Optional configuration - mode: Orchestrator mode - "simple", "advanced", "iterative", "deep", "auto", or None (auto-detect) - - "simple": Linear search-judge loop (Free Tier) - - "advanced": Multi-agent coordination (Requires OpenAI) - - "iterative": Knowledge-gap-driven research (Free Tier) - - "deep": Parallel section-based research (Free Tier) - - "auto": Intelligent mode detection (Free Tier) - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Orchestrator instance - """ - effective_mode = _determine_mode(mode) - logger.info("Creating orchestrator", mode=effective_mode) - - if effective_mode == "advanced": - orchestrator_cls = _get_magentic_orchestrator_class() - return orchestrator_cls( - max_rounds=config.max_iterations if config else 10, - ) - - # Graph-based orchestrators (iterative, deep, auto) - if effective_mode in ("iterative", "deep", "auto"): - create_graph_orchestrator = _get_graph_orchestrator_factory() - return create_graph_orchestrator( - mode=effective_mode, # type: ignore[arg-type] - max_iterations=config.max_iterations if config else 5, - max_time_minutes=10, - use_graph=True, - search_handler=search_handler, - judge_handler=judge_handler, - oauth_token=oauth_token, - ) - - # Simple mode requires handlers - if search_handler is None or judge_handler is None: - raise ValueError("Simple mode requires search_handler and judge_handler") - - return Orchestrator( - search_handler=search_handler, - judge_handler=judge_handler, - config=config, - ) - - -def _determine_mode(explicit_mode: str | None) -> str: - """Determine which mode to use.""" - if explicit_mode: - if explicit_mode in ("magentic", "advanced"): - return "advanced" - if explicit_mode in ("iterative", "deep", "auto"): - return explicit_mode - return "simple" - - # Auto-detect: advanced if paid API key available, otherwise simple - if settings.has_openai_key: - return "advanced" - - return "simple" diff --git a/src/orchestrator_hierarchical.py b/src/orchestrator_hierarchical.py deleted file mode 100644 index a7bfb85adc874c60bd7ab84c49e0990cc2f1a620..0000000000000000000000000000000000000000 --- a/src/orchestrator_hierarchical.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Hierarchical orchestrator using middleware and sub-teams.""" - -import asyncio -from collections.abc import AsyncGenerator -from typing import Any - -import structlog - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] - -from src.agents.judge_agent_llm import 
LLMSubIterationJudge -from src.agents.magentic_agents import create_search_agent -from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam -from src.services.embeddings import get_embedding_service -from src.state import init_magentic_state -from src.utils.models import AgentEvent - -logger = structlog.get_logger() - - -class ResearchTeam(SubIterationTeam): - """Adapts Magentic ChatAgent to SubIterationTeam protocol.""" - - def __init__(self) -> None: - self.agent = create_search_agent() - - async def execute(self, task: str) -> str: - response = await self.agent.run(task) - if response.messages: - for msg in reversed(response.messages): - if msg.role == "assistant" and msg.text: - return str(msg.text) - return "No response from agent." - - -class HierarchicalOrchestrator: - """Orchestrator that uses hierarchical teams and sub-iterations.""" - - def __init__(self) -> None: - self.team = ResearchTeam() - self.judge = LLMSubIterationJudge() - self.middleware = SubIterationMiddleware(self.team, self.judge, max_iterations=5) - - async def run( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> AsyncGenerator[AgentEvent, None]: - logger.info( - "Starting hierarchical orchestrator", - query=query, - has_history=bool(message_history), - ) - - try: - service = get_embedding_service() - init_magentic_state(service) - except Exception as e: - logger.warning( - "Embedding service initialization failed, using default state", - error=str(e), - ) - init_magentic_state() - - yield AgentEvent(type="started", message=f"Starting research: {query}") - - queue: asyncio.Queue[AgentEvent | None] = asyncio.Queue() - - async def event_callback(event: AgentEvent) -> None: - await queue.put(event) - - # Note: middleware.run() may not support message_history yet - # Pass query for now, message_history can be added to middleware later if needed - task_future = asyncio.create_task(self.middleware.run(query, event_callback)) - - while not task_future.done(): - get_event = asyncio.create_task(queue.get()) - done, _ = await asyncio.wait( - {task_future, get_event}, return_when=asyncio.FIRST_COMPLETED - ) - - if get_event in done: - event = get_event.result() - if event: - yield event - else: - get_event.cancel() - - # Process remaining events - while not queue.empty(): - ev = queue.get_nowait() - if ev: - yield ev - - try: - result, assessment = await task_future - - assessment_text = assessment.reasoning if assessment else "None" - yield AgentEvent( - type="complete", - message=( - f"Research complete.\n\nResult:\n{result}\n\nAssessment:\n{assessment_text}" - ), - data={"assessment": assessment.model_dump() if assessment else None}, - ) - except Exception as e: - logger.error("Orchestrator failed", error=str(e)) - yield AgentEvent(type="error", message=f"Orchestrator failed: {e}") diff --git a/src/orchestrator_magentic.py b/src/orchestrator_magentic.py deleted file mode 100644 index 416d896830c42c7e1ff5a1304db3eb06cb00dc4b..0000000000000000000000000000000000000000 --- a/src/orchestrator_magentic.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Magentic-based orchestrator using ChatAgent pattern.""" - -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any - -import structlog - -try: - from pydantic_ai import ModelMessage -except ImportError: - ModelMessage = Any # type: ignore[assignment, misc] -from agent_framework import ( - MagenticAgentDeltaEvent, - MagenticAgentMessageEvent, - MagenticBuilder, - MagenticFinalResultEvent, - 
MagenticOrchestratorMessageEvent, - WorkflowOutputEvent, -) - -from src.agents.magentic_agents import ( - create_hypothesis_agent, - create_judge_agent, - create_report_agent, - create_search_agent, -) -from src.agents.state import init_magentic_state -from src.utils.llm_factory import check_magentic_requirements, get_chat_client_for_agent -from src.utils.models import AgentEvent - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - -logger = structlog.get_logger() - - -class MagenticOrchestrator: - """ - Magentic-based orchestrator using ChatAgent pattern. - - Each agent has an internal LLM that understands natural language - instructions from the manager and can call tools appropriately. - """ - - def __init__( - self, - max_rounds: int = 10, - chat_client: Any | None = None, - ) -> None: - """Initialize orchestrator. - - Args: - max_rounds: Maximum coordination rounds - chat_client: Optional shared chat client for agents. - If None, uses factory default (HuggingFace preferred, OpenAI fallback) - """ - # Validate requirements via centralized factory - check_magentic_requirements() - - self._max_rounds = max_rounds - self._chat_client = chat_client - - def _init_embedding_service(self) -> "EmbeddingService | None": - """Initialize embedding service if available.""" - try: - from src.services.embeddings import get_embedding_service - - service = get_embedding_service() - logger.info("Embedding service enabled") - return service - except ImportError: - logger.info("Embedding service not available (dependencies missing)") - except Exception as e: - logger.warning("Failed to initialize embedding service", error=str(e)) - return None - - def _build_workflow(self) -> Any: - """Build the Magentic workflow with ChatAgent participants.""" - # Create agents with internal LLMs - search_agent = create_search_agent(self._chat_client) - judge_agent = create_judge_agent(self._chat_client) - hypothesis_agent = create_hypothesis_agent(self._chat_client) - report_agent = create_report_agent(self._chat_client) - - # Manager chat client (orchestrates the agents) - # Use same client type as agents for consistency - manager_client = self._chat_client or get_chat_client_for_agent() - - return ( - MagenticBuilder() - .participants( - searcher=search_agent, - hypothesizer=hypothesis_agent, - judge=judge_agent, - reporter=report_agent, - ) - .with_standard_manager( - chat_client=manager_client, - max_round_count=self._max_rounds, - max_stall_count=3, - max_reset_count=2, - ) - .build() - ) - - async def run( - self, query: str, message_history: list[ModelMessage] | None = None - ) -> AsyncGenerator[AgentEvent, None]: - """ - Run the Magentic workflow. 
- - Args: - query: User's research question - message_history: Optional user conversation history (for compatibility) - - Yields: - AgentEvent objects for real-time UI updates - """ - logger.info( - "Starting Magentic orchestrator", - query=query, - has_history=bool(message_history), - ) - - yield AgentEvent( - type="started", - message=f"Starting research (Magentic mode): {query}", - iteration=0, - ) - - # Initialize context state - embedding_service = self._init_embedding_service() - init_magentic_state(embedding_service) - - workflow = self._build_workflow() - - # Include conversation history context if provided - history_context = "" - if message_history: - # Convert message history to string context for task - from src.utils.message_history import message_history_to_string - - history_str = message_history_to_string(message_history, max_messages=5) - if history_str: - history_context = f"\n\nPrevious conversation context:\n{history_str}" - - task = f"""Research query: {query}{history_context} - -Workflow: -1. SearchAgent: Find evidence from available sources (automatically selects: web search, PubMed, ClinicalTrials.gov, Europe PMC, or RAG based on query) -2. HypothesisAgent: Generate research hypotheses and questions based on evidence -3. JudgeAgent: Evaluate if evidence is sufficient to answer the query precisely -4. If insufficient -> SearchAgent refines search based on identified gaps -5. If sufficient -> ReportAgent synthesizes final comprehensive report - -Focus on: -- Finding precise answers to the research question -- Identifying all relevant evidence from appropriate sources -- Understanding mechanisms, relationships, and key findings -- Synthesizing comprehensive findings with proper citations - -The DETERMINATOR stops at nothing until finding precise answers, only stopping at configured limits (budget, time, iterations). - -The final output should be a structured research report with comprehensive evidence synthesis.""" - - iteration = 0 - try: - async for event in workflow.run_stream(task): - agent_event = self._process_event(event, iteration) - if agent_event: - if isinstance(event, MagenticAgentMessageEvent): - iteration += 1 - yield agent_event - - except Exception as e: - logger.error("Magentic workflow failed", error=str(e)) - yield AgentEvent( - type="error", - message=f"Workflow error: {e!s}", - iteration=iteration, - ) - - def _extract_text(self, message: Any) -> str: - """ - Defensively extract text from a message object. - - Fixes bug where message.text might return the object itself or its repr. 
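For orientation, a caller consumes this orchestrator the same way as the other flows: iterate the async generator and watch for a `complete` (or `error`) event. A minimal usage sketch, assuming `agent-framework-core` and a chat-model key are configured and using an example query string:

```python
import asyncio

from src.orchestrator_magentic import MagenticOrchestrator


async def main() -> None:
    orchestrator = MagenticOrchestrator(max_rounds=5)
    final_report = ""
    async for event in orchestrator.run("What pathways link metformin to autophagy?"):
        # Each AgentEvent carries a type ("started", "search_complete", ..., "complete",
        # "error"), a human-readable message, and an optional data payload.
        print(f"[{event.type}] {event.message[:120]}")
        if event.type == "complete":
            final_report = event.message
    print(final_report)


if __name__ == "__main__":
    asyncio.run(main())
```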
- """ - if not message: - return "" - - # Priority 1: .content (often the raw string or list of content) - if hasattr(message, "content") and message.content: - content = message.content - # If it's a list (e.g., Multi-modal), join text parts - if isinstance(content, list): - return " ".join([str(c.text) for c in content if hasattr(c, "text")]) - return str(content) - - # Priority 2: .text (standard, but sometimes buggy/missing) - if hasattr(message, "text") and message.text: - # Verify it's not the object itself or a repr string - text = str(message.text) - if text.startswith("<") and "object at" in text: - # Likely a repr string, ignore if possible - pass - else: - return text - - # Fallback: If we can't find clean text, return str(message) - # taking care to avoid infinite recursion if str() calls .text - return str(message) - - def _process_event(self, event: Any, iteration: int) -> AgentEvent | None: - """Process workflow event into AgentEvent.""" - if isinstance(event, MagenticOrchestratorMessageEvent): - text = self._extract_text(event.message) - if text: - return AgentEvent( - type="judging", - message=f"Manager ({event.kind}): {text[:200]}...", - iteration=iteration, - ) - - elif isinstance(event, MagenticAgentMessageEvent): - agent_name = event.agent_id or "unknown" - text = self._extract_text(event.message) - - event_type = "judging" - if "search" in agent_name.lower(): - event_type = "search_complete" - elif "judge" in agent_name.lower(): - event_type = "judge_complete" - elif "hypothes" in agent_name.lower(): - event_type = "hypothesizing" - elif "report" in agent_name.lower(): - event_type = "synthesizing" - - return AgentEvent( - type=event_type, # type: ignore[arg-type] - message=f"{agent_name}: {text[:200]}...", - iteration=iteration + 1, - ) - - elif isinstance(event, MagenticFinalResultEvent): - text = self._extract_text(event.message) if event.message else "No result" - return AgentEvent( - type="complete", - message=text, - data={"iterations": iteration}, - iteration=iteration, - ) - - elif isinstance(event, MagenticAgentDeltaEvent): - if event.text: - return AgentEvent( - type="streaming", - message=event.text, - data={"agent_id": event.agent_id}, - iteration=iteration, - ) - - elif isinstance(event, WorkflowOutputEvent): - if event.data: - return AgentEvent( - type="complete", - message=str(event.data), - iteration=iteration, - ) - - return None diff --git a/src/prompts/__init__.py b/src/prompts/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/prompts/hypothesis.py b/src/prompts/hypothesis.py deleted file mode 100644 index becfc0a300639d81f538c8a001111a2fdde62574..0000000000000000000000000000000000000000 --- a/src/prompts/hypothesis.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Prompts for Hypothesis Agent.""" - -from typing import TYPE_CHECKING - -from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - from src.utils.models import Evidence - -SYSTEM_PROMPT = """You are an expert research scientist functioning as a generalist research assistant. - -Your role is to generate research hypotheses, questions, and investigation paths based on evidence from any domain. - -IMPORTANT: You are a research assistant. You cannot provide medical advice or answer medical questions directly. Your hypotheses are for research investigation purposes only. - -A good hypothesis: -1. 
Proposes a MECHANISM or RELATIONSHIP: Explains how things work or relate - - For medical: Drug -> Target -> Pathway -> Effect - - For technical: Technology -> Mechanism -> Outcome - - For business: Strategy -> Market -> Result -2. Is TESTABLE: Can be supported or refuted by further research -3. Is SPECIFIC: Names actual entities, processes, or mechanisms -4. Generates SEARCH QUERIES: Helps find more evidence - -Example hypothesis formats: -- Medical: "Metformin -> AMPK activation -> mTOR inhibition -> autophagy -> amyloid clearance" -- Technical: "Transformer architecture -> attention mechanism -> improved NLP performance" -- Business: "Subscription model -> recurring revenue -> higher valuation" - -Be specific. Use actual names, technical terms, and precise language when possible.""" - - -async def format_hypothesis_prompt( - query: str, evidence: list["Evidence"], embeddings: "EmbeddingService | None" = None -) -> str: - """Format prompt for hypothesis generation. - - Uses smart evidence selection instead of arbitrary truncation. - - Args: - query: The research query - evidence: All collected evidence - embeddings: Optional EmbeddingService for diverse selection - """ - # Select diverse, relevant evidence (not arbitrary first 10) - # We use n=10 as a reasonable context window limit - selected = await select_diverse_evidence(evidence, n=10, query=query, embeddings=embeddings) - - # Format with sentence-aware truncation - evidence_text = "\n".join( - [ - f"- **{e.citation.title}** ({e.citation.source}): " - f"{truncate_at_sentence(e.content, 300)}" - for e in selected - ] - ) - - return f"""Based on the following evidence about "{query}", generate research hypotheses and investigation paths. - -## Evidence ({len(selected)} sources selected for diversity) -{evidence_text} - -## Task -1. Identify key mechanisms, relationships, or processes mentioned in the evidence -2. Propose testable hypotheses explaining how things work or relate -3. Rate confidence based on evidence strength -4. Suggest specific search queries to test each hypothesis - -Generate 2-4 hypotheses, prioritized by confidence. Adapt the hypothesis format to the domain of the query (medical, technical, business, etc.).""" diff --git a/src/prompts/judge.py b/src/prompts/judge.py deleted file mode 100644 index 9f1ad78480c6c9ea53fdb8b8a6e18c2cd7c1d6bf..0000000000000000000000000000000000000000 --- a/src/prompts/judge.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Judge prompts for evidence assessment.""" - -from src.utils.models import Evidence - -SYSTEM_PROMPT = """You are an expert research evidence evaluator for a generalist deep research agent. - -Your task is to evaluate evidence from any domain (medical, scientific, technical, business, etc.) and determine if sufficient evidence has been gathered to provide a precise answer to the research question. - -IMPORTANT: You are a research assistant. You cannot provide medical advice or answer medical questions directly. Your role is to assess whether enough high-quality evidence has been collected to synthesize comprehensive findings. - -## Evaluation Criteria - -1. **Mechanism/Explanation Score (0-10)**: How well does the evidence explain the underlying mechanism, process, or concept? 
- - For medical queries: biological mechanisms, pathways, drug actions - - For technical queries: how systems work, algorithms, processes - - For business queries: market dynamics, business models, strategies - - 0-3: No clear explanation, speculative - - 4-6: Some insight, but gaps exist - - 7-10: Clear, well-supported explanation - -2. **Evidence Quality Score (0-10)**: Strength and reliability of the evidence? - - For medical: clinical trials, peer-reviewed studies, meta-analyses - - For technical: peer-reviewed papers, authoritative sources, verified implementations - - For business: market reports, financial data, expert analysis - - 0-3: Weak or theoretical evidence only - - 4-6: Moderate quality evidence - - 7-10: Strong, authoritative evidence - -3. **Sufficiency**: Evidence is sufficient when: - - Combined scores >= 12 AND - - Key questions from the research query are addressed AND - - Evidence is comprehensive enough to provide a precise answer - -## Output Rules - -- Always output valid JSON matching the schema -- Be conservative: only recommend "synthesize" when truly confident the answer is precise -- If continuing, suggest specific, actionable search queries to fill gaps -- Never hallucinate findings, names, or facts not in the evidence -- Adapt evaluation criteria to the domain of the query (medical vs technical vs business) -""" - - -def format_user_prompt(question: str, evidence: list[Evidence]) -> str: - """ - Format the user prompt with question and evidence. - - Args: - question: The user's research question - evidence: List of Evidence objects from search - - Returns: - Formatted prompt string - """ - max_content_len = 1500 - - def format_single_evidence(i: int, e: Evidence) -> str: - content = e.content - if len(content) > max_content_len: - content = content[:max_content_len] + "..." - - return ( - f"### Evidence {i + 1}\n" - f"**Source**: {e.citation.source.upper()} - {e.citation.title}\n" - f"**URL**: {e.citation.url}\n" - f"**Date**: {e.citation.date}\n" - f"**Content**:\n{content}" - ) - - evidence_text = "\n\n".join([format_single_evidence(i, e) for i, e in enumerate(evidence)]) - - return f"""## Research Question -{question} - -## Available Evidence ({len(evidence)} sources) - -{evidence_text} - -## Your Task - -Evaluate this evidence and determine if it's sufficient to synthesize research findings. Consider the quality, quantity, and relevance of the evidence collected. -Respond with a JSON object matching the JudgeAssessment schema. -""" - - -def format_empty_evidence_prompt(question: str) -> str: - """ - Format prompt when no evidence was found. - - Args: - question: The user's research question - - Returns: - Formatted prompt string - """ - return f"""## Research Question -{question} - -## Available Evidence - -No evidence was found from the search. - -## Your Task - -Since no evidence was found, recommend search queries that might yield better results. -Set sufficient=False and recommendation=\"continue\". -Suggest 3-5 specific search queries. 
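The sufficiency rule stated in the system prompt (combined scores of at least 12 and the key questions addressed) is simple enough to sanity-check in code. The helper below is an illustrative sketch, not the project's actual `JudgeAssessment` logic, and it deliberately leaves the third criterion — whether the evidence is comprehensive enough for a precise answer — to the LLM's judgment:

```python
def is_sufficient(mechanism_score: int, evidence_quality_score: int, key_questions_addressed: bool) -> bool:
    """Mirror the prompt's rule: synthesize only when combined scores >= 12 AND key questions are addressed."""
    return (mechanism_score + evidence_quality_score) >= 12 and key_questions_addressed


# Worked examples of the threshold behaviour:
assert is_sufficient(7, 5, True) is True    # 12 combined, questions addressed -> synthesize
assert is_sufficient(7, 5, False) is False  # score is enough, but key questions remain open
assert is_sufficient(4, 6, True) is False   # 10 combined falls below the threshold
```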
-""" diff --git a/src/prompts/report.py b/src/prompts/report.py deleted file mode 100644 index 41257d65f9d7daa0681bd41a6a304580f731c994..0000000000000000000000000000000000000000 --- a/src/prompts/report.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Prompts for Report Agent.""" - -from typing import TYPE_CHECKING, Any - -from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - from src.utils.models import Evidence, MechanismHypothesis - -SYSTEM_PROMPT = """You are a scientific writer functioning as a medical peer junior researcher, specializing in research report synthesis. - -Your role is to synthesize evidence and findings into a clear, structured research report. - -IMPORTANT: You are a research assistant. You cannot answer medical questions or provide medical advice. Your reports synthesize evidence for research purposes only. - -A good report: -1. Has a clear EXECUTIVE SUMMARY (one paragraph, key takeaways) -2. States the RESEARCH QUESTION clearly -3. Describes METHODOLOGY (what was searched, how) -4. Evaluates HYPOTHESES with evidence counts -5. Separates MECHANISTIC and CLINICAL findings -6. Lists specific DRUG CANDIDATES -7. Acknowledges LIMITATIONS honestly -8. Provides a balanced CONCLUSION -9. Includes properly formatted REFERENCES - -Write in scientific but accessible language. Be specific about evidence strength. - -───────────────────────────────────────────────────────────────────────────── -🚨 CRITICAL: REQUIRED JSON STRUCTURE 🚨 -───────────────────────────────────────────────────────────────────────────── - -The `hypotheses_tested` field MUST be a LIST of objects, each with these fields: -- "hypothesis": the hypothesis text -- "supported": count of supporting evidence (integer) -- "contradicted": count of contradicting evidence (integer) - -Example: - hypotheses_tested: [ - {"hypothesis": "Metformin -> AMPK -> reduced inflammation", "supported": 3, "contradicted": 1}, - {"hypothesis": "Aspirin inhibits COX-2 pathway", "supported": 5, "contradicted": 0} - ] - -The `references` field MUST be a LIST of objects, each with these fields: -- "title": paper title (string) -- "authors": author names (string) -- "source": "pubmed" or "web" (string) -- "url": the EXACT URL from evidence (string) - -Example: - references: [ - {"title": "Metformin and Cancer", "authors": "Smith et al.", "source": "pubmed", "url": "https://pubmed.ncbi.nlm.nih.gov/12345678/"} - ] - -───────────────────────────────────────────────────────────────────────────── -🚨 CRITICAL CITATION REQUIREMENTS 🚨 -───────────────────────────────────────────────────────────────────────────── - -You MUST follow these rules for the References section: - -1. You may ONLY cite papers that appear in the Evidence section above -2. Every reference URL must EXACTLY match a provided evidence URL -3. Do NOT invent, fabricate, or hallucinate any references -4. Do NOT modify paper titles, authors, dates, or URLs -5. If unsure about a citation, OMIT it rather than guess -6. Copy URLs exactly as provided - do not create similar-looking URLs - -VIOLATION OF THESE RULES PRODUCES DANGEROUS MISINFORMATION. 
-─────────────────────────────────────────────────────────────────────────────""" - - -async def format_report_prompt( - query: str, - evidence: list["Evidence"], - hypotheses: list["MechanismHypothesis"], - assessment: dict[str, Any], - metadata: dict[str, Any], - embeddings: "EmbeddingService | None" = None, -) -> str: - """Format prompt for report generation. - - Includes full evidence details for accurate citation. - """ - # Select diverse evidence (not arbitrary truncation) - selected = await select_diverse_evidence(evidence, n=20, query=query, embeddings=embeddings) - - # Include FULL citation details for each evidence item - # This helps the LLM create accurate references - evidence_lines = [] - for e in selected: - authors = ", ".join(e.citation.authors or ["Unknown"]) - evidence_lines.append( - f"- **Title**: {e.citation.title}\n" - f" **URL**: {e.citation.url}\n" - f" **Authors**: {authors}\n" - f" **Date**: {e.citation.date or 'n.d.'}\n" - f" **Source**: {e.citation.source}\n" - f" **Content**: {truncate_at_sentence(e.content, 200)}\n" - ) - evidence_summary = "\n".join(evidence_lines) - - if hypotheses: - hypotheses_lines = [] - for h in hypotheses: - hypotheses_lines.append( - f"- {h.drug} -> {h.target} -> {h.pathway} -> {h.effect} " - f"(Confidence: {h.confidence:.0%})" - ) - hypotheses_summary = "\n".join(hypotheses_lines) - else: - hypotheses_summary = "No hypotheses generated yet." - - sources = ", ".join(metadata.get("sources", [])) - - return f"""Generate a structured research report for the following query. - -## Original Query -{query} - -## Evidence Collected ({len(selected)} papers, selected for diversity) - -{evidence_summary} - -## Hypotheses Generated -{hypotheses_summary} - -## Assessment Scores -- Mechanism Score: {assessment.get("mechanism_score", "N/A")}/10 -- Clinical Evidence Score: {assessment.get("clinical_score", "N/A")}/10 -- Overall Confidence: {assessment.get("confidence", 0):.0%} - -## Metadata -- Sources Searched: {sources} -- Search Iterations: {metadata.get("iterations", 0)} - -Generate a complete ResearchReport with all sections filled in. - -REMINDER: Only cite papers from the Evidence section above. 
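Rule 2 above (every reference URL must exactly match a provided evidence URL) can also be enforced mechanically after generation. A minimal post-hoc filter in that spirit — an illustrative sketch, not the project's citation validator — might be:

```python
from typing import Any


def drop_unverifiable_references(
    references: list[dict[str, Any]],
    evidence_urls: set[str],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Keep only references whose URL exactly matches a collected evidence URL.

    Returns (kept, removed) so a caller can log a warning for every removed entry.
    """
    kept = [ref for ref in references if ref.get("url") in evidence_urls]
    removed = [ref for ref in references if ref.get("url") not in evidence_urls]
    return kept, removed


# Example: one verifiable reference is kept, one unverifiable reference is dropped.
evidence_urls = {"https://pubmed.ncbi.nlm.nih.gov/12345678/"}
refs = [
    {"title": "Metformin and Cancer", "url": "https://pubmed.ncbi.nlm.nih.gov/12345678/"},
    {"title": "Fabricated paper", "url": "https://example.org/not-in-evidence"},
]
kept, removed = drop_unverifiable_references(refs, evidence_urls)
assert len(kept) == 1 and len(removed) == 1
```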
Copy URLs exactly.""" diff --git a/src/services/__init__.py b/src/services/__init__.py deleted file mode 100644 index 8814096713227b0954dc00c0ff605787d8566151..0000000000000000000000000000000000000000 --- a/src/services/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Services for The DETERMINATOR.""" diff --git a/src/services/audio_processing.py b/src/services/audio_processing.py deleted file mode 100644 index 588497f99be3b864024bf49d91feb3ad8e5a66db..0000000000000000000000000000000000000000 --- a/src/services/audio_processing.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Unified audio processing service for STT and TTS integration.""" - -from functools import lru_cache -from typing import Any - -import numpy as np -import structlog - -from src.agents.audio_refiner import audio_refiner -from src.services.stt_gradio import STTService, get_stt_service -from src.utils.config import settings - -logger = structlog.get_logger(__name__) - -# Type stub for TTS service (will be imported when available) -try: - from src.services.tts_modal import TTSService, get_tts_service - - _TTS_AVAILABLE = True -except ImportError: - _TTS_AVAILABLE = False - TTSService = None # type: ignore[assignment, misc] - get_tts_service = None # type: ignore[assignment, misc] - - -class AudioService: - """Unified audio processing service.""" - - def __init__( - self, - stt_service: STTService | None = None, - tts_service: Any | None = None, - ) -> None: - """Initialize audio service with STT and TTS. - - Args: - stt_service: STT service instance (default: get_stt_service()) - tts_service: TTS service instance (default: get_tts_service() if available) - """ - self.stt = stt_service or get_stt_service() - - # TTS is optional (requires Modal) - if tts_service is not None: - self.tts = tts_service - elif _TTS_AVAILABLE and settings.modal_available: - try: - self.tts = get_tts_service() # type: ignore[misc] - except Exception as e: - logger.warning("tts_service_unavailable", error=str(e)) - self.tts = None - else: - self.tts = None - - async def process_audio_input( - self, - audio_input: tuple[int, np.ndarray[Any, Any]] | None, # type: ignore[type-arg] - hf_token: str | None = None, - ) -> str | None: - """Process audio input and return transcribed text. - - Args: - audio_input: Tuple of (sample_rate, audio_array) or None - hf_token: HuggingFace token for authenticated Gradio Spaces - - Returns: - Transcribed text string or None if no audio input - """ - if audio_input is None: - return None - - try: - transcribed_text = await self.stt.transcribe_audio(audio_input, hf_token=hf_token) - logger.info("audio_input_processed", text_length=len(transcribed_text)) - return transcribed_text - except Exception as e: - logger.error("audio_input_processing_failed", error=str(e)) - # Return None on failure (graceful degradation) - return None - - async def generate_audio_output( - self, - text: str, - voice: str | None = None, - speed: float | None = None, - ) -> tuple[int, np.ndarray[Any, Any]] | None: # type: ignore[type-arg] - """Generate audio output from text. 
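Putting the two halves of `AudioService` together, a typical round trip looks like the hedged sketch below; it assumes a Gradio-style `(sample_rate, numpy array)` tuple for input and that TTS is configured — otherwise `generate_audio_output` simply returns `None`:

```python
import asyncio

import numpy as np

from src.services.audio_processing import get_audio_service


async def demo() -> None:
    service = get_audio_service()

    # 1. STT: transcribe microphone input; None is returned on failure or missing audio.
    fake_audio = (16000, np.zeros(16000, dtype=np.float32))  # one second of silence, for illustration
    text = await service.process_audio_input(fake_audio)
    print("Transcribed:", text)

    # 2. TTS: read a markdown report back as audio; None if the TTS backend is unavailable.
    audio_out = await service.generate_audio_output("## Findings\nMetformin activates AMPK.")
    if audio_out is not None:
        sample_rate, samples = audio_out
        print(f"Synthesized {len(samples)} samples at {sample_rate} Hz")


asyncio.run(demo())
```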
- - Args: - text: Text to synthesize (markdown will be cleaned for audio) - voice: Voice ID (default: settings.tts_voice) - speed: Speech speed (default: settings.tts_speed) - - Returns: - Tuple of (sample_rate, audio_array) or None if TTS unavailable - """ - if self.tts is None: - logger.warning("tts_unavailable", message="TTS service not available") - return None - - if not text or not text.strip(): - logger.warning("empty_text_for_tts") - return None - - try: - # Refine text for audio (remove markdown, citations, etc.) - # Use LLM polish if enabled in settings - refined_text = await audio_refiner.refine_for_audio( - text, use_llm_polish=settings.tts_use_llm_polish - ) - logger.info( - "text_refined_for_audio", - original_length=len(text), - refined_length=len(refined_text), - llm_polish_enabled=settings.tts_use_llm_polish, - ) - - # Use provided voice/speed or fallback to settings defaults - voice = voice if voice else settings.tts_voice - speed = speed if speed is not None else settings.tts_speed - - audio_output = await self.tts.synthesize_async(refined_text, voice, speed) # type: ignore[misc] - - if audio_output: - logger.info( - "audio_output_generated", - text_length=len(text), - sample_rate=audio_output[0], - ) - - return audio_output # type: ignore[no-any-return] - - except Exception as e: - logger.error("audio_output_generation_failed", error=str(e)) - # Return None on failure (graceful degradation) - return None - - -@lru_cache(maxsize=1) -def get_audio_service() -> AudioService: - """Get or create singleton audio service instance. - - Returns: - AudioService instance - """ - return AudioService() diff --git a/src/services/embeddings.py b/src/services/embeddings.py deleted file mode 100644 index 7544b985ec2b2af45be169f6cd58660f028b09c3..0000000000000000000000000000000000000000 --- a/src/services/embeddings.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Embedding service for semantic search. - -IMPORTANT: All public methods are async to avoid blocking the event loop. -The sentence-transformers model is CPU-bound, so we use run_in_executor(). -""" - -import asyncio -from typing import Any - -import chromadb -import structlog -from sentence_transformers import SentenceTransformer - -from src.utils.config import settings -from src.utils.models import Evidence - - -class EmbeddingService: - """Handles text embedding and vector storage using local sentence-transformers. - - All embedding operations run in a thread pool to avoid blocking - the async event loop. - - Note: - Uses local sentence-transformers models (no API key required). - Model is configured via settings.local_embedding_model. 
- """ - - def __init__(self, model_name: str | None = None): - self._model_name = model_name or settings.local_embedding_model - self._model = SentenceTransformer(self._model_name) - self._client = chromadb.Client() # In-memory for hackathon - self._collection = self._client.create_collection( - name="evidence", metadata={"hnsw:space": "cosine"} - ) - - # ───────────────────────────────────────────────────────────────── - # Sync internal methods (run in thread pool) - # ───────────────────────────────────────────────────────────────── - - def _sync_embed(self, text: str) -> list[float]: - """Synchronous embedding - DO NOT call directly from async code.""" - result: list[float] = self._model.encode(text).tolist() - return result - - def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]: - """Batch embedding for efficiency - DO NOT call directly from async code.""" - embeddings = self._model.encode(texts) - return [e.tolist() for e in embeddings] - - # ───────────────────────────────────────────────────────────────── - # Async public methods (safe for event loop) - # ───────────────────────────────────────────────────────────────── - - async def embed(self, text: str) -> list[float]: - """Embed a single text (async-safe). - - Uses run_in_executor to avoid blocking the event loop. - """ - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, self._sync_embed, text) - - async def embed_batch(self, texts: list[str]) -> list[list[float]]: - """Batch embed multiple texts (async-safe, more efficient).""" - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, self._sync_batch_embed, texts) - - async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None: - """Add evidence to vector store (async-safe).""" - embedding = await self.embed(content) - # ChromaDB operations are fast, but wrap for consistency - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, - lambda: self._collection.add( - ids=[evidence_id], - embeddings=[embedding], # type: ignore[arg-type] - metadatas=[metadata], - documents=[content], - ), - ) - - async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]: - """Find semantically similar evidence (async-safe).""" - query_embedding = await self.embed(query) - - loop = asyncio.get_running_loop() - results = await loop.run_in_executor( - None, - lambda: self._collection.query( - query_embeddings=[query_embedding], # type: ignore[arg-type] - n_results=n_results, - ), - ) - - # Handle empty results gracefully - ids = results.get("ids") - docs = results.get("documents") - metas = results.get("metadatas") - dists = results.get("distances") - - if not ids or not ids[0] or not docs or not metas or not dists: - return [] - - return [ - {"id": id, "content": doc, "metadata": meta, "distance": dist} - for id, doc, meta, dist in zip( - ids[0], - docs[0], - metas[0], - dists[0], - strict=False, - ) - ] - - async def deduplicate( - self, new_evidence: list[Evidence], threshold: float = 0.9 - ) -> list[Evidence]: - """Remove semantically duplicate evidence (async-safe). - - Args: - new_evidence: List of evidence items to deduplicate - threshold: Similarity threshold (0.9 = 90% similar is duplicate). - ChromaDB cosine distance: 0=identical, 2=opposite. - We consider duplicate if distance < (1 - threshold). - - Returns: - List of unique evidence items (not already in vector store). 
- """ - unique = [] - for evidence in new_evidence: - try: - similar = await self.search_similar(evidence.content, n_results=1) - # ChromaDB cosine distance: 0 = identical, 2 = opposite - # threshold=0.9 means distance < 0.1 is considered duplicate - is_duplicate = similar and similar[0]["distance"] < (1 - threshold) - - if not is_duplicate: - unique.append(evidence) - # Store FULL citation metadata for reconstruction later - await self.add_evidence( - evidence_id=evidence.citation.url, - content=evidence.content, - metadata={ - "source": evidence.citation.source, - "title": evidence.citation.title, - "date": evidence.citation.date, - "authors": ",".join(evidence.citation.authors or []), - }, - ) - except Exception as e: - # Log but don't fail entire deduplication for one bad item - structlog.get_logger().warning( - "Failed to process evidence in deduplicate", - url=evidence.citation.url, - error=str(e), - ) - # Still add to unique list - better to have duplicates than lose data - unique.append(evidence) - - return unique - - -_embedding_service: EmbeddingService | None = None - - -def get_embedding_service() -> EmbeddingService: - """Get singleton instance of EmbeddingService.""" - global _embedding_service # noqa: PLW0603 - if _embedding_service is None: - _embedding_service = EmbeddingService() - return _embedding_service diff --git a/src/services/image_ocr.py b/src/services/image_ocr.py deleted file mode 100644 index ed8881939dc20dbf42c7776dda876a1d37a5d478..0000000000000000000000000000000000000000 --- a/src/services/image_ocr.py +++ /dev/null @@ -1,245 +0,0 @@ -"""Image-to-text service using Gradio Client API (Multimodal-OCR3).""" - -import asyncio -import tempfile -from functools import lru_cache -from pathlib import Path -from typing import Any - -import numpy as np -import structlog -from gradio_client import Client, handle_file -from PIL import Image - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger(__name__) - - -class ImageOCRService: - """Image OCR service using prithivMLmods/Multimodal-OCR3 Gradio Space.""" - - def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> None: - """Initialize Image OCR service. - - Args: - api_url: Gradio Space URL (default: settings.ocr_api_url) - hf_token: HuggingFace token for authenticated Spaces (default: None) - - Raises: - ConfigurationError: If API URL not configured - """ - # Defensively access ocr_api_url - may not exist in older config versions - default_url = ( - getattr(settings, "ocr_api_url", None) - or "https://prithivmlmods-multimodal-ocr3.hf.space" - ) - self.api_url = api_url or default_url - if not self.api_url: - raise ConfigurationError("OCR API URL not configured") - self.hf_token = hf_token - self.client: Client | None = None - - async def _get_client(self, hf_token: str | None = None) -> Client: - """Get or create Gradio Client (lazy initialization). 
- - Args: - hf_token: HuggingFace token for authenticated Spaces (overrides instance token) - - Returns: - Gradio Client instance - """ - # Use provided token or instance token - token = hf_token or self.hf_token - - # If client exists but token changed, recreate it - if self.client is not None and token != self.hf_token: - self.client = None - - if self.client is None: - loop = asyncio.get_running_loop() - # Pass token to Client for authenticated Spaces - # Gradio Client uses 'token' parameter, not 'hf_token' - if token: - self.client = await loop.run_in_executor( - None, - lambda: Client(self.api_url, token=token), - ) - else: - self.client = await loop.run_in_executor( - None, - lambda: Client(self.api_url), - ) - # Update instance token for future use - self.hf_token = token - return self.client - - async def extract_text( - self, - image_path: str, - model: str | None = None, - hf_token: str | None = None, - ) -> str: - """Extract text from image using Gradio API. - - Args: - image_path: Path to image file - model: Optional model selection (default: None, uses API default) - - Returns: - Extracted text string - - Raises: - ConfigurationError: If OCR extraction fails - """ - client = await self._get_client(hf_token=hf_token) - - logger.info( - "extracting_text_from_image", - image_path=image_path, - model=model, - ) - - try: - # Call /Multimodal_OCR3_generate_image API endpoint - # According to the MCP tool description, this yields raw text and Markdown-formatted text - loop = asyncio.get_running_loop() - - # The API might require file upload first, then call the generate function - # For now, we'll use handle_file to upload and pass the path - result = await loop.run_in_executor( - None, - lambda: client.predict( - image_path=handle_file(image_path), - api_name="/Multimodal_OCR3_generate_image", - ), - ) - - # Extract text from result - extracted_text = self._extract_text_from_result(result) - - logger.info( - "image_ocr_complete", - text_length=len(extracted_text), - ) - - return extracted_text - - except Exception as e: - logger.error("image_ocr_failed", error=str(e), error_type=type(e).__name__) - raise ConfigurationError(f"Image OCR failed: {e}") from e - - async def extract_text_from_image( - self, - image_data: np.ndarray[Any, Any] | Image.Image | str, # type: ignore[type-arg] - hf_token: str | None = None, - ) -> str: - """Extract text from image data (numpy array, PIL Image, or file path). 
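A minimal caller looks like the sketch below; it assumes network access to the configured Gradio Space and uses a blank test image purely for illustration:

```python
import asyncio

from PIL import Image

from src.services.image_ocr import get_image_ocr_service


async def demo() -> None:
    service = get_image_ocr_service()

    # Accepts a file path, a PIL Image, or a numpy array; PIL/numpy inputs are
    # written to a temporary PNG before being uploaded to the Space.
    image = Image.new("RGB", (200, 60), color="white")
    text = await service.extract_text_from_image(image)
    print(repr(text))


asyncio.run(demo())
```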
- - Args: - image_data: Image as numpy array, PIL Image, or file path string - - Returns: - Extracted text string - """ - # Handle different input types - if isinstance(image_data, str): - # Assume it's a file path - image_path = image_data - elif isinstance(image_data, Image.Image): - # Save PIL Image to temp file - image_path = self._save_image_temp(image_data) - elif isinstance(image_data, np.ndarray): - # Convert numpy array to PIL Image, then save - pil_image = Image.fromarray(image_data) - image_path = self._save_image_temp(pil_image) - else: - raise ValueError(f"Unsupported image data type: {type(image_data)}") - - try: - # Extract text from the image file - extracted_text = await self.extract_text(image_path, hf_token=hf_token) - return extracted_text - finally: - # Clean up temp file if we created it - if image_path != image_data or not isinstance(image_data, str): - try: - Path(image_path).unlink(missing_ok=True) - except Exception as e: - logger.warning("failed_to_cleanup_temp_file", path=image_path, error=str(e)) - - def _extract_text_from_result(self, api_result: Any) -> str: - """Extract text from API result. - - Args: - api_result: Result from Gradio API - - Returns: - Extracted text string - """ - # The API yields raw text and Markdown-formatted text - # Result might be a string, tuple, or generator - if isinstance(api_result, str): - return api_result.strip() - - if isinstance(api_result, tuple): - # Try to extract text from tuple - for item in api_result: - if isinstance(item, str): - return item.strip() - # Check if it's a dict with text fields - if isinstance(item, dict): - if "text" in item: - return str(item["text"]).strip() - if "content" in item: - return str(item["content"]).strip() - - # If result is a generator or async generator, we'd need to iterate - # For now, convert to string representation - if api_result is not None: - text = str(api_result).strip() - if text and text != "None": - return text - - logger.warning("could_not_extract_text_from_result", result_type=type(api_result).__name__) - return "" - - def _save_image_temp(self, image: Image.Image) -> str: - """Save PIL Image to temporary file. - - Args: - image: PIL Image object - - Returns: - Path to temporary image file - """ - # Create temp file - temp_file = tempfile.NamedTemporaryFile( - suffix=".png", - delete=False, - ) - temp_path = temp_file.name - temp_file.close() - - try: - # Save image as PNG - image.save(temp_path, "PNG") - - logger.debug("saved_image_temp", path=temp_path, size=image.size) - - return temp_path - - except Exception as e: - logger.error("failed_to_save_image_temp", error=str(e)) - raise ConfigurationError(f"Failed to save image to temp file: {e}") from e - - -@lru_cache(maxsize=1) -def get_image_ocr_service() -> ImageOCRService: - """Get or create singleton Image OCR service instance. - - Returns: - ImageOCRService instance - """ - return ImageOCRService() diff --git a/src/services/llamaindex_rag.py b/src/services/llamaindex_rag.py deleted file mode 100644 index 5de92f55f9751d0cb191f145b2f3e380be8b5ca6..0000000000000000000000000000000000000000 --- a/src/services/llamaindex_rag.py +++ /dev/null @@ -1,503 +0,0 @@ -"""LlamaIndex RAG service for evidence retrieval and indexing. 
- -Requires optional dependencies: uv sync --extra modal -""" - -from typing import Any - -import structlog - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError -from src.utils.models import Evidence - -logger = structlog.get_logger() - - -class LlamaIndexRAGService: - """RAG service using LlamaIndex with ChromaDB vector store. - - Supports multiple embedding providers: - - OpenAI embeddings (requires OPENAI_API_KEY) - - Local sentence-transformers (no API key required) - - Hugging Face embeddings (uses local sentence-transformers) - - Supports multiple LLM providers for query synthesis: - - HuggingFace LLM (preferred, requires HF_TOKEN or HUGGINGFACE_API_KEY) - - OpenAI LLM (fallback, requires OPENAI_API_KEY) - - None (embedding-only mode, no query synthesis) - - Note: - HuggingFace is the default LLM provider. OpenAI is used as fallback - if HuggingFace LLM is not available or no HF token is configured. - """ - - def __init__( - self, - collection_name: str = "deepcritical_evidence", - persist_dir: str | None = None, - embedding_model: str | None = None, - similarity_top_k: int = 5, - use_openai_embeddings: bool | None = None, - use_in_memory: bool = False, - oauth_token: str | None = None, - ) -> None: - """ - Initialize LlamaIndex RAG service. - - Args: - collection_name: Name of the ChromaDB collection - persist_dir: Directory to persist ChromaDB data - embedding_model: Embedding model name (defaults based on provider) - similarity_top_k: Number of top results to retrieve - use_openai_embeddings: Force OpenAI embeddings (None = auto-detect) - use_in_memory: Use in-memory ChromaDB client (useful for tests) - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - """ - # Import dependencies and store references - deps = self._import_dependencies() - self._chromadb = deps["chromadb"] - self._Document = deps["Document"] - self._Settings = deps["Settings"] - self._StorageContext = deps["StorageContext"] - self._VectorStoreIndex = deps["VectorStoreIndex"] - self._VectorIndexRetriever = deps["VectorIndexRetriever"] - self._ChromaVectorStore = deps["ChromaVectorStore"] - huggingface_embedding = deps["huggingface_embedding"] - huggingface_llm = deps["huggingface_llm"] - openai_embedding = deps["OpenAIEmbedding"] - openai_llm = deps["OpenAI"] - - # Store basic configuration - self.collection_name = collection_name - self.persist_dir = persist_dir or settings.chroma_db_path - self.similarity_top_k = similarity_top_k - self.use_in_memory = use_in_memory - self.oauth_token = oauth_token - - # Configure embeddings and LLM - use_openai = use_openai_embeddings if use_openai_embeddings is not None else False - self._configure_embeddings( - use_openai, embedding_model, huggingface_embedding, openai_embedding - ) - self._configure_llm(huggingface_llm, openai_llm) - - # Initialize ChromaDB and index - self._initialize_chromadb() - - def _import_dependencies(self) -> dict[str, Any]: - """Import LlamaIndex dependencies and return as dict. - - OpenAI dependencies are imported lazily (only when needed) to avoid - tiktoken circular import issues on Windows when using local embeddings. 
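The constructor exposes a few knobs worth highlighting; the configurations below are a hedged sketch (they assume the optional extras are installed and, for the OpenAI variant, that `OPENAI_API_KEY` is set):

```python
from src.services.llamaindex_rag import LlamaIndexRAGService

# Default: persistent ChromaDB collection with local/HuggingFace embeddings, and a
# HuggingFace LLM for query synthesis when a token is available.
rag = LlamaIndexRAGService(collection_name="deepcritical_evidence")

# Test-friendly: in-memory ChromaDB, nothing written to disk.
rag_test = LlamaIndexRAGService(use_in_memory=True)

# Explicit OpenAI embeddings (requires OPENAI_API_KEY and the optional extra).
rag_openai = LlamaIndexRAGService(use_openai_embeddings=True)
```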
- """ - try: - import chromadb - from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex - from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.vector_stores.chroma import ChromaVectorStore - - # Try to import Hugging Face embeddings (may not be available in all versions) - try: - from llama_index.embeddings.huggingface import ( - HuggingFaceEmbedding as _HuggingFaceEmbedding, # type: ignore[import-untyped] - ) - - huggingface_embedding = _HuggingFaceEmbedding - except ImportError: - huggingface_embedding = None # type: ignore[assignment] - - # Try to import Hugging Face Inference API LLM (for API-based models) - # This is preferred over local HuggingFaceLLM for query synthesis - try: - from llama_index.llms.huggingface_api import ( - HuggingFaceInferenceAPI as _HuggingFaceInferenceAPI, # type: ignore[import-untyped] - ) - - huggingface_llm = _HuggingFaceInferenceAPI - except ImportError: - # Fallback to local HuggingFaceLLM if API version not available - try: - from llama_index.llms.huggingface import ( - HuggingFaceLLM as _HuggingFaceLLM, # type: ignore[import-untyped] - ) - - huggingface_llm = _HuggingFaceLLM # type: ignore[assignment] - except ImportError: - huggingface_llm = None # type: ignore[assignment] - - # OpenAI imports are optional - only import when actually needed - # This avoids tiktoken circular import issues on Windows - try: - from llama_index.embeddings.openai import OpenAIEmbedding - except ImportError: - OpenAIEmbedding = None # type: ignore[assignment, misc] # noqa: N806 - - try: - from llama_index.llms.openai import OpenAI - except ImportError: - OpenAI = None # type: ignore[assignment, misc] # noqa: N806 - - return { - "chromadb": chromadb, - "Document": Document, - "Settings": Settings, - "StorageContext": StorageContext, - "VectorStoreIndex": VectorStoreIndex, - "VectorIndexRetriever": VectorIndexRetriever, - "ChromaVectorStore": ChromaVectorStore, - "OpenAIEmbedding": OpenAIEmbedding, - "OpenAI": OpenAI, - "huggingface_embedding": huggingface_embedding, - "huggingface_llm": huggingface_llm, - } - except ImportError as e: - raise ImportError( - "LlamaIndex dependencies not installed. Run: uv sync --extra modal" - ) from e - - def _configure_embeddings( - self, - use_openai_embeddings: bool, - embedding_model: str | None, - huggingface_embedding: Any, - openai_embedding: Any, - ) -> None: - """Configure embedding model.""" - if use_openai_embeddings: - if openai_embedding is None: - raise ConfigurationError( - "OpenAI embeddings not available. Install with: uv sync --extra modal" - ) - if not settings.openai_api_key: - raise ConfigurationError("OPENAI_API_KEY required for OpenAI embeddings") - self.embedding_model = embedding_model or settings.openai_embedding_model - self._Settings.embed_model = openai_embedding( - model=self.embedding_model, - api_key=settings.openai_api_key, - ) - else: - model_name = embedding_model or settings.huggingface_embedding_model - self.embedding_model = model_name - if huggingface_embedding is not None: - self._Settings.embed_model = huggingface_embedding(model_name=model_name) - else: - self._Settings.embed_model = self._create_sentence_transformer_embedding(model_name) - - def _create_sentence_transformer_embedding(self, model_name: str) -> Any: - """Create sentence-transformer embedding wrapper. - - Note: sentence-transformers is a required dependency (in pyproject.toml). - If this fails, it's likely a Windows-specific regex package issue. 
- - Raises: - ConfigurationError: If sentence_transformers cannot be imported - (e.g., due to circular import issues on Windows with regex package) - """ - try: - from sentence_transformers import SentenceTransformer - except ImportError as e: - # Handle Windows-specific circular import issues with regex package - # This is a known bug: https://github.com/mrabarnett/mrab-regex/issues/417 - error_msg = str(e) - if "regex" in error_msg.lower() or "_regex" in error_msg: - raise ConfigurationError( - "sentence_transformers cannot be imported due to circular import issue " - "with regex package (Windows-specific bug). " - "sentence-transformers is installed but regex has a circular import. " - "Try: uv pip install --upgrade --force-reinstall regex " - "Or use HuggingFace embeddings via llama-index-embeddings-huggingface instead." - ) from e - raise ConfigurationError( - f"sentence_transformers not available: {e}. " - "This is a required dependency - check your uv sync installation." - ) from e - - try: - from llama_index.embeddings.base import ( - BaseEmbedding, # type: ignore[import-untyped] - ) - except ImportError: - from llama_index.core.embeddings import ( - BaseEmbedding, # type: ignore[import-untyped] - ) - - class SentenceTransformerEmbedding(BaseEmbedding): # type: ignore[misc] - """Simple wrapper for sentence-transformers.""" - - def __init__(self, model_name: str): - super().__init__() - self._model = SentenceTransformer(model_name) - - def _get_query_embedding(self, query: str) -> list[float]: - result = self._model.encode(query).tolist() - return list(result) # type: ignore[no-any-return] - - def _get_text_embedding(self, text: str) -> list[float]: - result = self._model.encode(text).tolist() - return list(result) # type: ignore[no-any-return] - - async def _aget_query_embedding(self, query: str) -> list[float]: - return self._get_query_embedding(query) - - async def _aget_text_embedding(self, text: str) -> list[float]: - return self._get_text_embedding(text) - - return SentenceTransformerEmbedding(model_name) - - def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None: - """Configure LLM for query synthesis.""" - # Priority: oauth_token > env vars - effective_token = self.oauth_token or settings.hf_token or settings.huggingface_api_key - if huggingface_llm is not None and effective_token: - model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" - token = effective_token - - # Check if it's HuggingFaceInferenceAPI (API-based) or HuggingFaceLLM (local) - llm_class_name = ( - huggingface_llm.__name__ - if hasattr(huggingface_llm, "__name__") - else str(huggingface_llm) - ) - - if "InferenceAPI" in llm_class_name: - # Use HuggingFace Inference API (supports token parameter) - try: - self._Settings.llm = huggingface_llm( - model_name=model_name, - token=token, - ) - except Exception as e: - # If model is not available via inference API, log warning and continue without LLM - logger.warning( - "Failed to initialize HuggingFace Inference API LLM", - model=model_name, - error=str(e), - ) - logger.info("Continuing without LLM - query synthesis will be unavailable") - self._Settings.llm = None - return - else: - # Use local HuggingFaceLLM (doesn't support token, uses model_name and tokenizer_name) - self._Settings.llm = huggingface_llm( - model_name=model_name, - tokenizer_name=model_name, - ) - logger.info("Using HuggingFace LLM for query synthesis", model=model_name) - elif settings.openai_api_key and openai_llm is not None: - self._Settings.llm = 
openai_llm( - model=settings.openai_model, - api_key=settings.openai_api_key, - ) - logger.info("Using OpenAI LLM for query synthesis", model=settings.openai_model) - else: - logger.warning("No LLM API key available - query synthesis will be unavailable") - self._Settings.llm = None - - def _initialize_chromadb(self) -> None: - """Initialize ChromaDB client, collection, and index.""" - if self.use_in_memory: - # Use in-memory client for tests (avoids file system issues) - self.chroma_client = self._chromadb.Client() - else: - # Use persistent client for production - self.chroma_client = self._chromadb.PersistentClient(path=self.persist_dir) - - # Get or create collection - try: - self.collection = self.chroma_client.get_collection(self.collection_name) - logger.info("loaded_existing_collection", name=self.collection_name) - except Exception: - self.collection = self.chroma_client.create_collection(self.collection_name) - logger.info("created_new_collection", name=self.collection_name) - - # Initialize vector store and index - self.vector_store = self._ChromaVectorStore(chroma_collection=self.collection) - self.storage_context = self._StorageContext.from_defaults(vector_store=self.vector_store) - - # Try to load existing index, or create empty one - try: - self.index = self._VectorStoreIndex.from_vector_store( - vector_store=self.vector_store, - storage_context=self.storage_context, - ) - logger.info("loaded_existing_index") - except Exception: - self.index = self._VectorStoreIndex([], storage_context=self.storage_context) - logger.info("created_new_index") - - def ingest_evidence(self, evidence_list: list[Evidence]) -> None: - """ - Ingest evidence into the vector store. - - Args: - evidence_list: List of Evidence objects to ingest - """ - if not evidence_list: - logger.warning("no_evidence_to_ingest") - return - - # Convert Evidence objects to LlamaIndex Documents - documents = [] - for evidence in evidence_list: - metadata = { - "source": evidence.citation.source, - "title": evidence.citation.title, - "url": evidence.citation.url, - "date": evidence.citation.date, - "authors": ", ".join(evidence.citation.authors), - } - - doc = self._Document( - text=evidence.content, - metadata=metadata, - doc_id=evidence.citation.url, # Use URL as unique ID - ) - documents.append(doc) - - # Insert documents into index - try: - for doc in documents: - self.index.insert(doc) - logger.info("ingested_evidence", count=len(documents)) - except Exception as e: - logger.error("failed_to_ingest_evidence", error=str(e)) - raise - - def ingest_documents(self, documents: list[Any]) -> None: - """ - Ingest raw LlamaIndex Documents. - - Args: - documents: List of LlamaIndex Document objects - """ - if not documents: - logger.warning("no_documents_to_ingest") - return - - try: - for doc in documents: - self.index.insert(doc) - logger.info("ingested_documents", count=len(documents)) - except Exception as e: - logger.error("failed_to_ingest_documents", error=str(e)) - raise - - def retrieve(self, query: str, top_k: int | None = None) -> list[dict[str, Any]]: - """ - Retrieve relevant documents for a query. 
- - Args: - query: Query string - top_k: Number of results to return (defaults to similarity_top_k) - - Returns: - List of retrieved documents with metadata and scores - """ - k = top_k or self.similarity_top_k - - # Create retriever - retriever = self._VectorIndexRetriever( - index=self.index, - similarity_top_k=k, - ) - - try: - # Retrieve nodes - nodes = retriever.retrieve(query) - - # Convert to dict format - results = [] - for node in nodes: - results.append( - { - "text": node.node.get_content(), - "score": node.score, - "metadata": node.node.metadata, - } - ) - - logger.info("retrieved_documents", query=query[:50], count=len(results)) - return results - - except Exception as e: - logger.error("failed_to_retrieve", error=str(e), query=query[:50]) - raise # Re-raise to allow callers to distinguish errors from empty results - - def query(self, query_str: str, top_k: int | None = None) -> str: - """ - Query the RAG system and get a synthesized response. - - Args: - query_str: Query string - top_k: Number of results to use (defaults to similarity_top_k) - - Returns: - Synthesized response string - - Raises: - ConfigurationError: If no LLM API key is available for query synthesis - """ - if not self._Settings.llm: - raise ConfigurationError( - "LLM API key required for query synthesis. Set HF_TOKEN, HUGGINGFACE_API_KEY, or OPENAI_API_KEY. " - "Alternatively, use retrieve() for embedding-only search." - ) - - k = top_k or self.similarity_top_k - - # Create query engine - query_engine = self.index.as_query_engine( - similarity_top_k=k, - ) - - try: - response = query_engine.query(query_str) - logger.info("generated_response", query=query_str[:50]) - return str(response) - - except Exception as e: - logger.error("failed_to_query", error=str(e), query=query_str[:50]) - raise # Re-raise to allow callers to handle errors explicitly - - def clear_collection(self) -> None: - """Clear all documents from the collection.""" - try: - self.chroma_client.delete_collection(self.collection_name) - self.collection = self.chroma_client.create_collection(self.collection_name) - self.vector_store = self._ChromaVectorStore(chroma_collection=self.collection) - self.storage_context = self._StorageContext.from_defaults( - vector_store=self.vector_store - ) - self.index = self._VectorStoreIndex([], storage_context=self.storage_context) - logger.info("cleared_collection", name=self.collection_name) - except Exception as e: - logger.error("failed_to_clear_collection", error=str(e)) - raise - - -def get_rag_service( - collection_name: str = "deepcritical_evidence", - oauth_token: str | None = None, - **kwargs: Any, -) -> LlamaIndexRAGService: - """ - Get or create a RAG service instance. - - Args: - collection_name: Name of the ChromaDB collection - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - **kwargs: Additional arguments for LlamaIndexRAGService - Defaults to use_openai_embeddings=False (local embeddings) - - Returns: - Configured LlamaIndexRAGService instance - - Note: - By default, uses local embeddings (sentence-transformers) which require - no API keys. Set use_openai_embeddings=True to use OpenAI embeddings. 
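# --- Hedged usage sketch (not part of the original file) ---
# How the RAG service above might be exercised end-to-end: build Evidence,
# ingest it, then run an embedding-only retrieval. The import path for this
# module is assumed (src.services.llamaindex_rag); Evidence/Citation fields
# follow their use elsewhere in this repo, and the example values are made up.
from src.services.llamaindex_rag import get_rag_service  # module path assumed
from src.utils.models import Citation, Evidence

rag = get_rag_service(collection_name="demo", use_in_memory=True)
rag.ingest_evidence(
    [
        Evidence(
            content="Metformin activates AMPK signalling in neuronal cell models.",
            citation=Citation(
                source="pubmed",
                title="Example article (placeholder)",
                url="https://example.org/paper-1",
                date="2024-01-01",
                authors=["Doe J"],
            ),
            relevance=0.9,
        )
    ]
)
hits = rag.retrieve("AMPK signalling", top_k=3)  # embedding-only, no LLM needed
print(hits[0]["text"], hits[0]["score"])
# rag.query("...") additionally requires HF_TOKEN / HUGGINGFACE_API_KEY or OPENAI_API_KEY.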
- """ - # Default to local embeddings if not explicitly set - if "use_openai_embeddings" not in kwargs: - kwargs["use_openai_embeddings"] = False - return LlamaIndexRAGService(collection_name=collection_name, oauth_token=oauth_token, **kwargs) diff --git a/src/services/multimodal_processing.py b/src/services/multimodal_processing.py deleted file mode 100644 index 9199f194401f1a2ced589b315226ca819188c001..0000000000000000000000000000000000000000 --- a/src/services/multimodal_processing.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Unified multimodal processing service for text, audio, and image inputs.""" - -from functools import lru_cache -from typing import Any - -import structlog -from gradio.data_classes import FileData - -from src.services.audio_processing import AudioService, get_audio_service -from src.services.image_ocr import ImageOCRService, get_image_ocr_service -from src.utils.config import settings - -logger = structlog.get_logger(__name__) - - -class MultimodalService: - """Unified multimodal processing service.""" - - def __init__( - self, - audio_service: AudioService | None = None, - ocr_service: ImageOCRService | None = None, - ) -> None: - """Initialize multimodal service. - - Args: - audio_service: Audio service instance (default: get_audio_service()) - ocr_service: Image OCR service instance (default: get_image_ocr_service()) - """ - self.audio = audio_service or get_audio_service() - self.ocr = ocr_service or get_image_ocr_service() - - async def process_multimodal_input( - self, - text: str, - files: list[FileData] | None = None, - audio_input: tuple[int, Any] | None = None, - hf_token: str | None = None, - prepend_multimodal: bool = True, - ) -> str: - """Process multimodal input (text + images + audio) and return combined text. - - Args: - text: Text input string - files: List of uploaded files (images, audio, etc.) 
- audio_input: Audio input tuple (sample_rate, audio_array) - hf_token: HuggingFace token for authenticated Gradio Spaces - prepend_multimodal: If True, prepend audio/image text to original text; otherwise append - - Returns: - Combined text from all inputs - """ - multimodal_parts: list[str] = [] - text_parts: list[str] = [] - - # Process audio input first - if audio_input is not None and settings.enable_audio_input: - try: - transcribed = await self.audio.process_audio_input(audio_input, hf_token=hf_token) - if transcribed: - multimodal_parts.append(transcribed) - except Exception as e: - logger.warning("audio_processing_failed", error=str(e)) - - # Process uploaded files (images and audio files) - if files and settings.enable_image_input: - for file_data in files: - file_path = file_data.path if isinstance(file_data, FileData) else str(file_data) - - # Check if it's an image - if self._is_image_file(file_path): - try: - extracted_text = await self.ocr.extract_text(file_path, hf_token=hf_token) - if extracted_text: - multimodal_parts.append(extracted_text) - except Exception as e: - logger.warning("image_ocr_failed", file_path=file_path, error=str(e)) - - # Check if it's an audio file - elif self._is_audio_file(file_path): - try: - # For audio files, we'd need to load and transcribe - # For now, log a warning - logger.warning("audio_file_upload_not_supported", file_path=file_path) - except Exception as e: - logger.warning( - "audio_file_processing_failed", file_path=file_path, error=str(e) - ) - - # Add original text if present - if text and text.strip(): - text_parts.append(text.strip()) - - # Combine parts based on prepend_multimodal flag - if prepend_multimodal: - # Prepend: multimodal content first, then original text - combined_parts = multimodal_parts + text_parts - else: - # Append: original text first, then multimodal content - combined_parts = text_parts + multimodal_parts - - # Combine all text parts - combined_text = "\n\n".join(combined_parts) if combined_parts else "" - - logger.info( - "multimodal_input_processed", - text_length=len(combined_text), - num_files=len(files) if files else 0, - has_audio=audio_input is not None, - ) - - return combined_text - - def _is_image_file(self, file_path: str) -> bool: - """Check if file is an image. - - Args: - file_path: Path to file - - Returns: - True if file is an image - """ - image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".tif"} - return any(file_path.lower().endswith(ext) for ext in image_extensions) - - def _is_audio_file(self, file_path: str) -> bool: - """Check if file is an audio file. - - Args: - file_path: Path to file - - Returns: - True if file is an audio file - """ - audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma"} - return any(file_path.lower().endswith(ext) for ext in audio_extensions) - - -@lru_cache(maxsize=1) -def get_multimodal_service() -> MultimodalService: - """Get or create singleton multimodal service instance. 
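# --- Hedged usage sketch (not part of the original file) ---
# Combining a typed query with an uploaded image via the MultimodalService
# defined above. The file path is a placeholder, image input must be enabled
# in settings, and FileData construction is assumed from gradio's API.
import asyncio

from gradio.data_classes import FileData

from src.services.multimodal_processing import MultimodalService

async def _multimodal_demo() -> str:
    service = MultimodalService()
    return await service.process_multimodal_input(
        text="Summarise the attached figure",
        files=[FileData(path="/tmp/figure.png")],  # hypothetical upload
    )

combined_text = asyncio.run(_multimodal_demo())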
- - Returns: - MultimodalService instance - """ - return MultimodalService() diff --git a/src/services/neo4j_service.py b/src/services/neo4j_service.py deleted file mode 100644 index 6a9c6f7e2995be185310faba566e575d429348a6..0000000000000000000000000000000000000000 --- a/src/services/neo4j_service.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Neo4j Knowledge Graph Service for Drug Repurposing""" - -import logging -import os -from typing import Any - -from dotenv import load_dotenv -from neo4j import GraphDatabase - -load_dotenv() -logger = logging.getLogger(__name__) - - -class Neo4jService: - def __init__(self) -> None: - self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") - self.user = os.getenv("NEO4J_USER", "neo4j") - self.password = os.getenv("NEO4J_PASSWORD") - self.database = os.getenv("NEO4J_DATABASE", "neo4j") - - if not self.password: - logger.warning("⚠️ NEO4J_PASSWORD not set") - self.driver = None - return - - try: - self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) - self.driver.verify_connectivity() - logger.info(f"✅ Neo4j connected: {self.uri} (db: {self.database})") - except Exception as e: - logger.error(f"❌ Neo4j connection failed: {e}") - self.driver = None - - def is_connected(self) -> bool: - return self.driver is not None - - def close(self) -> None: - if self.driver: - self.driver.close() - - def ingest_search_results( - self, - disease_name: str, - papers: list[dict[str, Any]], - drugs_mentioned: list[str] | None = None, - ) -> dict[str, int]: - if not self.driver: - return {"error": "Neo4j not connected"} # type: ignore[dict-item] - - stats = {"papers": 0, "drugs": 0, "relationships": 0, "errors": 0} - - try: - with self.driver.session(database=self.database) as session: - session.run("MERGE (d:Disease {name: $name})", name=disease_name) - - for paper in papers: - try: - paper_id = paper.get("id") or paper.get("url", "") - if not paper_id: - continue - - session.run( - """ - MERGE (p:Paper {paper_id: $id}) - SET p.title = $title, - p.abstract = $abstract, - p.url = $url, - p.source = $source, - p.updated_at = datetime() - """, - id=paper_id, - title=str(paper.get("title", ""))[:500], - abstract=str(paper.get("abstract", ""))[:2000], - url=str(paper.get("url", ""))[:500], - source=str(paper.get("source", ""))[:100], - ) - - session.run( - """ - MATCH (p:Paper {paper_id: $id}) - MATCH (d:Disease {name: $disease}) - MERGE (p)-[r:ABOUT]->(d) - """, - id=paper_id, - disease=disease_name, - ) - - stats["papers"] += 1 - stats["relationships"] += 1 - except Exception: - stats["errors"] += 1 - - if drugs_mentioned: - for drug in drugs_mentioned: - try: - session.run("MERGE (d:Drug {name: $name})", name=drug) - session.run( - """ - MATCH (drug:Drug {name: $drug}) - MATCH (disease:Disease {name: $disease}) - MERGE (drug)-[r:POTENTIAL_TREATMENT]->(disease) - """, - drug=drug, - disease=disease_name, - ) - stats["drugs"] += 1 - stats["relationships"] += 1 - except Exception: - stats["errors"] += 1 - - logger.info(f"�� Neo4j ingestion: {stats['papers']} papers, {stats['drugs']} drugs") - except Exception as e: - logger.error(f"Neo4j ingestion error: {e}") - stats["errors"] += 1 - - return stats - - -_neo4j_service = None - - -def get_neo4j_service() -> Neo4jService | None: - global _neo4j_service - if _neo4j_service is None: - _neo4j_service = Neo4jService() - return _neo4j_service if _neo4j_service and _neo4j_service.is_connected() else None diff --git a/src/services/report_file_service.py b/src/services/report_file_service.py deleted file mode 
100644 index 144966409f446ead658ebf86c309b0da1f331d2d..0000000000000000000000000000000000000000 --- a/src/services/report_file_service.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Service for saving research reports to files.""" - -import hashlib -import tempfile -from datetime import datetime -from pathlib import Path -from typing import Literal - -import structlog - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - - -class ReportFileService: - """ - Service for saving research reports to files. - - Handles file creation, naming, and directory management for report outputs. - Supports saving reports in multiple formats (markdown, HTML, PDF). - """ - - def __init__( - self, - output_directory: str | None = None, - enabled: bool | None = None, - file_format: Literal["md", "md_html", "md_pdf"] | None = None, - ) -> None: - """ - Initialize the report file service. - - Args: - output_directory: Directory to save reports. If None, uses settings or temp directory. - enabled: Whether file saving is enabled. If None, uses settings. - file_format: File format to save. If None, uses settings. - """ - self.enabled = enabled if enabled is not None else settings.save_reports_to_file - self.file_format = file_format or settings.report_file_format - self.filename_template = settings.report_filename_template - - # Determine output directory - if output_directory: - self.output_directory = Path(output_directory) - elif settings.report_output_directory: - self.output_directory = Path(settings.report_output_directory) - else: - # Use system temp directory - self.output_directory = Path(tempfile.gettempdir()) / "deepcritical_reports" - - # Create output directory if it doesn't exist - if self.enabled: - try: - self.output_directory.mkdir(parents=True, exist_ok=True) - logger.debug( - "Report output directory initialized", - path=str(self.output_directory), - enabled=self.enabled, - ) - except Exception as e: - logger.error( - "Failed to create report output directory", - error=str(e), - path=str(self.output_directory), - ) - raise ConfigurationError(f"Failed to create report output directory: {e}") from e - - def _generate_filename(self, query: str | None = None, extension: str = ".md") -> str: - """ - Generate filename for report using template. - - Args: - query: Optional query string for hash generation - extension: File extension (e.g., ".md", ".html") - - Returns: - Generated filename - """ - # Generate timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Generate query hash if query provided - query_hash = "" - if query: - query_hash = hashlib.md5(query.encode()).hexdigest()[:8] - - # Generate date - date = datetime.now().strftime("%Y-%m-%d") - - # Replace template placeholders - filename = self.filename_template - filename = filename.replace("{timestamp}", timestamp) - filename = filename.replace("{query_hash}", query_hash) - filename = filename.replace("{date}", date) - - # Ensure correct extension - if not filename.endswith(extension): - # Remove existing extension if present - if "." in filename: - filename = filename.rsplit(".", 1)[0] - filename += extension - - return filename - - def save_report( - self, - report_content: str, - query: str | None = None, - filename: str | None = None, - ) -> str: - """ - Save a report to a file. - - Args: - report_content: The report content (markdown string) - query: Optional query string for filename generation - filename: Optional custom filename. 
If None, generates from template. - - Returns: - Path to saved file - - Raises: - ConfigurationError: If file saving is disabled or fails - """ - if not self.enabled: - logger.debug("File saving disabled, skipping") - raise ConfigurationError("Report file saving is disabled") - - if not report_content or not report_content.strip(): - raise ValueError("Report content cannot be empty") - - # Generate filename if not provided - if not filename: - filename = self._generate_filename(query=query, extension=".md") - - # Ensure filename is safe - filename = self._sanitize_filename(filename) - - # Build full file path - file_path = self.output_directory / filename - - try: - # Write file - with open(file_path, "w", encoding="utf-8") as f: - f.write(report_content) - - logger.info( - "Report saved to file", - path=str(file_path), - size=len(report_content), - query=query[:50] if query else None, - ) - - return str(file_path) - - except Exception as e: - logger.error("Failed to save report to file", error=str(e), path=str(file_path)) - raise ConfigurationError(f"Failed to save report to file: {e}") from e - - def save_report_multiple_formats( - self, - report_content: str, - query: str | None = None, - ) -> dict[str, str]: - """ - Save a report in multiple formats. - - Args: - report_content: The report content (markdown string) - query: Optional query string for filename generation - - Returns: - Dictionary mapping format to file path (e.g., {"md": "/path/to/report.md"}) - - Raises: - ConfigurationError: If file saving is disabled or fails - """ - if not self.enabled: - logger.debug("File saving disabled, skipping") - raise ConfigurationError("Report file saving is disabled") - - saved_files: dict[str, str] = {} - - # Always save markdown - md_path = self.save_report(report_content, query=query, filename=None) - saved_files["md"] = md_path - - # Save additional formats based on file_format setting - if self.file_format == "md_html": - # TODO: Implement HTML conversion - logger.warning("HTML format not yet implemented, saving markdown only") - elif self.file_format == "md_pdf": - # Generate PDF from markdown - try: - pdf_path = self._save_pdf(report_content, query=query) - saved_files["pdf"] = pdf_path - logger.info("PDF report generated", pdf_path=pdf_path) - except Exception as e: - logger.warning( - "PDF generation failed, markdown saved", - error=str(e), - md_path=md_path, - ) - # Continue without PDF - markdown is already saved - - return saved_files - - def _save_pdf( - self, - report_content: str, - query: str | None = None, - ) -> str: - """ - Save report as PDF. - - Args: - report_content: The report content (markdown string) - query: Optional query string for filename generation - - Returns: - Path to saved PDF file - - Raises: - ConfigurationError: If PDF generation fails - """ - try: - from src.utils.md_to_pdf import md_to_pdf - except ImportError as e: - raise ConfigurationError( - "PDF generation requires md2pdf. 
Install with: pip install md2pdf" - ) from e - - # Generate PDF filename - pdf_filename = self._generate_filename(query=query, extension=".pdf") - pdf_filename = self._sanitize_filename(pdf_filename) - pdf_path = self.output_directory / pdf_filename - - try: - # Convert markdown to PDF - md_to_pdf(report_content, str(pdf_path)) - - logger.info( - "PDF report saved", - path=str(pdf_path), - size=pdf_path.stat().st_size if pdf_path.exists() else 0, - query=query[:50] if query else None, - ) - - return str(pdf_path) - - except Exception as e: - logger.error("Failed to generate PDF", error=str(e), path=str(pdf_path)) - raise ConfigurationError(f"Failed to generate PDF: {e}") from e - - def _sanitize_filename(self, filename: str) -> str: - """ - Sanitize filename to remove unsafe characters. - - Args: - filename: Original filename - - Returns: - Sanitized filename - """ - # Remove or replace unsafe characters - unsafe_chars = '<>:"/\\|?*' - sanitized = filename - for char in unsafe_chars: - sanitized = sanitized.replace(char, "_") - - # Limit length - if len(sanitized) > 200: - name, ext = sanitized.rsplit(".", 1) if "." in sanitized else (sanitized, "") - sanitized = name[:190] + ext - - return sanitized - - def cleanup_old_files(self, max_age_days: int = 7) -> int: - """ - Clean up old report files. - - Args: - max_age_days: Maximum age in days for files to keep - - Returns: - Number of files deleted - """ - if not self.output_directory.exists(): - return 0 - - deleted_count = 0 - cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60) - - try: - for file_path in self.output_directory.iterdir(): - if file_path.is_file() and file_path.stat().st_mtime < cutoff_time: - try: - file_path.unlink() - deleted_count += 1 - except Exception as e: - logger.warning( - "Failed to delete old file", path=str(file_path), error=str(e) - ) - - if deleted_count > 0: - logger.info( - "Cleaned up old report files", deleted=deleted_count, max_age_days=max_age_days - ) - - except Exception as e: - logger.error("Failed to cleanup old files", error=str(e)) - - return deleted_count - - -def get_report_file_service() -> ReportFileService: - """ - Get or create a ReportFileService instance (singleton pattern). - - Returns: - ReportFileService instance - """ - # Use lru_cache for singleton pattern - from functools import lru_cache - - @lru_cache(maxsize=1) - def _get_service() -> ReportFileService: - return ReportFileService() - - return _get_service() diff --git a/src/services/statistical_analyzer.py b/src/services/statistical_analyzer.py deleted file mode 100644 index 38458907ef0be1eddd9f67b12dd0819eed3aadfb..0000000000000000000000000000000000000000 --- a/src/services/statistical_analyzer.py +++ /dev/null @@ -1,255 +0,0 @@ -"""Statistical analysis service using Modal code execution. - -This module provides Modal-based statistical analysis WITHOUT depending on -agent_framework. This allows it to be used in the simple orchestrator mode -without requiring the magentic optional dependency. - -The AnalysisAgent (in src/agents/) wraps this service for magentic mode. 
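# --- Hedged usage sketch (not part of the original file) ---
# Saving a finished markdown report with the ReportFileService defined above.
# The output directory and report text are placeholders; PDF/HTML variants
# depend on the configured report_file_format.
from src.services.report_file_service import ReportFileService

report_service = ReportFileService(output_directory="/tmp/deepcritical_reports", enabled=True)
saved_path = report_service.save_report("# Report\n\nFindings ...", query="metformin alzheimer")
print(saved_path)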
-""" - -import asyncio -import re -from functools import lru_cache, partial -from typing import Any, Literal - -# Type alias for verdict values -VerdictType = Literal["SUPPORTED", "REFUTED", "INCONCLUSIVE"] - -from pydantic import BaseModel, Field -from pydantic_ai import Agent - -from src.agent_factory.judges import get_model -from src.tools.code_execution import ( - CodeExecutionError, - get_code_executor, - get_sandbox_library_prompt, -) -from src.utils.models import Evidence - - -class AnalysisResult(BaseModel): - """Result of statistical analysis.""" - - verdict: VerdictType = Field( - description="SUPPORTED, REFUTED, or INCONCLUSIVE", - ) - confidence: float = Field(ge=0.0, le=1.0, description="Confidence in verdict (0-1)") - statistical_evidence: str = Field( - description="Summary of statistical findings from code execution" - ) - code_generated: str = Field(description="Python code that was executed") - execution_output: str = Field(description="Output from code execution") - key_findings: list[str] = Field(default_factory=list, description="Key takeaways") - limitations: list[str] = Field(default_factory=list, description="Limitations") - - -class StatisticalAnalyzer: - """Performs statistical analysis using Modal code execution. - - This service: - 1. Generates Python code for statistical analysis using LLM - 2. Executes code in Modal sandbox - 3. Interprets results - 4. Returns verdict (SUPPORTED/REFUTED/INCONCLUSIVE) - - Note: This class has NO agent_framework dependency, making it safe - to use in the simple orchestrator without the magentic extra. - """ - - def __init__(self) -> None: - """Initialize the analyzer.""" - self._code_executor: Any = None - self._agent: Agent[None, str] | None = None - - def _get_code_executor(self) -> Any: - """Lazy initialization of code executor.""" - if self._code_executor is None: - self._code_executor = get_code_executor() - return self._code_executor - - def _get_agent(self) -> Agent[None, str]: - """Lazy initialization of LLM agent for code generation.""" - if self._agent is None: - library_versions = get_sandbox_library_prompt() - self._agent = Agent( - model=get_model(), - output_type=str, - system_prompt=f"""You are a biomedical data scientist. - -Generate Python code to analyze research evidence and test hypotheses. - -Guidelines: -1. Use pandas, numpy, scipy.stats for analysis -2. Print clear, interpretable results -3. Include statistical tests (t-tests, chi-square, etc.) -4. Calculate effect sizes and confidence intervals -5. Keep code concise (<50 lines) -6. Set 'result' variable to SUPPORTED, REFUTED, or INCONCLUSIVE - -Available libraries: -{library_versions} - -Output format: Return ONLY executable Python code, no explanations.""", - ) - return self._agent - - async def analyze( - self, - query: str, - evidence: list[Evidence], - hypothesis: dict[str, Any] | None = None, - ) -> AnalysisResult: - """Run statistical analysis on evidence. 
- - Args: - query: The research question - evidence: List of Evidence objects to analyze - hypothesis: Optional hypothesis dict with drug, target, pathway, effect - - Returns: - AnalysisResult with verdict and statistics - """ - # Build analysis prompt (method handles slicing internally) - evidence_summary = self._summarize_evidence(evidence) - hypothesis_text = "" - if hypothesis: - hypothesis_text = ( - f"\nHypothesis: {hypothesis.get('drug', 'Unknown')} → " - f"{hypothesis.get('target', '?')} → " - f"{hypothesis.get('pathway', '?')} → " - f"{hypothesis.get('effect', '?')}\n" - f"Confidence: {hypothesis.get('confidence', 0.5):.0%}\n" - ) - - prompt = f"""Generate Python code to statistically analyze: - -**Research Question**: {query} -{hypothesis_text} - -**Evidence Summary**: -{evidence_summary} - -Generate executable Python code to analyze this evidence.""" - - try: - # Generate code - agent = self._get_agent() - code_result = await agent.run(prompt) - generated_code = code_result.output - - # Execute in Modal sandbox - loop = asyncio.get_running_loop() - executor = self._get_code_executor() - execution = await loop.run_in_executor( - None, partial(executor.execute, generated_code, timeout=120) - ) - - if not execution["success"]: - return AnalysisResult( - verdict="INCONCLUSIVE", - confidence=0.0, - statistical_evidence=( - f"Execution failed: {execution.get('error', 'Unknown error')}" - ), - code_generated=generated_code, - execution_output=execution.get("stderr", ""), - key_findings=[], - limitations=["Code execution failed"], - ) - - # Interpret results - return self._interpret_results(generated_code, execution) - - except CodeExecutionError as e: - return AnalysisResult( - verdict="INCONCLUSIVE", - confidence=0.0, - statistical_evidence=str(e), - code_generated="", - execution_output="", - key_findings=[], - limitations=[f"Analysis error: {e}"], - ) - - def _summarize_evidence(self, evidence: list[Evidence]) -> str: - """Summarize evidence for code generation prompt.""" - if not evidence: - return "No evidence available." - - lines = [] - for i, ev in enumerate(evidence[:5], 1): - content = ev.content - truncated = content[:200] + ("..." if len(content) > 200 else "") - lines.append(f"{i}. 
{truncated}") - lines.append(f" Source: {ev.citation.title}") - lines.append(f" Relevance: {ev.relevance:.0%}\n") - - return "\n".join(lines) - - def _interpret_results( - self, - code: str, - execution: dict[str, Any], - ) -> AnalysisResult: - """Interpret code execution results.""" - stdout = execution["stdout"] - stdout_upper = stdout.upper() - - # Extract verdict with robust word-boundary matching - verdict: VerdictType = "INCONCLUSIVE" - if re.search(r"\bSUPPORTED\b", stdout_upper) and not re.search( - r"\b(?:NOT|UN)SUPPORTED\b", stdout_upper - ): - verdict = "SUPPORTED" - elif re.search(r"\bREFUTED\b", stdout_upper): - verdict = "REFUTED" - - # Extract key findings - key_findings = [] - for line in stdout.split("\n"): - line_lower = line.lower() - if any(kw in line_lower for kw in ["p-value", "significant", "effect", "mean"]): - key_findings.append(line.strip()) - - # Calculate confidence from p-values - confidence = self._calculate_confidence(stdout) - - return AnalysisResult( - verdict=verdict, - confidence=confidence, - statistical_evidence=stdout.strip(), - code_generated=code, - execution_output=stdout, - key_findings=key_findings[:5], - limitations=[ - "Analysis based on summary data only", - "Limited to available evidence", - "Statistical tests assume data independence", - ], - ) - - def _calculate_confidence(self, output: str) -> float: - """Calculate confidence based on statistical results.""" - p_values = re.findall(r"p[-\s]?value[:\s]+(\d+\.?\d*)", output.lower()) - - if p_values: - try: - min_p = min(float(p) for p in p_values) - if min_p < 0.001: - return 0.95 - elif min_p < 0.01: - return 0.90 - elif min_p < 0.05: - return 0.80 - else: - return 0.60 - except ValueError: - pass - - return 0.70 # Default - - -@lru_cache(maxsize=1) -def get_statistical_analyzer() -> StatisticalAnalyzer: - """Get or create singleton StatisticalAnalyzer instance (thread-safe via lru_cache).""" - return StatisticalAnalyzer() diff --git a/src/services/stt_gradio.py b/src/services/stt_gradio.py deleted file mode 100644 index 44b993b429f39aa3e15d7c83ce7ebabc77d619ba..0000000000000000000000000000000000000000 --- a/src/services/stt_gradio.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Speech-to-Text service using Gradio Client API.""" - -import asyncio -import tempfile -from functools import lru_cache -from pathlib import Path -from typing import Any - -import numpy as np -import structlog -from gradio_client import Client, handle_file - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger(__name__) - - -class STTService: - """STT service using nvidia/canary-1b-v2 Gradio Space.""" - - def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> None: - """Initialize STT service. - - Args: - api_url: Gradio Space URL (default: settings.stt_api_url or nvidia/canary-1b-v2) - hf_token: HuggingFace token for authenticated Spaces (default: None) - - Raises: - ConfigurationError: If API URL not configured - """ - self.api_url = api_url or settings.stt_api_url or "https://nvidia-canary-1b-v2.hf.space" - if not self.api_url: - raise ConfigurationError("STT API URL not configured") - self.hf_token = hf_token - self.client: Client | None = None - - async def _get_client(self, hf_token: str | None = None) -> Client: - """Get or create Gradio Client (lazy initialization). 
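# --- Hedged usage sketch (not part of the original file) ---
# Driving the StatisticalAnalyzer defined above. At runtime this needs Modal
# credentials for the sandbox plus an LLM for code generation; the query,
# hypothesis values, and evidence list are placeholders.
import asyncio

from src.services.statistical_analyzer import get_statistical_analyzer
from src.utils.models import Evidence

async def _analysis_demo(evidence: list[Evidence]) -> tuple[str, float]:
    analyzer = get_statistical_analyzer()
    result = await analyzer.analyze(
        query="Does metformin slow cognitive decline?",
        evidence=evidence,
        hypothesis={
            "drug": "metformin",
            "target": "AMPK",
            "pathway": "mTOR inhibition",
            "effect": "neuroprotection",
            "confidence": 0.6,
        },
    )
    return result.verdict, result.confidence  # verdict: SUPPORTED / REFUTED / INCONCLUSIVE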
- - Args: - hf_token: HuggingFace token for authenticated Spaces (overrides instance token) - - Returns: - Gradio Client instance - """ - # Use provided token or instance token - token = hf_token or self.hf_token - - # If client exists but token changed, recreate it - if self.client is not None and token != self.hf_token: - self.client = None - - if self.client is None: - loop = asyncio.get_running_loop() - # Pass token to Client for authenticated Spaces - # Gradio Client uses 'token' parameter, not 'hf_token' - if token: - self.client = await loop.run_in_executor( - None, - lambda: Client(self.api_url, token=token), - ) - else: - self.client = await loop.run_in_executor( - None, - lambda: Client(self.api_url), - ) - # Update instance token for future use - self.hf_token = token - return self.client - - async def transcribe_file( - self, - audio_path: str, - source_lang: str | None = None, - target_lang: str | None = None, - hf_token: str | None = None, - ) -> str: - """Transcribe audio file using Gradio API. - - Args: - audio_path: Path to audio file - source_lang: Source language (default: settings.stt_source_lang) - target_lang: Target language (default: settings.stt_target_lang) - - Returns: - Transcribed text string - - Raises: - ConfigurationError: If transcription fails - """ - client = await self._get_client(hf_token=hf_token) - source_lang = source_lang or settings.stt_source_lang - target_lang = target_lang or settings.stt_target_lang - - logger.info( - "transcribing_audio_file", - audio_path=audio_path, - source_lang=source_lang, - target_lang=target_lang, - ) - - try: - # Call /transcribe_file API endpoint - # API returns: (dataframe, csv_path, srt_path) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor( - None, - lambda: client.predict( - audio_path=handle_file(audio_path), - source_lang=source_lang, - target_lang=target_lang, - api_name="/transcribe_file", - ), - ) - - # Extract transcription from result - transcribed_text = self._extract_transcription(result) - - logger.info( - "audio_transcription_complete", - text_length=len(transcribed_text), - ) - - return transcribed_text - - except Exception as e: - logger.error("audio_transcription_failed", error=str(e), error_type=type(e).__name__) - raise ConfigurationError(f"Audio transcription failed: {e}") from e - - async def transcribe_audio( - self, - audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg] - hf_token: str | None = None, - ) -> str: - """Transcribe audio numpy array to text. - - Args: - audio_data: Tuple of (sample_rate, audio_array) - - Returns: - Transcribed text string - """ - sample_rate, audio_array = audio_data - - logger.info( - "transcribing_audio_array", - sample_rate=sample_rate, - audio_shape=audio_array.shape, - ) - - # Save audio to temp file - temp_path = self._save_audio_temp(audio_data) - - try: - # Transcribe the temp file - transcribed_text = await self.transcribe_file(temp_path, hf_token=hf_token) - return transcribed_text - finally: - # Clean up temp file - try: - Path(temp_path).unlink(missing_ok=True) - except Exception as e: - logger.warning("failed_to_cleanup_temp_file", path=temp_path, error=str(e)) - - def _extract_transcription(self, api_result: tuple[Any, ...]) -> str: - """Extract transcription text from API result. 
- - Args: - api_result: Tuple from Gradio API (dataframe, csv_path, srt_path) - - Returns: - Extracted transcription text - """ - # API returns: (dataframe, csv_path, srt_path) - # Try to extract from dataframe first - if isinstance(api_result, tuple) and len(api_result) >= 1: - dataframe = api_result[0] - if isinstance(dataframe, dict) and "data" in dataframe: - # Extract text from dataframe rows - rows = dataframe.get("data", []) - if rows: - # Combine all text segments - text_segments = [] - for row in rows: - if isinstance(row, list) and len(row) > 0: - # First column is usually the text - text_segments.append(str(row[0])) - if text_segments: - return " ".join(text_segments) - - # Fallback: try to read CSV file if available - if len(api_result) >= 2 and api_result[1]: - csv_path = api_result[1] - try: - import pandas as pd - - df = pd.read_csv(csv_path) - if "text" in df.columns: - return " ".join(df["text"].astype(str).tolist()) - elif len(df.columns) > 0: - # Use first column - return " ".join(df.iloc[:, 0].astype(str).tolist()) - except Exception as e: - logger.warning("failed_to_read_csv", csv_path=csv_path, error=str(e)) - - # Last resort: return empty string - logger.warning("could_not_extract_transcription", result_type=type(api_result).__name__) - return "" - - def _save_audio_temp( - self, - audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg] - ) -> str: - """Save audio numpy array to temporary WAV file. - - Args: - audio_data: Tuple of (sample_rate, audio_array) - - Returns: - Path to temporary WAV file - """ - sample_rate, audio_array = audio_data - - # Create temp file - temp_file = tempfile.NamedTemporaryFile( - suffix=".wav", - delete=False, - ) - temp_path = temp_file.name - temp_file.close() - - # Save audio using soundfile - try: - import soundfile as sf - - # Ensure audio is float32 and mono - if audio_array.dtype != np.float32: - audio_array = audio_array.astype(np.float32) - - # Handle stereo -> mono conversion - if len(audio_array.shape) > 1: - audio_array = np.mean(audio_array, axis=1) - - # Normalize to [-1, 1] range - if audio_array.max() > 1.0 or audio_array.min() < -1.0: - audio_array = audio_array / np.max(np.abs(audio_array)) - - sf.write(temp_path, audio_array, sample_rate) - - logger.debug("saved_audio_temp", path=temp_path, sample_rate=sample_rate) - - return temp_path - - except ImportError: - raise ConfigurationError( - "soundfile not installed. Install with: uv add soundfile" - ) from None - except Exception as e: - logger.error("failed_to_save_audio_temp", error=str(e)) - raise ConfigurationError(f"Failed to save audio to temp file: {e}") from e - - -@lru_cache(maxsize=1) -def get_stt_service() -> STTService: - """Get or create singleton STT service instance. 
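# --- Hedged usage sketch (not part of the original file) ---
# Transcribing a local WAV file with the STT service defined above. The file
# path is a placeholder and the nvidia/canary-1b-v2 Space must be reachable.
import asyncio

from src.services.stt_gradio import get_stt_service

async def _stt_demo() -> str:
    stt = get_stt_service()
    return await stt.transcribe_file("/tmp/question.wav")

transcript = asyncio.run(_stt_demo())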
- - Returns: - STTService instance - """ - return STTService() diff --git a/src/services/tts_modal.py b/src/services/tts_modal.py deleted file mode 100644 index ce55c49aad24f1f2f1e37ddccbdbbd79df208305..0000000000000000000000000000000000000000 --- a/src/services/tts_modal.py +++ /dev/null @@ -1,482 +0,0 @@ -"""Text-to-Speech service using Kokoro 82M via Modal GPU.""" - -import asyncio -import os -from collections.abc import Iterator -from contextlib import contextmanager -from functools import lru_cache -from typing import Any, cast - -import numpy as np -from numpy.typing import NDArray -import structlog - -# Load .env file BEFORE importing Modal SDK -# Modal SDK reads MODAL_TOKEN_ID and MODAL_TOKEN_SECRET from environment on import -from dotenv import load_dotenv - -load_dotenv() - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger(__name__) - -# Kokoro TTS dependencies for Modal image -KOKORO_DEPENDENCIES = [ - "torch>=2.0.0", - "transformers>=4.30.0", - "numpy<2.0", - # kokoro-82M can be installed from source: - # git+https://github.com/hexgrad/kokoro.git -] - -# Modal app and function definitions (module-level for Modal) -_modal_app: Any | None = None -_tts_function: Any | None = None -_tts_image: Any | None = None - - -@contextmanager -def modal_credentials_override(token_id: str | None, token_secret: str | None) -> Iterator[None]: - """Context manager to temporarily override Modal credentials. - - Args: - token_id: Modal token ID (overrides env if provided) - token_secret: Modal token secret (overrides env if provided) - - Yields: - None - - Note: - Resets global Modal state to force re-initialization with new credentials. - """ - global _modal_app, _tts_function - - # Save original credentials - original_token_id = os.environ.get("MODAL_TOKEN_ID") - original_token_secret = os.environ.get("MODAL_TOKEN_SECRET") - - # Save original Modal state - original_app = _modal_app - original_function = _tts_function - - try: - # Override environment variables if provided - if token_id: - os.environ["MODAL_TOKEN_ID"] = token_id - if token_secret: - os.environ["MODAL_TOKEN_SECRET"] = token_secret - - # Reset Modal state to force re-initialization - _modal_app = None - _tts_function = None - - yield - - finally: - # Restore original credentials - if original_token_id is not None: - os.environ["MODAL_TOKEN_ID"] = original_token_id - elif "MODAL_TOKEN_ID" in os.environ: - del os.environ["MODAL_TOKEN_ID"] - - if original_token_secret is not None: - os.environ["MODAL_TOKEN_SECRET"] = original_token_secret - elif "MODAL_TOKEN_SECRET" in os.environ: - del os.environ["MODAL_TOKEN_SECRET"] - - # Restore original Modal state - _modal_app = original_app - _tts_function = original_function - - -def _get_modal_app() -> Any: - """Get or create Modal app instance. - - Retrieves Modal credentials directly from environment variables (.env file) - instead of relying on settings configuration. - """ - global _modal_app - if _modal_app is None: - try: - import modal - - # Get credentials directly from environment variables - token_id = os.getenv("MODAL_TOKEN_ID") - token_secret = os.getenv("MODAL_TOKEN_SECRET") - - # Validate Modal credentials - if not token_id or not token_secret: - raise ConfigurationError( - "Modal credentials not found in environment. " - "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." 
- ) - - # Validate token ID format (Modal token IDs are typically UUIDs or specific formats) - if len(token_id.strip()) < 10: - raise ConfigurationError( - f"Modal token ID appears malformed (too short: {len(token_id)} chars). " - "Token ID should be a valid Modal token identifier." - ) - - logger.info( - "modal_credentials_loaded", - token_id_prefix=token_id[:8] + "...", # Log prefix for debugging - has_secret=bool(token_secret), - ) - - try: - # Use lookup with create_if_missing for inline function fallback - _modal_app = modal.App.lookup("deepcritical-tts", create_if_missing=True) - except Exception as e: - error_msg = str(e).lower() - if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg: - raise ConfigurationError( - f"Modal token validation failed: {e}. " - "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env are correctly set." - ) from e - raise - except ImportError as e: - raise ConfigurationError( - "Modal SDK not installed. Run: uv sync or pip install modal>=0.63.0" - ) from e - return _modal_app - - -# Define Modal image with Kokoro dependencies (module-level) -def _get_tts_image() -> Any: - """Get Modal image with Kokoro dependencies.""" - global _tts_image - if _tts_image is not None: - return _tts_image - - try: - import modal - - _tts_image = ( - modal.Image.debian_slim(python_version="3.11") - .pip_install(*KOKORO_DEPENDENCIES) - .pip_install("git+https://github.com/hexgrad/kokoro.git") - ) - return _tts_image - except ImportError: - return None - - -# Modal TTS function - Using serialized=True to allow dynamic creation -# This will be initialized lazily when _setup_modal_function() is called -def _create_tts_function() -> Any: - """Create the Modal TTS function using serialized=True. - - The serialized=True parameter allows the function to be defined outside - of global scope, which is necessary for dynamic initialization. - """ - app = _get_modal_app() - tts_image = _get_tts_image() - - if tts_image is None: - raise ConfigurationError("Modal image setup failed") - - # Get GPU and timeout from settings (with defaults) - gpu_type = getattr(settings, "tts_gpu", None) or "T4" - timeout_seconds = getattr(settings, "tts_timeout", None) or 120 # 2 minutes for cold starts - - @app.function( - image=tts_image, - gpu=gpu_type, - timeout=timeout_seconds, - serialized=True, # Allow function to be defined outside global scope - ) - def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, NDArray[np.float32]]: - """Modal GPU function for Kokoro TTS. - - This function runs on Modal's GPU infrastructure. - Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS - Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py - """ - import numpy as np - - # Import Kokoro inside function (lazy load) - try: - import torch - from kokoro import KModel, KPipeline - - # Initialize model (cached on GPU) - model = KModel().to("cuda").eval() - pipeline = KPipeline(lang_code=voice[0]) - pack = pipeline.load_voice(voice) - - # Generate audio - for _, ps, _ in pipeline(text, voice, speed): - ref_s = pack[len(ps) - 1] - audio = model(ps, ref_s, speed) - return (24000, audio.numpy()) - - # If no audio generated, return empty - return (24000, np.zeros(1, dtype=np.float32)) - - except ImportError as e: - raise ConfigurationError( - "Kokoro not installed. 
Install with: pip install git+https://github.com/hexgrad/kokoro.git" - ) from e - except Exception as e: - raise ConfigurationError(f"TTS synthesis failed: {e}") from e - - return kokoro_tts_function - - -def _setup_modal_function() -> None: - """Setup Modal GPU function for TTS (called once, lazy initialization). - - Hybrid approach: - 1. Try to lookup pre-deployed function (fast path for advanced users) - 2. If lookup fails, create function inline (fallback for casual users) - - This allows both workflows: - - Advanced: Deploy with `modal deploy deployments/modal_tts.py` for best performance - - Casual: Just add Modal keys and it auto-creates function on first use - """ - global _tts_function - - if _tts_function is not None: - return # Already set up - - try: - import modal - - # Try path 1: Lookup pre-deployed function (fast path) - try: - _tts_function = modal.Function.from_name("deepcritical-tts", "kokoro_tts_function") - logger.info( - "modal_tts_function_lookup_success", - app_name="deepcritical-tts", - function_name="kokoro_tts_function", - method="lookup", - ) - return - except Exception as lookup_error: - logger.info( - "modal_tts_function_lookup_failed", - error=str(lookup_error), - fallback="Creating function inline", - ) - - # Try path 2: Create function inline (fallback for casual users) - logger.info("modal_tts_creating_inline_function") - _tts_function = _create_tts_function() - logger.info( - "modal_tts_function_setup_complete", - app_name="deepcritical-tts", - function_name="kokoro_tts_function", - method="inline", - ) - - except Exception as e: - logger.error("modal_tts_function_setup_failed", error=str(e)) - raise ConfigurationError( - f"Failed to setup Modal TTS function: {e}. " - "Ensure Modal credentials (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) are valid." - ) from e - - -class ModalTTSExecutor: - """Execute Kokoro TTS synthesis on Modal GPU. - - This class provides TTS synthesis using Kokoro 82M model on Modal's GPU infrastructure. - Follows the same pattern as ModalCodeExecutor but uses GPU functions for TTS. - """ - - def __init__(self) -> None: - """Initialize Modal TTS executor. - - Note: - Logs a warning if Modal credentials are not configured in environment. - Execution will fail at runtime without valid credentials in .env file. - """ - # Check for Modal credentials directly from environment - token_id = os.getenv("MODAL_TOKEN_ID") - token_secret = os.getenv("MODAL_TOKEN_SECRET") - - if not token_id or not token_secret: - logger.warning( - "Modal credentials not found in environment. " - "TTS will not be available. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." - ) - - def synthesize( - self, - text: str, - voice: str = "af_heart", - speed: float = 1.0, - timeout: int = 120, - ) -> tuple[int, NDArray[np.float32]]: - """Synthesize text to speech using Kokoro on Modal GPU. 
- - Args: - text: Text to synthesize (max 5000 chars for free tier) - voice: Voice ID from Kokoro (e.g., af_heart, af_bella, am_michael) - speed: Speech speed multiplier (0.5-2.0) - timeout: Maximum execution time (not used, Modal function has its own timeout) - - Returns: - Tuple of (sample_rate, audio_array) - - Raises: - ConfigurationError: If synthesis fails - """ - # Setup Modal function if not already done - _setup_modal_function() - - if _tts_function is None: - raise ConfigurationError("Modal TTS function not initialized") - - logger.info("synthesizing_tts", text_length=len(text), voice=voice, speed=speed) - - try: - # Call the GPU function remotely - result = cast(tuple[int, NDArray[np.float32]], _tts_function.remote(text, voice, speed)) - - logger.info( - "tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape - ) - - return result - - except Exception as e: - logger.error("tts_synthesis_failed", error=str(e), error_type=type(e).__name__) - raise ConfigurationError(f"TTS synthesis failed: {e}") from e - - -class TTSService: - """TTS service wrapper for async usage.""" - - def __init__(self) -> None: - """Initialize TTS service. - - Validates Modal credentials from environment variables (.env file). - """ - # Check credentials directly from environment - token_id = os.getenv("MODAL_TOKEN_ID") - token_secret = os.getenv("MODAL_TOKEN_SECRET") - - if not token_id or not token_secret: - raise ConfigurationError( - "Modal credentials required for TTS. " - "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." - ) - self.executor = ModalTTSExecutor() - - async def synthesize_async( - self, - text: str, - voice: str = "af_heart", - speed: float = 1.0, - ) -> tuple[int, NDArray[np.float32]] | None: - """Async wrapper for TTS synthesis. - - Args: - text: Text to synthesize - voice: Voice ID (default: settings.tts_voice) - speed: Speech speed (default: settings.tts_speed) - - Returns: - Tuple of (sample_rate, audio_array) or None if error - """ - voice = voice or settings.tts_voice - speed = speed or settings.tts_speed - - loop = asyncio.get_running_loop() - - try: - result = await loop.run_in_executor( - None, - lambda: self.executor.synthesize(text, voice, speed), - ) - return result - except Exception as e: - logger.error("tts_synthesis_async_failed", error=str(e)) - return None - - -@lru_cache(maxsize=1) -def get_tts_service() -> TTSService: - """Get or create singleton TTS service instance. - - Returns: - TTSService instance - - Raises: - ConfigurationError: If Modal credentials not configured - """ - return TTSService() - - -async def generate_audio_on_demand( - text: str, - modal_token_id: str | None = None, - modal_token_secret: str | None = None, - voice: str = "af_heart", - speed: float = 1.0, - use_llm_polish: bool = False, -) -> tuple[tuple[int, NDArray[np.float32]] | None, str]: - """Generate audio on-demand with optional runtime credentials. 
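# --- Hedged usage sketch (not part of the original file) ---
# Synthesising a short answer with the TTS service defined above. Requires
# MODAL_TOKEN_ID / MODAL_TOKEN_SECRET in .env; the text is a placeholder.
import asyncio

from src.services.tts_modal import get_tts_service

async def _tts_demo() -> tuple[int, int] | None:
    tts = get_tts_service()
    audio = await tts.synthesize_async("Analysis complete.", voice="af_heart", speed=1.0)
    if audio is None:
        return None
    sample_rate, samples = audio
    return sample_rate, samples.shape[0]

print(asyncio.run(_tts_demo()))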
- - Args: - text: Text to synthesize - modal_token_id: Modal token ID (UI input, overrides .env) - modal_token_secret: Modal token secret (UI input, overrides .env) - voice: Voice ID (default: af_heart) - speed: Speech speed (default: 1.0) - use_llm_polish: Apply LLM polish to text (default: False) - - Returns: - Tuple of (audio_output, status_message) - - audio_output: (sample_rate, audio_array) or None if failed - - status_message: Status/error message for user - - Priority: UI credentials > .env credentials - """ - # Priority: UI keys > .env keys - token_id = (modal_token_id or "").strip() or os.getenv("MODAL_TOKEN_ID") - token_secret = (modal_token_secret or "").strip() or os.getenv("MODAL_TOKEN_SECRET") - - if not token_id or not token_secret: - return ( - None, - "❌ Modal credentials required. Enter keys above or set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env", - ) - - try: - # Use credentials override context - with modal_credentials_override(token_id, token_secret): - # Import audio_processing here to avoid circular import - from src.services.audio_processing import AudioService - - # Temporarily override LLM polish setting - original_llm_polish = settings.tts_use_llm_polish - try: - settings.tts_use_llm_polish = use_llm_polish - - # Create fresh AudioService instance (bypass cache to pick up new credentials) - audio_service = AudioService() - audio_output = await audio_service.generate_audio_output( - text=text, - voice=voice, - speed=speed, - ) - - if audio_output: - return audio_output, "✅ Audio generated successfully" - else: - return None, "⚠️ Audio generation returned no output" - - finally: - settings.tts_use_llm_polish = original_llm_polish - - except ConfigurationError as e: - logger.error("audio_generation_config_error", error=str(e)) - return None, f"❌ Configuration error: {e}" - except Exception as e: - logger.error("audio_generation_failed", error=str(e), exc_info=True) - return None, f"❌ Audio generation failed: {e}" diff --git a/src/state/__init__.py b/src/state/__init__.py deleted file mode 100644 index a2323db724ea3ee5902b1121cdb2a30406319fe5..0000000000000000000000000000000000000000 --- a/src/state/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""State package - re-exports from agents.state for compatibility.""" - -from src.agents.state import ( - MagenticState, - get_magentic_state, - init_magentic_state, -) - -__all__ = ["MagenticState", "get_magentic_state", "init_magentic_state"] diff --git a/src/tools/__init__.py b/src/tools/__init__.py deleted file mode 100644 index 19f5c3ecd4c646af301ab73eab0286698ce99822..0000000000000000000000000000000000000000 --- a/src/tools/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Search tools package.""" - -from src.tools.base import SearchTool -from src.tools.pubmed import PubMedTool -from src.tools.rag_tool import RAGTool, create_rag_tool -from src.tools.search_handler import SearchHandler - -# Re-export -__all__ = [ - "PubMedTool", - "RAGTool", - "SearchHandler", - "SearchTool", - "create_rag_tool", -] diff --git a/src/tools/base.py b/src/tools/base.py deleted file mode 100644 index 0533c851973e3063d735d55f618c44c4ad8418ba..0000000000000000000000000000000000000000 --- a/src/tools/base.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Base classes and protocols for search tools.""" - -from typing import Protocol - -from src.utils.models import Evidence - - -class SearchTool(Protocol): - """Protocol defining the interface for all search tools.""" - - @property - def name(self) -> str: - """Human-readable name of this tool.""" - ... 
- - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """ - Execute a search and return evidence. - - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects - - Raises: - SearchError: If the search fails - RateLimitError: If we hit rate limits - """ - ... diff --git a/src/tools/clinicaltrials.py b/src/tools/clinicaltrials.py deleted file mode 100644 index 06ef61158096e0ab2a5ff5a52f77e428de13bb4c..0000000000000000000000000000000000000000 --- a/src/tools/clinicaltrials.py +++ /dev/null @@ -1,143 +0,0 @@ -"""ClinicalTrials.gov search tool using API v2.""" - -import asyncio -from typing import Any, ClassVar - -import requests -from tenacity import retry, stop_after_attempt, wait_exponential - -from src.utils.exceptions import SearchError -from src.utils.models import Citation, Evidence - - -class ClinicalTrialsTool: - """Search tool for ClinicalTrials.gov. - - Note: Uses `requests` library instead of `httpx` because ClinicalTrials.gov's - WAF blocks httpx's TLS fingerprint. The `requests` library is not blocked. - See: https://clinicaltrials.gov/data-api/api - """ - - BASE_URL = "https://clinicaltrials.gov/api/v2/studies" - - # Fields to retrieve - FIELDS: ClassVar[list[str]] = [ - "NCTId", - "BriefTitle", - "Phase", - "OverallStatus", - "Condition", - "InterventionName", - "StartDate", - "BriefSummary", - ] - - # Status filter: Only active/completed studies with potential data - STATUS_FILTER = "COMPLETED,ACTIVE_NOT_RECRUITING,RECRUITING,ENROLLING_BY_INVITATION" - - # Study type filter: Only interventional (drug/treatment studies) - STUDY_TYPE_FILTER = "INTERVENTIONAL" - - @property - def name(self) -> str: - return "clinicaltrials" - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10), - reraise=True, - ) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Search ClinicalTrials.gov for interventional studies. 
- - Args: - query: Search query (e.g., "metformin alzheimer") - max_results: Maximum results to return (max 100) - - Returns: - List of Evidence objects from clinical trials - """ - # Add study type filter to query string (parameter is not supported) - # AREA[StudyType]INTERVENTIONAL restricts to interventional studies - final_query = f"{query} AND AREA[StudyType]INTERVENTIONAL" - - params: dict[str, Any] = { - "query.term": final_query, - "pageSize": min(max_results, 100), - "fields": ",".join(self.FIELDS), - # FILTERS - Only active/completed studies - "filter.overallStatus": self.STATUS_FILTER, - } - - try: - # Run blocking requests.get in a separate thread for async compatibility - response = await asyncio.to_thread( - requests.get, - self.BASE_URL, - params=params, - headers={"User-Agent": "DETERMINATOR-Research-Agent/1.0"}, - timeout=30, - ) - response.raise_for_status() - - data = response.json() - studies = data.get("studies", []) - return [self._study_to_evidence(study) for study in studies[:max_results]] - - except requests.HTTPError as e: - raise SearchError(f"ClinicalTrials.gov API error: {e}") from e - except requests.RequestException as e: - raise SearchError(f"ClinicalTrials.gov request failed: {e}") from e - - def _study_to_evidence(self, study: dict[str, Any]) -> Evidence: - """Convert a clinical trial study to Evidence.""" - # Navigate nested structure - protocol = study.get("protocolSection", {}) - id_module = protocol.get("identificationModule", {}) - status_module = protocol.get("statusModule", {}) - desc_module = protocol.get("descriptionModule", {}) - design_module = protocol.get("designModule", {}) - conditions_module = protocol.get("conditionsModule", {}) - arms_module = protocol.get("armsInterventionsModule", {}) - - nct_id = id_module.get("nctId", "Unknown") - title = id_module.get("briefTitle", "Untitled Study") - status = status_module.get("overallStatus", "Unknown") - start_date = status_module.get("startDateStruct", {}).get("date", "Unknown") - - # Get phase (might be a list) - phases = design_module.get("phases", []) - phase = phases[0] if phases else "Not Applicable" - - # Get conditions - conditions = conditions_module.get("conditions", []) - conditions_str = ", ".join(conditions[:3]) if conditions else "Unknown" - - # Get interventions - interventions = arms_module.get("interventions", []) - intervention_names = [i.get("name", "") for i in interventions[:3]] - interventions_str = ", ".join(intervention_names) if intervention_names else "Unknown" - - # Get summary - summary = desc_module.get("briefSummary", "No summary available.") - - # Build content with key trial info - content = ( - f"{summary[:500]}... " - f"Trial Phase: {phase}. " - f"Status: {status}. " - f"Conditions: {conditions_str}. " - f"Interventions: {interventions_str}." - ) - - return Evidence( - content=content[:2000], - citation=Citation( - source="clinicaltrials", - title=title[:500], - url=f"https://clinicaltrials.gov/study/{nct_id}", - date=start_date, - authors=[], # Trials don't have traditional authors - ), - relevance=0.85, # Trials are highly relevant for repurposing - ) diff --git a/src/tools/code_execution.py b/src/tools/code_execution.py index da22fa9b1f864b0ab0e3778706c150ae60c714fd..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/tools/code_execution.py +++ b/src/tools/code_execution.py @@ -1,256 +0,0 @@ -"""Modal-based secure code execution tool for statistical analysis. - -This module provides sandboxed Python code execution using Modal's serverless infrastructure. 
-It's designed for running LLM-generated statistical analysis code safely. -""" - -import os -from functools import lru_cache -from typing import Any - -import structlog - -logger = structlog.get_logger(__name__) - -# Shared library versions for Modal sandbox - used by both executor and LLM prompts -# Keep these in sync to avoid version mismatch between generated code and execution -SANDBOX_LIBRARIES: dict[str, str] = { - "pandas": "2.2.0", - "numpy": "1.26.4", - "scipy": "1.11.4", - "matplotlib": "3.8.2", - "scikit-learn": "1.4.0", - "statsmodels": "0.14.1", -} - - -def get_sandbox_library_list() -> list[str]: - """Get list of library==version strings for Modal image.""" - return [f"{lib}=={ver}" for lib, ver in SANDBOX_LIBRARIES.items()] - - -def get_sandbox_library_prompt() -> str: - """Get formatted library versions for LLM prompts.""" - return "\n".join(f"- {lib}=={ver}" for lib, ver in SANDBOX_LIBRARIES.items()) - - -class CodeExecutionError(Exception): - """Raised when code execution fails.""" - - pass - - -class ModalCodeExecutor: - """Execute Python code securely using Modal sandboxes. - - This class provides a safe environment for executing LLM-generated code, - particularly for scientific computing and statistical analysis tasks. - - Features: - - Sandboxed execution (isolated from host system) - - Pre-installed scientific libraries (numpy, scipy, pandas, matplotlib) - - Network isolation for security - - Timeout protection - - Stdout/stderr capture - - Example: - >>> executor = ModalCodeExecutor() - >>> result = executor.execute(''' - ... import pandas as pd - ... df = pd.DataFrame({'a': [1, 2, 3]}) - ... result = df['a'].sum() - ... ''') - >>> print(result['stdout']) - 6 - """ - - def __init__(self) -> None: - """Initialize Modal code executor. - - Note: - Logs a warning if Modal credentials are not configured. - Execution will fail at runtime without valid credentials. - """ - # Check for Modal credentials - self.modal_token_id = os.getenv("MODAL_TOKEN_ID") - self.modal_token_secret = os.getenv("MODAL_TOKEN_SECRET") - - if not self.modal_token_id or not self.modal_token_secret: - logger.warning( - "Modal credentials not found. Code execution will fail unless modal setup is run." - ) - - def execute(self, code: str, timeout: int = 60, allow_network: bool = False) -> dict[str, Any]: - """Execute Python code in a Modal sandbox. - - Args: - code: Python code to execute - timeout: Maximum execution time in seconds (default: 60) - allow_network: Whether to allow network access (default: False for security) - - Returns: - Dictionary containing: - - stdout: Standard output from code execution - - stderr: Standard error from code execution - - success: Boolean indicating if execution succeeded - - error: Error message if execution failed - - Raises: - CodeExecutionError: If execution fails or times out - """ - try: - import modal - except ImportError as e: - raise CodeExecutionError( - "Modal SDK not installed. 
Run: uv sync or pip install modal>=0.63.0" - ) from e - - logger.info("executing_code", code_length=len(code), timeout=timeout) - - try: - # Create or lookup Modal app - app = modal.App.lookup("deepcritical-code-execution", create_if_missing=True) - - # Define scientific computing image with common libraries - scientific_image = modal.Image.debian_slim(python_version="3.11").uv_pip_install( - *get_sandbox_library_list() - ) - - # Create sandbox with security restrictions - sandbox = modal.Sandbox.create( - app=app, - image=scientific_image, - timeout=timeout, - block_network=not allow_network, # Wire the network control - ) - - try: - # Execute the code - # Wrap code to capture result - wrapped_code = f""" -import sys -import io -from contextlib import redirect_stdout, redirect_stderr - -stdout_io = io.StringIO() -stderr_io = io.StringIO() - -try: - with redirect_stdout(stdout_io), redirect_stderr(stderr_io): - {self._indent_code(code, 8)} - print("__EXECUTION_SUCCESS__") -except Exception as e: - print(f"__EXECUTION_ERROR__: {{type(e).__name__}}: {{e}}", file=sys.stderr) - -print("__STDOUT_START__") -print(stdout_io.getvalue()) -print("__STDOUT_END__") -print("__STDERR_START__") -print(stderr_io.getvalue(), file=sys.stderr) -print("__STDERR_END__", file=sys.stderr) -""" - - # Run the wrapped code - process = sandbox.exec("python", "-c", wrapped_code, timeout=timeout) - - # Read output - stdout_raw = process.stdout.read() - stderr_raw = process.stderr.read() - finally: - # Always clean up sandbox to prevent resource leaks - sandbox.terminate() - - # Parse output - success = "__EXECUTION_SUCCESS__" in stdout_raw - - # Extract actual stdout/stderr - stdout = self._extract_output(stdout_raw, "__STDOUT_START__", "__STDOUT_END__") - stderr = self._extract_output(stderr_raw, "__STDERR_START__", "__STDERR_END__") - - result = { - "stdout": stdout, - "stderr": stderr, - "success": success, - "error": stderr if not success else None, - } - - logger.info( - "code_execution_completed", - success=success, - stdout_length=len(stdout), - stderr_length=len(stderr), - ) - - return result - - except Exception as e: - logger.error("code_execution_failed", error=str(e), error_type=type(e).__name__) - raise CodeExecutionError(f"Code execution failed: {e}") from e - - def execute_with_return(self, code: str, timeout: int = 60) -> Any: - """Execute code and return the value of the 'result' variable. - - Convenience method that executes code and extracts a return value. - The code should assign its final result to a variable named 'result'. 
- - Args: - code: Python code to execute (must set 'result' variable) - timeout: Maximum execution time in seconds - - Returns: - The value of the 'result' variable from the executed code - - Example: - >>> executor.execute_with_return("result = 2 + 2") - 4 - """ - # Modify code to print result as JSON - wrapped = f""" -import json -{code} -print(json.dumps({{"__RESULT__": result}})) -""" - - execution_result = self.execute(wrapped, timeout=timeout) - - if not execution_result["success"]: - raise CodeExecutionError(f"Execution failed: {execution_result['error']}") - - # Parse result from stdout - import json - - try: - output = execution_result["stdout"].strip() - if "__RESULT__" in output: - # Extract JSON line - for line in output.split("\n"): - if "__RESULT__" in line: - data = json.loads(line) - return data["__RESULT__"] - raise ValueError("Result not found in output") - except (json.JSONDecodeError, ValueError) as e: - logger.warning( - "failed_to_parse_result", error=str(e), stdout=execution_result["stdout"] - ) - return execution_result["stdout"] - - def _indent_code(self, code: str, spaces: int) -> str: - """Indent code by specified number of spaces.""" - indent = " " * spaces - return "\n".join(indent + line if line.strip() else line for line in code.split("\n")) - - def _extract_output(self, text: str, start_marker: str, end_marker: str) -> str: - """Extract content between markers.""" - try: - start_idx = text.index(start_marker) + len(start_marker) - end_idx = text.index(end_marker) - return text[start_idx:end_idx].strip() - except ValueError: - # Markers not found, return original text - return text.strip() - - -@lru_cache(maxsize=1) -def get_code_executor() -> ModalCodeExecutor: - """Get or create singleton code executor instance (thread-safe via lru_cache).""" - return ModalCodeExecutor() diff --git a/src/tools/crawl_adapter.py b/src/tools/crawl_adapter.py deleted file mode 100644 index 332569c52698e39a46085f81989e0c386d402347..0000000000000000000000000000000000000000 --- a/src/tools/crawl_adapter.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Website crawl tool adapter for Pydantic AI agents. - -Uses the vendored crawl_website implementation from src/tools/vendored/crawl_website.py. -""" - -import structlog - -logger = structlog.get_logger() - - -async def crawl_website(starting_url: str) -> str: - """ - Crawl a website starting from the given URL and return formatted results. - - Use this tool to crawl a website for information relevant to the query. - Provide a starting URL as input. - - Args: - starting_url: The starting URL to crawl (e.g., "https://example.com") - - Returns: - Formatted string with crawled content including titles, descriptions, and URLs - """ - try: - # Import vendored crawl tool - from src.tools.vendored.crawl_website import crawl_website as crawl_tool - - # Call the tool function - # The tool returns List[ScrapeResult] or str - results = await crawl_tool(starting_url) - - if isinstance(results, str): - # Error message returned - logger.warning("Crawl returned error", error=results) - return results - - if not results: - return f"No content found when crawling: {starting_url}" - - # Format results for agent consumption - formatted = [f"Found {len(results)} pages from {starting_url}:\n"] - for i, result in enumerate(results[:10], 1): # Limit to 10 pages - formatted.append(f"{i}. 
**{result.title or 'Untitled'}**") - if result.description: - formatted.append(f" {result.description[:200]}...") - formatted.append(f" URL: {result.url}") - if result.text: - formatted.append(f" Content: {result.text[:500]}...") - formatted.append("") - - return "\n".join(formatted) - - except ImportError as e: - logger.error("Crawl tool not available", error=str(e)) - return f"Crawl tool not available: {e!s}" - except Exception as e: - logger.error("Crawl failed", error=str(e), url=starting_url) - return f"Error crawling website: {e!s}" diff --git a/src/tools/europepmc.py b/src/tools/europepmc.py deleted file mode 100644 index 4dee1f58bc9e1e674578b9e4e3bcd7261f9895c4..0000000000000000000000000000000000000000 --- a/src/tools/europepmc.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Europe PMC search tool - replaces BioRxiv.""" - -from typing import Any - -import httpx -from tenacity import retry, stop_after_attempt, wait_exponential - -from src.utils.exceptions import SearchError -from src.utils.models import Citation, Evidence - - -class EuropePMCTool: - """ - Search Europe PMC for papers and preprints. - - Europe PMC indexes: - - PubMed/MEDLINE articles - - PMC full-text articles - - Preprints from bioRxiv, medRxiv, ChemRxiv, etc. - - Patents and clinical guidelines - - API Docs: https://europepmc.org/RestfulWebService - """ - - BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search" - - @property - def name(self) -> str: - return "europepmc" - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10), - reraise=True, - ) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """ - Search Europe PMC for papers matching query. - - Args: - query: Search keywords - max_results: Maximum results to return - - Returns: - List of Evidence objects - """ - params: dict[str, str | int] = { - "query": query, - "resultType": "core", - "pageSize": min(max_results, 100), - "format": "json", - } - - async with httpx.AsyncClient(timeout=30.0) as client: - try: - response = await client.get(self.BASE_URL, params=params) - response.raise_for_status() - - data = response.json() - results = data.get("resultList", {}).get("result", []) - - return [self._to_evidence(r) for r in results[:max_results]] - - except httpx.HTTPStatusError as e: - raise SearchError(f"Europe PMC API error: {e}") from e - except httpx.RequestError as e: - raise SearchError(f"Europe PMC connection failed: {e}") from e - - def _to_evidence(self, result: dict[str, Any]) -> Evidence: - """Convert Europe PMC result to Evidence.""" - title = result.get("title", "Untitled") - abstract = result.get("abstractText", "No abstract available.") - doi = result.get("doi", "") - pub_year = result.get("pubYear", "Unknown") - - # Get authors - author_list = result.get("authorList", {}).get("author", []) - authors = [a.get("fullName", "") for a in author_list[:5] if a.get("fullName")] - - # Check if preprint - pub_types = result.get("pubTypeList", {}).get("pubType", []) - is_preprint = "Preprint" in pub_types - source_db = result.get("source", "europepmc") - - # Build content - preprint_marker = "[PREPRINT - Not peer-reviewed] " if is_preprint else "" - content = f"{preprint_marker}{abstract[:1800]}" - - # Build URL - if doi: - url = f"https://doi.org/{doi}" - elif result.get("pmid"): - url = f"https://pubmed.ncbi.nlm.nih.gov/{result['pmid']}/" - else: - url = f"https://europepmc.org/article/{source_db}/{result.get('id', '')}" - - return Evidence( - content=content[:2000], - 
citation=Citation( - source="preprint" if is_preprint else "europepmc", - title=title[:500], - url=url, - date=str(pub_year), - authors=authors, - ), - relevance=0.75 if is_preprint else 0.9, - ) diff --git a/src/tools/fallback_web_search.py b/src/tools/fallback_web_search.py deleted file mode 100644 index 62d2dfd0e7daf98c6a71f7f92825eb1a96d92503..0000000000000000000000000000000000000000 --- a/src/tools/fallback_web_search.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Fallback web search tool that tries Serper first, then DuckDuckGo on errors.""" - -import structlog - -from src.tools.serper_web_search import SerperWebSearchTool -from src.tools.web_search import WebSearchTool -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError -from src.utils.models import Evidence - -logger = structlog.get_logger() - - -class FallbackWebSearchTool: - """Web search tool that tries Serper first, falls back to DuckDuckGo on any error. - - This ensures search always works even if Serper fails due to: - - Credit exhaustion (403 Forbidden) - - Rate limiting (429) - - Network errors - - Invalid API key - - Any other errors - """ - - def __init__(self) -> None: - """Initialize fallback web search tool.""" - self._serper_tool: SerperWebSearchTool | None = None - self._duckduckgo_tool: WebSearchTool | None = None - self._serper_available = False - - # Try to initialize Serper if API key is available - if settings.serper_api_key: - try: - self._serper_tool = SerperWebSearchTool() - self._serper_available = True - logger.info("Serper web search initialized for fallback tool") - except Exception as e: - logger.warning( - "Failed to initialize Serper, will use DuckDuckGo only", - error=str(e), - ) - self._serper_available = False - - # DuckDuckGo is always available as fallback - self._duckduckgo_tool = WebSearchTool() - logger.info("DuckDuckGo web search initialized as fallback") - - @property - def name(self) -> str: - """Return the name of this search tool.""" - return "serper" if self._serper_available else "duckduckgo" - - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Execute web search with automatic fallback. 
- - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects from Serper (if successful) or DuckDuckGo (if fallback) - """ - # Try Serper first if available - if self._serper_available and self._serper_tool: - try: - logger.debug("Attempting Serper search", query=query) - results = await self._serper_tool.search(query, max_results=max_results) - logger.info( - "Serper search successful", - query=query, - results_count=len(results), - ) - return results - except (ConfigurationError, RateLimitError, SearchError) as e: - # Serper failed - log and fall back to DuckDuckGo - logger.warning( - "Serper search failed, falling back to DuckDuckGo", - error=str(e), - error_type=type(e).__name__, - query=query, - ) - # Mark Serper as unavailable for future requests (optional optimization) - # self._serper_available = False - except Exception as e: - # Unexpected error from Serper - fall back - logger.error( - "Unexpected error in Serper search, falling back to DuckDuckGo", - error=str(e), - error_type=type(e).__name__, - query=query, - ) - - # Fall back to DuckDuckGo - if self._duckduckgo_tool: - try: - logger.info("Using DuckDuckGo search", query=query) - results = await self._duckduckgo_tool.search(query, max_results=max_results) - logger.info( - "DuckDuckGo search successful", - query=query, - results_count=len(results), - ) - return results - except Exception as e: - logger.error( - "DuckDuckGo search also failed", - error=str(e), - query=query, - ) - # If even DuckDuckGo fails, return empty list - return [] - - # Should never reach here, but just in case - logger.error("No web search tools available") - return [] - - - diff --git a/src/tools/neo4j_search.py b/src/tools/neo4j_search.py deleted file mode 100644 index 0198388961f9f0dba02c813944560b898ce584b7..0000000000000000000000000000000000000000 --- a/src/tools/neo4j_search.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Neo4j knowledge graph search tool.""" - -import structlog - -from src.services.neo4j_service import get_neo4j_service -from src.utils.models import Citation, Evidence - -logger = structlog.get_logger() - - -class Neo4jSearchTool: - """Search Neo4j knowledge graph for papers.""" - - def __init__(self) -> None: - self.name = "neo4j" # ✅ Definir explícitamente - - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Search Neo4j for papers about diseases in the query.""" - try: - service = get_neo4j_service() - if not service: - logger.warning("Neo4j service not available") - return [] - - # Extract disease name from query - disease = query - if "for" in query.lower(): - disease = query.split("for")[-1].strip().rstrip("?") - - # Query Neo4j - if not service.driver: - logger.warning("Neo4j driver not available") - return [] - with service.driver.session(database=service.database) as session: - result = session.run( - """ - MATCH (p:Paper)-[:ABOUT]->(d:Disease) - WHERE d.name CONTAINS $disease - RETURN p.title as title, p.abstract as abstract, - p.url as url, p.source as source - ORDER BY p.updated_at DESC - LIMIT $max_results - """, - disease=disease, - max_results=max_results, - ) - - records = list(result) - - results = [] - for record in records: - citation = Citation( - source="neo4j", - title=record["title"] or "Untitled", - url=record["url"] or "", - date="", - authors=[], - ) - - evidence = Evidence( - content=record["abstract"] or record["title"] or "", - citation=citation, - relevance=1.0, - metadata={"from_kb": True, 
"original_source": record["source"]}, - ) - results.append(evidence) - - logger.info(f"📊 Neo4j returned {len(results)} results") - return results - except Exception as e: - logger.error(f"Neo4j search failed: {e}") - return [] diff --git a/src/tools/pubmed.py b/src/tools/pubmed.py deleted file mode 100644 index fe6ed6fa422cea1d80998a52cfe790a2922ad2c0..0000000000000000000000000000000000000000 --- a/src/tools/pubmed.py +++ /dev/null @@ -1,207 +0,0 @@ -"""PubMed search tool using NCBI E-utilities.""" - -from typing import Any - -import httpx -import xmltodict -from tenacity import retry, stop_after_attempt, wait_exponential - -from src.tools.query_utils import preprocess_query -from src.tools.rate_limiter import get_pubmed_limiter -from src.utils.config import settings -from src.utils.exceptions import RateLimitError, SearchError -from src.utils.models import Citation, Evidence - - -class PubMedTool: - """Search tool for PubMed/NCBI.""" - - BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" - HTTP_TOO_MANY_REQUESTS = 429 - - def __init__(self, api_key: str | None = None) -> None: - self.api_key = api_key or settings.ncbi_api_key - # Ignore placeholder values from .env.example - if self.api_key == "your-ncbi-key-here": - self.api_key = None - - # Use shared rate limiter - self._limiter = get_pubmed_limiter(self.api_key) - - @property - def name(self) -> str: - return "pubmed" - - async def _rate_limit(self) -> None: - """Enforce NCBI rate limiting.""" - await self._limiter.acquire() - - def _build_params(self, **kwargs: Any) -> dict[str, Any]: - """Build request params with optional API key.""" - params = {**kwargs, "retmode": "json"} - if self.api_key: - params["api_key"] = self.api_key - return params - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10), - reraise=True, - ) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """ - Search PubMed and return evidence. - - 1. ESearch: Get PMIDs matching query - 2. EFetch: Get abstracts for those PMIDs - 3. 
Parse and return Evidence objects - """ - await self._rate_limit() - - # Preprocess query to remove noise and expand synonyms - clean_query = preprocess_query(query) - final_query = clean_query if clean_query else query - - async with httpx.AsyncClient(timeout=30.0) as client: - # Step 1: Search for PMIDs - search_params = self._build_params( - db="pubmed", - term=final_query, - retmax=max_results, - sort="relevance", - ) - - try: - search_resp = await client.get( - f"{self.BASE_URL}/esearch.fcgi", - params=search_params, - ) - search_resp.raise_for_status() - except httpx.TimeoutException as e: - raise SearchError(f"PubMed search timeout: {e}") from e - except httpx.HTTPStatusError as e: - if e.response.status_code == self.HTTP_TOO_MANY_REQUESTS: - raise RateLimitError("PubMed rate limit exceeded") from e - raise SearchError(f"PubMed search failed: {e}") from e - - search_data = search_resp.json() - pmids = search_data.get("esearchresult", {}).get("idlist", []) - - if not pmids: - return [] - - # Step 2: Fetch abstracts - await self._rate_limit() - fetch_params = self._build_params( - db="pubmed", - id=",".join(pmids), - rettype="abstract", - ) - # Use XML for fetch (more reliable parsing) - fetch_params["retmode"] = "xml" - - try: - fetch_resp = await client.get( - f"{self.BASE_URL}/efetch.fcgi", - params=fetch_params, - ) - fetch_resp.raise_for_status() - except httpx.TimeoutException as e: - raise SearchError(f"PubMed fetch timeout: {e}") from e - - # Step 3: Parse XML to Evidence - return self._parse_pubmed_xml(fetch_resp.text) - - def _parse_pubmed_xml(self, xml_text: str) -> list[Evidence]: - """Parse PubMed XML into Evidence objects.""" - try: - data = xmltodict.parse(xml_text) - except Exception as e: - raise SearchError(f"Failed to parse PubMed XML: {e}") from e - - if data is None: - return [] - - # Handle case where PubmedArticleSet might not exist or be empty - pubmed_set = data.get("PubmedArticleSet") - if not pubmed_set: - return [] - - articles = pubmed_set.get("PubmedArticle", []) - - # Handle single article (xmltodict returns dict instead of list) - if isinstance(articles, dict): - articles = [articles] - - evidence_list = [] - for article in articles: - try: - evidence = self._article_to_evidence(article) - if evidence: - evidence_list.append(evidence) - except Exception: - continue # Skip malformed articles - - return evidence_list - - def _article_to_evidence(self, article: dict[str, Any]) -> Evidence | None: - """Convert a single PubMed article to Evidence.""" - medline = article.get("MedlineCitation", {}) - article_data = medline.get("Article", {}) - - # Extract PMID - pmid = medline.get("PMID", {}) - if isinstance(pmid, dict): - pmid = pmid.get("#text", "") - - # Extract title - title = article_data.get("ArticleTitle", "") - if isinstance(title, dict): - title = title.get("#text", str(title)) - - # Extract abstract - abstract_data = article_data.get("Abstract", {}).get("AbstractText", "") - if isinstance(abstract_data, list): - abstract = " ".join( - item.get("#text", str(item)) if isinstance(item, dict) else str(item) - for item in abstract_data - ) - elif isinstance(abstract_data, dict): - abstract = abstract_data.get("#text", str(abstract_data)) - else: - abstract = str(abstract_data) - - if not abstract or not title: - return None - - # Extract date - pub_date = article_data.get("Journal", {}).get("JournalIssue", {}).get("PubDate", {}) - year = pub_date.get("Year", "Unknown") - month = pub_date.get("Month", "01") - day = pub_date.get("Day", "01") - date_str = 
f"{year}-{month}-{day}" if year != "Unknown" else "Unknown" - - # Extract authors - author_list = article_data.get("AuthorList", {}).get("Author", []) - if isinstance(author_list, dict): - author_list = [author_list] - authors = [] - for author in author_list[:5]: # Limit to 5 authors - last = author.get("LastName", "") - first = author.get("ForeName", "") - if last: - authors.append(f"{last} {first}".strip()) - - # Truncation rationale: LLM context limits + cost optimization - # - Abstract: 2000 chars (~500 tokens) captures key findings - # - Title: 500 chars covers even verbose journal titles - return Evidence( - content=abstract[:2000], - citation=Citation( - source="pubmed", - title=title[:500], - url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", - date=date_str, - authors=authors, - ), - ) diff --git a/src/tools/query_utils.py b/src/tools/query_utils.py deleted file mode 100644 index c52507f226400cfe8adef3b4dce90d3f70a8f5a5..0000000000000000000000000000000000000000 --- a/src/tools/query_utils.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Query preprocessing utilities for biomedical search.""" - -import re - -# Question words and filler words to remove -QUESTION_WORDS: set[str] = { - # Question starters - "what", - "which", - "how", - "why", - "when", - "where", - "who", - "whom", - # Auxiliary verbs in questions - "is", - "are", - "was", - "were", - "do", - "does", - "did", - "can", - "could", - "would", - "should", - "will", - "shall", - "may", - "might", - # Filler words in natural questions - "show", - "promise", - "help", - "believe", - "think", - "suggest", - "possible", - "potential", - "effective", - "useful", - "good", - # Articles (remove but less aggressively) - "the", - "a", - "an", -} - -# Medical synonym expansions -SYNONYMS: dict[str, list[str]] = { - "long covid": [ - "long COVID", - "PASC", - "post-acute sequelae of SARS-CoV-2", - "post-COVID syndrome", - "post-COVID-19 condition", - ], - "alzheimer": [ - "Alzheimer's disease", - "Alzheimer disease", - "AD", - "Alzheimer dementia", - ], - "parkinson": [ - "Parkinson's disease", - "Parkinson disease", - "PD", - ], - "diabetes": [ - "diabetes mellitus", - "type 2 diabetes", - "T2DM", - "diabetic", - ], - "cancer": [ - "cancer", - "neoplasm", - "tumor", - "malignancy", - "carcinoma", - ], - "heart disease": [ - "cardiovascular disease", - "CVD", - "coronary artery disease", - "heart failure", - ], -} - - -def strip_question_words(query: str) -> str: - """ - Remove question words and filler terms from query. - - Args: - query: Raw query string - - Returns: - Query with question words removed - """ - words = query.lower().split() - filtered = [w for w in words if w not in QUESTION_WORDS] - return " ".join(filtered) - - -def expand_synonyms(query: str) -> str: - """ - Expand medical terms to include synonyms. - - Args: - query: Query string - - Returns: - Query with synonym expansions in OR groups - """ - result = query.lower() - - for term, expansions in SYNONYMS.items(): - if term in result: - # Create OR group: ("term1" OR "term2" OR "term3") - or_group = " OR ".join([f'"{exp}"' for exp in expansions]) - # Case insensitive replacement is tricky with simple replace - # But we lowercased result already. - # However, this replaces ALL instances. - # Also, result is lowercased, so we lose original casing if any. - # But search engines are usually case-insensitive. - result = result.replace(term, f"({or_group})") - - return result - - -def preprocess_query(raw_query: str) -> str: - """ - Full preprocessing pipeline for PubMed queries. 
- - Pipeline: - 1. Strip whitespace and punctuation - 2. Remove question words - 3. Expand medical synonyms - - Args: - raw_query: Natural language query from user - - Returns: - Optimized query for PubMed - """ - if not raw_query or not raw_query.strip(): - return "" - - # Remove question marks and extra whitespace - query = raw_query.replace("?", "").strip() - query = re.sub(r"\s+", " ", query) - - # Strip question words - query = strip_question_words(query) - - # Expand synonyms - query = expand_synonyms(query) - - return query.strip() - - -def preprocess_web_query(raw_query: str) -> str: - """ - Simplified preprocessing pipeline for web search engines (Serper, DuckDuckGo, etc.). - - Web search engines work better with natural language queries rather than - complex boolean syntax. This function: - 1. Strips whitespace and punctuation - 2. Removes question words (less aggressively) - 3. Removes complex boolean syntax (OR groups, parentheses) - 4. Uses primary synonym terms instead of expanding to OR groups - - Args: - raw_query: Natural language query from user - - Returns: - Simplified query optimized for web search engines - """ - if not raw_query or not raw_query.strip(): - return "" - - # Remove question marks and extra whitespace - query = raw_query.replace("?", "").strip() - query = re.sub(r"\s+", " ", query) - - # Remove complex boolean syntax that might have been added - # Remove OR groups like: ("term1" OR "term2" OR "term3") - query = re.sub(r'\([^)]*OR[^)]*\)', '', query, flags=re.IGNORECASE) - # Remove standalone OR statements - query = re.sub(r'\s+OR\s+', ' ', query, flags=re.IGNORECASE) - # Remove extra parentheses - query = re.sub(r'[()]', '', query) - # Remove extra quotes that might be left - query = re.sub(r'"([^"]*)"', r'\1', query) - # Clean up multiple spaces - query = re.sub(r'\s+', ' ', query) - - # Strip question words (less aggressively - keep important context words) - # Only remove very common question starters - minimal_question_words = {"what", "which", "how", "why", "when", "where", "who"} - words = query.split() - filtered = [w for w in words if w.lower() not in minimal_question_words] - query = " ".join(filtered) - - # Replace known medical terms with their primary/common form - # Use the first synonym (most common) instead of OR groups - query_lower = query.lower() - for term, expansions in SYNONYMS.items(): - if term in query_lower: - # Use the first expansion (usually the most common term) - primary_term = expansions[0] if expansions else term - # Replace case-insensitively, preserving original case where possible - pattern = re.compile(re.escape(term), re.IGNORECASE) - query = pattern.sub(primary_term, query) - - return query.strip() \ No newline at end of file diff --git a/src/tools/rag_tool.py b/src/tools/rag_tool.py deleted file mode 100644 index 230405db4134ab99b557c6a2f839b9384b628bdf..0000000000000000000000000000000000000000 --- a/src/tools/rag_tool.py +++ /dev/null @@ -1,200 +0,0 @@ -"""RAG tool for semantic search within collected evidence. - -Implements SearchTool protocol to enable RAG as a search option in the research workflow. -""" - -from typing import TYPE_CHECKING, Any - -import structlog - -from src.utils.exceptions import ConfigurationError -from src.utils.models import Citation, Evidence, SourceName - -if TYPE_CHECKING: - from src.services.llamaindex_rag import LlamaIndexRAGService - -logger = structlog.get_logger() - - -class RAGTool: - """Search tool that uses LlamaIndex RAG for semantic search within collected evidence. 
- - Wraps LlamaIndexRAGService to implement the SearchTool protocol. - Returns Evidence objects from RAG retrieval results. - """ - - def __init__( - self, - rag_service: "LlamaIndexRAGService | None" = None, - oauth_token: str | None = None, - ) -> None: - """ - Initialize RAG tool. - - Args: - rag_service: Optional RAG service instance. If None, will be lazy-initialized. - oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM) - """ - self._rag_service = rag_service - self.oauth_token = oauth_token - self.logger = logger - - @property - def name(self) -> str: - """Return the tool name.""" - return "rag" - - def _get_rag_service(self) -> "LlamaIndexRAGService": - """ - Get or create RAG service instance. - - Returns: - LlamaIndexRAGService instance - - Raises: - ConfigurationError: If RAG service cannot be initialized - """ - if self._rag_service is None: - try: - from src.services.llamaindex_rag import get_rag_service - - # Use local embeddings by default (no API key required) - # Use in-memory ChromaDB to avoid file system issues - # Pass OAuth token for LLM query synthesis - self._rag_service = get_rag_service( - use_openai_embeddings=False, - use_in_memory=True, # Use in-memory for better reliability - oauth_token=self.oauth_token, - ) - self.logger.info("RAG service initialized with local embeddings") - except (ConfigurationError, ImportError) as e: - self.logger.error("Failed to initialize RAG service", error=str(e)) - raise ConfigurationError( - "RAG service unavailable. Check LlamaIndex dependencies are installed." - ) from e - - return self._rag_service - - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """ - Search RAG system and return evidence. - - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects from RAG retrieval - - Note: - Returns empty list on error (does not raise exceptions). - """ - try: - rag_service = self._get_rag_service() - except ConfigurationError: - self.logger.warning("RAG service unavailable, returning empty results") - return [] - - try: - # Retrieve documents from RAG - retrieved_docs = rag_service.retrieve(query, top_k=max_results) - - if not retrieved_docs: - self.logger.info("No RAG results found", query=query[:50]) - return [] - - # Convert retrieved documents to Evidence objects - evidence_list: list[Evidence] = [] - for doc in retrieved_docs: - try: - evidence = self._doc_to_evidence(doc) - evidence_list.append(evidence) - except Exception as e: - self.logger.warning( - "Failed to convert document to evidence", - error=str(e), - doc_text=doc.get("text", "")[:50], - ) - continue - - self.logger.info( - "RAG search completed", - query=query[:50], - results=len(evidence_list), - ) - return evidence_list - - except Exception as e: - self.logger.error("RAG search failed", error=str(e), query=query[:50]) - # Return empty list on error (graceful degradation) - return [] - - def _doc_to_evidence(self, doc: dict[str, Any]) -> Evidence: - """ - Convert RAG document to Evidence object. 
- - Args: - doc: Document dict with keys: text, score, metadata - - Returns: - Evidence object - - Raises: - ValueError: If document is missing required fields - """ - text = doc.get("text", "") - if not text: - raise ValueError("Document missing text content") - - metadata = doc.get("metadata", {}) - score = doc.get("score", 0.0) - - # Extract citation information from metadata - source: SourceName = "rag" # RAG is the source - title = metadata.get("title", "Untitled") - url = metadata.get("url", "") - date = metadata.get("date", "Unknown") - authors_str = metadata.get("authors", "") - authors = [a.strip() for a in authors_str.split(",") if a.strip()] if authors_str else [] - - # Create citation - citation = Citation( - source=source, - title=title[:500], # Enforce max length - url=url, - date=date, - authors=authors, - ) - - # Create evidence with relevance score (normalize score to 0-1 if needed) - relevance = min(max(float(score), 0.0), 1.0) if score else 0.0 - - return Evidence( - content=text, - citation=citation, - relevance=relevance, - ) - - -def create_rag_tool( - rag_service: "LlamaIndexRAGService | None" = None, - oauth_token: str | None = None, -) -> RAGTool: - """ - Factory function to create a RAG tool. - - Args: - rag_service: Optional RAG service instance. If None, will be lazy-initialized. - oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM) - - Returns: - Configured RAGTool instance - - Raises: - ConfigurationError: If RAG service cannot be initialized and rag_service is None - """ - try: - return RAGTool(rag_service=rag_service, oauth_token=oauth_token) - except Exception as e: - logger.error("Failed to create RAG tool", error=str(e)) - raise ConfigurationError(f"Failed to create RAG tool: {e}") from e diff --git a/src/tools/rate_limiter.py b/src/tools/rate_limiter.py deleted file mode 100644 index 48cf5385cc04dcfc3d7d7e938813aa24887254ef..0000000000000000000000000000000000000000 --- a/src/tools/rate_limiter.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Rate limiting utilities using the limits library.""" - -import asyncio -import random -from typing import ClassVar - -from limits import RateLimitItem, parse -from limits.storage import MemoryStorage -from limits.strategies import MovingWindowRateLimiter - - -class RateLimiter: - """ - Async-compatible rate limiter using limits library. - - Uses moving window algorithm for smooth rate limiting. - """ - - def __init__(self, rate: str) -> None: - """ - Initialize rate limiter. - - Args: - rate: Rate string like "3/second" or "10/second" - """ - self.rate = rate - self._storage = MemoryStorage() - self._limiter = MovingWindowRateLimiter(self._storage) - self._rate_limit: RateLimitItem = parse(rate) - self._identity = "default" # Single identity for shared limiting - - async def acquire(self, wait: bool = True, jitter: bool = False) -> bool: - """ - Acquire permission to make a request. - - ASYNC-SAFE: Uses asyncio.sleep(), never time.sleep(). - The polling pattern allows other coroutines to run while waiting. - - Args: - wait: If True, wait until allowed. If False, return immediately. 
- jitter: If True, add random jitter (0-20% of wait time) to avoid thundering herd - - Returns: - True if allowed, False if not (only when wait=False) - """ - while True: - # Check if we can proceed (synchronous, fast - ~microseconds) - if self._limiter.hit(self._rate_limit, self._identity): - # Add jitter after acquiring to spread out requests - if jitter: - # Add 0-1 second jitter to spread requests slightly - # This prevents thundering herd without long delays - jitter_seconds = random.uniform(0, 1.0) - await asyncio.sleep(jitter_seconds) - return True - - if not wait: - return False - - # CRITICAL: Use asyncio.sleep(), NOT time.sleep() - # This yields control to the event loop, allowing other - # coroutines (UI, parallel searches) to run. - # Using 0.01s for fine-grained responsiveness. - await asyncio.sleep(0.01) - - def reset(self) -> None: - """Reset the rate limiter (for testing).""" - self._storage.reset() - - -# Singleton limiter for PubMed/NCBI -_pubmed_limiter: RateLimiter | None = None - - -def get_pubmed_limiter(api_key: str | None = None) -> RateLimiter: - """ - Get the shared PubMed rate limiter. - - Rate depends on whether API key is provided: - - Without key: 3 requests/second - - With key: 10 requests/second - - Args: - api_key: NCBI API key (optional) - - Returns: - Shared RateLimiter instance - """ - global _pubmed_limiter - - if _pubmed_limiter is None: - rate = "10/second" if api_key else "3/second" - _pubmed_limiter = RateLimiter(rate) - - return _pubmed_limiter - - -def reset_pubmed_limiter() -> None: - """Reset the PubMed limiter (for testing).""" - global _pubmed_limiter - _pubmed_limiter = None - - -def get_serper_limiter(api_key: str | None = None) -> RateLimiter: - """ - Get the shared Serper API rate limiter. - - Rate: 100 requests/second (Serper free tier limit) - - Serper free tier provides: - - 2,500 credits (one-time, expire after 6 months) - - 100 requests/second rate limit - - Credits only deduct for successful responses - - We use a slightly conservative rate (90/second) to stay safely under the limit - while allowing high throughput when needed. - - Args: - api_key: Serper API key (optional, for consistency with other limiters) - - Returns: - Shared RateLimiter instance - """ - # Use 90/second to stay safely under 100/second limit - return RateLimiterFactory.get("serper", "90/second") - - -def get_searchxng_limiter() -> RateLimiter: - """ - Get the shared SearchXNG API rate limiter. - - Rate: 5 requests/second (conservative limit) - - Returns: - Shared RateLimiter instance - """ - return RateLimiterFactory.get("searchxng", "5/second") - - -# Factory for other APIs -class RateLimiterFactory: - """Factory for creating/getting rate limiters for different APIs.""" - - _limiters: ClassVar[dict[str, RateLimiter]] = {} - - @classmethod - def get(cls, api_name: str, rate: str) -> RateLimiter: - """ - Get or create a rate limiter for an API. 
- - Args: - api_name: Unique identifier for the API - rate: Rate limit string (e.g., "10/second") - - Returns: - RateLimiter instance (shared for same api_name) - """ - if api_name not in cls._limiters: - cls._limiters[api_name] = RateLimiter(rate) - return cls._limiters[api_name] - - @classmethod - def reset_all(cls) -> None: - """Reset all limiters (for testing).""" - cls._limiters.clear() diff --git a/src/tools/search_handler.py b/src/tools/search_handler.py deleted file mode 100644 index 9d2b4adf7d00001af74dd2c4614ae92459c74b1d..0000000000000000000000000000000000000000 --- a/src/tools/search_handler.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Search handler - orchestrates multiple search tools.""" - -import asyncio -from typing import TYPE_CHECKING, cast - -import structlog - -from src.services.neo4j_service import get_neo4j_service -from src.tools.base import SearchTool -from src.tools.rag_tool import create_rag_tool -from src.utils.exceptions import ConfigurationError, SearchError -from src.utils.models import Evidence, SearchResult, SourceName - -if TYPE_CHECKING: - from src.services.llamaindex_rag import LlamaIndexRAGService -else: - LlamaIndexRAGService = object - -logger = structlog.get_logger() - - -class SearchHandler: - """Orchestrates parallel searches across multiple tools.""" - - def __init__( - self, - tools: list[SearchTool], - timeout: float = 30.0, - include_rag: bool = False, - auto_ingest_to_rag: bool = True, - oauth_token: str | None = None, - ) -> None: - """ - Initialize the search handler. - - Args: - tools: List of search tools to use - timeout: Timeout for each search in seconds - include_rag: Whether to include RAG tool in searches - auto_ingest_to_rag: Whether to automatically ingest results into RAG - oauth_token: Optional OAuth token from HuggingFace login (for RAG LLM) - """ - self.tools = list(tools) # Make a copy - self.timeout = timeout - self.auto_ingest_to_rag = auto_ingest_to_rag - self.oauth_token = oauth_token - self._rag_service: LlamaIndexRAGService | None = None - - if include_rag: - self.add_rag_tool() - - def add_rag_tool(self) -> None: - """Add RAG tool to the tools list if available.""" - try: - rag_tool = create_rag_tool(oauth_token=self.oauth_token) - self.tools.append(rag_tool) - logger.info("RAG tool added to search handler") - except ConfigurationError: - logger.warning( - "RAG tool unavailable, not adding to search handler", - hint="LlamaIndex dependencies required", - ) - except Exception as e: - logger.error("Failed to add RAG tool", error=str(e)) - - def _get_rag_service(self) -> "LlamaIndexRAGService | None": - """Get or create RAG service for ingestion.""" - if self._rag_service is None and self.auto_ingest_to_rag: - try: - from src.services.llamaindex_rag import get_rag_service - - # Use local embeddings by default (no API key required) - # Use in-memory ChromaDB to avoid file system issues - # Pass OAuth token for LLM query synthesis - self._rag_service = get_rag_service( - use_openai_embeddings=False, - use_in_memory=True, # Use in-memory for better reliability - oauth_token=self.oauth_token, - ) - logger.info("RAG service initialized for ingestion with local embeddings") - except (ConfigurationError, ImportError): - logger.warning("RAG service unavailable for ingestion") - return None - return self._rag_service - - async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult: - """ - Execute search across all tools in parallel. 
- - Args: - query: The search query - max_results_per_tool: Max results from each tool - - Returns: - SearchResult containing all evidence and metadata - """ - logger.info("Starting search", query=query, tools=[t.name for t in self.tools]) - - # Create tasks for parallel execution - tasks = [ - self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools - ] - - # Gather results (don't fail if one tool fails) - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Process results - all_evidence: list[Evidence] = [] - sources_searched: list[SourceName] = [] - errors: list[str] = [] - - # Map tool names to SourceName values - # Some tools have internal names that differ from SourceName literals - tool_name_to_source: dict[str, SourceName] = { - "duckduckgo": "web", - "serper": "web", # Serper uses Google search but maps to "web" source - "searchxng": "web", # SearchXNG also maps to "web" source - "pubmed": "pubmed", - "clinicaltrials": "clinicaltrials", - "europepmc": "europepmc", - "neo4j": "neo4j", - "rag": "rag", - "web": "web", # In case tool already uses "web" - } - - for tool, result in zip(self.tools, results, strict=True): - if isinstance(result, Exception): - errors.append(f"{tool.name}: {result!s}") - logger.warning("Search tool failed", tool=tool.name, error=str(result)) - else: - # Cast result to list[Evidence] as we know it succeeded - success_result = cast(list[Evidence], result) - all_evidence.extend(success_result) - - # Map tool.name to SourceName (handle tool names that don't match SourceName literals) - tool_name = tool_name_to_source.get(tool.name, cast(SourceName, tool.name)) - if tool_name not in [ - "pubmed", - "clinicaltrials", - "biorxiv", - "europepmc", - "preprint", - "rag", - "web", - "neo4j", - ]: - logger.warning( - "Tool name not in SourceName literals, defaulting to 'web'", - tool_name=tool.name, - ) - tool_name = "web" - sources_searched.append(tool_name) - logger.info("Search tool succeeded", tool=tool.name, count=len(success_result)) - - search_result = SearchResult( - query=query, - evidence=all_evidence, - sources_searched=sources_searched, - total_found=len(all_evidence), - errors=errors, - ) - - # Ingest evidence into RAG if enabled and available - if self.auto_ingest_to_rag and all_evidence: - rag_service = self._get_rag_service() - if rag_service: - try: - # Filter out RAG-sourced evidence (avoid circular ingestion) - evidence_to_ingest = [e for e in all_evidence if e.citation.source != "rag"] - if evidence_to_ingest: - rag_service.ingest_evidence(evidence_to_ingest) - logger.info( - "Ingested evidence into RAG", - count=len(evidence_to_ingest), - ) - except Exception as e: - logger.warning("Failed to ingest evidence into RAG", error=str(e)) - - # 🔥 INGEST INTO NEO4J KNOWLEDGE GRAPH 🔥 - if all_evidence: - try: - neo4j_service = get_neo4j_service() - if neo4j_service: - # Extract disease from query - disease = query - if "for" in query.lower(): - disease = query.split("for")[-1].strip().rstrip("?") - - # Convert Evidence objects to dicts for Neo4j - papers = [] - for ev in all_evidence: - papers.append( - { - "id": ev.citation.url or "", - "title": ev.citation.title or "", - "abstract": ev.content, - "url": ev.citation.url or "", - "source": ev.citation.source, - } - ) - - stats = neo4j_service.ingest_search_results(disease, papers) - logger.info("💾 Saved to Neo4j", stats=stats) - except Exception as e: - logger.warning("Neo4j ingestion failed", error=str(e)) - - return search_result - - async def 
_search_with_timeout( - self, - tool: SearchTool, - query: str, - max_results: int, - ) -> list[Evidence]: - """Execute a single tool search with timeout.""" - try: - return await asyncio.wait_for( - tool.search(query, max_results), - timeout=self.timeout, - ) - except TimeoutError as e: - raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e diff --git a/src/tools/searchxng_web_search.py b/src/tools/searchxng_web_search.py deleted file mode 100644 index e120b61f72b4e28e780ad43a5673f49f5d27ef58..0000000000000000000000000000000000000000 --- a/src/tools/searchxng_web_search.py +++ /dev/null @@ -1,120 +0,0 @@ -"""SearchXNG web search tool using SearchXNG API for Google searches.""" - -import structlog -from tenacity import retry, stop_after_attempt, wait_exponential - -from src.tools.query_utils import preprocess_query -from src.tools.rate_limiter import get_searchxng_limiter -from src.tools.vendored.searchxng_client import SearchXNGClient -from src.tools.vendored.web_search_core import scrape_urls -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError -from src.utils.models import Citation, Evidence - -logger = structlog.get_logger() - - -class SearchXNGWebSearchTool: - """Tool for searching the web using SearchXNG API (Google search).""" - - def __init__(self, host: str | None = None) -> None: - """Initialize SearchXNG web search tool. - - Args: - host: SearchXNG host URL. If None, reads from settings. - - Raises: - ConfigurationError: If no host is available. - """ - self.host = host or settings.searchxng_host - if not self.host: - raise ConfigurationError( - "SearchXNG host required. Set SEARCHXNG_HOST environment variable or searchxng_host in settings." - ) - - self._client = SearchXNGClient(host=self.host) - self._limiter = get_searchxng_limiter() - - @property - def name(self) -> str: - """Return the name of this search tool.""" - return "searchxng" - - async def _rate_limit(self) -> None: - """Enforce SearchXNG API rate limiting.""" - await self._limiter.acquire() - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10), - reraise=True, - ) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Execute a web search using SearchXNG API. - - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects - - Raises: - SearchError: If the search fails - RateLimitError: If rate limit is exceeded - """ - await self._rate_limit() - - # Preprocess query to remove noise - clean_query = preprocess_query(query) - final_query = clean_query if clean_query else query - - try: - # Get search results (snippets) - search_results = await self._client.search( - final_query, filter_for_relevance=False, max_results=max_results - ) - - if not search_results: - logger.info("No search results found", query=final_query) - return [] - - # Scrape URLs to get full content - scraped = await scrape_urls(search_results) - - # Convert ScrapeResult to Evidence objects - evidence = [] - for result in scraped: - # Truncate title to max 500 characters to match Citation model validation - title = result.title - if len(title) > 500: - title = title[:497] + "..." 
- - ev = Evidence( - content=result.text, - citation=Citation( - title=title, - url=result.url, - source="web", # Use "web" to match SourceName literal, not "searchxng" - date="Unknown", - authors=[], - ), - relevance=0.0, - ) - evidence.append(ev) - - logger.info( - "SearchXNG search complete", - query=final_query, - results_found=len(evidence), - ) - - return evidence - - except RateLimitError: - raise - except SearchError: - raise - except Exception as e: - logger.error("Unexpected error in SearchXNG search", error=str(e), query=final_query) - raise SearchError(f"SearchXNG search failed: {e}") from e diff --git a/src/tools/serper_web_search.py b/src/tools/serper_web_search.py deleted file mode 100644 index d5b6bdb350ca7d6cae834b77b0cb57f043833689..0000000000000000000000000000000000000000 --- a/src/tools/serper_web_search.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Serper web search tool using Serper API for Google searches.""" - -import structlog -from tenacity import ( - retry, - retry_if_not_exception_type, - stop_after_attempt, - wait_random_exponential, -) - -from src.tools.query_utils import preprocess_web_query -from src.tools.rate_limiter import get_serper_limiter -from src.tools.vendored.serper_client import SerperClient -from src.tools.vendored.web_search_core import scrape_urls -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError -from src.utils.models import Citation, Evidence - -logger = structlog.get_logger() - - -class SerperWebSearchTool: - """Tool for searching the web using Serper API (Google search).""" - - def __init__(self, api_key: str | None = None) -> None: - """Initialize Serper web search tool. - - Args: - api_key: Serper API key. If None, reads from settings. - - Raises: - ConfigurationError: If no API key is available. - """ - self.api_key = api_key or settings.serper_api_key - if not self.api_key: - raise ConfigurationError( - "Serper API key required. Set SERPER_API_KEY environment variable or serper_api_key in settings." - ) - - self._client = SerperClient(api_key=self.api_key) - self._limiter = get_serper_limiter(self.api_key) - - # Validate API key format (basic check) - if self.api_key and len(self.api_key.strip()) < 10: - logger.warning( - "Serper API key appears to be too short", - key_length=len(self.api_key), - hint="Verify SERPER_API_KEY is correct", - ) - - @property - def name(self) -> str: - """Return the name of this search tool.""" - return "serper" - - async def _rate_limit(self) -> None: - """Enforce Serper API rate limiting with jitter. - - Uses jitter to spread out requests and avoid thundering herd problems. - Rate limit is 100 requests/second for free tier, we use 90/second to stay safe. - """ - await self._limiter.acquire(jitter=True) - - @retry( - stop=stop_after_attempt(3), # Reduced retries for faster fallback - wait=wait_random_exponential( - multiplier=1, min=2, max=10, exp_base=2 - ), # 2s to 10s backoff with jitter (faster for fallback) - reraise=True, - retry=retry_if_not_exception_type(ConfigurationError), # Don't retry on config errors - ) - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Execute a web search using Serper API. 
- - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects - - Raises: - SearchError: If the search fails - RateLimitError: If rate limit is exceeded - ConfigurationError: If API key is invalid (403 Forbidden) - """ - await self._rate_limit() - - # Preprocess query for web search (simplified, no boolean syntax) - clean_query = preprocess_web_query(query) - final_query = clean_query if clean_query else query - - try: - # Get search results (snippets) - search_results = await self._client.search( - final_query, filter_for_relevance=False, max_results=max_results - ) - - if not search_results: - logger.info("No search results found", query=final_query) - return [] - - # Scrape URLs to get full content - scraped = await scrape_urls(search_results) - - # Convert ScrapeResult to Evidence objects - evidence = [] - for result in scraped: - # Truncate title to max 500 characters to match Citation model validation - title = result.title - if len(title) > 500: - title = title[:497] + "..." - - ev = Evidence( - content=result.text, - citation=Citation( - title=title, - url=result.url, - source="web", # Use "web" to match SourceName literal, not "serper" - date="Unknown", - authors=[], - ), - relevance=0.0, - ) - evidence.append(ev) - - logger.info( - "Serper search complete", - query=final_query, - results_found=len(evidence), - ) - - return evidence - - except ConfigurationError: - # Don't retry configuration errors (e.g., 403 Forbidden = invalid API key) - raise - except RateLimitError: - raise - except SearchError: - raise - except Exception as e: - logger.error("Unexpected error in Serper search", error=str(e), query=final_query) - raise SearchError(f"Serper search failed: {e}") from e diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py deleted file mode 100644 index e499960726ae4f0044e863fbc0f30a4a4afd0e9e..0000000000000000000000000000000000000000 --- a/src/tools/tool_executor.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Tool executor for running AgentTask objects. - -Executes tool tasks selected by the tool selector agent and returns ToolAgentOutput. -""" - -import structlog - -from src.tools.crawl_adapter import crawl_website -from src.tools.rag_tool import RAGTool, create_rag_tool -from src.tools.web_search_adapter import web_search -from src.utils.exceptions import ConfigurationError -from src.utils.models import AgentTask, Evidence, ToolAgentOutput - -logger = structlog.get_logger() - -# Module-level RAG tool instance (lazy initialization) -_rag_tool: RAGTool | None = None - - -def _get_rag_tool() -> RAGTool | None: - """ - Get or create RAG tool instance. - - Returns: - RAGTool instance, or None if unavailable - """ - global _rag_tool - if _rag_tool is None: - try: - _rag_tool = create_rag_tool() - logger.info("RAG tool initialized") - except ConfigurationError: - logger.warning("RAG tool unavailable (OPENAI_API_KEY required)") - return None - except Exception as e: - logger.error("Failed to initialize RAG tool", error=str(e)) - return None - return _rag_tool - - -def _evidence_to_text(evidence_list: list[Evidence]) -> str: - """ - Convert Evidence objects to formatted text. - - Args: - evidence_list: List of Evidence objects - - Returns: - Formatted text string with citations and content - """ - if not evidence_list: - return "No evidence found." 
- - formatted_parts = [] - for i, evidence in enumerate(evidence_list, 1): - citation = evidence.citation - citation_str = f"{citation.formatted}" - if citation.url: - citation_str += f" [{citation.url}]" - - formatted_parts.append(f"[{i}] {citation_str}\n\n{evidence.content}\n\n---\n") - - return "\n".join(formatted_parts) - - -async def execute_agent_task(task: AgentTask) -> ToolAgentOutput: - """ - Execute a single agent task and return ToolAgentOutput. - - Args: - task: AgentTask specifying which tool to use and what query to run - - Returns: - ToolAgentOutput with results and source URLs - """ - logger.info( - "Executing agent task", - agent=task.agent, - query=task.query[:100] if task.query else "", - gap=task.gap[:100] if task.gap else "", - ) - - try: - if task.agent == "WebSearchAgent": - # Use web search adapter - result_text = await web_search(task.query) - # Extract URLs from result (simple heuristic - look for http/https) - import re - - urls = re.findall(r"https?://[^\s\)]+", result_text) - sources = list(set(urls)) # Deduplicate - - return ToolAgentOutput(output=result_text, sources=sources) - - elif task.agent == "SiteCrawlerAgent": - # Use crawl adapter - if task.entity_website: - starting_url = task.entity_website - elif task.query.startswith(("http://", "https://")): - starting_url = task.query - else: - # Try to construct URL from query - starting_url = f"https://{task.query}" - - result_text = await crawl_website(starting_url) - # Extract URLs from result - import re - - urls = re.findall(r"https?://[^\s\)]+", result_text) - sources = list(set(urls)) # Deduplicate - - return ToolAgentOutput(output=result_text, sources=sources) - - elif task.agent == "RAGAgent": - # Use RAG tool for semantic search - rag_tool = _get_rag_tool() - if rag_tool is None: - return ToolAgentOutput( - output="RAG service unavailable. OPENAI_API_KEY required.", - sources=[], - ) - - # Search RAG and get Evidence objects - evidence_list = await rag_tool.search(task.query, max_results=10) - - if not evidence_list: - return ToolAgentOutput( - output="No relevant evidence found in collected research.", - sources=[], - ) - - # Convert Evidence to formatted text - result_text = _evidence_to_text(evidence_list) - - # Extract URLs from evidence citations - sources = [evidence.citation.url for evidence in evidence_list if evidence.citation.url] - - return ToolAgentOutput(output=result_text, sources=sources) - - else: - logger.warning("Unknown agent type", agent=task.agent) - return ToolAgentOutput( - output=f"Unknown agent type: {task.agent}. Available: WebSearchAgent, SiteCrawlerAgent, RAGAgent", - sources=[], - ) - - except Exception as e: - logger.error("Tool execution failed", error=str(e), agent=task.agent) - return ToolAgentOutput( - output=f"Error executing {task.agent} for gap '{task.gap}': {e!s}", - sources=[], - ) - - -async def execute_tool_tasks( - tasks: list[AgentTask], -) -> dict[str, ToolAgentOutput]: - """ - Execute multiple agent tasks concurrently. 
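Usage illustration only: the dispatch above can be exercised roughly as follows. AgentTask is defined in src/utils/models.py, outside this diff, so the constructor call below is an assumption based on the attributes referenced above (agent, query, gap, entity_website); field names or required values may differ.

import asyncio

from src.tools.tool_executor import execute_agent_task
from src.utils.models import AgentTask


async def _demo_dispatch() -> None:
    # "WebSearchAgent" routes to the web_search adapter; sources are de-duplicated URLs
    # pulled out of the returned text with a simple regex.
    task = AgentTask(
        agent="WebSearchAgent",
        query="semaglutide cardiovascular outcomes trial",
        gap="Long-term cardiovascular safety evidence",
    )
    output = await execute_agent_task(task)
    print(output.output[:300])
    print(output.sources)


if __name__ == "__main__":
    asyncio.run(_demo_dispatch())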
- - Args: - tasks: List of AgentTask objects to execute - - Returns: - Dictionary mapping task keys to ToolAgentOutput results - """ - import asyncio - - logger.info("Executing tool tasks", count=len(tasks)) - - # Create async tasks - async_tasks = [execute_agent_task(task) for task in tasks] - - # Run concurrently - results_list = await asyncio.gather(*async_tasks, return_exceptions=True) - - # Build results dictionary - results: dict[str, ToolAgentOutput] = {} - for i, (task, result) in enumerate(zip(tasks, results_list, strict=False)): - if isinstance(result, Exception): - logger.error("Task execution failed", error=str(result), task_index=i) - results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[]) - else: - # Type narrowing: result is ToolAgentOutput after Exception check - assert isinstance(result, ToolAgentOutput), ( - "Expected ToolAgentOutput after Exception check" - ) - key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}" - results[key] = result - - logger.info("Tool tasks completed", completed=len(results)) - - return results diff --git a/src/tools/vendored/__init__.py b/src/tools/vendored/__init__.py deleted file mode 100644 index e96e825253d1d80cb65b5f33c22878fd4ddc225a..0000000000000000000000000000000000000000 --- a/src/tools/vendored/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Vendored web search components from folder/tools/web_search.py.""" - -from src.tools.vendored.crawl_website import crawl_website -from src.tools.vendored.searchxng_client import SearchXNGClient -from src.tools.vendored.serper_client import SerperClient -from src.tools.vendored.web_search_core import ( - CONTENT_LENGTH_LIMIT, - ScrapeResult, - WebpageSnippet, - fetch_and_process_url, - html_to_text, - is_valid_url, - scrape_urls, -) - -__all__ = [ - "CONTENT_LENGTH_LIMIT", - "ScrapeResult", - "SearchXNGClient", - "SerperClient", - "WebpageSnippet", - "crawl_website", - "fetch_and_process_url", - "html_to_text", - "is_valid_url", - "scrape_urls", -] diff --git a/src/tools/vendored/crawl_website.py b/src/tools/vendored/crawl_website.py deleted file mode 100644 index cd75bb509fde305a5ba26d51877b4c06d87351ba..0000000000000000000000000000000000000000 --- a/src/tools/vendored/crawl_website.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Website crawl tool vendored from folder/tools/crawl_website.py. - -This module provides website crawling functionality that starts from a given URL -and crawls linked pages in a breadth-first manner, prioritizing navigation links. 
-""" - -from urllib.parse import urljoin, urlparse - -import aiohttp -import structlog -from bs4 import BeautifulSoup - -from src.tools.vendored.web_search_core import ( - ScrapeResult, - WebpageSnippet, - scrape_urls, - ssl_context, -) - -logger = structlog.get_logger() - - -async def _extract_links( - html: str, current_url: str, base_domain: str -) -> tuple[list[str], list[str]]: - """Extract prioritized links from HTML content.""" - soup = BeautifulSoup(html, "html.parser") - nav_links = set() - body_links = set() - - # Find navigation/header links - for nav_element in soup.find_all(["nav", "header"]): - for a in nav_element.find_all("a", href=True): - href = str(a["href"]) - link = urljoin(current_url, href) - if urlparse(link).netloc == base_domain: - nav_links.add(link) - - # Find remaining body links - for a in soup.find_all("a", href=True): - href = str(a["href"]) - link = urljoin(current_url, href) - if urlparse(link).netloc == base_domain and link not in nav_links: - body_links.add(link) - - return list(nav_links), list(body_links) - - -async def _fetch_page(url: str) -> str: - """Fetch HTML content from a URL.""" - connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession(connector=connector) as session: - try: - timeout = aiohttp.ClientTimeout(total=30) - async with session.get(url, timeout=timeout) as response: - if response.status == 200: - return await response.text() - return "" - except Exception as e: - logger.warning("Error fetching URL", url=url, error=str(e)) - return "" - - -def _add_links_to_queue( - links: list[str], - queue: list[str], - all_pages_to_scrape: set[str], - remaining_slots: int, -) -> int: - """Add normalized links to queue if not already visited.""" - for link in links: - normalized_link = link.rstrip("/") - if normalized_link not in all_pages_to_scrape and remaining_slots > 0: - queue.append(normalized_link) - all_pages_to_scrape.add(normalized_link) - remaining_slots -= 1 - return remaining_slots - - -async def crawl_website(starting_url: str) -> list[ScrapeResult] | str: - """Crawl the pages of a website starting with the starting_url and then descending into the pages linked from there. - - Prioritizes links found in headers/navigation, then body links, then subsequent pages. 
- - Args: - starting_url: Starting URL to scrape - - Returns: - List of ScrapeResult objects which have the following fields: - - url: The URL of the web page - - title: The title of the web page - - description: The description of the web page - - text: The text content of the web page - """ - if not starting_url: - return "Empty URL provided" - - # Ensure URL has a protocol - if not starting_url.startswith(("http://", "https://")): - starting_url = "http://" + starting_url - - max_pages = 10 - base_domain = urlparse(starting_url).netloc - - # Initialize with starting URL - queue: list[str] = [starting_url] - next_level_queue: list[str] = [] - all_pages_to_scrape: set[str] = set([starting_url]) - - # Breadth-first crawl - while queue and len(all_pages_to_scrape) < max_pages: - current_url = queue.pop(0) - - # Fetch and process the page - html_content = await _fetch_page(current_url) - if html_content: - nav_links, body_links = await _extract_links(html_content, current_url, base_domain) - - # Add unvisited nav links to current queue (higher priority) - remaining_slots = max_pages - len(all_pages_to_scrape) - remaining_slots = _add_links_to_queue( - nav_links, queue, all_pages_to_scrape, remaining_slots - ) - - # Add unvisited body links to next level queue (lower priority) - remaining_slots = _add_links_to_queue( - body_links, next_level_queue, all_pages_to_scrape, remaining_slots - ) - - # If current queue is empty, add next level links - if not queue: - queue = next_level_queue - next_level_queue = [] - - # Convert set to list for final processing - pages_to_scrape = list(all_pages_to_scrape)[:max_pages] - pages_to_scrape_snippets: list[WebpageSnippet] = [ - WebpageSnippet(url=page, title="", description="") for page in pages_to_scrape - ] - - # Use scrape_urls to get the content for all discovered pages - result = await scrape_urls(pages_to_scrape_snippets) - return result - - - - - - - - - - diff --git a/src/tools/vendored/searchxng_client.py b/src/tools/vendored/searchxng_client.py deleted file mode 100644 index 2cf4ce83b2f58e47bfb0f543790a6bd7d540bab6..0000000000000000000000000000000000000000 --- a/src/tools/vendored/searchxng_client.py +++ /dev/null @@ -1,106 +0,0 @@ -"""SearchXNG API client for Google searches. - -Vendored and adapted from folder/tools/web_search.py. -""" - -import os - -import aiohttp -import structlog - -from src.tools.vendored.web_search_core import WebpageSnippet, ssl_context -from src.utils.exceptions import RateLimitError, SearchError - -logger = structlog.get_logger() - - -class SearchXNGClient: - """A client for the SearchXNG API to perform Google searches.""" - - def __init__(self, host: str | None = None) -> None: - """Initialize SearchXNG client. - - Args: - host: SearchXNG host URL. If None, reads from SEARCHXNG_HOST env var. - - Raises: - ConfigurationError: If no host is provided. - """ - host = host or os.getenv("SEARCHXNG_HOST") - if not host: - from src.utils.exceptions import ConfigurationError - - raise ConfigurationError("SEARCHXNG_HOST environment variable is not set") - - # Ensure host ends with /search - if not host.endswith("/search"): - host = f"{host}/search" if not host.endswith("/") else f"{host}search" - - self.host: str = host - - async def search( - self, query: str, filter_for_relevance: bool = False, max_results: int = 5 - ) -> list[WebpageSnippet]: - """Perform a search using SearchXNG API. 
- - Args: - query: The search query - filter_for_relevance: Whether to filter results (currently not implemented) - max_results: Maximum number of results to return - - Returns: - List of WebpageSnippet objects with search results - - Raises: - SearchError: If the search fails - RateLimitError: If rate limit is exceeded - """ - connector = aiohttp.TCPConnector(ssl=ssl_context) - try: - async with aiohttp.ClientSession(connector=connector) as session: - params = { - "q": query, - "format": "json", - } - - async with session.get(self.host, params=params) as response: - if response.status == 429: - raise RateLimitError("SearchXNG API rate limit exceeded") - - response.raise_for_status() - results = await response.json() - - results_list = [ - WebpageSnippet( - url=result.get("url", ""), - title=result.get("title", ""), - description=result.get("content", ""), - ) - for result in results.get("results", []) - ] - - if not results_list: - logger.info("No search results found", query=query) - return [] - - # Return results up to max_results - return results_list[:max_results] - - except aiohttp.ClientError as e: - logger.error("SearchXNG API request failed", error=str(e), query=query) - raise SearchError(f"SearchXNG API request failed: {e}") from e - except RateLimitError: - raise - except Exception as e: - logger.error("Unexpected error in SearchXNG search", error=str(e), query=query) - raise SearchError(f"SearchXNG search failed: {e}") from e - - - - - - - - - - diff --git a/src/tools/vendored/serper_client.py b/src/tools/vendored/serper_client.py deleted file mode 100644 index 17d30b6a8b86f0ec9d36e4107e571a397bf9c050..0000000000000000000000000000000000000000 --- a/src/tools/vendored/serper_client.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Serper API client for Google searches. - -Vendored and adapted from folder/tools/web_search.py. -""" - -import os - -import aiohttp -import structlog - -from src.tools.vendored.web_search_core import WebpageSnippet, ssl_context -from src.utils.exceptions import ConfigurationError, RateLimitError, SearchError - -logger = structlog.get_logger() - - -class SerperClient: - """A client for the Serper API to perform Google searches.""" - - def __init__(self, api_key: str | None = None) -> None: - """Initialize Serper client. - - Args: - api_key: Serper API key. If None, reads from SERPER_API_KEY env var. - - Raises: - ConfigurationError: If no API key is provided. - """ - self.api_key = api_key or os.getenv("SERPER_API_KEY") - if not self.api_key: - from src.utils.exceptions import ConfigurationError - - raise ConfigurationError( - "No API key provided. Set SERPER_API_KEY environment variable." - ) - - # Serper API endpoint and headers - # Documentation: https://serper.dev/api - # Format: POST https://google.serper.dev/search - # Headers: X-API-KEY (required), Content-Type: application/json - # Body: {"q": "search query", "autocorrect": false} - self.url = "https://google.serper.dev/search" - self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} - - async def search( - self, query: str, filter_for_relevance: bool = False, max_results: int = 5 - ) -> list[WebpageSnippet]: - """Perform a Google search using Serper API. 
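For orientation only: both vendored clients share the same calling convention. They return lightweight WebpageSnippet objects (url, title, description) and leave full-content scraping to scrape_urls from web_search_core.py, which appears later in this diff. A rough sketch, assuming a SearchXNG instance is reachable at the given host:

import asyncio

from src.tools.vendored.searchxng_client import SearchXNGClient
from src.tools.vendored.web_search_core import scrape_urls


async def _demo_client() -> None:
    client = SearchXNGClient(host="http://localhost:8080")  # assumed local instance
    snippets = await client.search("CRISPR base editing review", max_results=3)
    # Expand snippets into full page text; failed fetches are filtered out by scrape_urls.
    scraped = await scrape_urls(snippets)
    for page in scraped:
        print(page.url, len(page.text))


if __name__ == "__main__":
    asyncio.run(_demo_client())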
- - Args: - query: The search query - filter_for_relevance: Whether to filter results (currently not implemented) - max_results: Maximum number of results to return - - Returns: - List of WebpageSnippet objects with search results - - Raises: - SearchError: If the search fails - RateLimitError: If rate limit is exceeded - """ - connector = aiohttp.TCPConnector(ssl=ssl_context) - try: - async with aiohttp.ClientSession(connector=connector) as session: - # Verify API call format matches Serper API documentation: - # POST https://google.serper.dev/search - # Headers: X-API-KEY, Content-Type: application/json - # Body: {"q": query, "autocorrect": false} - async with session.post( - self.url, - headers=self.headers, - json={"q": query, "autocorrect": False}, - timeout=aiohttp.ClientTimeout(total=30), # 30 second timeout - ) as response: - if response.status == 429: - raise RateLimitError("Serper API rate limit exceeded") - - if response.status == 403: - # 403 can mean either invalid key OR credits exhausted - # For free tier (2,500 credits), it's often credit exhaustion - # Read response body to get more details - try: - error_body = await response.text() - logger.warning( - "Serper API returned 403 Forbidden", - status=403, - body=error_body[:200], # Truncate for logging - hint="May be credit exhaustion (free tier: 2,500 credits) or invalid key", - ) - except Exception: - pass - - # Raise RateLimitError instead of ConfigurationError - # This allows retry logic to handle credit exhaustion - # The retry decorator will use exponential backoff with jitter - raise RateLimitError( - "Serper API credits may be exhausted (403 Forbidden). " - "Free tier provides 2,500 credits (one-time, expire after 6 months). " - "Check your dashboard at https://serper.dev/dashboard. " - "Retrying with backoff..." - ) - - response.raise_for_status() - results = await response.json() - - results_list = [ - WebpageSnippet( - url=result.get("link", ""), - title=result.get("title", ""), - description=result.get("snippet", ""), - ) - for result in results.get("organic", []) - ] - - if not results_list: - logger.info("No search results found", query=query) - return [] - - # Return results up to max_results - return results_list[:max_results] - - except aiohttp.ClientError as e: - logger.error("Serper API request failed", error=str(e), query=query) - raise SearchError(f"Serper API request failed: {e}") from e - except RateLimitError: - raise - except Exception as e: - logger.error("Unexpected error in Serper search", error=str(e), query=query) - raise SearchError(f"Serper search failed: {e}") from e - - - - - - - - - - diff --git a/src/tools/vendored/web_search_core.py b/src/tools/vendored/web_search_core.py deleted file mode 100644 index a465f3982df18728c2f4d4eebb21477bae3dbe65..0000000000000000000000000000000000000000 --- a/src/tools/vendored/web_search_core.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Core web search utilities vendored from folder/tools/web_search.py. - -This module contains shared utilities for web scraping, URL processing, -and HTML text extraction used by web search tools. 
-""" - -import asyncio -import ssl - -import aiohttp -import structlog -from bs4 import BeautifulSoup -from pydantic import BaseModel, Field - -logger = structlog.get_logger() - -# Content length limit to avoid exceeding token limits -CONTENT_LENGTH_LIMIT = 10000 - -# Create a shared SSL context for web requests -ssl_context = ssl.create_default_context() -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE -ssl_context.set_ciphers("DEFAULT:@SECLEVEL=1") # Allow older cipher suites - - -class ScrapeResult(BaseModel): - """Result of scraping a single webpage.""" - - url: str = Field(description="The URL of the webpage") - text: str = Field(description="The full text content of the webpage") - title: str = Field(description="The title of the webpage") - description: str = Field(description="A short description of the webpage") - - -class WebpageSnippet(BaseModel): - """Snippet information for a webpage (before scraping).""" - - url: str = Field(description="The URL of the webpage") - title: str = Field(description="The title of the webpage") - description: str | None = Field(default=None, description="A short description of the webpage") - - -async def scrape_urls(items: list[WebpageSnippet]) -> list[ScrapeResult]: - """Fetch text content from provided URLs. - - Args: - items: List of WebpageSnippet items to extract content from - - Returns: - List of ScrapeResult objects with scraped content - """ - connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession(connector=connector) as session: - # Create list of tasks for concurrent execution - tasks = [] - for item in items: - if item.url: # Skip empty URLs - tasks.append(fetch_and_process_url(session, item)) - - # Execute all tasks concurrently and gather results - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Filter out errors and return successful results - successful_results: list[ScrapeResult] = [] - for result in results: - if isinstance(result, ScrapeResult): - successful_results.append(result) - elif isinstance(result, Exception): - logger.warning("Failed to scrape URL", error=str(result)) - - return successful_results - - -async def fetch_and_process_url( - session: aiohttp.ClientSession, item: WebpageSnippet -) -> ScrapeResult: - """Helper function to fetch and process a single URL. 
- - Args: - session: aiohttp ClientSession - item: WebpageSnippet with URL to fetch - - Returns: - ScrapeResult with fetched content - """ - if not is_valid_url(item.url): - return ScrapeResult( - url=item.url, - title=item.title, - description=item.description or "", - text="Error fetching content: URL contains restricted file extension", - ) - - try: - timeout = aiohttp.ClientTimeout(total=8) - async with session.get(item.url, timeout=timeout) as response: - if response.status == 200: - content = await response.text() - # Run html_to_text in a thread pool to avoid blocking - loop = asyncio.get_event_loop() - text_content = await loop.run_in_executor(None, html_to_text, content) - text_content = text_content[ - :CONTENT_LENGTH_LIMIT - ] # Trim content to avoid exceeding token limit - return ScrapeResult( - url=item.url, - title=item.title, - description=item.description or "", - text=text_content, - ) - else: - # Return a ScrapeResult with an error message - return ScrapeResult( - url=item.url, - title=item.title, - description=item.description or "", - text=f"Error fetching content: HTTP {response.status}", - ) - except Exception as e: - logger.warning("Error fetching URL", url=item.url, error=str(e)) - # Return a ScrapeResult with an error message - return ScrapeResult( - url=item.url, - title=item.title, - description=item.description or "", - text=f"Error fetching content: {e!s}", - ) - - -def html_to_text(html_content: str) -> str: - """Strip out unnecessary elements from HTML to prepare for text extraction. - - Args: - html_content: Raw HTML content - - Returns: - Extracted text from relevant HTML tags - """ - # Parse the HTML using lxml for speed - soup = BeautifulSoup(html_content, "lxml") - - # Extract text from relevant tags - tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote") - - # Use a generator expression for efficiency - extracted_text = "\n".join( - element.get_text(strip=True) - for element in soup.find_all(tags_to_extract) - if element.get_text(strip=True) - ) - - return extracted_text - - -def is_valid_url(url: str) -> bool: - """Check that a URL does not contain restricted file extensions. 
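A quick illustration of what html_to_text keeps: only text from heading, paragraph, list-item, and blockquote tags, joined with newlines (the lxml parser must be installed).

from src.tools.vendored.web_search_core import html_to_text

sample = """
<html><body>
  <nav><a href="/">Home</a></nav>
  <h1>Results</h1>
  <p>Primary endpoint was met.</p>
  <script>console.log("ignored");</script>
  <li>Adverse events were mild.</li>
</body></html>
"""

print(html_to_text(sample))
# Expected output, roughly:
# Results
# Primary endpoint was met.
# Adverse events were mild.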
- - Args: - url: URL to validate - - Returns: - True if URL is valid, False if it contains restricted extensions - """ - restricted_extensions = [ - ".pdf", - ".doc", - ".xls", - ".ppt", - ".zip", - ".rar", - ".7z", - ".txt", - ".js", - ".xml", - ".css", - ".png", - ".jpg", - ".jpeg", - ".gif", - ".ico", - ".svg", - ".webp", - ".mp3", - ".mp4", - ".avi", - ".mov", - ".wmv", - ".flv", - ".wma", - ".wav", - ".m4a", - ".m4v", - ".m4b", - ".m4p", - ".m4u", - ] - - if any(ext in url for ext in restricted_extensions): - return False - return True diff --git a/src/tools/web_search.py b/src/tools/web_search.py deleted file mode 100644 index f6350e820f2edffa346fe2392828a38518ce3a00..0000000000000000000000000000000000000000 --- a/src/tools/web_search.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Web search tool using DuckDuckGo.""" - -import asyncio - -import structlog - -try: - from ddgs import DDGS # New package name -except ImportError: - # Fallback to old package name for backward compatibility - from duckduckgo_search import DDGS # type: ignore[no-redef] - -from src.tools.query_utils import preprocess_query -from src.utils.exceptions import SearchError -from src.utils.models import Citation, Evidence - -logger = structlog.get_logger() - - -class WebSearchTool: - """Tool for searching the web using DuckDuckGo.""" - - def __init__(self) -> None: - self._ddgs = DDGS() - - @property - def name(self) -> str: - """Return the name of this search tool.""" - return "duckduckgo" - - async def search(self, query: str, max_results: int = 10) -> list[Evidence]: - """Execute a web search and return evidence. - - Args: - query: The search query string - max_results: Maximum number of results to return - - Returns: - List of Evidence objects - - Raises: - SearchError: If the search fails - """ - try: - # Preprocess query to remove noise - clean_query = preprocess_query(query) - final_query = clean_query if clean_query else query - - loop = asyncio.get_running_loop() - - def _do_search() -> list[dict[str, str]]: - # text() returns an iterator, need to list() it or iterate - return list(self._ddgs.text(final_query, max_results=max_results)) - - raw_results = await loop.run_in_executor(None, _do_search) - - evidence = [] - for r in raw_results: - # Truncate title to max 500 characters to match Citation model validation - title = r.get("title", "No Title") - if len(title) > 500: - title = title[:497] + "..." - - ev = Evidence( - content=r.get("body", ""), - citation=Citation( - title=title, - url=r.get("href", ""), - source="web", - date="Unknown", - authors=[], - ), - relevance=0.0, - ) - evidence.append(ev) - - return evidence - - except Exception as e: - logger.error("Web search failed", error=str(e), query=query) - raise SearchError(f"DuckDuckGo search failed: {e}") from e diff --git a/src/tools/web_search_adapter.py b/src/tools/web_search_adapter.py deleted file mode 100644 index 6ee833f2fdda510bd99d3311b923a248b55fa963..0000000000000000000000000000000000000000 --- a/src/tools/web_search_adapter.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Web search tool adapter for Pydantic AI agents. - -Uses the new web search factory to provide web search functionality. -""" - -import structlog - -from src.tools.web_search_factory import create_web_search_tool - -logger = structlog.get_logger() - - -async def web_search(query: str) -> str: - """ - Perform a web search for a given query and return formatted results. - - Use this tool to search the web for information relevant to the query. - Provide a query with 3-6 words as input. 
- - Args: - query: The search query (3-6 words recommended) - - Returns: - Formatted string with search results including titles, descriptions, and URLs - """ - try: - # Get web search tool from factory - tool = create_web_search_tool() - - if tool is None: - logger.warning("Web search tool not available", hint="Check configuration") - return "Web search tool not available. Please configure a web search provider." - - # Call the tool - it returns list[Evidence] - evidence = await tool.search(query, max_results=5) - - if not evidence: - return f"No web search results found for: {query}" - - # Format results for agent consumption - formatted = [f"Found {len(evidence)} web search results:\n"] - for i, ev in enumerate(evidence, 1): - citation = ev.citation - formatted.append(f"{i}. **{citation.title}**") - if citation.url: - formatted.append(f" URL: {citation.url}") - if ev.content: - formatted.append(f" Content: {ev.content[:300]}...") - formatted.append("") - - return "\n".join(formatted) - - except Exception as e: - logger.error("Web search failed", error=str(e), query=query) - return f"Error performing web search: {e!s}" diff --git a/src/tools/web_search_factory.py b/src/tools/web_search_factory.py deleted file mode 100644 index ae79811639a141e49053aff27f37336dbf146755..0000000000000000000000000000000000000000 --- a/src/tools/web_search_factory.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Factory for creating web search tools based on configuration.""" - -import structlog - -from src.tools.base import SearchTool -from src.tools.fallback_web_search import FallbackWebSearchTool -from src.tools.searchxng_web_search import SearchXNGWebSearchTool -from src.tools.serper_web_search import SerperWebSearchTool -from src.tools.web_search import WebSearchTool -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - - -def create_web_search_tool(provider: str | None = None) -> SearchTool | None: - """Create a web search tool based on configuration. - - Args: - provider: Override provider selection. If None, uses settings.web_search_provider. - - Returns: - SearchTool instance, or None if not available/configured - - The tool is selected based on provider (or settings.web_search_provider if None): - - "serper": SerperWebSearchTool (requires SERPER_API_KEY) - - "searchxng": SearchXNGWebSearchTool (requires SEARCHXNG_HOST) - - "duckduckgo": WebSearchTool (always available, no API key) - - "brave" or "tavily": Not yet implemented, returns None - - "auto": Auto-detect best available provider (prefers Serper > SearchXNG > DuckDuckGo) - - Auto-detection logic (when provider is "auto" or not explicitly set): - 1. Try Serper if SERPER_API_KEY is available (best quality - Google search + full content scraping) - 2. Try SearchXNG if SEARCHXNG_HOST is available - 3. 
Fall back to DuckDuckGo (always available, but lower quality - snippets only) - """ - provider = provider or settings.web_search_provider - - # Auto-detect best available provider if "auto" or if provider is duckduckgo but better options exist - if provider == "auto" or (provider == "duckduckgo" and settings.serper_api_key): - # Use fallback tool if Serper API key is available - # This automatically falls back to DuckDuckGo on any Serper error - if settings.serper_api_key: - try: - logger.info( - "Auto-detected Serper with DuckDuckGo fallback (SERPER_API_KEY found)", - provider="serper+duckduckgo", - ) - return FallbackWebSearchTool() - except Exception as e: - logger.warning( - "Failed to initialize fallback web search, trying alternatives", - error=str(e), - ) - - # Try SearchXNG as second choice - if settings.searchxng_host: - try: - logger.info( - "Auto-detected SearchXNG web search (SEARCHXNG_HOST found)", - provider="searchxng", - ) - return SearchXNGWebSearchTool() - except Exception as e: - logger.warning( - "Failed to initialize SearchXNG, falling back", - error=str(e), - ) - - # Fall back to DuckDuckGo only - if provider == "auto": - logger.info( - "Auto-detected DuckDuckGo web search (no API keys found)", - provider="duckduckgo", - ) - return WebSearchTool() - - try: - if provider == "serper": - if not settings.serper_api_key: - logger.warning( - "Serper provider selected but no API key found", - hint="Set SERPER_API_KEY environment variable", - ) - return None - return SerperWebSearchTool() - - elif provider == "searchxng": - if not settings.searchxng_host: - logger.warning( - "SearchXNG provider selected but no host found", - hint="Set SEARCHXNG_HOST environment variable", - ) - return None - return SearchXNGWebSearchTool() - - elif provider == "duckduckgo": - # DuckDuckGo is always available (no API key required) - return WebSearchTool() - - elif provider in ("brave", "tavily"): - logger.warning( - f"Web search provider '{provider}' not yet implemented", - hint="Use 'serper', 'searchxng', or 'duckduckgo'", - ) - return None - - else: - logger.warning(f"Unknown web search provider '{provider}', falling back to DuckDuckGo") - return WebSearchTool() - - except ConfigurationError as e: - logger.error("Failed to create web search tool", error=str(e), provider=provider) - return None - except Exception as e: - logger.error("Unexpected error creating web search tool", error=str(e), provider=provider) - return None diff --git a/src/agent_factory/__init__.py b/src/tools/websearch.py similarity index 100% rename from src/agent_factory/__init__.py rename to src/tools/websearch.py diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/utils/citation_validator.py b/src/utils/citation_validator.py deleted file mode 100644 index c124c1a529ad46449c5182576a0e58a54e22381c..0000000000000000000000000000000000000000 --- a/src/utils/citation_validator.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Citation validation to prevent LLM hallucination. - -CRITICAL: Medical research requires accurate citations. -This module validates that all references exist in collected evidence. 
-""" - -import logging -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from src.utils.models import Evidence, ResearchReport - -logger = logging.getLogger(__name__) - -# Max characters to display for URLs in log messages -_MAX_URL_DISPLAY_LENGTH = 80 - - -def validate_references(report: "ResearchReport", evidence: list["Evidence"]) -> "ResearchReport": - """Ensure all references actually exist in collected evidence. - - CRITICAL: Prevents LLM hallucination of citations. - - Note: - This function MUTATES report.references in-place and returns the same - report object. This is intentional for efficiency. - - Args: - report: The generated research report (will be mutated) - evidence: All evidence collected during research - - Returns: - The same report object with references updated in-place - """ - # Build set of valid URLs from evidence - valid_urls = {e.citation.url for e in evidence} - # Also check titles (case-insensitive, exact match) as fallback - valid_titles = {e.citation.title.lower() for e in evidence} - - validated_refs = [] - removed_count = 0 - - for ref in report.references: - ref_url = ref.get("url", "") - ref_title = ref.get("title", "").lower() - - # Check if URL matches collected evidence - if ref_url in valid_urls: - validated_refs.append(ref) - # Fallback: exact title match (case-insensitive) - elif ref_title and ref_title in valid_titles: - validated_refs.append(ref) - else: - removed_count += 1 - # Truncate URL for display - if len(ref_url) > _MAX_URL_DISPLAY_LENGTH: - url_display = ref_url[:_MAX_URL_DISPLAY_LENGTH] + "..." - else: - url_display = ref_url - logger.warning( - f"Removed hallucinated reference: '{ref.get('title', 'Unknown')}' " - f"(URL: {url_display})" - ) - - if removed_count > 0: - logger.info( - f"Citation validation removed {removed_count} hallucinated references. " - f"{len(validated_refs)} valid references remain." - ) - - # Update report with validated references - report.references = validated_refs - return report - - -def build_reference_from_evidence(evidence: "Evidence") -> dict[str, str]: - """Build a properly formatted reference from evidence. - - Use this to ensure references match the original evidence exactly. - """ - return { - "title": evidence.citation.title, - "authors": ", ".join(evidence.citation.authors or ["Unknown"]), - "source": evidence.citation.source, - "date": evidence.citation.date or "n.d.", - "url": evidence.citation.url, - } - - -def validate_markdown_citations( - markdown_report: str, evidence: list["Evidence"] -) -> tuple[str, int]: - """Validate citations in a markdown report against collected evidence. - - This function validates citations in markdown format (e.g., [1], [2]) by: - 1. Extracting URLs from the references section - 2. Matching them against Evidence objects - 3. Removing invalid citations from the report - - Note: - This is a basic validation. For full validation, use ResearchReport - objects with validate_references(). 
- - Args: - markdown_report: The markdown report string with citations - evidence: List of Evidence objects collected during research - - Returns: - Tuple of (validated_markdown, removed_count) - """ - import re - - # Build set of valid URLs from evidence - valid_urls = {e.citation.url for e in evidence} - valid_urls_lower = {url.lower() for url in valid_urls} - - # Extract references section (everything after "## References" or "References:") - ref_section_pattern = r"(?i)(?:##\s*)?References:?\s*\n(.*?)(?=\n##|\Z)" - ref_match = re.search(ref_section_pattern, markdown_report, re.DOTALL) - - if not ref_match: - # No references section found, return as-is - return markdown_report, 0 - - ref_section = ref_match.group(1) - ref_lines = ref_section.strip().split("\n") - - # Parse references: [1] https://example.com or [1] https://example.com Title - valid_refs = [] - removed_count = 0 - - for ref_line in ref_lines: - stripped_line = ref_line.strip() - if not stripped_line: - continue - - # Extract URL from reference line - # Pattern: [N] URL or [N] URL Title - url_match = re.search(r"https?://[^\s\)]+", stripped_line) - if url_match: - url = url_match.group(0).rstrip(".,;") - url_lower = url.lower() - - # Check if URL is valid - if url in valid_urls or url_lower in valid_urls_lower: - valid_refs.append(stripped_line) - else: - removed_count += 1 - logger.warning( - f"Removed invalid citation from markdown: {url[:80]}" - + ("..." if len(url) > 80 else "") - ) - else: - # No URL found, keep the line (might be formatted differently) - valid_refs.append(stripped_line) - - # Rebuild references section - if valid_refs: - new_ref_section = "\n".join(valid_refs) - # Replace the old references section - validated_markdown = ( - markdown_report[: ref_match.start(1)] - + new_ref_section - + markdown_report[ref_match.end(1) :] - ) - else: - # No valid references, remove the entire section - validated_markdown = ( - markdown_report[: ref_match.start()] + markdown_report[ref_match.end() :] - ) - - if removed_count > 0: - logger.info( - f"Citation validation removed {removed_count} invalid citations from markdown report. " - f"{len(valid_refs)} valid citations remain." 
- ) - - return validated_markdown, removed_count diff --git a/src/utils/config.py b/src/utils/config.py deleted file mode 100644 index 9d4356d7b5afd35c6e782c870708cf12bdaae657..0000000000000000000000000000000000000000 --- a/src/utils/config.py +++ /dev/null @@ -1,329 +0,0 @@ -"""Application configuration using Pydantic Settings.""" - -import logging -from typing import Literal - -import structlog -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - -from src.utils.exceptions import ConfigurationError - - -class Settings(BaseSettings): - """Strongly-typed application settings.""" - - model_config = SettingsConfigDict( - env_file=".env", - env_file_encoding="utf-8", - case_sensitive=False, - extra="ignore", - ) - - # LLM Configuration - openai_api_key: str | None = Field(default=None, description="OpenAI API key") - anthropic_api_key: str | None = Field(default=None, description="Anthropic API key") - llm_provider: Literal["openai", "anthropic", "huggingface"] = Field( - default="huggingface", description="Which LLM provider to use" - ) - openai_model: str = Field(default="gpt-5.1", description="OpenAI model name") - anthropic_model: str = Field( - default="claude-sonnet-4-5-20250929", description="Anthropic model" - ) - hf_token: str | None = Field( - default=None, alias="HF_TOKEN", description="HuggingFace API token" - ) - - # Embedding Configuration - # Note: OpenAI embeddings require OPENAI_API_KEY (Anthropic has no embeddings API) - openai_embedding_model: str = Field( - default="text-embedding-3-small", - description="OpenAI embedding model (used by LlamaIndex RAG)", - ) - local_embedding_model: str = Field( - default="all-MiniLM-L6-v2", - description="Local sentence-transformers model (used by EmbeddingService)", - ) - embedding_provider: Literal["openai", "local", "huggingface"] = Field( - default="local", - description="Embedding provider to use", - ) - huggingface_embedding_model: str = Field( - default="sentence-transformers/all-MiniLM-L6-v2", - description="HuggingFace embedding model ID", - ) - - # HuggingFace Configuration - huggingface_api_key: str | None = Field( - default=None, description="HuggingFace API token (HF_TOKEN or HUGGINGFACE_API_KEY)" - ) - huggingface_model: str = Field( - default="meta-llama/Llama-3.1-8B-Instruct", - description="Default HuggingFace model ID for inference", - ) - hf_fallback_models: str = Field( - default="Qwen/Qwen3-Next-80B-A3B-Thinking,Qwen/Qwen3-Next-80B-A3B-Instruct,meta-llama/Llama-3.3-70B-Instruct,meta-llama/Llama-3.1-8B-Instruct,HuggingFaceH4/zephyr-7b-beta,Qwen/Qwen2-7B-Instruct", - alias="HF_FALLBACK_MODELS", - description=( - "Comma-separated list of fallback models for provider discovery and error recovery. " - "Reads from HF_FALLBACK_MODELS environment variable. " - "Default value is used only if the environment variable is not set." - ), - ) - - # PubMed Configuration - ncbi_api_key: str | None = Field( - default=None, description="NCBI API key for higher rate limits" - ) - - # Web Search Configuration - web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo", "auto"] = ( - Field( - default="auto", - description="Web search provider to use. 
'auto' will auto-detect best available (prefers Serper > SearchXNG > DuckDuckGo)", - ) - ) - serper_api_key: str | None = Field(default=None, description="Serper API key for Google search") - searchxng_host: str | None = Field(default=None, description="SearchXNG host URL") - brave_api_key: str | None = Field(default=None, description="Brave Search API key") - tavily_api_key: str | None = Field(default=None, description="Tavily API key") - - # Agent Configuration - max_iterations: int = Field(default=10, ge=1, le=50) - search_timeout: int = Field(default=30, description="Seconds to wait for search") - use_graph_execution: bool = Field( - default=False, description="Use graph-based execution for research flows" - ) - - # Budget & Rate Limiting Configuration - default_token_limit: int = Field( - default=100000, - ge=1000, - le=1000000, - description="Default token budget per research loop", - ) - default_time_limit_minutes: int = Field( - default=10, - ge=1, - le=120, - description="Default time limit per research loop (minutes)", - ) - default_iterations_limit: int = Field( - default=10, - ge=1, - le=50, - description="Default iterations limit per research loop", - ) - - # Logging - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO" - - # External Services - modal_token_id: str | None = Field(default=None, description="Modal token ID") - modal_token_secret: str | None = Field(default=None, description="Modal token secret") - chroma_db_path: str = Field(default="./chroma_db", description="ChromaDB storage path") - chroma_db_persist: bool = Field( - default=True, - description="Whether to persist ChromaDB to disk", - ) - chroma_db_host: str | None = Field( - default=None, - description="ChromaDB server host (for remote ChromaDB)", - ) - chroma_db_port: int | None = Field( - default=None, - description="ChromaDB server port (for remote ChromaDB)", - ) - - # RAG Service Configuration - rag_collection_name: str = Field( - default="deepcritical_evidence", - description="ChromaDB collection name for RAG", - ) - rag_similarity_top_k: int = Field( - default=5, - ge=1, - le=50, - description="Number of top results to retrieve from RAG", - ) - rag_auto_ingest: bool = Field( - default=True, - description="Automatically ingest evidence into RAG", - ) - - # Audio/TTS Configuration - enable_audio_input: bool = Field( - default=True, - description="Enable audio input (speech-to-text) in multimodal interface", - ) - enable_audio_output: bool = Field( - default=True, - description="Enable audio output (text-to-speech) for responses", - ) - enable_image_input: bool = Field( - default=True, - description="Enable image input (OCR) in multimodal interface", - ) - tts_voice: str = Field( - default="af_heart", - description="TTS voice ID for Kokoro TTS (e.g., af_heart, am_michael)", - ) - tts_speed: float = Field( - default=1.0, - ge=0.5, - le=2.0, - description="TTS speech speed multiplier (0.5x to 2.0x)", - ) - tts_use_llm_polish: bool = Field( - default=False, - description="Use LLM for final text polish before TTS (optional, costs API calls)", - ) - tts_gpu: str | None = Field( - default=None, - description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). 
None uses default T4.", - ) - - # STT (Speech-to-Text) Configuration - stt_api_url: str | None = Field( - default="https://nvidia-canary-1b-v2.hf.space", - description="Gradio Space URL for STT service (default: nvidia/canary-1b-v2)", - ) - stt_source_lang: str = Field( - default="English", - description="Source language for STT (full name like 'English', 'Spanish', etc.)", - ) - stt_target_lang: str = Field( - default="English", - description="Target language for STT (full name like 'English', 'Spanish', etc.)", - ) - - # Image OCR Configuration - ocr_api_url: str | None = Field( - default="https://prithivmlmods-multimodal-ocr3.hf.space", - description="Gradio Space URL for OCR service (default: prithivMLmods/Multimodal-OCR3)", - ) - - # Report File Output Configuration - save_reports_to_file: bool = Field( - default=True, - description="Save generated reports to files (enables file downloads in Gradio)", - ) - report_output_directory: str | None = Field( - default=None, - description="Directory to save report files. If None, uses system temp directory.", - ) - report_file_format: Literal["md", "md_html", "md_pdf"] = Field( - default="md", - description="File format(s) to save reports in. 'md' saves only markdown, others save multiple formats.", - ) - report_filename_template: str = Field( - default="report_{timestamp}_{query_hash}.md", - description="Template for report filenames. Supports {timestamp}, {query_hash}, {date} placeholders.", - ) - - @property - def modal_available(self) -> bool: - """Check if Modal credentials are configured.""" - return bool(self.modal_token_id and self.modal_token_secret) - - def get_api_key(self) -> str: - """Get the API key for the configured provider.""" - if self.llm_provider == "openai": - if not self.openai_api_key: - raise ConfigurationError("OPENAI_API_KEY not set") - return self.openai_api_key - - if self.llm_provider == "anthropic": - if not self.anthropic_api_key: - raise ConfigurationError("ANTHROPIC_API_KEY not set") - return self.anthropic_api_key - - raise ConfigurationError(f"Unknown LLM provider: {self.llm_provider}") - - def get_openai_api_key(self) -> str: - """Get OpenAI API key (required for Magentic function calling).""" - if not self.openai_api_key: - raise ConfigurationError( - "OPENAI_API_KEY not set. Magentic mode requires OpenAI for function calling. " - "Use mode='simple' for other providers." 
- ) - return self.openai_api_key - - @property - def has_openai_key(self) -> bool: - """Check if OpenAI API key is available.""" - return bool(self.openai_api_key) - - @property - def has_anthropic_key(self) -> bool: - """Check if Anthropic API key is available.""" - return bool(self.anthropic_api_key) - - @property - def has_huggingface_key(self) -> bool: - """Check if HuggingFace API key is available.""" - return bool(self.huggingface_api_key or self.hf_token) - - @property - def has_any_llm_key(self) -> bool: - """Check if any LLM API key is available.""" - return self.has_openai_key or self.has_anthropic_key or self.has_huggingface_key - - @property - def web_search_available(self) -> bool: - """Check if web search is available (either no-key provider or API key present).""" - if self.web_search_provider == "duckduckgo": - return True # No API key required - if self.web_search_provider == "serper": - return bool(self.serper_api_key) - if self.web_search_provider == "searchxng": - return bool(self.searchxng_host) - if self.web_search_provider == "brave": - return bool(self.brave_api_key) - if self.web_search_provider == "tavily": - return bool(self.tavily_api_key) - return False - - def get_hf_fallback_models_list(self) -> list[str]: - """Get the list of fallback models as a list. - - Parses the comma-separated HF_FALLBACK_MODELS string into a list, - stripping whitespace from each model ID. - - Returns: - List of model IDs - """ - if not self.hf_fallback_models: - return [] - return [model.strip() for model in self.hf_fallback_models.split(",") if model.strip()] - - -def get_settings() -> Settings: - """Factory function to get settings (allows mocking in tests).""" - return Settings() - - -def configure_logging(settings: Settings) -> None: - """Configure structured logging with the configured log level.""" - # Set stdlib logging level from settings - logging.basicConfig( - level=getattr(logging, settings.log_level), - format="%(message)s", - ) - - structlog.configure( - processors=[ - structlog.stdlib.filter_by_level, - structlog.stdlib.add_logger_name, - structlog.stdlib.add_log_level, - structlog.processors.TimeStamper(fmt="iso"), - structlog.processors.JSONRenderer(), - ], - wrapper_class=structlog.stdlib.BoundLogger, - context_class=dict, - logger_factory=structlog.stdlib.LoggerFactory(), - ) - - -# Singleton for easy import -settings = get_settings() diff --git a/src/utils/exceptions.py b/src/utils/exceptions.py deleted file mode 100644 index abedf1baf49f3f88f01b2451a81c71c221b0caac..0000000000000000000000000000000000000000 --- a/src/utils/exceptions.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Custom exceptions for The DETERMINATOR.""" - - -class DeepCriticalError(Exception): - """Base exception for all DETERMINATOR errors. - - Note: Class name kept for backward compatibility. 
- """ - - pass - - -class SearchError(DeepCriticalError): - """Raised when a search operation fails.""" - - pass - - -class JudgeError(DeepCriticalError): - """Raised when the judge fails to assess evidence.""" - - pass - - -class ConfigurationError(DeepCriticalError): - """Raised when configuration is invalid.""" - - pass - - -class RateLimitError(SearchError): - """Raised when we hit API rate limits.""" - - pass diff --git a/src/utils/hf_error_handler.py b/src/utils/hf_error_handler.py deleted file mode 100644 index becc6862275260ef63249fee69fc7225468dec30..0000000000000000000000000000000000000000 --- a/src/utils/hf_error_handler.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Utility functions for handling HuggingFace API errors and token validation.""" - -import re -from typing import Any - -import structlog - -logger = structlog.get_logger() - - -def extract_error_details(error: Exception) -> dict[str, Any]: - """Extract error details from HuggingFace API errors. - - Pydantic AI and HuggingFace Inference API errors often contain - information in the error message string like: - "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden" - - Args: - error: The exception object - - Returns: - Dictionary with extracted error details: - - status_code: HTTP status code (if found) - - model_name: Model name (if found) - - body: Error body/message (if found) - - error_type: Type of error (403, 422, etc.) - - is_auth_error: Whether this is an authentication/authorization error - - is_model_error: Whether this is a model-specific error - """ - error_str = str(error) - details: dict[str, Any] = { - "status_code": None, - "model_name": None, - "body": None, - "error_type": "unknown", - "is_auth_error": False, - "is_model_error": False, - } - - # Try to extract status_code - status_match = re.search(r"status_code:\s*(\d+)", error_str) - if status_match: - details["status_code"] = int(status_match.group(1)) - details["error_type"] = f"http_{details['status_code']}" - - # Determine error category - if details["status_code"] == 403: - details["is_auth_error"] = True - elif details["status_code"] == 422: - details["is_model_error"] = True - - # Try to extract model_name - model_match = re.search(r"model_name:\s*([^\s,]+)", error_str) - if model_match: - details["model_name"] = model_match.group(1) - - # Try to extract body - body_match = re.search(r"body:\s*(.+)", error_str) - if body_match: - details["body"] = body_match.group(1).strip() - - return details - - -def get_user_friendly_error_message(error: Exception, model_name: str | None = None) -> str: - """Generate a user-friendly error message from an exception. - - Args: - error: The exception object - model_name: Optional model name for context - - Returns: - User-friendly error message - """ - details = extract_error_details(error) - - if details["is_auth_error"]: - return ( - "🔐 **Authentication Error**\n\n" - "Your HuggingFace token doesn't have permission to access this model or API.\n\n" - "**Possible solutions:**\n" - "1. **Re-authenticate**: Log out and log back in to ensure your token has the `inference-api` scope\n" - "2. **Check model access**: Visit the model page on HuggingFace and request access if it's gated\n" - "3. 
**Use alternative model**: Try a different model that's publicly available\n\n" - f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n" - f"**Error**: {details['body'] or str(error)}" - ) - - if details["is_model_error"]: - return ( - "⚠️ **Model Compatibility Error**\n\n" - "The selected model is not compatible with the current provider or has specific requirements.\n\n" - "**Possible solutions:**\n" - "1. **Try a different model**: Use a model that's compatible with the current provider\n" - "2. **Check provider status**: The provider may be in staging mode or unavailable\n" - "3. **Wait and retry**: If the model is in staging, it may become available later\n\n" - f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n" - f"**Error**: {details['body'] or str(error)}" - ) - - # Generic error - return ( - "❌ **API Error**\n\n" - f"An error occurred while calling the HuggingFace API:\n\n" - f"**Error**: {error!s}\n\n" - "Please try again or contact support if the issue persists." - ) - - -def validate_hf_token(token: str | None) -> tuple[bool, str | None]: - """Validate HuggingFace token format. - - Args: - token: The token to validate - - Returns: - Tuple of (is_valid, error_message) - - is_valid: True if token appears valid - - error_message: Error message if invalid, None if valid - """ - if not token: - return False, "Token is None or empty" - - if not isinstance(token, str): - return False, f"Token is not a string (type: {type(token).__name__})" - - if len(token) < 10: - return False, "Token appears too short (minimum 10 characters expected)" - - # HuggingFace tokens typically start with "hf_" for user tokens - # OAuth tokens may have different formats, so we're lenient - # Just check it's not obviously invalid - - return True, None - - -def log_token_info(token: str | None, context: str = "") -> None: - """Log token information for debugging (without exposing the actual token). - - Args: - token: The token to log info about - context: Additional context for the log message - """ - if token: - is_valid, error_msg = validate_hf_token(token) - logger.debug( - "Token validation", - context=context, - has_token=True, - is_valid=is_valid, - token_length=len(token), - token_prefix=token[:4] + "..." if len(token) > 4 else "***", - validation_error=error_msg, - ) - else: - logger.debug("Token validation", context=context, has_token=False) - - -def should_retry_with_fallback(error: Exception) -> bool: - """Determine if an error should trigger a fallback to alternative models. - - Args: - error: The exception object - - Returns: - True if the error suggests we should try a fallback model - """ - details = extract_error_details(error) - - # Retry with fallback for: - # - 403 errors (authentication/permission issues - might work with different model) - # - 422 errors (model/provider compatibility - definitely try different model) - # - Model-specific errors - return ( - details["is_auth_error"] or details["is_model_error"] or details["model_name"] is not None - ) - - -def get_fallback_models(original_model: str | None = None) -> list[str]: - """Get a list of fallback models to try. 
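A concrete illustration of the parsing above, using the message format quoted in the extract_error_details docstring:

from src.utils.hf_error_handler import extract_error_details, should_retry_with_fallback

err = Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
details = extract_error_details(err)

print(details["status_code"])    # 403
print(details["error_type"])     # http_403
print(details["is_auth_error"])  # True
print(details["model_name"])     # Qwen/Qwen3-Next-80B-A3B-Thinking
print(should_retry_with_fallback(err))  # True: auth errors trigger a model fallback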
- - Args: - original_model: The original model that failed - - Returns: - List of fallback model names to try in order - """ - # Publicly available models that should work with most tokens - fallbacks = [ - "meta-llama/Llama-3.1-8B-Instruct", # Common, often available - "mistralai/Mistral-7B-Instruct-v0.3", # Alternative - "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback - ] - - # If original model is in the list, remove it - if original_model and original_model in fallbacks: - fallbacks.remove(original_model) - - return fallbacks diff --git a/src/utils/hf_model_validator.py b/src/utils/hf_model_validator.py deleted file mode 100644 index e7dd75eea282b66fc655d4fb33fe1a884af397f7..0000000000000000000000000000000000000000 --- a/src/utils/hf_model_validator.py +++ /dev/null @@ -1,477 +0,0 @@ -"""Validator for querying available HuggingFace models and providers using OAuth token. - -This module provides functions to: -1. Query available models from HuggingFace Hub -2. Query available inference providers (with dynamic discovery) -3. Validate model/provider combinations -4. Return formatted lists for Gradio dropdowns - -Uses Hugging Face Hub API to discover providers dynamically by querying model -information. Falls back to known providers list if discovery fails. -""" - -import asyncio -from time import time -from typing import Any - -import structlog -from huggingface_hub import HfApi - -from src.utils.config import settings - -logger = structlog.get_logger() - - -def extract_oauth_token(oauth_token: Any) -> str | None: - """Extract OAuth token value from Gradio OAuthToken object. - - Handles both gr.OAuthToken objects (with .token attribute) and plain strings. - This is a convenience function for Gradio apps that use OAuth authentication. - - Args: - oauth_token: Gradio OAuthToken object or string token - - Returns: - Token string if available, None otherwise - """ - if oauth_token is None: - return None - - if hasattr(oauth_token, "token"): - return oauth_token.token # type: ignore[no-any-return] - elif isinstance(oauth_token, str): - return oauth_token - - logger.warning( - "Could not extract token from OAuthToken object", - oauth_token_type=type(oauth_token).__name__, - ) - return None - - -# Known providers as fallback (updated from Hugging Face documentation) -# These are used when dynamic discovery fails or times out -KNOWN_PROVIDERS = [ - "auto", # Auto-select (always available) - "hf-inference", # HuggingFace's own Inference API - "nebius", - "together", - "scaleway", - "hyperbolic", - "novita", - "nscale", - "sambanova", - "ovh", - "fireworks-ai", # Note: API uses "fireworks-ai", not "fireworks" - "cerebras", - "fal-ai", - "cohere", -] - - -def get_provider_discovery_models() -> list[str]: - """Get list of models to use for provider discovery. - - Reads from HF_FALLBACK_MODELS environment variable via settings. - The environment variable should be a comma-separated list of model IDs. 
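Editorial usage note: the token-extraction helper above is what a Gradio login callback would call first. A sketch under that assumption (the `on_login` handler is hypothetical):

```python
# Hypothetical Gradio login callback; `oauth_token` is whatever Gradio passes (gr.OAuthToken or str).
from src.utils.hf_model_validator import extract_oauth_token, get_available_providers


async def on_login(oauth_token) -> list[str]:
    token = extract_oauth_token(oauth_token)  # returns the raw token string or None
    return await get_available_providers(token=token)
```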
- - Returns: - List of model IDs to query for provider discovery - """ - # Get models from HF_FALLBACK_MODELS environment variable - # This is automatically read by Pydantic Settings from the env var - fallback_models = settings.get_hf_fallback_models_list() - - logger.debug( - "Using HF_FALLBACK_MODELS for provider discovery", - count=len(fallback_models), - models=fallback_models, - ) - - return fallback_models - - -# Simple in-memory cache for provider lists (TTL: 1 hour) -_provider_cache: dict[str, tuple[list[str], float]] = {} -PROVIDER_CACHE_TTL = 3600 # 1 hour in seconds - - -async def get_available_providers(token: str | None = None) -> list[str]: - """Get list of available inference providers. - - Discovers providers dynamically by querying model information from HuggingFace Hub. - Uses caching to avoid repeated API calls. Falls back to known providers if discovery fails. - - Strategy: - 1. Check cache (if valid, return cached list) - 2. Query popular models to extract unique providers from their inferenceProviderMapping - 3. Fall back to known providers list if discovery fails - 4. Cache results for future use - - Args: - token: Optional HuggingFace API token for authenticated requests - Can be extracted from gr.OAuthToken.token in Gradio apps - - Returns: - List of provider names sorted alphabetically, with "auto" first - (e.g., ["auto", "fireworks-ai", "hf-inference", "nebius", ...]) - """ - # Check cache first - cache_key = "providers" + (f"_{token[:8]}" if token else "_no_token") - if cache_key in _provider_cache: - cached_providers, cache_time = _provider_cache[cache_key] - if time() - cache_time < PROVIDER_CACHE_TTL: - logger.debug("Returning cached providers", count=len(cached_providers)) - return cached_providers - - try: - providers = set(["auto"]) # Always include "auto" - - # Try dynamic discovery by querying popular models - loop = asyncio.get_running_loop() - api = HfApi(token=token) - - # Get models to query from HF_FALLBACK_MODELS environment variable via settings - discovery_models = get_provider_discovery_models() - - # Query a sample of popular models to discover providers - # This is more efficient than querying all models - discovery_count = 0 - for model_id in discovery_models: - try: - - def _get_model_info(m: str) -> Any: - """Get model info synchronously.""" - return api.model_info(m, expand=["inferenceProviderMapping"]) # type: ignore[arg-type] - - info = await loop.run_in_executor(None, _get_model_info, model_id) - - # Extract providers from inference_provider_mapping - if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping: - mapping = info.inference_provider_mapping - # mapping is a dict like {'hf-inference': InferenceProviderMapping(...), ...} - providers.update(mapping.keys()) - discovery_count += 1 - logger.debug( - "Discovered providers from model", - model=model_id, - providers=list(mapping.keys()), - ) - except Exception as e: - logger.debug( - "Could not get provider info for model", - model=model_id, - error=str(e), - ) - continue - - # If we discovered providers, use them; otherwise fall back to known providers - if len(providers) > 1: # More than just "auto" - provider_list = sorted(list(providers)) - logger.info( - "Discovered providers dynamically", - count=len(provider_list), - models_queried=discovery_count, - has_token=bool(token), - ) - else: - # Fallback to known providers - provider_list = KNOWN_PROVIDERS.copy() - logger.info( - "Using known providers list (discovery failed or incomplete)", - 
count=len(provider_list), - models_queried=discovery_count, - ) - - # Cache the results - _provider_cache[cache_key] = (provider_list, time()) - - return provider_list - - except Exception as e: - logger.warning("Failed to get providers", error=str(e)) - # Return known providers as fallback - return KNOWN_PROVIDERS.copy() - - -async def get_available_models( - token: str | None = None, - task: str = "text-generation", - limit: int = 100, - inference_provider: str | None = None, -) -> list[str]: - """Get list of available models for text generation. - - Queries HuggingFace Hub API to get models that support text generation. - Optionally filters by inference provider to show only models available via that provider. - - Args: - token: Optional HuggingFace API token for authenticated requests - Can be extracted from gr.OAuthToken.token in Gradio apps - task: Task type to filter models (default: "text-generation") - limit: Maximum number of models to return - inference_provider: Optional provider name to filter models (e.g., "fireworks-ai", "nebius") - If None, returns all models for the task - - Returns: - List of model IDs (e.g., ["meta-llama/Llama-3.1-8B-Instruct", ...]) - """ - try: - loop = asyncio.get_running_loop() - - def _fetch_models() -> list[str]: - """Fetch models synchronously in executor.""" - api = HfApi(token=token) - - # Build query parameters - query_params: dict[str, Any] = { - "task": task, - "sort": "downloads", - "direction": -1, - "limit": limit, - } - - # Filter by inference provider if specified - if inference_provider and inference_provider != "auto": - query_params["inference_provider"] = inference_provider - - # Search for models - models = api.list_models(**query_params) - - # Extract model IDs - model_ids = [model.id for model in models] - return model_ids - - model_ids = await loop.run_in_executor(None, _fetch_models) - - logger.info( - "Fetched available models", - count=len(model_ids), - task=task, - provider=inference_provider or "all", - has_token=bool(token), - ) - - return model_ids - - except Exception as e: - logger.warning("Failed to get models from Hub API", error=str(e)) - # Return popular fallback models - return [ - "meta-llama/Llama-3.1-8B-Instruct", - "mistralai/Mistral-7B-Instruct-v0.3", - "HuggingFaceH4/zephyr-7b-beta", - "google/gemma-2-9b-it", - ] - - -async def validate_model_provider_combination( - model_id: str, - provider: str | None, - token: str | None = None, -) -> tuple[bool, str | None]: - """Validate that a model is available with a specific provider. - - Uses HuggingFace Hub API to check if the provider is listed in the model's - inferenceProviderMapping. This is faster and more reliable than making test API calls. 
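Editorial usage note: the two discovery functions above are designed to be awaited together when populating dropdown choices. A minimal sketch, with the calling function assumed:

```python
# Hypothetical dropdown wiring; both discovery helpers are defined in this module.
import asyncio

from src.utils.hf_model_validator import get_available_models, get_available_providers


async def build_dropdown_choices(token: str | None) -> tuple[list[str], list[str]]:
    providers, models = await asyncio.gather(
        get_available_providers(token=token),
        get_available_models(token=token, task="text-generation", limit=50),
    )
    return providers, models
```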
- - Args: - model_id: HuggingFace model ID - provider: Provider name (or None/empty for auto) - token: Optional HuggingFace API token (from gr.OAuthToken.token) - - Returns: - Tuple of (is_valid, error_message) - - is_valid: True if combination is valid or provider is "auto" - - error_message: Error message if invalid, None if valid - """ - # "auto" is always valid - let HuggingFace select the provider - if not provider or provider == "auto": - return True, None - - try: - loop = asyncio.get_running_loop() - api = HfApi(token=token) - - def _get_model_info() -> Any: - """Get model info with provider mapping synchronously.""" - return api.model_info(model_id, expand=["inferenceProviderMapping"]) # type: ignore[arg-type] - - info = await loop.run_in_executor(None, _get_model_info) - - # Check if provider is in the model's inference provider mapping - if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping: - mapping = info.inference_provider_mapping - available_providers = set(mapping.keys()) - - # Normalize provider name (some APIs use "fireworks-ai", others use "fireworks") - normalized_provider = provider.lower() - provider_variants = {normalized_provider} - - # Handle common provider name variations - if normalized_provider == "fireworks": - provider_variants.add("fireworks-ai") - elif normalized_provider == "fireworks-ai": - provider_variants.add("fireworks") - - # Check if any variant matches - if any(p in available_providers for p in provider_variants): - logger.debug( - "Model/provider combination validated via API", - model=model_id, - provider=provider, - available_providers=list(available_providers), - ) - return True, None - else: - error_msg = ( - f"Model {model_id} is not available with provider '{provider}'. " - f"Available providers: {', '.join(sorted(available_providers))}" - ) - logger.debug( - "Model/provider combination invalid", - model=model_id, - provider=provider, - available_providers=list(available_providers), - ) - return False, error_msg - else: - # Model doesn't have provider mapping - assume valid and let actual usage determine - logger.debug( - "Model has no provider mapping, assuming valid", - model=model_id, - provider=provider, - ) - return True, None - - except Exception as e: - logger.warning( - "Model/provider validation failed", - model=model_id, - provider=provider, - error=str(e), - ) - # Don't fail validation on error - let the actual request fail - # This is more user-friendly than blocking on validation errors - return True, None - - -async def get_models_for_provider( - provider: str, - token: str | None = None, - limit: int = 50, -) -> list[str]: - """Get models available for a specific provider. - - This is a convenience wrapper around get_available_models() with provider filtering. 
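Editorial usage note: a short sketch of how the validation call above might gate client creation; the surrounding error handling is assumed, not part of this module:

```python
# Hypothetical pre-flight check using the validator defined above.
from src.utils.hf_model_validator import validate_model_provider_combination


async def ensure_combination(model_id: str, provider: str, token: str | None = None) -> None:
    is_valid, error = await validate_model_provider_combination(model_id, provider, token=token)
    if not is_valid:
        raise ValueError(error or f"{model_id} is not available via provider '{provider}'")
```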
- - Args: - provider: Provider name (e.g., "nebius", "together", "fireworks-ai") - Note: Use "fireworks-ai" not "fireworks" for the API - token: Optional HuggingFace API token (from gr.OAuthToken.token) - limit: Maximum number of models to return - - Returns: - List of model IDs available for the provider - """ - # Normalize provider name for API - normalized_provider = provider - if provider.lower() == "fireworks": - normalized_provider = "fireworks-ai" - logger.debug("Normalized provider name", original=provider, normalized=normalized_provider) - - return await get_available_models( - token=token, - task="text-generation", - limit=limit, - inference_provider=normalized_provider, - ) - - -async def validate_oauth_token(token: str | None) -> dict[str, Any]: - """Validate OAuth token and return available resources. - - Args: - token: OAuth token to validate - - Returns: - Dictionary with: - - is_valid: Whether token is valid - - has_inference_api_scope: Whether token has inference-api scope - - available_models: List of available model IDs - - available_providers: List of available provider names - - username: HuggingFace username (if available) - - error: Error message if validation failed - """ - result: dict[str, Any] = { - "is_valid": False, - "has_inference_api_scope": False, - "available_models": [], - "available_providers": [], - "username": None, - "error": None, - } - - if not token: - result["error"] = "No token provided" - return result - - try: - # Validate token format - from src.utils.hf_error_handler import validate_hf_token - - is_valid_format, format_error = validate_hf_token(token) - if not is_valid_format: - result["error"] = f"Invalid token format: {format_error}" - return result - - # Try to get user info to validate token - loop = asyncio.get_running_loop() - - def _get_user_info() -> dict[str, Any] | None: - """Get user info from HuggingFace API.""" - try: - api = HfApi(token=token) - user_info = api.whoami() - return user_info - except Exception: - return None - - user_info = await loop.run_in_executor(None, _get_user_info) - - if user_info: - result["is_valid"] = True - result["username"] = user_info.get("name") or user_info.get("fullname") - logger.info("Token validated", username=result["username"]) - else: - result["error"] = "Token validation failed - could not authenticate" - return result - - # Try to query models to check inference-api scope - try: - models = await get_available_models(token=token, limit=10) - if models: - result["has_inference_api_scope"] = True - result["available_models"] = models - logger.info("Inference API scope confirmed", model_count=len(models)) - except Exception as e: - logger.warning("Could not verify inference-api scope", error=str(e)) - # Token might be valid but without inference-api scope - result["has_inference_api_scope"] = False - result["error"] = f"Token may not have inference-api scope: {e}" - - # Get available providers - try: - providers = await get_available_providers(token=token) - result["available_providers"] = providers - except Exception as e: - logger.warning("Could not get providers", error=str(e)) - # Use fallback providers - result["available_providers"] = ["auto"] - - return result - - except Exception as e: - logger.error("Token validation failed", error=str(e)) - result["error"] = str(e) - return result diff --git a/src/utils/huggingface_chat_client.py b/src/utils/huggingface_chat_client.py deleted file mode 100644 index 5b0d1f67a96e95b971f110d571486c816d9fb7c3..0000000000000000000000000000000000000000 --- 
a/src/utils/huggingface_chat_client.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Custom ChatClient implementation using HuggingFace InferenceClient. - -Uses HuggingFace InferenceClient which natively supports function calling, -making this a thin async wrapper rather than a complex implementation. - -Reference: https://huggingface.co/docs/huggingface_hub/package_reference/inference_client -""" - -import asyncio -from typing import Any - -import structlog -from huggingface_hub import InferenceClient - -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - - -class HuggingFaceChatClient: - """ChatClient implementation using HuggingFace InferenceClient. - - HuggingFace InferenceClient natively supports function calling via - the 'tools' parameter, making this a simple async wrapper. - - This client is compatible with agent-framework's ChatAgent interface. - """ - - def __init__( - self, - model_name: str = "meta-llama/Llama-3.1-8B-Instruct", - api_key: str | None = None, - provider: str = "auto", - ) -> None: - """Initialize HuggingFace chat client. - - Args: - model_name: HuggingFace model identifier (e.g., "meta-llama/Llama-3.1-8B-Instruct") - api_key: Optional HF_TOKEN for gated models. If None, uses environment token. - provider: Provider name or "auto" for automatic selection. - Options: "auto", "cerebras", "together", "sambanova", etc. - - Raises: - ConfigurationError: If initialization fails - """ - try: - # Type ignore: provider can be str but InferenceClient expects Literal - # We validate it's a valid provider at runtime - self.client = InferenceClient( - model=model_name, - api_key=api_key, - provider=provider, # type: ignore[arg-type] - ) - self.model_name = model_name - self.provider = provider - logger.info( - "Initialized HuggingFace chat client", - model=model_name, - provider=provider, - ) - except Exception as e: - raise ConfigurationError( - f"Failed to initialize HuggingFace InferenceClient: {e}" - ) from e - - async def chat_completion( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - tool_choice: str | dict[str, Any] | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - ) -> Any: - """Send chat completion with optional tools. - - HuggingFace InferenceClient natively supports tools parameter! - This is just an async wrapper around the synchronous API. - - Args: - messages: List of message dicts with 'role' and 'content' keys. - Format: [{"role": "user", "content": "Hello"}] - tools: Optional list of tool definitions in OpenAI format. - Format: [{"type": "function", "function": {...}}] - tool_choice: Tool selection strategy. - Options: "auto", "none", or {"type": "function", "function": {"name": "tool_name"}} - temperature: Sampling temperature (0.0 to 2.0). Defaults to 1.0. - max_tokens: Maximum tokens in response. Defaults to 100. - - Returns: - ChatCompletionOutput compatible with agent-framework. - Has .choices attribute with message and tool_calls. - - Raises: - ConfigurationError: If chat completion fails - """ - try: - loop = asyncio.get_running_loop() - response = await loop.run_in_executor( - None, - lambda: self.client.chat_completion( - messages=messages, - tools=tools, # type: ignore[arg-type] # ✅ Native support! - tool_choice=tool_choice, # type: ignore[arg-type] # ✅ Native support! 
- temperature=temperature, - max_tokens=max_tokens, - ), - ) - - logger.debug( - "Chat completion successful", - model=self.model_name, - has_tools=bool(tools), - has_tool_calls=bool( - response.choices[0].message.tool_calls - if response.choices and response.choices[0].message.tool_calls - else None - ), - ) - - return response - - except Exception as e: - logger.error( - "Chat completion failed", - model=self.model_name, - error=str(e), - error_type=type(e).__name__, - ) - raise ConfigurationError(f"HuggingFace chat completion failed: {e}") from e diff --git a/src/utils/llm_factory.py b/src/utils/llm_factory.py deleted file mode 100644 index 2009568298056aca40e6fdddb93f7e7b48fc9f54..0000000000000000000000000000000000000000 --- a/src/utils/llm_factory.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Centralized LLM client factory. - -This module provides factory functions for creating LLM clients, -ensuring consistent configuration and clear error messages. - -Agent-Framework Chat Clients: -- HuggingFace InferenceClient: Native function calling support via 'tools' parameter -- OpenAI ChatClient: Native function calling support (original implementation) -- Both can be used with agent-framework's ChatAgent - -Pydantic AI Models: -- Default provider is HuggingFace (free tier, no API key required for public models) -- OpenAI and Anthropic are available as fallback options -- All providers use Pydantic AI's unified interface -""" - -from typing import TYPE_CHECKING, Any - -import structlog - -from src.utils.config import settings -from src.utils.exceptions import ConfigurationError - -logger = structlog.get_logger() - -if TYPE_CHECKING: - from agent_framework.openai import OpenAIChatClient - - from src.utils.huggingface_chat_client import HuggingFaceChatClient - - -def get_magentic_client() -> "OpenAIChatClient": - """ - Get the OpenAI client for Magentic agents (legacy function). - - Note: This function is kept for backward compatibility. - For new code, use get_chat_client_for_agent() which supports - both OpenAI and HuggingFace. - - Raises: - ConfigurationError: If OPENAI_API_KEY is not set - - Returns: - Configured OpenAIChatClient for Magentic agents - """ - # Import here to avoid requiring agent-framework for simple mode - from agent_framework.openai import OpenAIChatClient - - api_key = settings.get_openai_api_key() - - return OpenAIChatClient( - model_id=settings.openai_model, - api_key=api_key, - ) - - -def get_huggingface_chat_client(oauth_token: str | None = None) -> "HuggingFaceChatClient": - """ - Get HuggingFace chat client for agent-framework. - - HuggingFace InferenceClient natively supports function calling, - making it compatible with agent-framework's ChatAgent. - - Args: - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured HuggingFaceChatClient - - Raises: - ConfigurationError: If initialization fails - """ - from src.utils.huggingface_chat_client import HuggingFaceChatClient - - model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" - # Priority: oauth_token > env vars - api_key = oauth_token or settings.hf_token or settings.huggingface_api_key - - return HuggingFaceChatClient( - model_name=model_name, - api_key=api_key, - provider="auto", # Auto-select best provider - ) - - -def get_chat_client_for_agent(oauth_token: str | None = None) -> Any: - """ - Get appropriate chat client for agent-framework based on configuration. 
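Editorial usage note: a call against the wrapper above might look like the following; the `.choices[0].message.content` access assumes the standard chat-completion response shape rather than anything specific to this project:

```python
# Hypothetical call site for HuggingFaceChatClient; the response shape is an assumption.
from src.utils.huggingface_chat_client import HuggingFaceChatClient


async def ask(client: HuggingFaceChatClient, question: str) -> str:
    response = await client.chat_completion(
        messages=[{"role": "user", "content": question}],
        temperature=0.2,
        max_tokens=256,
    )
    return response.choices[0].message.content
```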
- - Supports: - - HuggingFace InferenceClient (if HF_TOKEN available, preferred for free tier) - - OpenAI ChatClient (if OPENAI_API_KEY available, fallback) - - Args: - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - ChatClient compatible with agent-framework (HuggingFaceChatClient or OpenAIChatClient) - - Raises: - ConfigurationError: If no suitable client can be created - """ - # Check if we have OAuth token or env vars - has_hf_key = bool(oauth_token or settings.has_huggingface_key) - - # Prefer HuggingFace if available (free tier) - if has_hf_key: - return get_huggingface_chat_client(oauth_token=oauth_token) - - # Fallback to OpenAI if available - if settings.has_openai_key: - return get_magentic_client() - - # If neither available, try HuggingFace without key (public models) - try: - return get_huggingface_chat_client(oauth_token=oauth_token) - except Exception: - pass - - raise ConfigurationError( - "No chat client available. Set HF_TOKEN or OPENAI_API_KEY for agent-framework mode." - ) - - -def get_pydantic_ai_model(oauth_token: str | None = None) -> Any: - """ - Get the appropriate model for pydantic-ai based on configuration. - - Uses the configured LLM_PROVIDER to select between HuggingFace, OpenAI, and Anthropic. - Defaults to HuggingFace if provider is not specified or unknown. - This is used by simple mode components (JudgeHandler, etc.) - - Args: - oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) - - Returns: - Configured pydantic-ai model - """ - from pydantic_ai.models.huggingface import HuggingFaceModel - from pydantic_ai.providers.huggingface import HuggingFaceProvider - - # Priority: oauth_token > settings.hf_token > settings.huggingface_api_key - effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key - - # HuggingFaceProvider requires a token - cannot use None - if not effective_hf_token: - raise ConfigurationError( - "HuggingFace token required. Please either:\n" - "1. Log in via HuggingFace OAuth (recommended for Spaces)\n" - "2. Set HF_TOKEN environment variable\n" - "3. Set huggingface_api_key in settings" - ) - - # Validate and log token information - from src.utils.hf_error_handler import log_token_info, validate_hf_token - - log_token_info(effective_hf_token, context="get_pydantic_ai_model") - is_valid, error_msg = validate_hf_token(effective_hf_token) - if not is_valid: - logger.warning( - "Token validation failed in get_pydantic_ai_model", - error=error_msg, - has_oauth=bool(oauth_token), - ) - # Continue anyway - let the API call fail with a clear error - - # Always use HuggingFace with available token - model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" - hf_provider = HuggingFaceProvider(api_key=effective_hf_token) - return HuggingFaceModel(model_name, provider=hf_provider) - - -def check_magentic_requirements() -> None: - """ - Check if Magentic/agent-framework mode requirements are met. - - Note: HuggingFace InferenceClient now supports function calling natively, - so this check is relaxed. We prefer HuggingFace if available, fallback to OpenAI. - - Raises: - ConfigurationError: If no suitable client can be created - """ - # Try to get a chat client - will raise if none available - try: - get_chat_client_for_agent() - except ConfigurationError as e: - raise ConfigurationError( - "Agent-framework mode requires HF_TOKEN or OPENAI_API_KEY. 
" - "HuggingFace is preferred (free tier with function calling support). " - "Use mode='simple' for other LLM providers." - ) from e - - -def check_simple_mode_requirements() -> None: - """ - Check if simple mode requirements are met. - - Simple mode supports HuggingFace (default), OpenAI, and Anthropic. - HuggingFace can work without an API key for public models. - - Raises: - ConfigurationError: If no LLM is available (only if explicitly required) - """ - # HuggingFace can work without API key for public models, so we don't require it - # This allows simple mode to work out of the box - pass diff --git a/src/utils/markdown.css b/src/utils/markdown.css deleted file mode 100644 index 1854a4bb7d4cb4329e7c3b720f62d01a1b521473..0000000000000000000000000000000000000000 --- a/src/utils/markdown.css +++ /dev/null @@ -1,7 +0,0 @@ -body { - font-family: Arial, sans-serif; - font-size: 14px; - line-height: 1.8; - color: #000; -} - diff --git a/src/utils/md_to_pdf.py b/src/utils/md_to_pdf.py deleted file mode 100644 index 08e7f71d003032dc202f897563f99cf1b600f07c..0000000000000000000000000000000000000000 --- a/src/utils/md_to_pdf.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Utility for converting markdown to PDF.""" - -from pathlib import Path -from typing import TYPE_CHECKING - -import structlog - -if TYPE_CHECKING: - pass - -logger = structlog.get_logger() - -# Try to import md2pdf -try: - from md2pdf import md2pdf - - _MD2PDF_AVAILABLE = True -except ImportError: - md2pdf = None # type: ignore[assignment, misc] - _MD2PDF_AVAILABLE = False - logger.warning("md2pdf not available - PDF generation will be disabled") - - -def get_css_path() -> Path: - """Get the path to the markdown.css file.""" - curdir = Path(__file__).parent - css_path = curdir / "markdown.css" - return css_path - - -def md_to_pdf(md_text: str, pdf_file_path: str) -> None: - """ - Convert markdown text to PDF. - - Args: - md_text: Markdown text content - pdf_file_path: Path where PDF should be saved - - Raises: - ImportError: If md2pdf is not installed - ValueError: If markdown text is empty - OSError: If PDF file cannot be written - """ - if not _MD2PDF_AVAILABLE: - raise ImportError("md2pdf is not installed. 
Install it with: pip install md2pdf") - - if not md_text or not md_text.strip(): - raise ValueError("Markdown text cannot be empty") - - css_path = get_css_path() - - if not css_path.exists(): - logger.warning( - "CSS file not found, PDF will be generated without custom styling", - css_path=str(css_path), - ) - # Generate PDF without CSS - md2pdf(pdf_file_path, md_text) - else: - # Generate PDF with CSS - md2pdf(pdf_file_path, md_text, css_file_path=str(css_path)) - - logger.debug("PDF generated successfully", pdf_path=pdf_file_path) diff --git a/src/utils/message_history.py b/src/utils/message_history.py deleted file mode 100644 index 5ddbbe30af20de6ab9163494267a0b5626a480be..0000000000000000000000000000000000000000 --- a/src/utils/message_history.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Message history utilities for Pydantic AI integration.""" - -from typing import Any - -import structlog - -try: - from pydantic_ai import ModelMessage, ModelRequest, ModelResponse - from pydantic_ai.messages import TextPart, UserPromptPart - - _PYDANTIC_AI_AVAILABLE = True -except ImportError: - # Fallback for older pydantic-ai versions - ModelMessage = Any # type: ignore[assignment, misc] - ModelRequest = Any # type: ignore[assignment, misc] - ModelResponse = Any # type: ignore[assignment, misc] - TextPart = Any # type: ignore[assignment, misc] - UserPromptPart = Any # type: ignore[assignment, misc] - _PYDANTIC_AI_AVAILABLE = False - -logger = structlog.get_logger() - - -def convert_gradio_to_message_history( - history: list[dict[str, Any]], - max_messages: int = 20, -) -> list[ModelMessage]: - """ - Convert Gradio chat history to Pydantic AI message history. - - Args: - history: Gradio chat history format [{"role": "user", "content": "..."}, ...] - max_messages: Maximum messages to include (most recent) - - Returns: - List of ModelMessage objects for Pydantic AI - """ - if not history: - return [] - - if not _PYDANTIC_AI_AVAILABLE: - logger.warning( - "Pydantic AI message history not available, returning empty list", - ) - return [] - - messages: list[ModelMessage] = [] - - # Take most recent messages - recent = history[-max_messages:] if len(history) > max_messages else history - - for msg in recent: - role = msg.get("role", "") - content = msg.get("content", "") - - if not content or role not in ("user", "assistant"): - continue - - # Convert content to string if needed - content_str = str(content) - - if role == "user": - messages.append( - ModelRequest(parts=[UserPromptPart(content=content_str)]), - ) - elif role == "assistant": - messages.append( - ModelResponse(parts=[TextPart(content=content_str)]), - ) - - logger.debug( - "Converted Gradio history to message history", - input_turns=len(history), - output_messages=len(messages), - ) - - return messages - - -def message_history_to_string( - messages: list[ModelMessage], - max_messages: int = 5, - include_metadata: bool = False, -) -> str: - """ - Convert message history to string format for backward compatibility. - - Used during transition period when some agents still expect strings. 
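Editorial usage note: a sketch of the intended hand-off from Gradio history into a Pydantic AI run; the result attribute name (`output` vs. `data`) depends on the installed pydantic-ai version, so treat that access as an assumption:

```python
# Hypothetical hand-off from Gradio chat history to a Pydantic AI agent run.
from typing import Any

from pydantic_ai import Agent

from src.utils.message_history import convert_gradio_to_message_history


async def answer(agent: Agent, question: str, gradio_history: list[dict[str, Any]]) -> str:
    history = convert_gradio_to_message_history(gradio_history, max_messages=20)
    result = await agent.run(question, message_history=history)
    return str(result.output)  # `.data` on older pydantic-ai releases
```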
- - Args: - messages: List of ModelMessage objects - max_messages: Maximum messages to include - include_metadata: Whether to include metadata - - Returns: - Formatted string representation - """ - if not messages: - return "" - - recent = messages[-max_messages:] if len(messages) > max_messages else messages - - parts = ["PREVIOUS CONVERSATION:", "---"] - turn_num = 1 - - for msg in recent: - # Extract text content - text = "" - if isinstance(msg, ModelRequest): - for part in msg.parts: - if hasattr(part, "content"): - text += str(part.content) - parts.append(f"[Turn {turn_num}]") - parts.append(f"User: {text}") - turn_num += 1 - elif isinstance(msg, ModelResponse): - for part in msg.parts: # type: ignore[assignment] - if hasattr(part, "content"): - text += str(part.content) - parts.append(f"Assistant: {text}") - - parts.append("---") - return "\n".join(parts) - - -def create_truncation_processor(max_messages: int = 10) -> Any: - """Create a history processor that keeps only the most recent N messages. - - Args: - max_messages: Maximum number of messages to keep - - Returns: - Processor function that takes a list of messages and returns truncated list - """ - - def processor(messages: list[ModelMessage]) -> list[ModelMessage]: - return messages[-max_messages:] if len(messages) > max_messages else messages - - return processor - - -def create_relevance_processor(min_length: int = 10) -> Any: - """Create a history processor that filters out very short messages. - - Args: - min_length: Minimum message length to keep - - Returns: - Processor function that filters messages by length - """ - - def processor(messages: list[ModelMessage]) -> list[ModelMessage]: - filtered = [] - for msg in messages: - text = "" - if isinstance(msg, ModelRequest): - for part in msg.parts: - if hasattr(part, "content"): - text += str(part.content) - elif isinstance(msg, ModelResponse): - for part in msg.parts: # type: ignore[assignment] - if hasattr(part, "content"): - text += str(part.content) - - if len(text.strip()) >= min_length: - filtered.append(msg) - return filtered - - return processor diff --git a/src/utils/models.py b/src/utils/models.py deleted file mode 100644 index e43efac0e0477910c98bd50cee94f75bf321a39d..0000000000000000000000000000000000000000 --- a/src/utils/models.py +++ /dev/null @@ -1,574 +0,0 @@ -"""Data models for the Search feature.""" - -from datetime import UTC, datetime -from typing import Any, ClassVar, Literal - -from pydantic import BaseModel, Field - -# Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11) -SourceName = Literal[ - "pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j" -] - - -class Citation(BaseModel): - """A citation to a source document.""" - - source: SourceName = Field(description="Where this came from") - - title: str = Field(min_length=1, max_length=500) - url: str = Field(description="URL to the source") - date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')") - authors: list[str] = Field(default_factory=list) - - MAX_AUTHORS_IN_CITATION: ClassVar[int] = 3 - - @property - def formatted(self) -> str: - """Format as a citation string.""" - author_str = ", ".join(self.authors[: self.MAX_AUTHORS_IN_CITATION]) - if len(self.authors) > self.MAX_AUTHORS_IN_CITATION: - author_str += " et al." - return f"{author_str} ({self.date}). {self.title}. 
{self.source.upper()}" - - -class Evidence(BaseModel): - """A piece of evidence retrieved from search.""" - - content: str = Field(min_length=1, description="The actual text content") - citation: Citation - relevance: float = Field(default=0.0, ge=0.0, le=1.0, description="Relevance score 0-1") - metadata: dict[str, Any] = Field( - default_factory=dict, - description="Additional metadata (e.g., cited_by_count, concepts, is_open_access)", - ) - - model_config = {"frozen": True} - - -class SearchResult(BaseModel): - """Result of a search operation.""" - - query: str - evidence: list[Evidence] - sources_searched: list[SourceName] - total_found: int - errors: list[str] = Field(default_factory=list) - - -class AssessmentDetails(BaseModel): - """Detailed assessment of evidence quality.""" - - mechanism_score: int = Field( - ..., - ge=0, - le=10, - description="How well does the evidence explain the mechanism? 0-10", - ) - mechanism_reasoning: str = Field( - ..., min_length=10, description="Explanation of mechanism score" - ) - clinical_evidence_score: int = Field( - ..., - ge=0, - le=10, - description="Strength of clinical/preclinical evidence. 0-10", - ) - clinical_reasoning: str = Field( - ..., min_length=10, description="Explanation of clinical evidence score" - ) - drug_candidates: list[str] = Field( - default_factory=list, description="List of specific drug candidates mentioned" - ) - key_findings: list[str] = Field( - default_factory=list, description="Key findings from the evidence" - ) - - -class JudgeAssessment(BaseModel): - """Complete assessment from the Judge.""" - - details: AssessmentDetails - sufficient: bool = Field(..., description="Is evidence sufficient to provide a recommendation?") - confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in the assessment (0-1)") - recommendation: Literal["continue", "synthesize"] = Field( - ..., - description="continue = need more evidence, synthesize = ready to answer", - ) - next_search_queries: list[str] = Field( - default_factory=list, description="If continue, what queries to search next" - ) - reasoning: str = Field( - ..., min_length=20, description="Overall reasoning for the recommendation" - ) - - -class AgentEvent(BaseModel): - """Event emitted by the orchestrator for UI streaming.""" - - type: Literal[ - "started", - "searching", - "search_complete", - "judging", - "judge_complete", - "looping", - "synthesizing", - "complete", - "error", - "streaming", - "hypothesizing", - "analyzing", # NEW for Phase 13 - "analysis_complete", # NEW for Phase 13 - ] - message: str - data: Any = None - timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) - iteration: int = 0 - - def to_markdown(self) -> str: - """Format event as markdown for chat display.""" - icons = { - "started": "🚀", - "searching": "🔍", - "search_complete": "📚", - "judging": "🧠", - "judge_complete": "✅", - "looping": "🔄", - "synthesizing": "📝", - "complete": "🎉", - "error": "❌", - "streaming": "📡", - "hypothesizing": "🔬", # NEW - "analyzing": "📊", # NEW - "analysis_complete": "📈", # NEW - } - icon = icons.get(self.type, "•") - return f"{icon} **{self.type.upper()}**: {self.message}" - - -class MechanismHypothesis(BaseModel): - """A scientific hypothesis about drug mechanism.""" - - drug: str = Field(description="The drug being studied") - target: str = Field(description="Molecular target (e.g., AMPK, mTOR)") - pathway: str = Field(description="Biological pathway affected") - effect: str = Field(description="Downstream effect on disease") - 
confidence: float = Field(ge=0, le=1, description="Confidence in hypothesis") - supporting_evidence: list[str] = Field( - default_factory=list, description="PMIDs or URLs supporting this hypothesis" - ) - contradicting_evidence: list[str] = Field( - default_factory=list, description="PMIDs or URLs contradicting this hypothesis" - ) - search_suggestions: list[str] = Field( - default_factory=list, description="Suggested searches to test this hypothesis" - ) - - def to_search_queries(self) -> list[str]: - """Generate search queries to test this hypothesis.""" - return [ - f"{self.drug} {self.target}", - f"{self.target} {self.pathway}", - f"{self.pathway} {self.effect}", - *self.search_suggestions, - ] - - -class HypothesisAssessment(BaseModel): - """Assessment of evidence against hypotheses.""" - - hypotheses: list[MechanismHypothesis] - primary_hypothesis: MechanismHypothesis | None = Field( - default=None, description="Most promising hypothesis based on current evidence" - ) - knowledge_gaps: list[str] = Field(description="What we don't know yet") - recommended_searches: list[str] = Field(description="Searches to fill knowledge gaps") - - -class ReportSection(BaseModel): - """A section of the research report.""" - - title: str - content: str - # Reserved for future inline citation tracking within sections - citations: list[str] = Field(default_factory=list) - - -class ResearchReport(BaseModel): - """Structured scientific report.""" - - title: str = Field(description="Report title") - executive_summary: str = Field( - description="One-paragraph summary for quick reading", min_length=100, max_length=1000 - ) - research_question: str = Field(description="Clear statement of what was investigated") - - methodology: ReportSection = Field(description="How the research was conducted") - hypotheses_tested: list[dict[str, Any]] = Field( - description="Hypotheses with supporting/contradicting evidence counts" - ) - - mechanistic_findings: ReportSection = Field(description="Findings about drug mechanisms") - clinical_findings: ReportSection = Field( - description="Findings from clinical/preclinical studies" - ) - - drug_candidates: list[str] = Field(description="Identified drug candidates") - limitations: list[str] = Field(description="Study limitations") - conclusion: str = Field(description="Overall conclusion") - - references: list[dict[str, str]] = Field( - default_factory=list, - description="Formatted references with title, authors, source, URL", - ) - - # Metadata - sources_searched: list[str] = Field(default_factory=list) - total_papers_reviewed: int = 0 - search_iterations: int = 0 - confidence_score: float = Field(ge=0, le=1) - - def to_markdown(self) -> str: - """Render report as markdown.""" - sections = [ - f"# {self.title}\n", - f"## Executive Summary\n{self.executive_summary}\n", - f"## Research Question\n{self.research_question}\n", - f"## Methodology\n{self.methodology.content}\n", - ] - - # Hypotheses - sections.append("## Hypotheses Tested\n") - if not self.hypotheses_tested: - sections.append("*No hypotheses tested yet.*\n") - for h in self.hypotheses_tested: - supported = h.get("supported", 0) - contradicted = h.get("contradicted", 0) - if supported == 0 and contradicted == 0: - status = "❓ Untested" - elif supported > contradicted: - status = "✅ Supported" - else: - status = "⚠️ Mixed" - sections.append( - f"- **{h.get('mechanism', 'Unknown')}** ({status}): " - f"{supported} supporting, {contradicted} contradicting\n" - ) - - # Findings - sections.append(f"## Mechanistic 
Findings\n{self.mechanistic_findings.content}\n") - sections.append(f"## Clinical Findings\n{self.clinical_findings.content}\n") - - # Drug candidates - sections.append("## Drug Candidates\n") - if self.drug_candidates: - for drug in self.drug_candidates: - sections.append(f"- **{drug}**\n") - else: - sections.append("*No drug candidates identified.*\n") - - # Limitations - sections.append("## Limitations\n") - if self.limitations: - for lim in self.limitations: - sections.append(f"- {lim}\n") - else: - sections.append("*No limitations documented.*\n") - - # Conclusion - sections.append(f"## Conclusion\n{self.conclusion}\n") - - # References - sections.append("## References\n") - if self.references: - for i, ref in enumerate(self.references, 1): - sections.append( - f"{i}. {ref.get('authors', 'Unknown')}. " - f"*{ref.get('title', 'Untitled')}*. " - f"{ref.get('source', '')} ({ref.get('date', '')}). " - f"[Link]({ref.get('url', '#')})\n" - ) - else: - sections.append("*No references available.*\n") - - # Metadata footer - sections.append("\n---\n") - sections.append( - f"*Report generated from {self.total_papers_reviewed} papers " - f"across {self.search_iterations} search iterations. " - f"Confidence: {self.confidence_score:.0%}*" - ) - - return "\n".join(sections) - - -class OrchestratorConfig(BaseModel): - """Configuration for the orchestrator.""" - - max_iterations: int = Field(default=10, ge=1, le=20) - max_results_per_tool: int = Field(default=10, ge=1, le=50) - search_timeout: float = Field(default=30.0, ge=5.0, le=120.0) - - -# Models for iterative/deep research patterns - - -class IterationData(BaseModel): - """Data for a single iteration of the research loop.""" - - gap: str = Field(description="The gap addressed in the iteration", default="") - tool_calls: list[str] = Field(description="The tool calls made", default_factory=list) - findings: list[str] = Field( - description="The findings collected from tool calls", default_factory=list - ) - thought: str = Field( - description="The thinking done to reflect on the success of the iteration and next steps", - default="", - ) - - model_config = {"frozen": True} - - -class Conversation(BaseModel): - """A conversation between the user and the iterative researcher.""" - - history: list[IterationData] = Field( - description="The data for each iteration of the research loop", - default_factory=list, - ) - - def add_iteration(self, iteration_data: IterationData | None = None) -> None: - """Add a new iteration to the conversation history.""" - if iteration_data is None: - iteration_data = IterationData() - self.history.append(iteration_data) - - def set_latest_gap(self, gap: str) -> None: - """Set the gap for the latest iteration.""" - if not self.history: - self.add_iteration() - # Use model_copy() since IterationData is frozen - self.history[-1] = self.history[-1].model_copy(update={"gap": gap}) - - def set_latest_tool_calls(self, tool_calls: list[str]) -> None: - """Set the tool calls for the latest iteration.""" - if not self.history: - self.add_iteration() - # Use model_copy() since IterationData is frozen - self.history[-1] = self.history[-1].model_copy(update={"tool_calls": tool_calls}) - - def set_latest_findings(self, findings: list[str]) -> None: - """Set the findings for the latest iteration.""" - if not self.history: - self.add_iteration() - # Use model_copy() since IterationData is frozen - self.history[-1] = self.history[-1].model_copy(update={"findings": findings}) - - def set_latest_thought(self, thought: str) -> None: - 
"""Set the thought for the latest iteration.""" - if not self.history: - self.add_iteration() - # Use model_copy() since IterationData is frozen - self.history[-1] = self.history[-1].model_copy(update={"thought": thought}) - - def get_latest_gap(self) -> str: - """Get the gap from the latest iteration.""" - if not self.history: - return "" - return self.history[-1].gap - - def get_latest_tool_calls(self) -> list[str]: - """Get the tool calls from the latest iteration.""" - if not self.history: - return [] - return self.history[-1].tool_calls - - def get_latest_findings(self) -> list[str]: - """Get the findings from the latest iteration.""" - if not self.history: - return [] - return self.history[-1].findings - - def get_latest_thought(self) -> str: - """Get the thought from the latest iteration.""" - if not self.history: - return "" - return self.history[-1].thought - - def get_all_findings(self) -> list[str]: - """Get all findings from all iterations.""" - return [finding for iteration_data in self.history for finding in iteration_data.findings] - - def compile_conversation_history(self) -> str: - """Compile the conversation history into a string.""" - conversation = "" - for iteration_num, iteration_data in enumerate(self.history): - conversation += f"[ITERATION {iteration_num + 1}]\n\n" - if iteration_data.thought: - conversation += f"{self.get_thought_string(iteration_num)}\n\n" - if iteration_data.gap: - conversation += f"{self.get_task_string(iteration_num)}\n\n" - if iteration_data.tool_calls: - conversation += f"{self.get_action_string(iteration_num)}\n\n" - if iteration_data.findings: - conversation += f"{self.get_findings_string(iteration_num)}\n\n" - - return conversation - - def get_task_string(self, iteration_num: int) -> str: - """Get the task for the specified iteration.""" - if iteration_num < len(self.history) and self.history[iteration_num].gap: - return f"\nAddress this knowledge gap: {self.history[iteration_num].gap}\n" - return "" - - def get_action_string(self, iteration_num: int) -> str: - """Get the action for the specified iteration.""" - if iteration_num < len(self.history) and self.history[iteration_num].tool_calls: - joined_calls = "\n".join(self.history[iteration_num].tool_calls) - return ( - "\nCalling the following tools to address the knowledge gap:\n" - f"{joined_calls}\n" - ) - return "" - - def get_findings_string(self, iteration_num: int) -> str: - """Get the findings for the specified iteration.""" - if iteration_num < len(self.history) and self.history[iteration_num].findings: - joined_findings = "\n\n".join(self.history[iteration_num].findings) - return f"\n{joined_findings}\n" - return "" - - def get_thought_string(self, iteration_num: int) -> str: - """Get the thought for the specified iteration.""" - if iteration_num < len(self.history) and self.history[iteration_num].thought: - return f"\n{self.history[iteration_num].thought}\n" - return "" - - def latest_task_string(self) -> str: - """Get the latest task.""" - if not self.history: - return "" - return self.get_task_string(len(self.history) - 1) - - def latest_action_string(self) -> str: - """Get the latest action.""" - if not self.history: - return "" - return self.get_action_string(len(self.history) - 1) - - def latest_findings_string(self) -> str: - """Get the latest findings.""" - if not self.history: - return "" - return self.get_findings_string(len(self.history) - 1) - - def latest_thought_string(self) -> str: - """Get the latest thought.""" - if not self.history: - return "" - return 
self.get_thought_string(len(self.history) - 1) - - -class ReportPlanSection(BaseModel): - """A section of the report that needs to be written.""" - - title: str = Field(description="The title of the section") - key_question: str = Field(description="The key question to be addressed in the section") - - model_config = {"frozen": True} - - -class ReportPlan(BaseModel): - """Output from the Report Planner Agent.""" - - background_context: str = Field( - description="A summary of supporting context that can be passed onto the research agents" - ) - report_outline: list[ReportPlanSection] = Field( - description="List of sections that need to be written in the report" - ) - report_title: str = Field(description="The title of the report") - - model_config = {"frozen": True} - - -class KnowledgeGapOutput(BaseModel): - """Output from the Knowledge Gap Agent.""" - - research_complete: bool = Field( - description="Whether the research and findings are complete enough to end the research loop" - ) - outstanding_gaps: list[str] = Field( - description="List of knowledge gaps that still need to be addressed" - ) - - model_config = {"frozen": True} - - -class AgentTask(BaseModel): - """A task for a specific agent to address knowledge gaps.""" - - gap: str | None = Field(description="The knowledge gap being addressed", default=None) - agent: str = Field(description="The name of the agent to use") - query: str = Field(description="The specific query for the agent") - entity_website: str | None = Field( - description="The website of the entity being researched, if known", - default=None, - ) - - model_config = {"frozen": True} - - -class AgentSelectionPlan(BaseModel): - """Plan for which agents to use for knowledge gaps.""" - - tasks: list[AgentTask] = Field(description="List of agent tasks to address knowledge gaps") - - model_config = {"frozen": True} - - -class ReportDraftSection(BaseModel): - """A section of the report that needs to be written.""" - - section_title: str = Field(description="The title of the section") - section_content: str = Field(description="The content of the section") - - model_config = {"frozen": True} - - -class ReportDraft(BaseModel): - """Output from the Report Planner Agent.""" - - sections: list[ReportDraftSection] = Field( - description="List of sections that are in the report" - ) - - model_config = {"frozen": True} - - -class ToolAgentOutput(BaseModel): - """Standard output for all tool agents.""" - - output: str = Field(description="The output from the tool agent") - sources: list[str] = Field(description="List of source URLs", default_factory=list) - - model_config = {"frozen": True} - - -class ParsedQuery(BaseModel): - """Parsed and improved user query with research mode detection.""" - - original_query: str = Field(description="The original user query") - improved_query: str = Field(description="Improved/refined query") - research_mode: Literal["iterative", "deep"] = Field(description="Detected research mode") - key_entities: list[str] = Field( - default_factory=list, - description="Key entities extracted from query", - ) - research_questions: list[str] = Field( - default_factory=list, - description="Specific research questions extracted", - ) - - model_config = {"frozen": True} diff --git a/src/utils/report_generator.py b/src/utils/report_generator.py deleted file mode 100644 index 330f2ad281867acc08f21b9a04fc1617a0ad7029..0000000000000000000000000000000000000000 --- a/src/utils/report_generator.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Utility functions for generating 
reports from evidence when LLM fails.""" - -from typing import TYPE_CHECKING - -import structlog - -if TYPE_CHECKING: - from src.utils.models import Citation, Evidence - -logger = structlog.get_logger() - - -def _format_authors(citation: "Citation") -> str: - """Format authors string from citation.""" - authors = ", ".join(citation.authors[:3]) - if len(citation.authors) > 3: - authors += " et al." - elif not authors: - authors = "Unknown" - return authors - - -def _add_evidence_section(report_parts: list[str], evidence: list["Evidence"]) -> None: - """Add evidence summary section to report.""" - from src.utils.models import SourceName - - report_parts.append("## Evidence Summary\n") - report_parts.append(f"**Total Sources Found:** {len(evidence)}\n\n") - - # Group evidence by source - by_source: dict[SourceName, list[Evidence]] = {} - for ev in evidence: - source = ev.citation.source - if source not in by_source: - by_source[source] = [] - by_source[source].append(ev) - - # Organize by source - for source in sorted(by_source.keys()): # type: ignore[assignment] - source_evidence = by_source[source] - report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n") - - for i, ev in enumerate(source_evidence, 1): - authors = _format_authors(ev.citation) - report_parts.append(f"#### {i}. {ev.citation.title}\n") - if authors and authors != "Unknown": - report_parts.append(f"**Authors:** {authors} \n") - report_parts.append(f"**Date:** {ev.citation.date} \n") - report_parts.append(f"**Source:** {ev.citation.source.upper()} \n") - report_parts.append(f"**URL:** {ev.citation.url} \n\n") - - # Content (truncated if too long) - content = ev.content - if len(content) > 500: - content = content[:500] + "... [truncated]" - report_parts.append(f"{content}\n\n") - - -def _add_key_findings(report_parts: list[str], evidence: list["Evidence"]) -> None: - """Add key findings section to report.""" - report_parts.append("## Key Findings\n\n") - report_parts.append( - "Based on the evidence collected, the following key points were identified:\n\n" - ) - - # Extract key points from evidence (first sentence or summary) - key_points: list[str] = [] - for ev in evidence[:10]: # Limit to top 10 - # Try to extract first meaningful sentence - content = ev.content.strip() - if content: - # Find first sentence - first_period = content.find(".") - if first_period > 0 and first_period < 200: - key_point = content[: first_period + 1].strip() - else: - # Fallback: first 150 chars - key_point = content[:150].strip() - if len(content) > 150: - key_point += "..." - key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)") - - if key_points: - report_parts.append("\n".join(key_points)) - report_parts.append("\n\n") - else: - report_parts.append("*No specific key findings could be extracted from the evidence.*\n\n") - - -def _add_references(report_parts: list[str], evidence: list["Evidence"]) -> None: - """Add references section to report.""" - report_parts.append("## References\n\n") - for i, ev in enumerate(evidence, 1): - authors = _format_authors(ev.citation) - report_parts.append( - f"[{i}] {authors} ({ev.citation.date}). " - f"*{ev.citation.title}*. " - f"{ev.citation.source.upper()}. " - f"Available at: {ev.citation.url}\n\n" - ) - - -def generate_report_from_evidence( - query: str, - evidence: list["Evidence"] | None = None, - findings: str | None = None, -) -> str: - """ - Generate a structured markdown report from evidence or findings when LLM fails. 
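Editorial usage note: a minimal sketch tying this fallback generator to the `md_to_pdf` helper deleted earlier in this change; the export function itself is assumed:

```python
# Hypothetical export path combining the fallback generator with md_to_pdf (deleted above).
from src.utils.md_to_pdf import md_to_pdf
from src.utils.models import Evidence
from src.utils.report_generator import generate_report_from_evidence


def export_fallback_report(query: str, evidence: list[Evidence], pdf_path: str) -> str:
    markdown = generate_report_from_evidence(query=query, evidence=evidence)
    md_to_pdf(markdown, pdf_path)  # raises ImportError when md2pdf is not installed
    return markdown
```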
- - This function creates a proper report structure even without LLM assistance, - formatting the collected evidence into a readable, well-structured document. - - Args: - query: The original research query - evidence: List of Evidence objects (preferred if available) - findings: Pre-formatted findings string (fallback if evidence not available) - - Returns: - Markdown formatted report string - """ - report_parts: list[str] = [] - - # Title - report_parts.append(f"# Research Report: {query}\n") - - # Introduction - report_parts.append("## Introduction\n") - report_parts.append(f"This report addresses the following research query: **{query}**\n") - report_parts.append( - "*Note: This report was generated from collected evidence. " - "LLM-based synthesis was unavailable due to API limitations.*\n\n" - ) - - # Evidence Summary - if evidence and len(evidence) > 0: - _add_evidence_section(report_parts, evidence) - _add_key_findings(report_parts, evidence) - - elif findings: - # Fallback: use findings string if evidence not available - report_parts.append("## Research Findings\n\n") - # Truncate if too long - if len(findings) > 10000: - findings = findings[:10000] + "\n\n[Content truncated due to length]" - report_parts.append(f"{findings}\n\n") - else: - report_parts.append("## Research Findings\n\n") - report_parts.append( - "*No evidence or findings were collected during the research process.*\n\n" - ) - - # References Section - if evidence and len(evidence) > 0: - _add_references(report_parts, evidence) - - # Conclusion - report_parts.append("## Conclusion\n\n") - if evidence and len(evidence) > 0: - report_parts.append( - f"This report synthesized information from {len(evidence)} sources " - f"to address the research query: **{query}**\n\n" - ) - report_parts.append( - "*Note: Due to API limitations, this report was generated directly from " - "collected evidence without LLM-based synthesis. For a more comprehensive " - "analysis, please retry when API access is available.*\n" - ) - else: - report_parts.append( - "This report could not be fully generated due to limited evidence collection " - "and API access issues.\n\n" - ) - report_parts.append( - "*Please retry your query when API access is available for a more " - "comprehensive research report.*\n" - ) - - return "".join(report_parts) diff --git a/src/utils/text_utils.py b/src/utils/text_utils.py deleted file mode 100644 index da37ce8eeece0575b2009e483f75a3cf4c2ab9fc..0000000000000000000000000000000000000000 --- a/src/utils/text_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Text processing utilities for evidence handling.""" - -from typing import TYPE_CHECKING - -import numpy as np - -if TYPE_CHECKING: - from src.services.embeddings import EmbeddingService - from src.utils.models import Evidence - - -def truncate_at_sentence(text: str, max_chars: int = 300) -> str: - """Truncate text at sentence boundary, preserving meaning. - - Args: - text: The text to truncate - max_chars: Maximum characters (default 300) - - Returns: - Text truncated at last complete sentence within limit - """ - if len(text) <= max_chars: - return text - - # Find truncation point - truncated = text[:max_chars] - - # Look for sentence endings: . ! ? followed by space or end - # We check for sep at the END of the truncated string - for sep in [". ", "! ", "? 
", ".\n", "!\n", "?\n"]: - last_sep = truncated.rfind(sep) - if last_sep > max_chars // 2: # Don't truncate too aggressively (less than half) - return text[: last_sep + 1].strip() - - # Fallback: find last period (even if not followed by space, e.g. end of string) - last_period = truncated.rfind(".") - if last_period > max_chars // 2: - return text[: last_period + 1].strip() - - # Last resort: truncate at word boundary - last_space = truncated.rfind(" ") - if last_space > 0: - return text[:last_space].strip() + "..." - - return truncated + "..." - - -async def select_diverse_evidence( - evidence: list["Evidence"], n: int, query: str, embeddings: "EmbeddingService | None" = None -) -> list["Evidence"]: - """Select n most diverse and relevant evidence items. - - Uses Maximal Marginal Relevance (MMR) when embeddings available, - falls back to relevance_score sorting otherwise. - - Args: - evidence: All available evidence - n: Number of items to select - query: Original query for relevance scoring - embeddings: Optional EmbeddingService for semantic diversity - - Returns: - Selected evidence items, diverse and relevant - """ - if not evidence: - return [] - - if n >= len(evidence): - return evidence - - # Fallback: sort by relevance score if no embeddings - if embeddings is None: - return sorted( - evidence, - key=lambda e: e.relevance, # Use .relevance (from Pydantic model) - reverse=True, - )[:n] - - # MMR: Maximal Marginal Relevance for diverse selection - # Score = λ * relevance - (1-λ) * max_similarity_to_selected - lambda_param = 0.7 # Balance relevance vs diversity - - # Get query embedding - query_emb = await embeddings.embed(query) - - # Get all evidence embeddings - evidence_embs = await embeddings.embed_batch([e.content for e in evidence]) - - # Cosine similarity helper - def cosine(a: list[float], b: list[float]) -> float: - arr_a, arr_b = np.array(a), np.array(b) - denominator = float(np.linalg.norm(arr_a) * np.linalg.norm(arr_b)) - if denominator == 0: - return 0.0 - return float(np.dot(arr_a, arr_b) / denominator) - - # Compute relevance scores (cosine similarity to query) - # Note: We use semantic relevance to query, not the keyword search 'relevance' score - relevance_scores = [cosine(query_emb, emb) for emb in evidence_embs] - - # Greedy MMR selection - selected_indices: list[int] = [] - remaining = set(range(len(evidence))) - - for _ in range(n): - best_score = float("-inf") - best_idx = -1 - - for idx in remaining: - # Relevance component - relevance = relevance_scores[idx] - - # Diversity component: max similarity to already selected - if selected_indices: - max_sim = max( - cosine(evidence_embs[idx], evidence_embs[sel]) for sel in selected_indices - ) - else: - max_sim = 0 - - # MMR score - mmr_score = lambda_param * relevance - (1 - lambda_param) * max_sim - - if mmr_score > best_score: - best_score = mmr_score - best_idx = idx - - if best_idx >= 0: - selected_indices.append(best_idx) - remaining.remove(best_idx) - - return [evidence[i] for i in selected_indices]