import os
import logging
from typing import Dict, Any, List
from datetime import datetime
from src.tools.base_tool import BaseTool

logger = logging.getLogger("healthcare-mcp")


class PubMedTool(BaseTool):
    """Tool for searching medical literature in the PubMed database."""

    def __init__(self, cache_db_path: str = "healthcare_cache.db"):
        """Initialize the PubMed tool with its API key and base URL."""
        super().__init__(cache_db_path=cache_db_path)
        # The API key is optional; supplying one raises NCBI's E-utilities
        # rate limit from 3 to 10 requests per second.
        self.api_key = os.getenv("PUBMED_API_KEY", "")
        self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"

    async def search_literature(self, query: str, max_results: int = 5, date_range: str = "") -> Dict[str, Any]:
        """
        Search for medical literature in the PubMed database, with caching.

        Args:
            query: Search query for medical literature
            max_results: Maximum number of results to return (invalid values fall back to 5; capped at 100)
            date_range: Limit results to articles published within the last N years (e.g. '5' for the last 5 years)

        Returns:
            Dictionary containing search results or error details
        """
        if not query:
            return self._format_error_response("Search query is required")

        # Normalize max_results: reset invalid or non-positive values to the default, cap at 100
        try:
            max_results = int(max_results)
            if max_results < 1:
                max_results = 5
            elif max_results > 100:
                max_results = 100
        except (ValueError, TypeError):
            max_results = 5

        # Return a cached result if this exact search has been run before
        cache_key = self._get_cache_key("pubmed_search", query, max_results, date_range)
        cached_result = self.cache.get(cache_key)
        if cached_result:
            logger.info(f"Cache hit for PubMed search: {query}")
            return cached_result

        try:
            logger.info(f"Searching PubMed for: {query}, max_results={max_results}, date_range={date_range}")

            # Restrict the query to a publication-date window if a date range was given
            processed_query = query
            if date_range:
                try:
                    years_back = int(date_range)
                    current_year = datetime.now().year
                    min_year = current_year - years_back
                    processed_query += f" AND {min_year}:{current_year}[pdat]"
                    logger.debug(f"Added date range filter: {min_year}-{current_year}")
                except ValueError:
                    logger.warning(f"Invalid date range: {date_range}, ignoring")

            # Step 1: ESearch returns the PubMed IDs matching the query
            search_params = {
                "db": "pubmed",
                "term": processed_query,
                "retmax": max_results,
                "retmode": "json"
            }
            if self.api_key:
                search_params["api_key"] = self.api_key

            search_endpoint = f"{self.base_url}esearch.fcgi"
            search_data = await self._make_request(search_endpoint, params=search_params)
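            # The ESearch JSON payload is assumed to look roughly like:
            #   {"esearchresult": {"count": "42", "idlist": ["12345678", ...], ...}}
            # (count arrives as a string, hence the int() conversion below);
            # any other fields are ignored.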
            id_list = search_data.get("esearchresult", {}).get("idlist", [])
            total_results = int(search_data.get("esearchresult", {}).get("count", 0))

            # Step 2: ESummary returns metadata for the matching IDs
            articles = []
            if id_list:
                summary_params = {
                    "db": "pubmed",
                    "id": ",".join(id_list),
                    "retmode": "json"
                }
                if self.api_key:
                    summary_params["api_key"] = self.api_key

                summary_endpoint = f"{self.base_url}esummary.fcgi"
                summary_data = await self._make_request(summary_endpoint, params=summary_params)
                articles = await self._process_article_data(id_list, summary_data)

            result = self._format_success_response(
                query=query,
                total_results=total_results,
                articles=articles
            )

            # Cache the result for 12 hours (43200 seconds)
            self.cache.set(cache_key, result, ttl=43200)
            return result

        except Exception as e:
            logger.error(f"Error searching PubMed: {str(e)}")
            return self._format_error_response(f"Error searching PubMed: {str(e)}")

    async def _process_article_data(self, id_list: List[str], summary_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process article data from the PubMed ESummary response.

        Args:
            id_list: List of article IDs
            summary_data: Summary data from the PubMed API

        Returns:
            List of processed article dictionaries
        """
        articles = []
        result_data = summary_data.get("result", {})

        for article_id in id_list:
            if article_id in result_data:
                article_data = result_data[article_id]

                # Collect author names, skipping entries without a name
                authors = []
                if "authors" in article_data:
                    authors = [author.get("name", "") for author in article_data["authors"] if "name" in author]

                article = {
                    "id": article_id,
                    "title": article_data.get("title", ""),
                    "authors": authors,
                    "journal": article_data.get("fulljournalname", ""),
                    "publication_date": article_data.get("pubdate", ""),
                    "abstract_url": f"https://pubmed.ncbi.nlm.nih.gov/{article_id}/",
                }

                # Attach the DOI, if one is present among the article's identifiers
                if "articleids" in article_data:
                    for id_obj in article_data["articleids"]:
                        if id_obj.get("idtype") == "doi":
                            article["doi"] = id_obj.get("value", "")
                            break

                articles.append(article)

        return articles
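

# Minimal usage sketch (hypothetical, not part of the tool itself): it assumes
# BaseTool wires up the cache and _make_request as used above, and that
# _format_success_response surfaces the results under an "articles" key.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        tool = PubMedTool(cache_db_path="healthcare_cache.db")
        result = await tool.search_literature("metformin type 2 diabetes", max_results=3, date_range="5")
        for article in result.get("articles", []):
            print(f"{article['title']} ({article['publication_date']}) - {article['abstract_url']}")

    asyncio.run(_demo())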