import random
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from cross_encoder_reranker import CrossEncoderReranker
from logger_config import config_logger

logger = config_logger(__name__)


class ChatbotValidator:
    """
    Handles automated validation and performance analysis for the chatbot.

    This testing module executes domain-specific queries, obtains chatbot
    responses, and evaluates them with a quality checker.
    """

    def __init__(self, chatbot, quality_checker, cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-12-v2'):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot used for inference
            quality_checker: ResponseQualityChecker used to score retrieved responses
            cross_encoder_model: Name of the cross-encoder model used to rerank retrieved responses
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker
        self.reranker = CrossEncoderReranker(model_name=cross_encoder_model)
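
        # Domain-specific probe queries that serve as the validation test set.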
        self.domain_queries = {
            'restaurant': [
                "Hi, I have a question about your restaurant. Do they take reservations?",
                "I'd like to make a reservation for dinner tonight after 6pm. Is that time available?",
                "Can you recommend an Italian restaurant with wood-fired pizza?",
            ],
            'movie': [
                "How much are movie tickets for two people?",
                "I'm looking for showings after 6pm?",
                "Is this at the new theater with reclining seats?",
            ],
            'ride_share': [
                "I need a ride from the airport to downtown.",
                "What is the cost for Lyft? How about Uber XL?",
                "Can you book a car for tomorrow morning?",
            ],
            'coffee': [
                "Can I customize my coffee?",
                "Can I order a mocha from you?",
                "Can I get my usual venti vanilla latte?",
            ],
            'pizza': [
                "Do you have any pizza specials or deals available?",
                "How long is the wait until the pizza is ready and delivered to me?",
                "Please repeat my pizza order for two medium pizzas with thick crust.",
            ],
            'auto': [
                "The car is making a funny noise when I turn, and I'm due for an oil change.",
                "Is my buddy John available to work on my car?",
                "My Jeep needs a repair. Can you help me with that?",
            ],
        }

    def run_validation(
        self,
        num_examples: int = 3,
        top_k: int = 10,
        domains: Optional[List[str]] = None,
        randomize: bool = False,
        seed: int = 42
    ) -> Dict[str, Any]:
        """
        Run validation across the testable domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of domain keys to test. If None, test all domains.
            randomize: If True, randomly sample queries from each domain list
            seed: Random seed for reproducible sampling when randomize=True

        Returns:
            Dict of aggregate validation metrics
        """
        logger.info("\n=== Running Automatic Validation ===")

        test_domains = domains if domains else list(self.domain_queries.keys())

        metrics_history = []
        domain_metrics = {}

        # Dedicated RNG so sampling is reproducible and does not disturb global random state.
        rng = random.Random(seed)

        for domain in test_domains:
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue

            all_queries = self.domain_queries[domain]
            if randomize:
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]

            domain_metrics[domain] = []

            logger.info(f"\n=== Testing {domain.title()} Domain ===\n")

            for i, query in enumerate(queries, 1):
                logger.info(f"TEST CASE {i}: QUERY: {query}")

                # Retrieve candidates (reranked by the cross-encoder) and score their quality.
                responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=self.reranker)
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)
                self._log_validation_results(query, responses, quality_metrics)
                logger.info(f"Quality metrics: {quality_metrics}\n")

        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics
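
    # ------------------------------------------------------------------
    # Internal helpers: metric aggregation, per-domain analysis, and logging.
    # These operate purely on the metric dicts produced by the quality checker.
    # ------------------------------------------------------------------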

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate aggregate metrics over tested queries.
        """
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}

        top_scores = [m.get('top_score', 0.0) for m in metrics_history]

        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_history]),
            'median_top_score': np.median(top_scores),
            'score_std': np.std(top_scores),
            'min_score': np.min(top_scores),
            'max_score': np.max(top_scores)
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict.
        """
        analysis = {}

        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue

            top_scores = [m.get('top_score', 0.0) for m in metrics_list]

            analysis[domain] = {
                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_list]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_list]),
                'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_list]),
                'avg_top_score': np.mean(top_scores),
                'num_samples': len(metrics_list)
            }

        return analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.
        """
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0,
                    'percentile_75': 0.0, 'percentile_90': 0.0}

        scores = [m.get('top_score', 0.0) for m in metrics_history]
        return {
            'percentile_25': float(np.percentile(scores, 25)),
            'percentile_50': float(np.percentile(scores, 50)),
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
    ):
        """
        Log detailed validation results for a single test case.
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)

        logger.info(f"DOMAIN: {domain} | CONFIDENCE: {'Yes' if is_confident else 'No'}")

        # Guard against an empty retrieval result to avoid an IndexError below.
        if not responses:
            logger.info("SELECTED RESPONSE: NONE (No responses retrieved)")
            return

        if is_confident or responses[0][1] >= 0.5:
            logger.info(f"SELECTED RESPONSE: '{responses[0][0]}'")
        else:
            logger.info("SELECTED RESPONSE: NONE (Low Confidence)")

        logger.info(" Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
            logger.info(f"  {i}) Score: {score:.4f} | {resp_text}")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
        """
        if not metrics:
            logger.info("No metrics to summarize.")
            return

        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            # Skip the nested analysis dicts; they are logged separately below.
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
        for domain, domain_stats in domain_perf.items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f"  {metric}: {value:.4f}")

        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f"  {pct}: {val:.4f}")
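

# Usage sketch (illustrative only): how this validator might be wired up.
# The RetrievalChatbot and ResponseQualityChecker shown as `...` are placeholders;
# how those objects are constructed is project-specific and not defined here.
#
#     chatbot = ...            # a RetrievalChatbot instance
#     quality_checker = ...    # a ResponseQualityChecker instance
#     validator = ChatbotValidator(chatbot, quality_checker)
#     results = validator.run_validation(num_examples=2, domains=['pizza', 'coffee'], randomize=True)
#     print(f"Confidence rate: {results.get('confidence_rate', 0.0):.2%}")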