|
""" |
|
MedGenesis – dual-LLM orchestrator
----------------------------------
• Accepts `llm` arg ("openai" | "gemini")
• Defaults to "openai" if arg omitted
|
""" |
|
|
|
import asyncio, httpx |
|
from typing import Dict, Any, List |
|
|
|
from mcp.arxiv import fetch_arxiv |
|
from mcp.pubmed import fetch_pubmed |
|
from mcp.nlp import extract_keywords |
|
from mcp.umls import lookup_umls |
|
from mcp.openfda import fetch_drug_safety |
|
from mcp.ncbi import search_gene, get_mesh_definition |
|
from mcp.disgenet import disease_to_genes |
|
from mcp.clinicaltrials import search_trials |
|
from mcp.openai_utils import ai_summarize, ai_qa |
|
from mcp.gemini import gemini_summarize, gemini_qa |
|
|
|
|
|
def _get_llm(llm: str):
    """Return the ``(summarize, qa)`` callables for the requested engine.

    The engine name is matched case-insensitively and with surrounding
    whitespace ignored.

    Args:
        llm: Engine name. ``"gemini"`` selects the Gemini helpers; any other
            value, including ``None`` or an empty string, falls back to the
            OpenAI helpers.

    Returns:
        A 2-tuple ``(summarize, qa)``: either
        ``(gemini_summarize, gemini_qa)`` or ``(ai_summarize, ai_qa)``.
    """
    # `llm or ""` tolerates None (e.g. an unset optional field) instead of
    # raising AttributeError on `.lower()`; OpenAI is the documented default.
    engine = (llm or "").strip().lower()
    if engine == "gemini":
        return gemini_summarize, gemini_qa
    return ai_summarize, ai_qa
|
|
|
|
|
async def _enrich_genes_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    """Run the gene, MeSH and DisGeNET lookups for every keyword at once.

    For each keyword three lookups are issued, in this order:
    ``search_gene``, ``get_mesh_definition``, ``disease_to_genes``. All of
    them are awaited together, and a lookup that raised is dropped from the
    output rather than propagated.

    Args:
        keys: Keywords to look up.

    Returns:
        Dict with three lists:
        ``"genes"`` - items of every successful ``search_gene`` result,
        flattened into one list;
        ``"meshes"`` - one entry per successful ``get_mesh_definition`` call;
        ``"disgenet"`` - items of every successful ``disease_to_genes``
        result, flattened into one list.
    """
    # Flattened as [gene(k0), mesh(k0), disg(k0), gene(k1), ...]; the result
    # index modulo 3 therefore identifies which lookup produced a result.
    pending = [
        lookup(term)
        for term in keys
        for lookup in (search_gene, get_mesh_definition, disease_to_genes)
    ]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    gene_hits, mesh_defs, disgenet_hits = [], [], []
    # Gene and DisGeNET results are merged item-wise (extend); each MeSH
    # result is kept as a single entry (append).
    collectors = (gene_hits.extend, mesh_defs.append, disgenet_hits.extend)
    for position, outcome in enumerate(outcomes):
        if not isinstance(outcome, Exception):
            collectors[position % 3](outcome)

    return {"genes": gene_hits, "meshes": mesh_defs, "disgenet": disgenet_hits}
|
|
|
|
|
|
|
async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """Main orchestrator – returns the result dict consumed by the UI.

    Pipeline:
      1. Fetch papers from arXiv and PubMed concurrently.
      2. Join the paper summaries and extract up to eight keywords.
      3. Concurrently run the per-keyword UMLS and drug-safety lookups, the
         gene / MeSH / DisGeNET enrichment, and a clinical-trial search for
         the query.
      4. Summarise the joined paper summaries with the selected LLM.

    Args:
        query: Search string passed to the paper and clinical-trial searches.
        llm: Engine name, resolved by ``_get_llm``.

    Returns:
        Dict with the keys ``papers``, ``umls``, ``drug_safety``,
        ``ai_summary``, ``llm_used``, ``genes``, ``mesh_defs``,
        ``gene_disease`` and ``clinical_trials``. ``umls`` and
        ``drug_safety`` hold one entry per keyword, in keyword order; a
        lookup that failed leaves its exception instance in that slot.
        ``llm_used`` is ``"gemini"`` or ``"openai"``, i.e. the engine that
        was actually selected.

    Raises:
        Errors from the paper fetches, the clinical-trial search and the
        summariser are not caught here and propagate to the caller.
    """
    # Resolve the engine before any network call so an invalid `llm` value
    # fails immediately instead of after all the fetches have completed.
    summarize, _ = _get_llm(llm)
    # Report the engine that was really selected: unrecognised names fall back
    # to OpenAI, so echoing the raw argument could mislabel the summary.
    engine = "gemini" if summarize is gemini_summarize else "openai"

    # Stage 1: literature search; arXiv results come first, then PubMed.
    arxiv_papers, pubmed_papers = await asyncio.gather(
        fetch_arxiv(query),
        fetch_pubmed(query),
    )
    papers = [*arxiv_papers, *pubmed_papers]

    # Stage 2: keyword extraction over all summaries, capped at 8 keywords.
    blob = " ".join(p["summary"] for p in papers)
    keys = extract_keywords(blob)[:8]

    # Stage 3: enrichment fan-out. The per-keyword UMLS / openFDA lookups are
    # best-effort (return_exceptions=True keeps one failure from aborting the
    # batch); the trial search is not wrapped, so its errors propagate.
    umls, fda, genes, trials = await asyncio.gather(
        asyncio.gather(*(lookup_umls(k) for k in keys), return_exceptions=True),
        asyncio.gather(*(fetch_drug_safety(k) for k in keys), return_exceptions=True),
        _enrich_genes_mesh_disg(keys),
        search_trials(query, max_studies=10),
    )

    # Stage 4: LLM summary of the joined paper summaries.
    summary = await summarize(blob)

    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": summary,
        "llm_used": engine,
        "genes": genes["genes"],
        "mesh_defs": genes["meshes"],
        "gene_disease": genes["disgenet"],
        "clinical_trials": trials,
    }
|
|
|
|
|
async def answer_ai_question(question: str, context: str, llm: str = "openai") -> Dict[str, str]:
    """Answer a single follow-up question with the chosen engine.

    Args:
        question: The question to put to the model.
        context: Context passed to the Q-A helper alongside the question.
        llm: Engine name, resolved by ``_get_llm``.

    Returns:
        ``{"answer": <result of the engine's Q-A helper>}``.
    """
    # _get_llm returns (summarize, qa); only the Q-A callable is needed here.
    qa_fn = _get_llm(llm)[1]
    reply = await qa_fn(question, context)
    return {"answer": reply}
|
|