imurra committed on
Commit 4dcbdd1 · verified · 1 Parent(s): 1d7f6cb

Implement Solution A: modify your HuggingFace app.py to add the /batch_search_medqa endpoint (an example client request is sketched after the list below). This gives you:

✅ True batch processing (1 API call for 60 objectives)
✅ Automatic deduplication on the server
✅ Pre-organized results
✅ No user interaction needed
✅ Scales to hundreds of objectives
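
For reference, a minimal client-side sketch of how a batch call to the new endpoint might look once the Space is redeployed. The payload fields (`queries`, `num_results_per_query`, `source_filter`) come from the `BatchSearchRequest` model added in this commit; the Space URL is a placeholder, not a real deployment.

```python
# Hedged sketch: send one request covering all learning objectives to the
# new /batch_search_medqa endpoint. BASE_URL is a placeholder to replace.
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder Space URL

objectives = [
    "Explain the pathophysiology of diabetic ketoacidosis",
    "Describe first-line treatment of community-acquired pneumonia",
    # ... up to 60+ objectives in a single call
]

payload = {
    "queries": objectives,            # required list of objective texts
    "num_results_per_query": 10,      # server default per BatchSearchRequest
    "source_filter": None,            # optional; accepted values depend on the server-side search()
}

resp = requests.post(f"{BASE_URL}/batch_search_medqa", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()

print(f"{data['statistics']['unique_questions']} unique questions "
      f"across {data['statistics']['total_objectives']} objectives")
```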

Files changed (1)
  1. app.py +150 -2
app.py CHANGED
@@ -7,7 +7,9 @@ from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
+ from typing import List, Optional
import re
+ import time

# Extract and load database
DB_PATH = "./medqa_db"
@@ -17,7 +19,7 @@ if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
        z.extractall(".")
    print("✅ Database extracted")

- print("🔌 Loading ChromaDB...")
+ print("📌 Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"✅ Loaded {collection.count()} questions")
@@ -191,7 +193,7 @@ def ui_search(query, num_results=3, source_filter="all"):
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("""
-     # 🏥 MedQA Semantic Search
+     # 🥼 MedQA Semantic Search

    Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
    Uses medical-specific embeddings (MedCPT) for accurate retrieval.
@@ -271,6 +273,11 @@ class SearchRequest(BaseModel):
    num_results: int = 3
    source_filter: str = None

+ class BatchSearchRequest(BaseModel):
+     queries: List[str]
+     num_results_per_query: int = 10
+     source_filter: Optional[str] = None
+
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """
@@ -311,6 +318,147 @@ def api_search(req: SearchRequest):

    return {"results": results}

+ @app.post("/batch_search_medqa")
+ def batch_api_search(req: BatchSearchRequest):
+     """
+     NEW: Batch search for multiple learning objectives.
+     Processes all queries, tracks duplicates, and returns organized results.
+
+     Returns:
+     - results_by_objective: List of results organized by each objective
+     - unique_questions: Deduplicated list of all questions
+     - statistics: Coverage and quality metrics
+     """
+
+     start_time = time.time()
+
+     # Track all questions and their objective mappings
+     all_questions = {}  # key: question_text, value: question data + objectives list
+     results_by_objective = []
+
+     for obj_idx, query in enumerate(req.queries):
+         objective_id = obj_idx + 1
+
+         # Search for this objective
+         r = search(query, req.num_results_per_query, req.source_filter)
+
+         objective_results = []
+         similarities = []
+
+         if r['documents'][0]:
+             for i in range(len(r['documents'][0])):
+                 doc_text = r['documents'][0][i]
+                 metadata = r['metadatas'][0][i]
+                 similarity = round(1 - r['distances'][0][i], 3)
+                 similarities.append(similarity)
+
+                 # Parse the document
+                 parsed = parse_question_document(doc_text, metadata)
+
+                 # Create unique key for deduplication
+                 question_key = parsed['question'][:200]  # Use first 200 chars as key
+
+                 # Build result object
+                 result = {
+                     "question": parsed['question'],
+                     "choices": parsed['choices'],
+                     "correct_answer": parsed['correct_answer_letter'],
+                     "correct_answer_text": parsed['correct_answer_text'],
+                     "explanation": metadata.get('explanation', ''),
+                     "has_explanation": bool(metadata.get('explanation', '').strip()),
+                     "source": metadata.get('source', 'unknown'),
+                     "similarity": similarity
+                 }
+
+                 # Track for global deduplication
+                 if question_key in all_questions:
+                     # This question already exists - add this objective to its list
+                     all_questions[question_key]['matches_objectives'].append(objective_id)
+                     # Update similarity if higher
+                     if similarity > all_questions[question_key]['max_similarity']:
+                         all_questions[question_key]['max_similarity'] = similarity
+                 else:
+                     # First time seeing this question
+                     all_questions[question_key] = {
+                         **result,
+                         'matches_objectives': [objective_id],
+                         'max_similarity': similarity,
+                         'first_seen_at': objective_id
+                     }
+
+                 objective_results.append(result)
+
+         # Store results for this objective
+         results_by_objective.append({
+             "objective_id": objective_id,
+             "objective_text": query,
+             "num_results": len(objective_results),
+             "avg_similarity": round(sum(similarities) / len(similarities), 3) if similarities else 0,
+             "results": objective_results
+         })
+
+     # Prepare unique questions list
+     unique_questions = list(all_questions.values())
+
+     # Calculate statistics
+     execution_time = round(time.time() - start_time, 2)
+     total_retrieved = sum(obj['num_results'] for obj in results_by_objective)
+
+     # Coverage analysis
+     coverage = {
+         "excellent": [obj for obj in results_by_objective if obj['num_results'] >= 5],
+         "moderate": [obj for obj in results_by_objective if 2 <= obj['num_results'] < 5],
+         "limited": [obj for obj in results_by_objective if obj['num_results'] == 1],
+         "none": [obj for obj in results_by_objective if obj['num_results'] == 0]
+     }
+
+     # Multi-objective questions
+     multi_objective_questions = [q for q in unique_questions if len(q['matches_objectives']) > 1]
+
+     # Source distribution
+     sources = {}
+     for q in unique_questions:
+         source = q['source']
+         sources[source] = sources.get(source, 0) + 1
+
+     # Similarity distribution
+     all_similarities = [q['max_similarity'] for q in unique_questions]
+     high_sim = len([s for s in all_similarities if s > 0.8])
+     med_sim = len([s for s in all_similarities if 0.7 <= s <= 0.8])
+     low_sim = len([s for s in all_similarities if s < 0.7])
+
+     statistics = {
+         "total_objectives": len(req.queries),
+         "total_retrieved": total_retrieved,
+         "unique_questions": len(unique_questions),
+         "deduplication_rate": round((total_retrieved - len(unique_questions)) / total_retrieved * 100, 1) if total_retrieved > 0 else 0,
+         "execution_time_seconds": execution_time,
+         "coverage": {
+             "excellent_coverage_count": len(coverage["excellent"]),
+             "moderate_coverage_count": len(coverage["moderate"]),
+             "limited_coverage_count": len(coverage["limited"]),
+             "no_coverage_count": len(coverage["none"]),
+             "no_coverage_objectives": [obj['objective_id'] for obj in coverage["none"]]
+         },
+         "cross_objective": {
+             "multi_objective_questions": len(multi_objective_questions),
+             "multi_objective_percentage": round(len(multi_objective_questions) / len(unique_questions) * 100, 1) if unique_questions else 0
+         },
+         "sources": sources,
+         "similarity_distribution": {
+             "high_similarity_count": high_sim,
+             "medium_similarity_count": med_sim,
+             "low_similarity_count": low_sim,
+             "average_similarity": round(sum(all_similarities) / len(all_similarities), 3) if all_similarities else 0
+         }
+     }
+
+     return {
+         "results_by_objective": results_by_objective,
+         "unique_questions": unique_questions,
+         "statistics": statistics
+     }
+
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
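
For orientation, the JSON returned by batch_api_search has roughly the shape sketched below. The keys mirror the dictionaries constructed in the diff above; all values are illustrative, and the exact contents of each result entry depend on parse_question_document and the stored metadata, which are not shown in this commit.

```python
# Illustrative response shape for /batch_search_medqa (values are made up;
# keys follow the dictionaries built in batch_api_search above).
example_response = {
    "results_by_objective": [
        {
            "objective_id": 1,
            "objective_text": "Explain the pathophysiology of diabetic ketoacidosis",
            "num_results": 10,
            "avg_similarity": 0.81,
            "results": [
                {
                    "question": "...",
                    "choices": ...,          # structure comes from parse_question_document (not in this diff)
                    "correct_answer": "C",
                    "correct_answer_text": "...",
                    "explanation": "...",
                    "has_explanation": True,
                    "source": "med-gemini",  # example label; real values come from metadata['source']
                    "similarity": 0.84,
                },
            ],
        },
    ],
    "unique_questions": [
        # each entry is a result dict plus deduplication bookkeeping, e.g.
        # "matches_objectives": [1, 7], "max_similarity": 0.84, "first_seen_at": 1
    ],
    "statistics": {
        "total_objectives": 60,
        "total_retrieved": 600,
        "unique_questions": 412,
        "deduplication_rate": 31.3,          # (600 - 412) / 600 * 100
        "execution_time_seconds": 4.2,
        "coverage": {
            "excellent_coverage_count": 48,
            "moderate_coverage_count": 9,
            "limited_coverage_count": 2,
            "no_coverage_count": 1,
            "no_coverage_objectives": [37],
        },
        "cross_objective": {
            "multi_objective_questions": 95,
            "multi_objective_percentage": 23.1,
        },
        "sources": {"med-gemini": 200, "medqa": 212},   # example source labels
        "similarity_distribution": {
            "high_similarity_count": 150,
            "medium_similarity_count": 180,
            "low_similarity_count": 82,
            "average_similarity": 0.78,
        },
    },
}
```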