Spaces:
Runtime error
Runtime error
# app.py
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import faiss | |
| import numpy as np | |
| import nltk | |
# ------------------------------
# Step 0. NLTK setup
# ------------------------------
# Download the "punkt" resource in quiet mode before anything else runs.
nltk.download("punkt", quiet=True)

# ------------------------------
# Step 1. Load dataset
# ------------------------------
# NOTE(review): the leading symbol in the message below looks like a
# mis-decoded emoji; it is kept verbatim because the original character
# cannot be recovered from this file.
print("π Loading PubMedQA dataset...")
dataset = load_dataset(path="pubmed_qa", name="pqa_labeled")
def _context_to_text(ctx):
    """Flatten one PubMedQA ``context`` value into a single string."""
    if isinstance(ctx, dict):
        passages = ctx.get("contexts", [""])
        if isinstance(passages, list):
            return " ".join(map(str, passages))
        return str(passages)
    return str(ctx)


def extract_docs(ds):
    """Extract one plain-text document per example from PubMedQA-style data.

    Args:
        ds: Either an iterable of rows (dicts carrying a ``"context"`` entry,
            plain strings, or arbitrary objects), a columnar batch mapping
            column names to lists (what slicing a ``datasets.Dataset`` such
            as ``split[:500]`` produces), or a single row mapping.

    Returns:
        A list of strings, one per example. A ``context`` dict contributes
        its ``"contexts"`` passages joined by single spaces; any other value
        is converted with ``str``. Rows without a ``"context"`` entry yield
        an empty string.
    """
    if isinstance(ds, dict):
        # Iterating a mapping would only yield its keys (the column names),
        # so mappings are unpacked explicitly instead of falling through.
        if "context" not in ds:
            return []
        column = ds["context"]
        if isinstance(column, (list, tuple)):
            # Columnar batch: one context value per example.
            return [_context_to_text(ctx) for ctx in column]
        # A single row handed in on its own.
        return [_context_to_text(column)]

    docs = []
    for entry in ds:
        if isinstance(entry, dict):
            docs.append(_context_to_text(entry.get("context", "")))
        elif isinstance(entry, str):
            docs.append(entry)
        else:
            docs.append(str(entry))
    return docs
# Slicing a `datasets.Dataset` (`dataset["train"][:500]`) returns a dict of
# columns rather than a sequence of rows, so iterating it only yields the
# column names. `select` keeps row-wise access to the first 500 examples
# (or fewer if the split is smaller).
train_split = dataset["train"]
documents = extract_docs(train_split.select(range(min(500, len(train_split)))))
print(f"β Loaded {len(documents)} biomedical documents.")

# ------------------------------
# Step 2. Build embeddings (Biomedical)
# ------------------------------
print("π Building biomedical embeddings...")
embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
embeddings = embed_model.encode(documents, show_progress_bar=True)
# FAISS is fed float32 vectors here; the index dimension is taken from the
# embedding matrix's second axis.
embeddings = np.array(embeddings).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("β FAISS index built with biomedical embeddings.")

# ------------------------------
# Step 3. Load biomedical generation model
# ------------------------------
print("βοΈ Loading biomedical text generation model...")
# NOTE(review): confirm that this checkpoint id resolves on the Hugging Face
# Hub; the page this file was captured from reports a runtime error.
tokenizer = AutoTokenizer.from_pretrained("allenai/biomed-flan-t5-base")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("allenai/biomed-flan-t5-base")
# ------------------------------
# Step 4. Define RAG function
# ------------------------------
def rag_answer(question, k=3, max_new_tokens=256):
    """Retrieve the top-k passages for a question and generate an answer.

    Args:
        question: The user's question. ``None``, empty, or whitespace-only
            input short-circuits with a prompt to enter a question.
        k: Number of passages to retrieve. Coerced to ``int`` and clamped to
            the range ``[1, len(documents)]``.
        max_new_tokens: Generation budget passed to ``gen_model.generate``.
            Coerced to ``int`` and floored at 1.

    Returns:
        A ``(answer, sources)`` tuple of strings: the decoded model output
        and the retrieved passages joined by ``"\\n\\n---\\n"``.
    """
    if not question or not question.strip():
        return "Please enter a question.", ""

    # UI sliders can hand over non-int numbers; the index search and
    # generate() are given plain ints.
    top_k = max(1, min(int(k), len(documents)))
    token_budget = max(1, int(max_new_tokens))

    query_vec = np.asarray(embed_model.encode([question]), dtype="float32")
    _scores, indices = index.search(query_vec, top_k)

    # FAISS pads the result with -1 when fewer than k neighbours exist;
    # without this filter -1 would silently select the last document.
    retrieved = [documents[i] for i in indices[0] if 0 <= i < len(documents)]
    if not retrieved:
        return "No relevant passages were found.", ""

    context = "\n".join(retrieved)
    prompt = f"Question: {question}\n\nContext:\n{context}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = gen_model.generate(**inputs, max_new_tokens=token_budget)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, "\n\n---\n".join(retrieved)
# ------------------------------
# Step 5. Gradio Interface
# ------------------------------
def ask(question, k, max_tokens):
    """Gradio click handler: forward the UI inputs to ``rag_answer``.

    Returns the ``(answer, sources)`` pair in the order the output
    components are wired.
    """
    reply_text, context_text = rag_answer(question, k, max_tokens)
    return reply_text, context_text
# Build the Gradio UI. The nesting below is reconstructed: the original
# indentation was lost, so the placement of the output boxes outside the
# button row is a best guess.
# NOTE(review): the symbols in the title and heading look like mis-decoded
# emoji/dash characters; they are kept verbatim.
with gr.Blocks(title="π₯ MedQuery AI β Biomedical RAG Assistant") as demo:
    # The Markdown text is kept flush-left so it cannot be interpreted as an
    # indented code block.
    gr.Markdown(
        """
# π₯ MedQuery AI β Biomedical Knowledge Assistant
This app retrieves relevant PubMed-style passages and generates concise,
**evidence-based biomedical answers** using Retrieval-Augmented Generation (RAG).
"""
    )
    with gr.Row():
        question = gr.Textbox(
            label="Ask a biomedical or clinical question",
            placeholder="e.g. What are the diagnostic criteria for hypertension?"
        )
    with gr.Row():
        k = gr.Slider(1, 8, step=1, value=3, label="Top-K passages to retrieve")
        max_tokens = gr.Slider(64, 512, step=32, value=256, label="Max tokens for answer")
    with gr.Row():
        submit = gr.Button("Get Answer")
    answer = gr.Textbox(label="AI Answer", lines=4)
    sources = gr.Textbox(label="Retrieved Context", lines=10)
    # Event listeners are registered while the Blocks context is still open.
    submit.click(ask, inputs=[question, k, max_tokens], outputs=[answer, sources])

demo.launch()