# evaluation.py - Evaluation System (WITH SAFETY CAPS)
from typing import List, Dict, Optional
import time
import numpy as np
from dataclasses import dataclass
import json


@dataclass
class Question:
    """Represents a single evaluation question"""
    query: str
    query_type: str  # content_retrieval, version_inquiry, change_retrieval
    expected_answer: str
    expected_version: str
    domain: str
    topic: str
    expected_keywords: Optional[List[str]] = None

class VersionQADataset:
    """Dataset for evaluating version-aware QA"""

    def __init__(self, questions: List[Question]):
        self.questions = questions

    @classmethod
    def create_mini_versionqa(cls) -> 'VersionQADataset':
        """Create the Mini-VersionQA dataset as specified"""
        questions = [
            # Software - Node.js Assert
            Question(
                query="What is the assert module in Node.js v20.0?",
                query_type="content_retrieval",
                expected_answer="assert module provides testing functions",
                expected_version="v20.0",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["assert", "testing", "module"]
            ),
            Question(
                query="List all versions of the assert module",
                query_type="version_inquiry",
                expected_answer="v20.0, v21.0, v23.0",
                expected_version="all",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["v20.0", "v21.0", "v23.0"]
            ),
            Question(
                query="When was the strict mode added to assert?",
                query_type="change_retrieval",
                expected_answer="v21.0",
                expected_version="v21.0",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["strict", "mode", "v21.0"]
            ),
            # Software - Bootstrap
            Question(
                query="What are the grid classes in Bootstrap v5.2?",
                query_type="content_retrieval",
                expected_answer="col-*, row classes for responsive grid",
                expected_version="v5.2",
                domain="Software",
                topic="Bootstrap",
                expected_keywords=["grid", "col", "row"]
            ),
            Question(
                query="What changed in Bootstrap from v5.2 to v5.3?",
                query_type="change_retrieval",
                expected_answer="new utility classes and improvements",
                expected_version="v5.3",
                domain="Software",
                topic="Bootstrap",
                expected_keywords=["utility", "classes", "v5.3"]
            ),
            # Software - Spark
            Question(
                query="How does DataFrame work in Spark v3.0?",
                query_type="content_retrieval",
                expected_answer="distributed collection of data organized into named columns",
                expected_version="v3.0",
                domain="Software",
                topic="Spark",
                expected_keywords=["dataframe", "distributed", "columns"]
            ),
            Question(
                query="What was removed in Spark v3.5?",
                query_type="change_retrieval",
                expected_answer="deprecated APIs and legacy features",
                expected_version="v3.5",
                domain="Software",
                topic="Spark",
                expected_keywords=["removed", "deprecated", "v3.5"]
            ),
            # Healthcare
            Question(
                query="What are the treatment guidelines in v1.0?",
                query_type="content_retrieval",
                expected_answer="standard treatment protocols for patient care",
                expected_version="v1.0",
                domain="Healthcare",
                topic="Clinical Guidelines",
                expected_keywords=["treatment", "protocols", "guidelines"]
            ),
            Question(
                query="What changed in clinical guidelines from v1.0 to v2.0?",
                query_type="change_retrieval",
                expected_answer="updated treatment protocols and new recommendations",
                expected_version="v2.0",
                domain="Healthcare",
                topic="Clinical Guidelines",
                expected_keywords=["updated", "protocols", "v2.0"]
            ),
            # Finance
            Question(
                query="What are the compliance requirements in FY2023?",
                query_type="content_retrieval",
                expected_answer="regulatory compliance requirements for financial reporting",
                expected_version="FY2023",
                domain="Finance",
                topic="Compliance Reports",
                expected_keywords=["compliance", "requirements", "regulatory"]
            ),
            Question(
                query="What regulations changed from FY2023 to FY2024?",
                query_type="change_retrieval",
                expected_answer="new regulatory requirements and updated compliance standards",
                expected_version="FY2024",
                domain="Finance",
                topic="Compliance Reports",
                expected_keywords=["regulations", "changed", "FY2024"]
            ),
            # Industrial
            Question(
                query="What is the startup procedure in Rev. 1.0?",
                query_type="content_retrieval",
                expected_answer="machine startup steps and initialization procedures",
                expected_version="Rev. 1.0",
                domain="Industrial",
                topic="Machine Operation",
                expected_keywords=["startup", "procedure", "machine"]
            ),
            Question(
                query="What safety features were added in Rev. 2.0?",
                query_type="change_retrieval",
                expected_answer="enhanced safety features and emergency protocols",
                expected_version="Rev. 2.0",
                domain="Industrial",
                topic="Machine Operation",
                expected_keywords=["safety", "features", "Rev. 2.0"]
            ),
        ]
        return cls(questions)

    @classmethod
    def from_dict(cls, data: List[Dict]) -> 'VersionQADataset':
        """Load dataset from dictionary"""
        questions = []
        for q in data:
            questions.append(Question(
                query=q['query'],
                query_type=q['query_type'],
                expected_answer=q['expected_answer'],
                expected_version=q['expected_version'],
                domain=q['domain'],
                topic=q['topic'],
                expected_keywords=q.get('expected_keywords', [])
            ))
        return cls(questions)

    def to_dict(self) -> List[Dict]:
        """Convert dataset to dictionary"""
        return [
            {
                'query': q.query,
                'query_type': q.query_type,
                'expected_answer': q.expected_answer,
                'expected_version': q.expected_version,
                'domain': q.domain,
                'topic': q.topic,
                'expected_keywords': q.expected_keywords
            }
            for q in self.questions
        ]

class Evaluator:
    """Evaluates VersionRAG and Baseline systems"""

    def __init__(self, version_rag, baseline_rag):
        self.version_rag = version_rag
        self.baseline_rag = baseline_rag

    def evaluate(self, dataset: VersionQADataset) -> Dict:
        """Run full evaluation on dataset"""
        versionrag_results = []
        baseline_results = []

        for question in dataset.questions:
            # Evaluate VersionRAG
            start_time = time.time()
            try:
                if question.query_type == "content_retrieval":
                    vrag_answer = self.version_rag.query(
                        query=question.query,
                        version_filter=question.expected_version if question.expected_version != "all" else None
                    )
                elif question.query_type == "version_inquiry":
                    vrag_answer = self.version_rag.version_inquiry(question.query)
                else:  # change_retrieval
                    vrag_answer = self.version_rag.change_retrieval(question.query)
                vrag_latency = time.time() - start_time
            except Exception as e:
                print(f"VersionRAG error on '{question.query}': {e}")
                vrag_answer = {'answer': '', 'sources': []}
                vrag_latency = 0

            # Evaluate Baseline
            start_time = time.time()
            try:
                baseline_answer = self.baseline_rag.query(question.query)
                baseline_latency = time.time() - start_time
            except Exception as e:
                print(f"Baseline error on '{question.query}': {e}")
                baseline_answer = {'answer': '', 'sources': []}
                baseline_latency = 0

            # Score answers
            vrag_score = self._score_answer(
                vrag_answer.get('answer', ''),
                question.expected_answer,
                vrag_answer.get('sources', []),
                question.expected_version,
                question.expected_keywords
            )
            baseline_score = self._score_answer(
                baseline_answer.get('answer', ''),
                question.expected_answer,
                baseline_answer.get('sources', []),
                question.expected_version,
                question.expected_keywords
            )

            versionrag_results.append({
                'question': question,
                'score': vrag_score,
                'latency': vrag_latency,
                'answer': vrag_answer.get('answer', '')
            })
            baseline_results.append({
                'question': question,
                'score': baseline_score,
                'latency': baseline_latency,
                'answer': baseline_answer.get('answer', '')
            })

        # Compute metrics
        versionrag_metrics = self._compute_metrics(versionrag_results)
        baseline_metrics = self._compute_metrics(baseline_results)

        return {
            'versionrag': versionrag_metrics,
            'baseline': baseline_metrics,
            'questions': len(dataset.questions),
            'improvement': {
                'accuracy': versionrag_metrics['accuracy'] - baseline_metrics['accuracy'],
                'vsa': versionrag_metrics['vsa'] - baseline_metrics['vsa'],
                'hit_at_5': versionrag_metrics['hit_at_5'] - baseline_metrics['hit_at_5']
            }
        }

    def _score_answer(self, answer: str, expected: str, sources: List[Dict],
                      expected_version: str, expected_keywords: Optional[List[str]] = None) -> Dict:
        """Score an answer based on correctness and version awareness"""
        if not answer:
            return {
                'content_score': 0.0,
                'version_score': 0.0,
                'keyword_score': 0.0,
                'total_score': 0.0
            }

        # Keyword-based content scoring
        expected_keywords_set = set(expected.lower().split())
        if expected_keywords:
            expected_keywords_set.update([k.lower() for k in expected_keywords])
        answer_keywords = set(answer.lower().split())

        # Compute overlap
        overlap = len(expected_keywords_set & answer_keywords)
        keyword_score = min(overlap / max(len(expected_keywords_set), 1), 1.0)

        # Semantic similarity (simple word overlap as proxy)
        answer_words = answer.lower().split()
        expected_words = expected.lower().split()
        common_words = set(answer_words) & set(expected_words)
        if len(expected_words) > 0:
            content_score = len(common_words) / len(expected_words)
        else:
            content_score = 0.0

        # Boost score if answer is longer and contains key terms
        if len(answer) > 20 and keyword_score > 0.3:
            content_score = min(content_score * 1.2, 1.0)

        # Check version awareness
        version_score = self._compute_version_score(sources, expected_version)

        # Combined score with SAFETY CAP ✅
        total_score = min((content_score * 0.4 + version_score * 0.4 + keyword_score * 0.2), 1.0)

        return {
            'content_score': min(content_score, 1.0),
            'version_score': min(version_score, 1.0),
            'keyword_score': min(keyword_score, 1.0),
            'total_score': total_score
        }
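
    # Worked example of the scoring above (illustrative; values follow the formula in _score_answer):
    #   answer            = "assert provides testing functions"
    #   expected          = "assert module provides testing functions"
    #   expected_keywords = ["assert", "testing", "module"]
    #   sources           = [{'version': 'v20.0'}], expected_version = "v20.0"
    # keyword_score = 4/5 = 0.8; content_score = 4/5 = 0.8, boosted to 0.96 (answer > 20 chars and
    # keyword_score > 0.3); version_score = 1.0 (v20.0 appears in sources);
    # total_score = 0.96*0.4 + 1.0*0.4 + 0.8*0.2 = 0.944 (capped at 1.0).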

    def _compute_version_score(self, sources: List[Dict], expected_version: str) -> float:
        """Compute version-awareness score"""
        if expected_version == "all":
            # For version inquiry, check if multiple versions are present
            versions_in_sources = set()
            for source in sources:
                if isinstance(source, dict):
                    version = source.get('version', 'N/A')
                    if version != 'N/A':
                        versions_in_sources.add(version)
            # Score based on number of versions found (more is better)
            return min(len(versions_in_sources) / 3.0, 1.0)
        else:
            # For specific version, check if expected version is in sources
            for source in sources:
                if isinstance(source, dict):
                    version = source.get('version', '')
                    if expected_version in str(version):
                        return 1.0
            return 0.0

    def _compute_metrics(self, results: List[Dict]) -> Dict:
        """Compute evaluation metrics with SAFETY CAPS ✅"""
        if not results:
            return {
                'accuracy': 0.0,
                'hit_at_5': 0.0,
                'mrr': 0.0,
                'vsa': 0.0,
                'avg_latency': 0.0,
                'by_type': {
                    'content_retrieval': 0.0,
                    'version_inquiry': 0.0,
                    'change_retrieval': 0.0
                }
            }

        # Overall metrics
        total_scores = [r['score']['total_score'] for r in results]
        content_scores = [r['score']['content_score'] for r in results]
        version_scores = [r['score']['version_score'] for r in results]
        latencies = [r['latency'] for r in results]

        # Hit@k (consider hit if score > 0.5)
        hits = [1 if score > 0.5 else 0 for score in total_scores]

        # MRR (Mean Reciprocal Rank)
        # Assume rank 1 if score > 0.7, rank 2 if > 0.5, rank 3 if > 0.3, else rank 5
        reciprocal_ranks = []
        for score in total_scores:
            if score > 0.7:
                reciprocal_ranks.append(1.0)
            elif score > 0.5:
                reciprocal_ranks.append(1/2)
            elif score > 0.3:
                reciprocal_ranks.append(1/3)
            else:
                reciprocal_ranks.append(1/5)

        # By query type
        by_type = {
            'content_retrieval': [],
            'version_inquiry': [],
            'change_retrieval': []
        }
        for result in results:
            qtype = result['question'].query_type
            by_type[qtype].append(result['score']['total_score'])

        # Return metrics with SAFETY CAPS ✅
        return {
            'accuracy': min(np.mean(total_scores) * 100, 100.0),
            'hit_at_5': min(np.mean(hits) * 100, 100.0),
            'mrr': min(np.mean(reciprocal_ranks), 1.0),
            'vsa': min(np.mean(version_scores) * 100, 100.0),  # Version-Sensitive Accuracy
            'avg_latency': np.mean(latencies) if latencies else 0,
            'by_type': {
                'content_retrieval': min(np.mean(by_type['content_retrieval']) * 100, 100.0) if by_type['content_retrieval'] else 0,
                'version_inquiry': min(np.mean(by_type['version_inquiry']) * 100, 100.0) if by_type['version_inquiry'] else 0,
                'change_retrieval': min(np.mean(by_type['change_retrieval']) * 100, 100.0) if by_type['change_retrieval'] else 0
            }
        }
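

# Example usage (illustrative sketch, not part of the evaluation pipeline):
# builds the Mini-VersionQA dataset, round-trips it through JSON using the
# `json` import above, and shows how an Evaluator would be wired up.
# `my_version_rag` / `my_baseline_rag` are placeholders for the VersionRAG and
# baseline systems defined elsewhere in this project, so the actual
# evaluation call is left commented out.
if __name__ == "__main__":
    dataset = VersionQADataset.create_mini_versionqa()
    print(f"Loaded {len(dataset.questions)} questions")

    # Serialize the dataset to JSON and reload it.
    serialized = json.dumps(dataset.to_dict(), indent=2)
    reloaded = VersionQADataset.from_dict(json.loads(serialized))
    assert len(reloaded.questions) == len(dataset.questions)

    # Hypothetical wiring -- requires objects exposing query(), version_inquiry()
    # and change_retrieval() as used in Evaluator.evaluate():
    # evaluator = Evaluator(version_rag=my_version_rag, baseline_rag=my_baseline_rag)
    # report = evaluator.evaluate(dataset)
    # print(f"VersionRAG accuracy: {report['versionrag']['accuracy']:.1f}%")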