sniro23 committed on
Commit
f0cde84
·
1 Parent(s): 9b4f6f0

Feat: Implement performance and citation fixes

Browse files
Files changed (2) hide show
  1. app.py +23 -13
  2. src/enhanced_groq_medical_rag.py +12 -31
app.py CHANGED
@@ -74,15 +74,31 @@ def process_enhanced_medical_query(message: str, history: List[List[str]]) -> st
74
 
75
  def format_enhanced_medical_response(response: EnhancedMedicalResponse) -> str:
76
  """
77
- Format the enhanced medical response for display
78
  """
79
  formatted_parts = []
80
 
81
- # Main response
82
- formatted_parts.append(response.answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Enhanced information section
85
- formatted_parts.append("\n---\n")
86
  formatted_parts.append("### 📊 **Enhanced Medical Analysis**")
87
 
88
  # Safety and verification info
@@ -96,17 +112,11 @@ def format_enhanced_medical_response(response: EnhancedMedicalResponse) -> str:
96
  formatted_parts.append(f"**🧠 Medical Entities Extracted**: {response.medical_entities_count}")
97
  formatted_parts.append(f"**🎯 Context Adherence**: {response.context_adherence_score:.1%}")
98
  formatted_parts.append(f"**📚 Sources Used**: {len(response.sources)}")
99
- if hasattr(response, 'processing_time'):
100
- formatted_parts.append(f"**⚡ Processing Time**: {response.processing_time:.2f}s")
101
-
102
- # Sources
103
- if response.sources:
104
- formatted_parts.append("\n### 📋 **Clinical Sources**")
105
- for i, source in enumerate(response.sources[:5], 1): # Show top 5 sources
106
- formatted_parts.append(f"{i}. {source}")
107
 
108
  # Medical disclaimer
109
- formatted_parts.append("\n---")
110
  formatted_parts.append("*This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment.*")
111
 
112
  return "\n".join(formatted_parts)
 
74
 
75
  def format_enhanced_medical_response(response: EnhancedMedicalResponse) -> str:
76
  """
77
+ Format the enhanced medical response for display, ensuring citations are always included.
78
  """
79
  formatted_parts = []
80
 
81
+ # Main response from the LLM
82
+ # The new prompt instructs the LLM to include markdown citations like [1], [2]
83
+ # The final response text is now the primary source of the answer.
84
+ final_response_text = response.answer
85
+ formatted_parts.append(final_response_text)
86
+
87
+ # Always add the clinical sources section if sources exist
88
+ if response.sources:
89
+ formatted_parts.append("\n\n---\n")
90
+ formatted_parts.append("### 📋 **Clinical Sources**")
91
+ # Create a numbered list of sources for clarity
92
+ for i, source in enumerate(response.sources, 1):
93
+ # Ensure we don't list more sources than were used for citations
94
+ if f"[{i}]" in final_response_text:
95
+ formatted_parts.append(f"{i}. {source}")
96
+ else:
97
+ # If the LLM didn't cite this source, we can choose to omit it or list it as an uncited reference
98
+ pass # For now, only show cited sources to keep the output clean.
99
 
100
  # Enhanced information section
101
+ formatted_parts.append("\n\n---\n")
102
  formatted_parts.append("### 📊 **Enhanced Medical Analysis**")
103
 
104
  # Safety and verification info
 
112
  formatted_parts.append(f"**🧠 Medical Entities Extracted**: {response.medical_entities_count}")
113
  formatted_parts.append(f"**🎯 Context Adherence**: {response.context_adherence_score:.1%}")
114
  formatted_parts.append(f"**📚 Sources Used**: {len(response.sources)}")
115
+ if hasattr(response, 'query_time'): # Changed from processing_time to match the object attribute
116
+ formatted_parts.append(f"**⚡ Processing Time**: {response.query_time:.2f}s")
 
 
 
 
 
 
117
 
118
  # Medical disclaimer
119
+ formatted_parts.append("\n---\n")
120
  formatted_parts.append("*This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment.*")
121
 
122
  return "\n".join(formatted_parts)
src/enhanced_groq_medical_rag.py CHANGED
@@ -363,49 +363,30 @@ class EnhancedGroqMedicalRAG:
363
  query_analysis = self.analyze_medical_query(query)
364
  self._stop_timer("Query Analysis")
365
 
366
- # Step 2: Multi-stage comprehensive retrieval
367
- all_documents = []
368
- seen_content = set()
369
-
370
- # Stage 2a: Original query retrieval (increased from 15 to 25)
371
- stage1_docs = self.vector_store.search(query=query, k=25)
372
- for doc in stage1_docs:
373
- if doc.content not in seen_content:
374
- all_documents.append(doc)
375
- seen_content.add(doc.content)
376
-
377
- # Stage 2b: Expanded query retrieval
378
- for expanded_query in query_analysis['expanded_queries']:
379
- expanded_docs = self.vector_store.search(expanded_query, k=15)
380
- for doc in expanded_docs:
381
- if doc.content not in seen_content and len(all_documents) < 50:
382
- all_documents.append(doc)
383
- seen_content.add(doc.content)
384
-
385
- # Stage 2c: Entity-specific retrieval
386
- for entity in query_analysis['medical_entities']:
387
- entity_docs = self.vector_store.search(entity, k=10)
388
- for doc in entity_docs:
389
- if doc.content not in seen_content and len(all_documents) < 60:
390
- all_documents.append(doc)
391
- seen_content.add(doc.content)
392
-
393
  if not all_documents:
394
  return self._create_no_results_response(query, self._stop_timer("Total Query Time"))
395
 
396
  # Step 3: Advanced multi-criteria re-ranking
 
397
  reranked_docs = self._advanced_medical_reranking(query_analysis, all_documents)
 
398
 
399
- # Step 4: Select an initial set of documents, respecting the user's preference for more context.
400
- initial_doc_count = 10
401
- final_docs = reranked_docs[:initial_doc_count]
402
 
403
  # Step 5: Verify coverage and add missing context if needed, up to a hard limit to avoid API errors.
404
  MAX_FINAL_DOCS = 12
405
  coverage_score = self._verify_query_coverage(query_analysis, final_docs)
406
  if coverage_score < 0.7: # Less than 70% coverage
407
  self.logger.info(f"⚠️ Low coverage score ({coverage_score:.1%}). Retrieving additional context...")
408
- additional_docs = self._retrieve_missing_context(query_analysis, final_docs, seen_content)
409
  remaining_capacity = MAX_FINAL_DOCS - len(final_docs)
410
  if remaining_capacity > 0:
411
  final_docs.extend(additional_docs[:remaining_capacity])
 
363
  query_analysis = self.analyze_medical_query(query)
364
  self._stop_timer("Query Analysis")
365
 
366
+ # Step 2: Simplified single-stage retrieval
367
+ self._start_timer("Single Stage Retrieval")
368
+ NUM_CANDIDATE_DOCS = 40
369
+ all_documents = self.vector_store.search(query=query_analysis['original_query'], k=NUM_CANDIDATE_DOCS)
370
+ self._stop_timer("Single Stage Retrieval")
371
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  if not all_documents:
373
  return self._create_no_results_response(query, self._stop_timer("Total Query Time"))
374
 
375
  # Step 3: Advanced multi-criteria re-ranking
376
+ self._start_timer("Re-ranking")
377
  reranked_docs = self._advanced_medical_reranking(query_analysis, all_documents)
378
+ self._stop_timer("Re-ranking")
379
 
380
+ # Step 4: Select the final documents to be used for context
381
+ FINAL_DOC_COUNT = 10
382
+ final_docs = reranked_docs[:FINAL_DOC_COUNT]
383
 
384
  # Step 5: Verify coverage and add missing context if needed, up to a hard limit to avoid API errors.
385
  MAX_FINAL_DOCS = 12
386
  coverage_score = self._verify_query_coverage(query_analysis, final_docs)
387
  if coverage_score < 0.7: # Less than 70% coverage
388
  self.logger.info(f"⚠️ Low coverage score ({coverage_score:.1%}). Retrieving additional context...")
389
+ additional_docs = self._retrieve_missing_context(query_analysis, final_docs, set()) # Pass an empty set for seen_content
390
  remaining_capacity = MAX_FINAL_DOCS - len(final_docs)
391
  if remaining_capacity > 0:
392
  final_docs.extend(additional_docs[:remaining_capacity])