GraphGen / graphgen / models / searcher / db / rnacentral_searcher.py
github-actions[bot]
Auto-sync from demo at Mon Dec 1 10:51:45 UTC 2025
06c3276
import asyncio
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import tempfile
from typing import Dict, Optional, List, Any, Set
import hashlib
import requests
import aiohttp
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from graphgen.bases import BaseSearcher
from graphgen.utils import logger
@lru_cache(maxsize=None)
def _get_pool() -> ThreadPoolExecutor:
    """
    Lazily build and memoize the module-wide worker pool.

    Because this function takes no arguments, ``lru_cache`` turns it into a
    lazy singleton: the executor is created on the first call and the very
    same instance is handed back on every subsequent call.
    """
    executor = ThreadPoolExecutor(max_workers=10)
    return executor
class RNACentralSearch(BaseSearcher):
    """
    RNAcentral Search client to search RNA databases.

    1) Get RNA by RNAcentral ID.
    2) Search with keywords or RNA names (fuzzy search).
    3) Search with RNA sequence (optional local BLAST, then exact MD5 lookup via the API).

    API Documentation: https://rnacentral.org/api/v1
    """

    # RNAcentral unique sequence IDs are "URS" followed by 10 hexadecimal
    # characters (e.g. URS000075A546); a digits-only pattern misses most of them.
    _RNA_ID_PATTERN = re.compile(r"URS[0-9A-F]{10}", re.I)

    def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
        """
        :param use_local_blast: Try a local ``blastn`` search before the RNAcentral
            API for sequence queries.
        :param local_blast_db: Path prefix of the local BLAST nucleotide database.
            Local BLAST is disabled if no database files exist for this prefix.
        """
        super().__init__()
        self.base_url = "https://rnacentral.org/api/v1"
        self.headers = {"Accept": "application/json"}
        self.use_local_blast = use_local_blast
        self.local_blast_db = local_blast_db
        if self.use_local_blast and not self._local_db_exists(self.local_blast_db):
            logger.error(
                "Local BLAST database files not found for '%s'. Please check the path.",
                self.local_blast_db,
            )
            self.use_local_blast = False

    @staticmethod
    def _local_db_exists(db_prefix: str) -> bool:
        """Return True if a single-volume (.nhr) or multi-volume alias (.nal) nucleotide DB exists."""
        return any(os.path.isfile(f"{db_prefix}{ext}") for ext in (".nhr", ".nal"))

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((requests.ConnectionError, requests.Timeout)),
        reraise=True,
    )
    def _get_json(
        self,
        url: str,
        params: Optional[Dict[str, str]] = None,
        timeout: int = 30,
    ) -> Any:
        """
        GET ``url`` with the client's headers and return the decoded JSON body.

        Connection errors and timeouts are retried (3 attempts, exponential
        backoff) and then re-raised. HTTP error statuses raise
        ``requests.HTTPError`` immediately and are not retried.
        """
        resp = requests.get(url, params=params, headers=self.headers, timeout=timeout)
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _rna_data_to_dict(
        rna_id: str,
        rna_data: Dict[str, Any],
        xrefs_data: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Normalize an RNAcentral record into this searcher's result dict.

        Organism, gene, SO-term and modification values aggregated from
        ``xrefs_data`` take priority. When an aggregate is empty, the first
        non-None value among the fallback keys of ``rna_data`` is used.

        :param rna_id: The ID the record was requested under (used for ``id`` and ``url``).
        :param rna_data: The RNA record (detail endpoint payload or a search result).
        :param xrefs_data: Optional cross-reference entries; non-dict entries are ignored.
        :return: Flat dict describing the RNA.
        """
        organisms: Set[str] = set()
        gene_names: Set[str] = set()
        so_terms: Set[str] = set()
        modifications: List[Any] = []
        for xref in xrefs_data or []:
            # Skip malformed entries instead of failing the whole record.
            if not isinstance(xref, dict):
                continue
            acc = xref.get("accession")
            if not isinstance(acc, dict):  # also covers a JSON null accession
                acc = {}
            if s := acc.get("species"):
                organisms.add(s)
            # "or ''" guards against a gene field that is present but null.
            if g := str(acc.get("gene") or "").strip():
                gene_names.add(g)
            if m := xref.get("modifications"):
                modifications.extend(m)
            if b := acc.get("biotype"):
                so_terms.add(b)

        def format_unique_values(values: Set[str]) -> Optional[str]:
            # None for no values, the value itself for one, else a sorted CSV string.
            if not values:
                return None
            if len(values) == 1:
                return next(iter(values))
            return ", ".join(sorted(values))

        xrefs_info = {
            "organism": format_unique_values(organisms),
            "gene_name": format_unique_values(gene_names),
            # sorted() keeps the output order deterministic across runs.
            "related_genes": sorted(gene_names) if gene_names else None,
            "modifications": modifications or None,
            "so_term": format_unique_values(so_terms),
        }
        fallback_rules = {
            "organism": ["organism", "species"],
            "related_genes": ["related_genes", "genes"],
            "gene_name": ["gene_name", "gene"],
            "so_term": ["so_term"],
            "modifications": ["modifications"],
        }

        def resolve_field(field_name: str) -> Any:
            # Prefer the xref aggregate, then the first non-None fallback key.
            if (value := xrefs_info.get(field_name)) is not None:
                return value
            for key in fallback_rules[field_name]:
                if (value := rna_data.get(key)) is not None:
                    return value
            return None

        organism = resolve_field("organism")
        gene_name = resolve_field("gene_name")
        so_term = resolve_field("so_term")
        modifications = resolve_field("modifications")
        related_genes = resolve_field("related_genes")
        if not related_genes and (single_gene := rna_data.get("gene_name")):
            related_genes = [single_gene]
        # "or ''" also covers a sequence field that is present but null.
        sequence = rna_data.get("sequence") or ""
        return {
            "molecule_type": "RNA",
            "database": "RNAcentral",
            "id": rna_id,
            "rnacentral_id": rna_data.get("rnacentral_id", rna_id),
            "sequence": sequence,
            "sequence_length": rna_data.get("length", len(sequence)),
            "rna_type": rna_data.get("rna_type", "N/A"),
            "description": rna_data.get("description", "N/A"),
            "url": f"https://rnacentral.org/rna/{rna_id}",
            "organism": organism,
            "related_genes": related_genes or None,
            "gene_name": gene_name,
            "so_term": so_term,
            "modifications": modifications,
        }

    @staticmethod
    def _calculate_md5(sequence: str) -> str:
        """
        Calculate the MD5 hash of an RNA sequence as per the RNAcentral spec.

        - Replace U with T
        - Convert to uppercase
        - Encode as ASCII

        :raises ValueError: If the normalized sequence is empty or contains
            characters other than A, T, C, G, N.
        """
        normalized_seq = sequence.replace("U", "T").replace("u", "t").upper()
        if not re.fullmatch(r"[ATCGN]+", normalized_seq):
            raise ValueError(f"Invalid sequence characters after normalization: {normalized_seq[:50]}...")
        return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()

    def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
        """
        Get RNA information by RNAcentral ID.

        :param rna_id: RNAcentral ID (e.g., URS0000000001).
        :return: A dictionary containing RNA information or None if not found
            or on any error (errors are logged).
        """
        try:
            # flat=true requests the record with its cross-references inlined
            # under "xrefs".
            rna_data = self._get_json(f"{self.base_url}/rna/{rna_id}", params={"flat": "true"})
            xrefs_data = rna_data.get("xrefs", [])
            return self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
        except requests.RequestException as e:
            logger.error("Network error getting RNA ID %s: %s", rna_id, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
            return None

    def get_best_hit(self, keyword: str) -> Optional[dict]:
        """
        Search RNAcentral with a keyword and return the best hit.

        The first search result is enriched via :meth:`get_by_rna_id`. If that
        fails, the search result itself is converted.

        :param keyword: The search keyword (e.g., miRNA name, RNA name).
        :return: Dictionary with RNA information or None.
        """
        keyword = keyword.strip()
        if not keyword:
            logger.warning("Empty keyword provided to get_best_hit")
            return None
        try:
            data = self._get_json(
                f"{self.base_url}/rna",
                params={"search": keyword, "format": "json"},
            )
            results = data.get("results", [])
            if not results:
                logger.info("No search results for keyword: %s", keyword)
                return None
            first_result = results[0]
            rna_id = first_result.get("rnacentral_id")
            if rna_id:
                detailed = self.get_by_rna_id(rna_id)
                if detailed:
                    return detailed
            logger.debug("Using search result data for %s", rna_id or "unknown")
            return self._rna_data_to_dict(rna_id or "", first_result)
        except requests.RequestException as e:
            logger.error("Network error searching keyword '%s': %s", keyword, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error searching keyword '%s': %s", keyword, e)
            return None

    def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
        """
        Run ``blastn`` against the local database and return the top subject accession.

        :param seq: Query nucleotide sequence (no FASTA header).
        :param threshold: E-value cutoff passed to ``-evalue``.
        :return: First subject accession reported, or None on no hit or failure.
        """
        tmp_name: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
                # Record the name first so cleanup happens even if the write fails.
                tmp_name = tmp.name
                tmp.write(f">query\n{seq}\n")
            cmd = [
                "blastn", "-db", self.local_blast_db, "-query", tmp_name,
                "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc",
            ]
            logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
            out = subprocess.check_output(cmd, text=True).strip()
            return out.split("\n", maxsplit=1)[0] if out else None
        except Exception as exc:  # pylint: disable=broad-except
            logger.error("Local blastn failed: %s", exc)
            return None
        finally:
            # Remove the query file on every path, including blastn failures.
            if tmp_name is not None:
                try:
                    os.remove(tmp_name)
                except OSError as exc:
                    logger.warning("Could not remove temporary file %s: %s", tmp_name, exc)

    def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
        """
        Search RNAcentral with an RNA sequence.

        Tries local BLAST first if enabled, then falls back to an exact-sequence
        (MD5) lookup through the RNAcentral API. The RNA ID found either way is
        resolved with :meth:`get_by_rna_id` for complete information.

        :param sequence: RNA sequence (FASTA format or raw sequence).
        :param threshold: E-value threshold; only used by the local BLAST search.
        :return: A dictionary containing complete RNA information or None if not found.
        """

        def _extract_sequence(raw: str) -> Optional[str]:
            """Return the bare sequence (header and all whitespace removed), or None if invalid."""
            text = raw.strip()
            if text.startswith(">"):
                # Drop the header line and keep only the residue lines.
                text = text.split("\n", 1)[1] if "\n" in text else ""
            # Remove every whitespace char (spaces, tabs, CR from CRLF files,
            # newlines); otherwise the MD5 validation would reject the sequence.
            seq = "".join(text.split())
            return seq if seq and re.fullmatch(r"[AUCGN]+", seq, re.I) else None

        try:
            seq = _extract_sequence(sequence)
            if not seq:
                logger.error("Empty or invalid RNA sequence provided.")
                return None

            if self.use_local_blast:
                accession = self._local_blast(seq, threshold)
                if accession:
                    logger.debug("Local BLAST found accession: %s", accession)
                    return self.get_by_rna_id(accession)

            logger.debug("Falling back to RNAcentral API.")
            md5_hash = self._calculate_md5(seq)
            search_results = self._get_json(
                f"{self.base_url}/rna",
                params={"md5": md5_hash, "format": "json"},
                timeout=60,
            )
            results = search_results.get("results", [])
            if not results:
                logger.info("No exact match found in RNAcentral for sequence")
                return None
            rna_id = results[0].get("rnacentral_id")
            if not rna_id:
                logger.error("No RNAcentral ID found in search results.")
                return None
            return self.get_by_rna_id(rna_id)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Sequence search failed: %s", e)
            return None

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
        reraise=True,
    )
    async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional[Dict]:
        """
        Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence.

        The blocking helpers run in the shared thread pool from ``_get_pool()``.

        :param query: FASTA text, raw RNA sequence, RNAcentral ID (URS...), or keyword.
        :param threshold: E-value cutoff forwarded to sequence searches.
        :param kwargs: Accepted for interface compatibility; unused.
        :return: Result dict with ``_search_query`` set to the stripped query, or None.
        """
        if not query or not isinstance(query, str):
            logger.error("Empty or non-string input.")
            return None
        query = query.strip()
        logger.debug("RNAcentral search query: %s", query)
        loop = asyncio.get_running_loop()
        # NOTE: the sync helpers catch their own errors and return None, so the
        # @retry above effectively never triggers; HTTP retries live in _get_json.

        # FASTA, or a raw sequence of A/U/C/G/N (must contain U to look like RNA).
        if query.startswith(">") or (
            re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper()
        ):
            result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
        elif self._RNA_ID_PATTERN.fullmatch(query):
            # Canonical RNAcentral IDs are upper-case.
            result = await loop.run_in_executor(_get_pool(), self.get_by_rna_id, query.upper())
        else:
            # Otherwise treat the query as a keyword.
            result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
        if result:
            result["_search_query"] = query
        return result