|
|
"""Unit tests for SearchHandler.""" |
|
|
|
|
|
from unittest.mock import AsyncMock, create_autospec |
|
|
|
|
|
import pytest |
|
|
|
|
|
from src.tools.base import SearchTool |
|
|
from src.tools.search_handler import SearchHandler, deduplicate_evidence, extract_paper_id |
|
|
from src.utils.exceptions import SearchError |
|
|
from src.utils.models import Citation, Evidence |
|
|
|
|
|
|
|
|
def _make_evidence(source: str, url: str, metadata: dict | None = None) -> Evidence:
    """Build a minimal Evidence fixture whose identity is driven by *source* and *url*.

    Title, date, content and authors are fixed placeholders so that tests only
    vary the fields the ID-extraction / dedup logic looks at.

    Args:
        source: Citation source name (e.g. "pubmed", "europepmc").
        url: Citation URL the paper ID is parsed from.
        metadata: Optional metadata dict; any falsy value becomes a fresh ``{}``.

    Returns:
        A new Evidence instance.
    """
    citation = Citation(
        source=source,
        title="Test",
        url=url,
        date="2024",
        authors=[],
    )
    return Evidence(
        content="Test content",
        citation=citation,
        metadata=metadata if metadata else {},
    )
|
|
|
|
|
|
|
|
class TestExtractPaperId:
    """Check that extract_paper_id maps each source's URL scheme to a prefixed paper ID."""

    def test_extracts_pubmed_id(self) -> None:
        """PubMed article URLs map to a ``PMID:`` identifier."""
        paper_id = extract_paper_id(
            _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        )
        assert paper_id == "PMID:12345678"

    def test_extracts_europepmc_med_id(self) -> None:
        """Europe PMC ``MED`` articles are identified by their PMID."""
        paper_id = extract_paper_id(
            _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678")
        )
        assert paper_id == "PMID:12345678"

    def test_extracts_europepmc_pmc_id(self) -> None:
        """Europe PMC ``PMC`` articles use a ``PMCID:`` prefix instead of PMID."""
        paper_id = extract_paper_id(
            _make_evidence("europepmc", "https://europepmc.org/article/PMC/PMC7654321")
        )
        assert paper_id == "PMCID:PMC7654321"

    def test_extracts_europepmc_ppr_id(self) -> None:
        """Europe PMC preprints are identified by their ``PPR`` ID."""
        paper_id = extract_paper_id(
            _make_evidence("europepmc", "https://europepmc.org/article/PPR/PPR123456")
        )
        assert paper_id == "PPRID:PPR123456"

    def test_extracts_europepmc_pat_id(self) -> None:
        """Europe PMC patents in WIPO (``WO``) format get a ``PATID:`` prefix."""
        paper_id = extract_paper_id(
            _make_evidence("europepmc", "https://europepmc.org/article/PAT/WO8601415")
        )
        assert paper_id == "PATID:WO8601415"

    def test_extracts_europepmc_pat_id_eu_format(self) -> None:
        """European (``EP``) patent numbers are accepted as well."""
        paper_id = extract_paper_id(
            _make_evidence("europepmc", "https://europepmc.org/article/PAT/EP1234567")
        )
        assert paper_id == "PATID:EP1234567"

    def test_extracts_doi(self) -> None:
        """A doi.org URL yields a ``DOI:`` identifier regardless of source."""
        paper_id = extract_paper_id(
            _make_evidence("pubmed", "https://doi.org/10.1038/nature12345")
        )
        assert paper_id == "DOI:10.1038/nature12345"

    def test_extracts_doi_with_trailing_slash(self) -> None:
        """A trailing slash on a DOI URL is stripped during normalization."""
        paper_id = extract_paper_id(
            _make_evidence("pubmed", "https://doi.org/10.1038/nature12345/")
        )
        assert paper_id == "DOI:10.1038/nature12345"

    def test_extracts_openalex_id_from_url(self) -> None:
        """Without a PMID in metadata, the OpenAlex work ID from the URL is used."""
        paper_id = extract_paper_id(
            _make_evidence("openalex", "https://openalex.org/W1234567890")
        )
        assert paper_id == "OAID:W1234567890"

    def test_extracts_openalex_pmid_from_metadata(self) -> None:
        """A PMID in OpenAlex metadata wins over the work ID in the URL."""
        meta = {"pmid": "98765432"}
        ev = _make_evidence("openalex", "https://openalex.org/W1234567890", metadata=meta)
        assert extract_paper_id(ev) == "PMID:98765432"

    def test_extracts_nct_id_modern(self) -> None:
        """Current ClinicalTrials.gov ``/study/`` URLs yield an ``NCT:`` identifier."""
        paper_id = extract_paper_id(
            _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT12345678")
        )
        assert paper_id == "NCT:NCT12345678"

    def test_extracts_nct_id_legacy(self) -> None:
        """Legacy ClinicalTrials.gov ``/ct2/show/`` URLs are still recognized."""
        paper_id = extract_paper_id(
            _make_evidence("clinicaltrials", "https://clinicaltrials.gov/ct2/show/NCT12345678")
        )
        assert paper_id == "NCT:NCT12345678"

    def test_returns_none_for_unknown_url(self) -> None:
        """URLs matching no known scheme produce ``None``."""
        paper_id = extract_paper_id(_make_evidence("web", "https://example.com/unknown"))
        assert paper_id is None
|
|
|
|
|
|
|
|
class TestDeduplicateEvidence:
    """Check which Evidence items deduplicate_evidence keeps or collapses."""

    def test_removes_pubmed_europepmc_duplicate(self) -> None:
        """A paper found in both PubMed and Europe PMC collapses to the PubMed copy."""
        kept = deduplicate_evidence(
            [
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
                _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678"),
            ]
        )
        assert [ev.citation.source for ev in kept] == ["pubmed"]

    def test_removes_pubmed_openalex_duplicate_via_metadata(self) -> None:
        """An OpenAlex record carrying a matching PMID in metadata collapses to PubMed."""
        openalex_meta = {"pmid": "12345678", "cited_by_count": 100}
        kept = deduplicate_evidence(
            [
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
                _make_evidence(
                    "openalex", "https://openalex.org/W9999999", metadata=openalex_meta
                ),
            ]
        )
        assert [ev.citation.source for ev in kept] == ["pubmed"]

    def test_preserves_unique_evidence(self) -> None:
        """Two distinct PMIDs both survive."""
        kept = deduplicate_evidence(
            [
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/11111111/"),
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/22222222/"),
            ]
        )
        assert len(kept) == 2

    def test_preserves_openalex_without_pmid(self) -> None:
        """OpenAlex records lacking a PMID are never merged into a PubMed record."""
        openalex_meta = {"cited_by_count": 100}
        kept = deduplicate_evidence(
            [
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
                _make_evidence(
                    "openalex", "https://openalex.org/W9999999", metadata=openalex_meta
                ),
            ]
        )
        assert len(kept) == 2

    def test_keeps_unidentifiable_evidence(self) -> None:
        """Evidence whose URL yields no paper ID is kept rather than dropped."""
        kept = deduplicate_evidence([_make_evidence("web", "https://example.com/paper/123")])
        assert len(kept) == 1

    def test_clinicaltrials_unique_per_nct(self) -> None:
        """Trials with different NCT IDs are treated as distinct."""
        kept = deduplicate_evidence(
            [
                _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT11111111"),
                _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT22222222"),
            ]
        )
        assert len(kept) == 2

    def test_preprints_preserved_separately(self) -> None:
        """A preprint (PPR ID) is not merged with a peer-reviewed PubMed record."""
        kept = deduplicate_evidence(
            [
                _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
                _make_evidence("europepmc", "https://europepmc.org/article/PPR/PPR999999"),
            ]
        )
        assert len(kept) == 2
|
|
|
|
|
|
|
|
class TestSearchHandler:
    """Tests for SearchHandler's fan-out across tools, aggregation and error handling."""

    @staticmethod
    def _make_tool(name: str, search: AsyncMock):
        """Return a SearchTool mock named *name* whose ``search`` is the given AsyncMock.

        Both tests build their tools through this helper. A change in how
        SearchHandler inspects tools therefore affects them uniformly, instead of
        one test using ``AsyncMock(spec=...)`` and the other ``create_autospec``.
        """
        tool = create_autospec(SearchTool, instance=True)
        tool.name = name
        tool.search = search
        return tool

    @pytest.mark.asyncio
    async def test_execute_aggregates_and_deduplicates(self) -> None:
        """SearchHandler should aggregate results and deduplicate them."""
        pubmed_tool = self._make_tool(
            "pubmed",
            AsyncMock(
                return_value=[
                    _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
                ]
            ),
        )
        europepmc_tool = self._make_tool(
            "europepmc",
            AsyncMock(
                return_value=[
                    _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678")
                ]
            ),
        )
        handler = SearchHandler(tools=[pubmed_tool, europepmc_tool])

        result = await handler.execute("test")

        # Both tools return the same PMID; only the PubMed copy should survive.
        assert result.total_found == 1
        assert len(result.evidence) == 1
        assert result.evidence[0].citation.source == "pubmed"
        assert "pubmed" in result.sources_searched
        assert "europepmc" in result.sources_searched

    @pytest.mark.asyncio
    async def test_execute_handles_tool_failure(self) -> None:
        """SearchHandler should continue if one tool fails."""
        ok_tool = self._make_tool(
            "pubmed",
            AsyncMock(
                return_value=[
                    _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
                ]
            ),
        )
        failing_tool = self._make_tool(
            "clinicaltrials", AsyncMock(side_effect=SearchError("API down"))
        )
        handler = SearchHandler(tools=[ok_tool, failing_tool])

        result = await handler.execute("test")

        assert result.total_found == 1
        # The healthy tool's evidence must still come through intact despite the failure.
        assert len(result.evidence) == 1
        assert result.evidence[0].citation.source == "pubmed"
        assert "pubmed" in result.sources_searched
        assert len(result.errors) == 1
        assert "clinicaltrials: API down" in result.errors[0]
|
|
|