scrapeRL / backend /app /search /engine.py
NeerajCodz's picture
feat: add MCP tool registry and search engine integration
afefaea
"""Search engine router for aggregating multiple search providers."""
from typing import Any, Optional
from dataclasses import dataclass, field
from app.utils.logging import get_logger
# Module-level logger named after this module; the router below uses it for
# lifecycle, registration and search logging.
logger = get_logger(__name__)
@dataclass
class SearchResult:
    """
    A single hit returned by a search provider.

    This is a plain (mutable) dataclass, so ``source`` and ``position``
    can be rewritten after construction when results are attributed to a
    provider or re-ranked.
    """

    # Human-readable title of the hit.
    title: str
    # Location of the hit; also usable as a de-duplication key.
    url: str
    # Short excerpt describing the hit.
    snippet: str
    # Rank of this hit within its result list.
    position: int
    # Name of the provider that produced the hit.
    source: str
    # Relevance score; defaults to 1.0 when the provider supplies none.
    score: float = field(default=1.0)
    # Provider-specific extras; every instance receives its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class SearchEngineRouter:
    """
    Routes search queries to different providers and aggregates results.

    Providers are arbitrary objects registered under a string name. A
    provider must expose an awaitable ``search(query, max_results)`` that
    returns an iterable of results; each result may be a ``SearchResult``,
    another attribute-style object, or a plain ``dict``. Optional awaitable
    ``initialize()`` / ``shutdown()`` hooks are invoked by the router's
    lifecycle methods when a provider defines them.
    """

    def __init__(self) -> None:
        # Provider name -> provider instance. Insertion order is the order
        # used by search_all() when no explicit provider list is given.
        self._providers: dict[str, Any] = {}
        self._default_provider: Optional[str] = None
        self._initialized: bool = False

    async def initialize(self) -> None:
        """
        Initialize the router and every registered provider.

        Each provider's ``initialize()`` hook (if any) is awaited in turn.
        Failures are logged and skipped, so the router is marked initialized
        even if some providers failed.
        """
        logger.info("Initializing SearchEngineRouter")
        # Iterate over a snapshot: every await yields to the event loop, and a
        # concurrent register/unregister_provider call would otherwise raise
        # "dictionary changed size during iteration" outside the try below.
        for name, provider in list(self._providers.items()):
            try:
                if hasattr(provider, "initialize"):
                    await provider.initialize()
                    logger.info(f"Initialized provider: {name}")
            except Exception as e:
                logger.error(f"Failed to initialize provider {name}: {e}")
        self._initialized = True
        logger.info("SearchEngineRouter initialized")

    async def shutdown(self) -> None:
        """
        Shut down every registered provider and mark the router uninitialized.

        Each provider's ``shutdown()`` hook (if any) is awaited in turn.
        Failures are logged and skipped so one misbehaving provider cannot
        prevent the others from shutting down.
        """
        logger.info("Shutting down SearchEngineRouter")
        # Snapshot for the same reason as in initialize().
        for name, provider in list(self._providers.items()):
            try:
                if hasattr(provider, "shutdown"):
                    await provider.shutdown()
                    logger.info(f"Shut down provider: {name}")
            except Exception as e:
                logger.error(f"Error shutting down provider {name}: {e}")
        self._initialized = False

    def register_provider(
        self,
        name: str,
        provider: Any,
        set_default: bool = False,
    ) -> None:
        """
        Register a search provider.

        Registering under an existing name replaces the previous provider.
        The first provider ever registered becomes the default automatically.

        Args:
            name: Provider identifier
            provider: Provider instance
            set_default: Set as the default provider
        """
        self._providers[name] = provider
        logger.info(f"Registered search provider: {name}")
        if set_default or self._default_provider is None:
            self._default_provider = name
            logger.info(f"Set default provider: {name}")

    def unregister_provider(self, name: str) -> bool:
        """
        Unregister a search provider.

        If the removed provider was the default, the earliest remaining
        registered provider (or ``None``) becomes the new default. The
        provider's ``shutdown()`` hook is NOT called.

        Args:
            name: Provider identifier

        Returns:
            True if provider was removed, False if it was not registered
        """
        if name not in self._providers:
            return False
        del self._providers[name]
        if self._default_provider == name:
            self._default_provider = next(iter(self._providers), None)
        logger.info(f"Unregistered provider: {name}")
        return True

    def get_providers(self) -> list[str]:
        """
        Get list of registered provider names.

        Returns:
            List of provider identifiers, in registration order
        """
        return list(self._providers.keys())

    def get_provider(self, name: str) -> Optional[Any]:
        """
        Get a specific provider by name.

        Args:
            name: Provider identifier

        Returns:
            Provider instance or None
        """
        return self._providers.get(name)

    async def search(
        self,
        query: str,
        max_results: int = 10,
        provider: Optional[str] = None,
    ) -> list[SearchResult]:
        """
        Perform a search using a specific provider.

        Every returned result has its ``source`` set to the provider name and
        its ``position`` set to its 1-based index in the returned list. Dict
        results are always updated; attribute-style results are updated only
        if they have a ``source`` attribute.

        Args:
            query: Search query string
            max_results: Maximum results to return. It is forwarded to the
                provider and, when positive, also enforced here in case the
                provider returns more.
            provider: Provider to use (defaults to default provider; an empty
                string also falls back to the default)

        Returns:
            List of search results

        Raises:
            ValueError: If no provider is configured or the provider is not found
            Exception: Any error raised by the provider is logged and re-raised
        """
        provider_name = provider or self._default_provider
        if provider_name is None:
            raise ValueError("No search provider configured")
        if provider_name not in self._providers:
            raise ValueError(f"Provider '{provider_name}' not found")
        provider_instance = self._providers[provider_name]
        logger.info(f"Searching with provider '{provider_name}': {query}")
        try:
            # Materialize first: a provider may return a generator/iterator,
            # which the attribution loop below would otherwise exhaust and
            # leave the caller with an empty result.
            results = list(await provider_instance.search(query, max_results))
            if max_results > 0:
                # Enforce the documented cap even if the provider over-delivers.
                del results[max_results:]
            # Ensure results have proper source attribution
            for position, result in enumerate(results, start=1):
                if isinstance(result, dict):
                    result["source"] = provider_name
                    result["position"] = position
                elif hasattr(result, "source"):
                    result.source = provider_name
                    result.position = position
            return results
        except Exception as e:
            logger.error(f"Search failed with provider '{provider_name}': {e}")
            raise

    async def search_all(
        self,
        query: str,
        max_results_per_provider: int = 10,
        providers: Optional[list[str]] = None,
    ) -> list[SearchResult]:
        """
        Search across multiple providers and aggregate results.

        Providers are queried one after another. A failing or unknown
        provider is logged and skipped rather than aborting the whole search.

        Args:
            query: Search query string
            max_results_per_provider: Max results from each provider
            providers: Specific providers to use. ``None`` or an empty list
                means all registered providers; repeated names are queried
                only once.

        Returns:
            Aggregated, de-duplicated and ranked list of results
        """
        # dict.fromkeys de-duplicates while preserving the requested order.
        provider_names = list(dict.fromkeys(providers or self._providers))
        all_results: list[SearchResult] = []
        for provider_name in provider_names:
            try:
                results = await self.search(
                    query=query,
                    max_results=max_results_per_provider,
                    provider=provider_name,
                )
            except Exception as e:
                logger.warning(f"Provider '{provider_name}' failed: {e}")
                continue
            all_results.extend(results)
        # Rank and deduplicate results
        return self._rank_results(all_results)

    @staticmethod
    def _get_field(result: Any, name: str, default: Any) -> Any:
        """Read ``name`` from a dict result or an attribute-style result, else ``default``."""
        if isinstance(result, dict):
            return result.get(name, default)
        return getattr(result, name, default)

    def _rank_results(
        self,
        results: list[SearchResult],
    ) -> list[SearchResult]:
        """
        Rank and deduplicate search results.

        Results are de-duplicated by URL, keeping the first occurrence; results
        with an empty or missing URL are dropped. The survivors are sorted by
        score (descending, missing/None treated as 1.0), then by their
        per-provider position (ascending, missing/None treated as 999). The
        sort is stable, so remaining ties keep their input order. Finally,
        positions are rewritten in place to 1..N.

        Args:
            results: Raw results from multiple providers

        Returns:
            Ranked and deduplicated results
        """
        seen_urls: set[str] = set()
        unique_results: list[SearchResult] = []
        for result in results:
            url = self._get_field(result, "url", "")
            if url and url not in seen_urls:
                seen_urls.add(url)
                unique_results.append(result)

        def sort_key(r: Any) -> tuple[float, int]:
            score = self._get_field(r, "score", None)
            position = self._get_field(r, "position", None)
            # Treat explicit None like a missing value so the key never
            # tries to negate or compare None.
            return (
                -(1.0 if score is None else score),
                999 if position is None else position,
            )

        unique_results.sort(key=sort_key)

        # Update positions to reflect the merged ranking
        for new_position, result in enumerate(unique_results, start=1):
            if isinstance(result, dict):
                result["position"] = new_position
            elif hasattr(result, "position"):
                result.position = new_position
        return unique_results

    @property
    def is_initialized(self) -> bool:
        """Check if the router is initialized."""
        return self._initialized

    @property
    def default_provider(self) -> Optional[str]:
        """Get the default provider name."""
        return self._default_provider