# app.py
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import faiss
import numpy as np
import nltk
# ------------------------------
# Step 0. NLTK setup
# ------------------------------
# punkt is fetched up front; nothing below tokenizes with NLTK directly,
# so this download (and the nltk import) can be dropped if unused.
nltk.download("punkt", quiet=True)
# ------------------------------
# Step 1. Load dataset
# ------------------------------
print("πŸ“š Loading PubMedQA dataset...")
dataset = load_dataset("pubmed_qa", "pqa_labeled")
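# Note: "pubmed_qa" is a script-based dataset, and newer `datasets` releases
# may require an explicit opt-in to run it. A hedged fallback (assuming the
# loader keeps the same schema):
# dataset = load_dataset("pubmed_qa", "pqa_labeled", trust_remote_code=True)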
def extract_docs(ds):
    """Extract clean text documents safely from the PubMedQA dataset."""
    docs = []
    for e in ds:
        if isinstance(e, dict):
            ctx = e.get("context", "")
            if isinstance(ctx, dict):
                # pqa_labeled stores abstracts under context["contexts"]
                text = ctx.get("contexts", [""])
                if isinstance(text, list):
                    docs.append(" ".join(map(str, text)))
                else:
                    docs.append(str(text))
            else:
                docs.append(str(ctx))
        elif isinstance(e, str):
            docs.append(e)
        else:
            docs.append(str(e))
    return docs
# Slicing a Dataset (ds["train"][:500]) returns a dict of columns, so the
# loop above would iterate over column names; .select() keeps row dicts.
documents = extract_docs(dataset["train"].select(range(500)))
print(f"✅ Loaded {len(documents)} biomedical documents.")
# ------------------------------
# Step 2. Build embeddings (Biomedical)
# ------------------------------
print("πŸ” Building biomedical embeddings...")
embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
embeddings = embed_model.encode(documents, show_progress_bar=True)
embeddings = np.array(embeddings).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("βœ… FAISS index built with biomedical embeddings.")
# ------------------------------
# Step 3. Load biomedical generation model
# ------------------------------
print("βš™οΈ Loading biomedical text generation model...")
tokenizer = AutoTokenizer.from_pretrained("allenai/biomed-flan-t5-base")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("allenai/biomed-flan-t5-base")
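# Optional sketch (an assumption about the runtime, not part of the original
# app): place the generator on a GPU when one is present; torch is already a
# transformers dependency. If enabled, the tokenized inputs in rag_answer()
# must move to the same device, e.g. inputs = inputs.to(gen_model.device).
# import torch
# if torch.cuda.is_available():
#     gen_model = gen_model.to("cuda")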
# ------------------------------
# Step 4. Define RAG function
# ------------------------------
def rag_answer(question, k=3, max_new_tokens=256):
    """Retrieve top-k relevant biomedical passages and generate an answer."""
    if not question.strip():
        return "Please enter a question.", ""
    k = int(k)  # Gradio sliders can deliver floats; FAISS expects an int k
    query_vec = embed_model.encode([question]).astype("float32")
    distances, indices = index.search(query_vec, k)  # L2 distances: lower is closer
    retrieved = [documents[i] for i in indices[0]]
    context = "\n".join(retrieved)
    prompt = f"Question: {question}\n\nContext:\n{context}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = gen_model.generate(**inputs, max_new_tokens=int(max_new_tokens))
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, "\n\n---\n".join(retrieved)
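# Example usage (a hypothetical local smoke test; the question is illustrative):
# ans, ctx = rag_answer("Does metformin reduce cardiovascular risk?", k=3)
# print(ans)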
# ------------------------------
# Step 5. Gradio Interface
# ------------------------------
def ask(question, k, max_tokens):
    answer, sources = rag_answer(question, k, max_tokens)
    return answer, sources
with gr.Blocks(title="🏥 MedQuery AI — Biomedical RAG Assistant") as demo:
    gr.Markdown(
        """
        # 🏥 MedQuery AI — Biomedical Knowledge Assistant
        This app retrieves relevant PubMed-style passages and generates concise,
        **evidence-based biomedical answers** using Retrieval-Augmented Generation (RAG).
        """
    )
    with gr.Row():
        question = gr.Textbox(
            label="Ask a biomedical or clinical question",
            placeholder="e.g. What are the diagnostic criteria for hypertension?"
        )
    with gr.Row():
        k = gr.Slider(1, 8, step=1, value=3, label="Top-K passages to retrieve")
        max_tokens = gr.Slider(64, 512, step=32, value=256, label="Max tokens for answer")
    with gr.Row():
        submit = gr.Button("Get Answer")
    answer = gr.Textbox(label="AI Answer", lines=4)
    sources = gr.Textbox(label="Retrieved Context", lines=10)
    submit.click(ask, inputs=[question, k, max_tokens], outputs=[answer, sources])
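# Optional (an assumption about deployment, not in the original app): enable
# Gradio's request queue so long generations don't hit request timeouts.
# demo.queue()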
demo.launch()