Bman21 commited on
Commit
f9e7c04
·
verified ·
1 Parent(s): 94d9c74

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, pipeline
3
+ import torch
4
+ import faiss
5
+ import numpy as np
6
+
7
+ # -------------------
8
+ # CONFIG
9
+ # -------------------
10
+ DEVICE = "cpu"
11
+ MAX_TOKENS = 256
12
+ SEARCH_RESULTS = 3
13
+
14
+ SYSTEM_MESSAGE = (
15
+ "You are a helpful medical tutor. Provide clear, accurate explanations "
16
+ "based strictly on the provided notes. If the answer isn't in the notes, "
17
+ "say you don't have enough information."
18
+ )
19
+
20
+ # -------------------
21
+ # MODELS
22
+ # -------------------
23
+ # PubMedBERT for embeddings
24
+ EMBED_MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased"
25
+ embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)
26
+ embed_model = AutoModel.from_pretrained(EMBED_MODEL_NAME)
27
+
28
+ # Flan-T5 for generation
29
+ GEN_MODEL_NAME = "google/flan-t5-small"
30
+ gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
31
+ gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
32
+ generator = pipeline("text2text-generation", model=gen_model, tokenizer=gen_tokenizer, device=-1)
33
+
34
+ # -------------------
35
+ # SAMPLE NOTES
36
+ # -------------------
37
+ notes_dict = {
38
+ "Anatomy": """
39
+ The heart pumps blood through pulmonary and systemic circuits.
40
+ Oxygenated blood leaves the left ventricle and travels through the aorta.
41
+ """,
42
+ "Physiology": """
43
+ The respiratory system involves the exchange of oxygen and carbon dioxide
44
+ in the alveoli of the lungs.
45
+ """,
46
+ "Pharmacology": """
47
+ Paracetamol is used to relieve pain and reduce fever.
48
+ It acts centrally on the hypothalamic heat-regulating center.
49
+ """,
50
+ }
51
+
52
+ # -------------------
53
+ # BUILD EMBEDDING INDEX
54
+ # -------------------
55
+ def embed_texts(texts):
56
+ inputs = embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
57
+ with torch.no_grad():
58
+ outputs = embed_model(**inputs)
59
+ embeddings = outputs.last_hidden_state.mean(dim=1).numpy().astype("float32")
60
+ return embeddings
61
+
62
+ doc_map = []
63
+ chunks = []
64
+ for subject, content in notes_dict.items():
65
+ parts = [content[i:i + 400] for i in range(0, len(content), 400)]
66
+ for p in parts:
67
+ doc_map.append((subject, p))
68
+ chunks.append(p)
69
+
70
+ embeddings = embed_texts(chunks)
71
+ index = faiss.IndexFlatL2(embeddings.shape[1])
72
+ index.add(embeddings)
73
+
74
+ # -------------------
75
+ # SEMANTIC SEARCH
76
+ # -------------------
77
+ def semantic_search(query, max_results=SEARCH_RESULTS):
78
+ query_emb = embed_texts([query])
79
+ distances, ids = index.search(query_emb, max_results)
80
+ results = []
81
+ for idx in ids[0]:
82
+ subject, chunk = doc_map[idx]
83
+ results.append((subject, chunk))
84
+ return results
85
+
86
+ # -------------------
87
+ # RESPONSE FUNCTION
88
+ # -------------------
89
+ def respond(message, history):
90
+ if not message.strip():
91
+ return "Please enter a question."
92
+
93
+ search_results = semantic_search(message)
94
+ context_parts = [f"From {s}:\n{t}" for s, t in search_results]
95
+ context_text = "\n\n".join(context_parts) if context_parts else "No relevant notes found."
96
+
97
+ prompt = f"{SYSTEM_MESSAGE}\n\nNotes:\n{context_text}\n\nQuestion: {message}\nAnswer:"
98
+
99
+ response = generator(prompt, max_new_tokens=MAX_TOKENS)[0]["generated_text"]
100
+ return response
101
+
102
+ # -------------------
103
+ # GRADIO APP
104
+ # -------------------
105
+ with gr.Blocks(title="PubMed Medical Tutor") as demo:
106
+ gr.Markdown("# 🧬 PubMed Medical Tutor (CPU Friendly)")
107
+ chatbot = gr.Chatbot(label="Tutor Chat")
108
+ msg = gr.Textbox(label="Ask a medical question...")
109
+ clear = gr.Button("Clear Chat")
110
+
111
+ def user_input(message, history):
112
+ bot_reply = respond(message, history)
113
+ history.append((message, bot_reply))
114
+ return "", history
115
+
116
+ msg.submit(user_input, [msg, chatbot], [msg, chatbot])
117
+ clear.click(lambda: None, None, chatbot, queue=False)
118
+
119
+ demo.launch()