# MedQA/backend/main.py
from fastapi import FastAPI
from pydantic import BaseModel
from .llm_utils import simulate_search
from .umls_linker import link_umls
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import functools
# Hugging Face model ID used for answer generation.
ANSWER_MODEL = "sunhaonlp/SearchSimulation_14B"


@functools.lru_cache(maxsize=1)
def _load_answer_pipe():
    """Build and cache the text-generation pipeline for ANSWER_MODEL.

    ``lru_cache(maxsize=1)`` ensures the heavyweight model/tokenizer are
    loaded only once per process; every later call returns the same
    pipeline object.

    Returns:
        A transformers ``text-generation`` pipeline configured for greedy
        decoding with up to 256 new tokens.
    """
    tokenizer = AutoTokenizer.from_pretrained(ANSWER_MODEL)
    # NOTE(review): trust_remote_code=True executes repo-supplied Python on
    # load — acceptable only if this model source is trusted.
    model = AutoModelForCausalLM.from_pretrained(
        ANSWER_MODEL,
        trust_remote_code=True,
        device_map="auto",
    )
    # Greedy decoding: do_sample=False alone selects it. The original also
    # passed temperature=0.0, which is ignored when sampling is disabled and
    # triggers a warning (newer transformers versions reject temperature=0.0
    # outright), so it is removed here — generated output is unchanged.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=False,
    )
class Query(BaseModel):
    """Request body for the ``/ask`` endpoint: one free-text clinical question."""
    question: str
# ASGI application object; the server entry point (uvicorn) lives elsewhere.
app = FastAPI(
    title="ZeroSearch Medical Q&A API",
    description="Ask clinical questions; get answers with UMLS links, no external search APIs.",
    version="0.1.0",
)
@app.post("/ask")
def ask(query: Query):
    """Handle POST /ask: answer a question grounded in simulated retrieval.

    Retrieves k=5 pseudo-documents via ``simulate_search``, prompts the
    cached local LLM to answer strictly from that context, then annotates
    the answer text with UMLS concept links.
    """
    retrieved = simulate_search(query.question, k=5)
    joined_context = "\n\n".join(retrieved)
    prompt = (
        "Answer the medical question strictly based on the provided context.\n\n"
        f"Context:\n{joined_context}\n\n"
        f"Question: {query.question}\nAnswer:"
    )
    generations = _load_answer_pipe()(prompt, num_return_sequences=1)
    raw_text = generations[0]["generated_text"]
    # Keep only what follows the final "Answer:" marker in the echoed prompt.
    answer = raw_text.rsplit("Answer:", 1)[-1].strip()
    return {"answer": answer, "docs": retrieved, "umls": link_umls(answer)}