# Kaanta / test_optimization.py
# Eniiyanu's picture
# Upload 14 files
# c54e3bd verified
#!/usr/bin/env python3
"""
Test script to validate FREE optimization improvements.
Measures before/after quality on sample tax queries.
"""
import sys
import json
from pathlib import Path
from rag_pipeline import RAGPipeline, DocumentStore
# Benchmark queries, one per tax topic. Each entry holds the question text,
# the keywords looked for in the retrieved context, and a short category
# label that is printed in the retrieval report.
TEST_QUESTIONS = [
    dict(
        question="What are the personal income tax rates in Nigeria?",
        expected_keywords=["₦800,000", "15%", "18%", "21%", "23%", "25%"],
        category="rates",
    ),
    dict(
        question="What is CRA and how is it calculated?",
        expected_keywords=["Consolidated Relief Allowance", "₦200,000", "20%", "1%"],
        category="relief",
    ),
    dict(
        question="What are the company income tax rates?",
        expected_keywords=["30%", "20%", "CIT", "company"],
        category="corporate",
    ),
    dict(
        question="Tell me about PAYE deductions",
        expected_keywords=["Pay As You Earn", "employer", "monthly", "withholding"],
        category="paye",
    ),
    dict(
        question="What tax reliefs are available for individuals?",
        expected_keywords=["relief", "allowance", "deduction", "pension"],
        category="reliefs",
    ),
]
def test_retrieval_quality(rag: RAGPipeline):
    """Measure how many expected keywords show up in the retrieved context.

    For every entry in ``TEST_QUESTIONS`` the question is passed to
    ``rag._retrieve``; the ``page_content`` of the first 10 returned
    documents is joined and searched case-insensitively for each expected
    keyword. The per-question score is the fraction of expected keywords
    found (labelled "precision" in the printed report).

    Args:
        rag: Pipeline exposing ``_retrieve(question)``, whose result must be
            sliceable and yield objects with a ``page_content`` string.

    Returns:
        The mean per-question score as a float between 0 and 1, or ``0.0``
        when there are no test questions.
    """
    print("\n" + "=" * 80)
    print("RETRIEVAL QUALITY TEST")
    print("=" * 80)

    results = []
    for item in TEST_QUESTIONS:
        question = item["question"]
        expected = item["expected_keywords"]

        # Retrieve docs and flatten the top hits into one searchable string
        docs = rag._retrieve(question)
        retrieved_text = " ".join(d.page_content for d in docs[:10]).lower()

        # Split the keywords into found / missing in a single pass each.
        # ``dict.fromkeys`` de-duplicates while keeping declaration order, so
        # the "Missing" line is stable from run to run.
        found = [kw for kw in expected if kw.lower() in retrieved_text]
        missing = list(
            dict.fromkeys(kw for kw in expected if kw.lower() not in retrieved_text)
        )
        precision = len(found) / len(expected) if expected else 0.0

        results.append({
            "question": question,
            "precision": precision,
            "found": len(found),
            "total": len(expected),
            "found_keywords": found,
        })

        print(f"\n{item['category'].upper()}: {question}")
        print(f" Found: {len(found)}/{len(expected)} keywords ({precision*100:.0f}%)")
        if missing:
            print(f" Missing: {', '.join(missing)}")

    # Guard the average so an empty question list reports 0% instead of
    # raising ZeroDivisionError.
    avg_precision = (
        sum(r["precision"] for r in results) / len(results) if results else 0.0
    )

    print(f"\n{'='*80}")
    print(f"AVERAGE RETRIEVAL PRECISION: {avg_precision*100:.1f}%")
    print(f"{'='*80}\n")
    return avg_precision
def _answer_quality_checks(answer):
    """Return the named formatting/content checks for one generated answer."""
    import re  # local import: only this helper needs regular expressions

    # Text enclosed by paired ``**`` markers. Odd-indexed pieces of the split
    # are inside bold; stopping before the last piece ignores an unpaired
    # trailing marker.
    bold_spans = answer.split("**")[1:-1:2]
    return {
        "has_bottom_line": "**Bottom line**" in answer,
        "has_numbers": any(ch.isdigit() for ch in answer),
        # A bold span counts when it holds a naira amount or a percentage,
        # e.g. ``**₦800,000**`` or ``**15%**``.
        "has_bold_numbers": any(
            re.search(r"₦\s?\d|\d\s?%", span) for span in bold_spans
        ),
        # Any ``[F<number>]`` marker counts, not just [F1] and [F2].
        "no_fact_ids": re.search(r"\[F\d+\]", answer) is None,
        "has_structure": "**Here's what you need to know**" in answer,
    }


def test_answer_quality(rag: RAGPipeline):
    """Print generated answers with a pass/fail line per quality check.

    Only the first three entries of ``TEST_QUESTIONS`` are queried, to keep
    the run short. Each answer from ``rag.query(question, verbose=False)`` is
    printed followed by the result of every check. A warning is printed when
    any of the bottom-line, numbers, fact-ID or structure checks fails; the
    bold-number check is reported but does not trigger the warning.

    Args:
        rag: Pipeline exposing ``query(question, verbose=False)``; the result
            is treated as a string.

    Returns:
        None. Results are written to stdout; an exception raised for one
        question is reported and the loop moves on to the next.
    """
    print("\n" + "=" * 80)
    print("ANSWER QUALITY TEST")
    print("=" * 80)

    report_lines = (
        ("has_bottom_line", "Has bottom line"),
        ("has_numbers", "Contains numbers"),
        ("has_bold_numbers", "Numbers emphasized (bold)"),
        ("no_fact_ids", "No fact IDs ([F1], etc.)"),
        ("has_structure", "Structured format"),
    )
    required = ("has_bottom_line", "has_numbers", "no_fact_ids", "has_structure")

    for idx, item in enumerate(TEST_QUESTIONS[:3], 1):  # Test first 3 for speed
        question = item["question"]
        print(f"\n[{idx}] QUESTION: {question}")
        print("-" * 80)

        try:
            answer = rag.query(question, verbose=False)
            checks = _answer_quality_checks(answer)
        except Exception as e:
            # Best-effort report: one failing query must not stop the rest.
            print(f" ❌ ERROR: {e}")
            continue

        print(f"ANSWER:\n{answer}\n")
        print("QUALITY CHECKS:")
        for key, label in report_lines:
            # Show the real outcome instead of an unconditional tick.
            mark = "✓" if checks[key] else "✗"
            print(f" {mark} {label}: {checks[key]}")

        if not all(checks[key] for key in required):
            print(" ⚠️ WARNING: Some quality checks failed!")

    print(f"\n{'='*80}\n")
def test_hallucination_prevention(rag: RAGPipeline):
    """Check answers to tempting questions for known made-up phrasings.

    Each trick question is sent to ``rag.query(question, verbose=False)`` and
    the answer is searched case-insensitively for that question's forbidden
    phrases. The reported rate is the percentage of successfully answered
    questions containing at least one forbidden phrase, so it never exceeds
    100%.

    Args:
        rag: Pipeline exposing ``query(question, verbose=False)``; the result
            is treated as a string.

    Returns:
        None. Results are written to stdout. Questions whose query raises are
        reported and left out of the rate; no rate is printed when every
        query failed.
    """
    print("\n" + "=" * 80)
    print("HALLUCINATION PREVENTION TEST")
    print("=" * 80)

    # Questions designed to tempt hallucination.
    # NOTE: "should_calculate" is recorded here but not checked below.
    trick_questions = [
        {
            "question": "How much tax will I pay if I earn ₦500,000 per month?",
            "should_calculate": True,  # Should use tax calculator
            "forbidden_phrases": [],  # Calculator is allowed to show examples
        },
        {
            "question": "What happens if I don't pay my taxes?",
            "should_calculate": False,
            "forbidden_phrases": ["for example, you could be fined ₦", "typically around ₦"],
        },
    ]

    hallucinations = 0
    total = 0
    for item in trick_questions:
        question = item["question"]
        print(f"\nQUESTION: {question}")
        try:
            answer = rag.query(question, verbose=False)

            # Lower-case the answer once, then collect every forbidden phrase
            # that appears in it.
            answer_lower = answer.lower()
            found_forbidden = [
                phrase
                for phrase in item["forbidden_phrases"]
                if phrase.lower() in answer_lower
            ]

            if found_forbidden:
                # Count the answer once however many phrases matched, so the
                # numerator uses the same unit as ``total``.
                hallucinations += 1
                print(f" ❌ HALLUCINATION DETECTED: {found_forbidden}")
                print(f" Answer excerpt: {answer[:200]}...")
            else:
                print(" ✓ No hallucinations detected")
            total += 1
        except Exception as e:
            # Best-effort: report the failure and continue with the next one.
            print(f" ⚠️ ERROR: {e}")

    if total > 0:
        hallucination_rate = (hallucinations / total) * 100
        print(f"\n{'='*80}")
        print(f"HALLUCINATION RATE: {hallucination_rate:.1f}%")
        if hallucination_rate == 0:
            print("✓ EXCELLENT: No hallucinations detected!")
        elif hallucination_rate < 10:
            print("✓ GOOD: Low hallucination rate")
        else:
            print("⚠️ WARNING: High hallucination rate, review prompts")
        print(f"{'='*80}\n")
def main():
    """Build the RAG pipeline, run all three test suites, and print a summary.

    The document store and pipeline are created first. The suites and the
    summary then run inside one ``try``; if any of them raises, the error and
    its traceback are printed and the process exits with status 1.
    """
    rule = "=" * 80

    print(rule)
    print("FREE OPTIMIZATION VALIDATION TEST")
    print("Testing: Improved embeddings, prompts, formatting, and retrieval")
    print(rule)

    # Set up the document store and make sure the vector index exists
    print("\nInitializing RAG pipeline...")
    doc_store = DocumentStore(
        persist_dir=Path("vector_store"),
        embedding_model="BAAI/bge-large-en-v1.5",  # New embedding model
    )
    pdf_files = doc_store.discover_pdfs(Path("data"))
    doc_store.build_vector_store(pdf_files, force_rebuild=False)

    pipeline_settings = {
        "model": "llama-3.3-70b-versatile",
        "temperature": 0.1,
        "top_k": 15,  # Increased from 8
        "use_hybrid": True,
        "use_mmr": True,
        "use_reranker": True,
    }
    rag = RAGPipeline(doc_store=doc_store, **pipeline_settings)
    print("✓ RAG pipeline initialized\n")

    applied_optimizations = (
        " ✓ Upgraded embedding: all-MiniLM-L6-v2 → bge-large-en-v1.5",
        " ✓ Upgraded reranker: MiniLM-L-6 → MiniLM-L-12",
        " ✓ Anti-hallucination system prompts",
        " ✓ Enhanced fact schema with number extraction",
        " ✓ Removed fact IDs from output",
        " ✓ Bold emphasis on numbers and percentages",
        " ✓ Tax-aware query expansion",
        " ✓ Increased retrieval: 8 → 15 docs",
        " ✓ Context added to thresholds (₦800K → ₦800K (₦66,667/month))",
    )

    # Run the suites, then report
    try:
        retrieval_precision = test_retrieval_quality(rag)
        test_answer_quality(rag)
        test_hallucination_prevention(rag)

        print("\n" + rule)
        print("SUMMARY")
        print(rule)
        print(f"Retrieval Precision: {retrieval_precision*100:.1f}%")
        print(" Target: >55% (baseline was ~42%)")

        # Grade the score against the two thresholds, highest first
        if retrieval_precision > 0.55:
            verdict = " ✓ EXCELLENT: Retrieval improved!"
        elif retrieval_precision > 0.45:
            verdict = " ✓ GOOD: Retrieval improved"
        else:
            verdict = " ⚠️ Need improvement"
        print(verdict)

        print("\nOPTIMIZATIONS APPLIED:")
        for entry in applied_optimizations:
            print(entry)

        print("\n" + rule)
        print("TEST COMPLETE")
        print(rule)
    except Exception as exc:
        print(f"\n❌ TEST FAILED: {exc}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


# Run the validation only when this file is executed as a script.
if __name__ == "__main__":
    main()