imurra committed · verified
Commit f33d22c · Parent(s): 1d5bf01

updated to give entire exemplar question and answer choices


Generate All Explanations

For each exemplar, we need the AI to generate:

✅ An explanation of WHY each choice (A-E) is correct/wrong
✅ A new, similar question
✅ Explanations for all choices in the new question

This means 2 AI calls per exemplar:

Call 1: "Explain why each answer choice is right/wrong"
Call 2: "Generate a similar question with explanations"

For 3 exemplars = 6 AI calls = ~30 seconds total
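
In outline, the per-exemplar loop looks like this (a sketch only; generate_choice_explanations and generate_similar_question are the functions added in the diff below, and parsed_exemplars is a hypothetical stand-in for the parsed search results):

    for parsed in parsed_exemplars:                    # 3 exemplars
        explanations = generate_choice_explanations(   # AI call 1
            parsed['question'], parsed['choices'], parsed['correct_answer'])
        new_question = generate_similar_question(      # AI call 2
            parsed['question'], parsed['choices'], parsed['correct_answer'])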

My Plan:

1. Fix the API to return the FULL untruncated exemplar text
2. Add functions that call an AI (GPT-4/Claude) to generate:
   - choice explanations for the original question
   - a new question plus its choice explanations
3. Format everything as plain text (not JSON)
4. Return the complete formatted output

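For reference, each exemplar in the final plain-text output should look roughly like this (shape taken from format_complete_output in the diff below; bracketed parts are placeholders):

    ================================================================================
    EXEMPLAR 1
    ================================================================================

    ORIGINAL QUESTION:
    [full untruncated question text]

    ANSWER CHOICES:
    A. [choice] ... E. [choice]

    CORRECT ANSWER: [letter]

    EXPLANATION FOR EACH CHOICE:
    [AI-generated explanations]

    --------------------------------------------------------------------------------
    AI-GENERATED SIMILAR QUESTION:
    --------------------------------------------------------------------------------
    [AI-generated question, choices, answer, explanations]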
Files changed (1): app.py (+238 -135)
@@ -7,6 +7,9 @@ from sentence_transformers import SentenceTransformer
 import gradio as gr
 from fastapi import FastAPI
 from pydantic import BaseModel
+import re
+import anthropic  # You'll need: pip install anthropic
+# OR if using OpenAI: import openai

 # Extract and load database
 DB_PATH = "./medqa_db"
@@ -25,15 +28,17 @@ print("🧠 Loading MedCPT model...")
 model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
 print("✅ Model ready")

+# Initialize AI client (choose one)
+# Option 1: Claude
+claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+
+# Option 2: OpenAI (uncomment if using)
+# openai.api_key = os.environ.get("OPENAI_API_KEY")
+
 # ============================================================================
-# NEW: Deduplication function
+# Deduplication function (same as before)
 # ============================================================================
 def deduplicate_results(results, target_count):
-    """
-    Remove duplicate questions based on:
-    1. High text similarity (>0.92) - catches near-exact duplicates
-    2. Same answer + moderate similarity (>0.85) - catches conceptual duplicates
-    """
     if not results['documents'][0]:
         return results

@@ -47,32 +52,24 @@ def deduplicate_results(results, target_count):
         is_duplicate = False
         current_answer = metadatas[i].get('answer', '')

-        # Compare to already-selected results
         for j in selected_indices:
             selected_answer = metadatas[j].get('answer', '')
-
-            # Calculate similarity between questions
-            # Lower distance = higher similarity
             dist_diff = abs(distances[i] - distances[j])

-            # Rule 1: Very similar questions (likely exact/near-exact duplicates)
-            if dist_diff < 0.08:  # Roughly equivalent to >0.92 similarity
+            if dist_diff < 0.08:
                 is_duplicate = True
                 break

-            # Rule 2: Same answer + similar question (conceptual duplicates)
-            if current_answer == selected_answer and dist_diff < 0.15:  # ~0.85 similarity
+            if current_answer == selected_answer and dist_diff < 0.15:
                 is_duplicate = True
                 break

         if not is_duplicate:
             selected_indices.append(i)

-        # Stop when we have enough unique results
         if len(selected_indices) >= target_count:
             break

-    # Return filtered results in same format
     return {
         'documents': [[documents[i] for i in selected_indices]],
         'metadatas': [[metadatas[i] for i in selected_indices]],
@@ -81,18 +78,16 @@ def deduplicate_results(results, target_count):
     }

 # ============================================================================
-# MODIFIED: Search function with deduplication
+# Search function (same as before)
 # ============================================================================
 def search(query, num_results=3, source_filter=None):
     emb = model.encode(query).tolist()

-    # Apply source filter if specified
     where_clause = None
     if source_filter and source_filter != "all":
         where_clause = {"source": source_filter}

-    # Over-fetch to ensure we get enough unique results
-    fetch_count = min(num_results * 4, 50)  # Fetch 4x but cap at 50
+    fetch_count = min(num_results * 4, 50)

     results = collection.query(
         query_embeddings=[emb],
@@ -100,155 +95,263 @@ def search(query, num_results=3, source_filter=None):
         where=where_clause
     )

-    # Deduplicate and return only requested number
     return deduplicate_results(results, num_results)

-# Enhanced Gradio UI
-def ui_search(query, num_results=3, source_filter="all"):
-    if not query.strip():
-        return "💡 Enter a medical query to search"
-
-    try:
-        r = search(query, num_results, source_filter if source_filter != "all" else None)
-
-        if not r['documents'][0]:
-            return "❌ No results found"
-
-        out = f"🔍 Found {len(r['documents'][0])} unique results\n\n"
-
-        for i in range(len(r['documents'][0])):
-            source = r['metadatas'][0][i].get('source', 'unknown')
-            distance = r['distances'][0][i]
-            similarity = 1 - distance
-
-            # Source emoji
-            if source == 'medgemini':
-                source_icon = "🔬"
-                source_name = "Med-Gemini"
-            elif source.startswith('medqa_'):
-                source_icon = "📚"
-                split = source.replace('medqa_', '').upper()
-                source_name = f"MedQA {split}"
-            else:
-                source_icon = "📄"
-                source_name = source.upper()
-
-            out += f"\n{'='*70}\n"
-            out += f"{source_icon} Result {i+1} | {source_name} | Similarity: {similarity:.3f}\n"
-            out += f"{'='*70}\n\n"
-            out += r['documents'][0][i]
-
-            # Show answer
-            answer = r['metadatas'][0][i].get('answer', 'N/A')
-            out += f"\n\n✅ CORRECT ANSWER: {answer}\n"
-
-            # Show explanation if available (Med-Gemini)
-            explanation = r['metadatas'][0][i].get('explanation', '')
-            if explanation and explanation.strip():
-                out += f"\n💡 EXPLANATION:\n{explanation}\n"
-
-            out += "\n"
-
-        return out
-
-    except Exception as e:
-        return f"❌ Error: {e}"
-
-# Create Gradio interface
-with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
-    gr.Markdown("""
-    # 🏥 MedQA Semantic Search
-
-    Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
-    Uses medical-specific embeddings (MedCPT) for accurate retrieval.
-
-    ✨ **New**: Automatic deduplication removes similar/duplicate questions
-    """)
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            query_input = gr.Textbox(
-                label="Medical Query",
-                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
-                lines=2
-            )
-        with gr.Column(scale=1):
-            num_results = gr.Slider(
-                minimum=1,
-                maximum=10,
-                value=3,
-                step=1,
-                label="Number of Results"
-            )
-
-    with gr.Row():
-        source_filter = gr.Radio(
-            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
-            value="all",
-            label="Filter by Source"
-        )
-
-    search_btn = gr.Button("🔍 Search", variant="primary", size="lg")
-
-    output = gr.Textbox(
-        label="Search Results",
-        lines=25,
-        max_lines=50
-    )
-
-    search_btn.click(
-        fn=ui_search,
-        inputs=[query_input, num_results, source_filter],
-        outputs=output
-    )
-
-    query_input.submit(
-        fn=ui_search,
-        inputs=[query_input, num_results, source_filter],
-        outputs=output
-    )
-
-    gr.Markdown("""
-    ### 📊 Database Info
-
-    **Med-Gemini**: Expert-relabeled questions with detailed explanations
-    **MedQA**: USMLE-style questions (Train/Dev/Test splits)
-
-    **Total Questions**: Use the database you built with `build_combined_db.py`
-    """)
-
-    gr.Examples(
-        examples=[
-            ["hyponatremia", 3, "all"],
-            ["myocardial infarction treatment", 2, "medgemini"],
-            ["diabetes complications", 3, "all"],
-            ["antibiotics for pneumonia", 2, "medqa_train"]
-        ],
-        inputs=[query_input, num_results, source_filter]
-    )
-
-# FastAPI
+# ============================================================================
+# NEW: Parser to extract question structure
+# ============================================================================
+def parse_question_document(doc_text, metadata):
+    """Extract question and choices from document text."""
+    lines = doc_text.split('\n')
+    question_lines = []
+    options_started = False
+    options = {}
+
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+
+        option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line)
+
+        if option_match:
+            options_started = True
+            letter = option_match.group(1)
+            text = option_match.group(2).strip()
+            options[letter] = text
+        elif not options_started:
+            question_lines.append(line)
+
+    question_text = ' '.join(question_lines).strip()
+
+    answer_idx = metadata.get('answer_idx', 'N/A')
+
+    return {
+        'question': question_text,
+        'choices': options,
+        'correct_answer': answer_idx
+    }
+
+# ============================================================================
+# NEW: AI generation functions
+# ============================================================================
+def generate_choice_explanations(question, choices, correct_answer):
+    """Generate explanations for why each choice is correct/wrong."""
+    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])
+
+    prompt = f"""You are a medical educator. For this USMLE-style question, explain why EACH answer choice is correct or incorrect.
+
+QUESTION:
+{question}
+
+ANSWER CHOICES:
+{choices_text}
+
+CORRECT ANSWER: {correct_answer}
+
+Provide a 1-2 sentence explanation for EACH choice (A through E) explaining why it is correct or incorrect. Format as:
+
+A. [Choice text] - [Explanation]
+B. [Choice text] - [Explanation]
+C. [Choice text] - [Explanation]
+D. [Choice text] - [Explanation]
+E. [Choice text] - [Explanation]"""
+
+    # Using Claude
+    message = claude_client.messages.create(
+        model="claude-sonnet-4-20250514",
+        max_tokens=1000,
+        messages=[{"role": "user", "content": prompt}]
+    )
+
+    return message.content[0].text
+
+    # OR using OpenAI (uncomment if using):
+    # response = openai.ChatCompletion.create(
+    #     model="gpt-4",
+    #     messages=[{"role": "user", "content": prompt}],
+    #     max_tokens=1000
+    # )
+    # return response.choices[0].message.content
+
+def generate_similar_question(original_question, choices, correct_answer):
+    """Generate a new question based on the exemplar."""
+    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])
+
+    prompt = f"""You are a medical educator. Based on this USMLE-style question, create a NEW similar question that tests the SAME medical concept but with a different clinical scenario.
+
+ORIGINAL QUESTION:
+{original_question}
+
+ANSWER CHOICES:
+{choices_text}
+
+CORRECT ANSWER: {correct_answer}
+
+Create a NEW question that:
+1. Tests the same medical concept
+2. Uses a different patient scenario
+3. Has 5 answer choices (A-E)
+4. Includes explanations for why each choice is correct/incorrect
+
+Format your response EXACTLY as:
+
+NEW QUESTION:
+[Your new question text]
+
+ANSWER CHOICES:
+A. [Choice A]
+B. [Choice B]
+C. [Choice C]
+D. [Choice D]
+E. [Choice E]
+
+CORRECT ANSWER: [Letter]
+
+EXPLANATIONS:
+A. [Choice A text] - [Explanation]
+B. [Choice B text] - [Explanation]
+C. [Choice C text] - [Explanation]
+D. [Choice D text] - [Explanation]
+E. [Choice E text] - [Explanation]"""
+
+    # Using Claude
+    message = claude_client.messages.create(
+        model="claude-sonnet-4-20250514",
+        max_tokens=2000,
+        messages=[{"role": "user", "content": prompt}]
+    )
+
+    return message.content[0].text
+
+    # OR using OpenAI:
+    # response = openai.ChatCompletion.create(
+    #     model="gpt-4",
+    #     messages=[{"role": "user", "content": prompt}],
+    #     max_tokens=2000
+    # )
+    # return response.choices[0].message.content
+
+# ============================================================================
+# NEW: Format complete output
+# ============================================================================
+def format_complete_output(exemplar_num, parsed, original_explanation, choice_explanations, new_question_text):
+    """Format everything into readable plain text."""
+    choices_text = '\n'.join([f"{k}. {v}" for k, v in parsed['choices'].items()])
+
+    output = f"""{'='*80}
+EXEMPLAR {exemplar_num}
+{'='*80}
+
+ORIGINAL QUESTION:
+{parsed['question']}
+
+ANSWER CHOICES:
+{choices_text}
+
+CORRECT ANSWER: {parsed['correct_answer']}
+
+EXPLANATION FOR EACH CHOICE:
+{choice_explanations}
+"""
+
+    if original_explanation:
+        output += f"\nORIGINAL EXPLANATION FROM DATABASE:\n{original_explanation}\n"
+
+    output += f"""
+{'-'*80}
+AI-GENERATED SIMILAR QUESTION:
+{'-'*80}
+
+{new_question_text}
+
+{'='*80}
+
+"""
+
+    return output
+
+# ============================================================================
+# MODIFIED: API endpoint with full generation
+# ============================================================================
 app = FastAPI()

 class SearchRequest(BaseModel):
     query: str
     num_results: int = 3
     source_filter: str = None
+    generate_ai: bool = True  # Option to skip AI generation for faster response

 @app.post("/search_medqa")
 def api_search(req: SearchRequest):
+    """Search and return complete formatted exemplars with AI-generated content."""
+    print(f"🔍 Searching for: {req.query}")
     r = search(req.query, req.num_results, req.source_filter)
-    return {"results": [{
-        "result_number": i+1,
-        "question": r['documents'][0][i],
-        "answer": r['metadatas'][0][i].get('answer', 'N/A'),
-        "source": r['metadatas'][0][i].get('source', 'unknown'),
-        "similarity": 1 - r['distances'][0][i]
-    } for i in range(len(r['documents'][0]))]}
+
+    if not r['documents'][0]:
+        return {"output": "No results found."}
+
+    complete_output = f"SEARCH QUERY: {req.query}\n"
+    complete_output += f"FOUND {len(r['documents'][0])} EXEMPLARS\n\n"
+
+    for i in range(len(r['documents'][0])):
+        print(f"Processing exemplar {i+1}...")
+
+        doc_text = r['documents'][0][i]
+        metadata = r['metadatas'][0][i]
+
+        # Parse the exemplar
+        parsed = parse_question_document(doc_text, metadata)
+        original_explanation = metadata.get('explanation', '')
+
+        if req.generate_ai:
+            # Generate AI content
+            print("  Generating choice explanations...")
+            choice_explanations = generate_choice_explanations(
+                parsed['question'],
+                parsed['choices'],
+                parsed['correct_answer']
+            )
+
+            print("  Generating similar question...")
+            new_question = generate_similar_question(
+                parsed['question'],
+                parsed['choices'],
+                parsed['correct_answer']
+            )
+        else:
+            choice_explanations = "(AI generation skipped)"
+            new_question = "(AI generation skipped)"
+
+        # Format complete output
+        formatted = format_complete_output(
+            i + 1,
+            parsed,
+            original_explanation,
+            choice_explanations,
+            new_question
+        )
+
+        complete_output += formatted
+
+    return {
+        "output": complete_output,
+        "content_type": "text/plain"
+    }

+# Gradio UI (simplified - just shows we have it)
+with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
+    gr.Markdown("# 🏥 MedQA Search with AI Generation")
+    query_input = gr.Textbox(label="Query")
+    output = gr.Textbox(label="Results", lines=50)
+
 app = gr.mount_gradio_app(app, demo, path="/")

-# Launch
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
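
For reference, a minimal client sketch for the updated endpoint (assumes the app is running locally on port 7860; the field names match SearchRequest above):

    import requests

    resp = requests.post(
        "http://localhost:7860/search_medqa",
        json={
            "query": "hyponatremia",
            "num_results": 3,
            "source_filter": "all",
            "generate_ai": True,  # set False to skip the two AI calls per exemplar
        },
    )
    print(resp.json()["output"])  # complete formatted plain-text output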