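"""LangGraph workflow that evaluates a model response one metric at a time
and aggregates the per-metric scores into a single EvaluationResult."""
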
from typing import Any, Dict, List, Optional

from langgraph.graph import StateGraph, END
from typing_extensions import TypedDict

from schemas.data_models import EvaluationResult
from chains.evaluation_chains import EvaluationChains
from .tools import evaluate_response
from config import settings

class EvaluationState(TypedDict):
    question: str
    ground_truth: str
    model_response: str
    context: str
    metrics: List[str]
    results: Dict[str, Any]
    current_metric: Optional[str]  # None once every metric has been evaluated
    final_result: Optional[EvaluationResult]  # populated by the aggregation node

class EvaluationGraphBuilder:
    def __init__(self, model_name: Optional[str] = None, api_provider: Optional[str] = None):
        self.model_name = model_name
        self.api_provider = api_provider or settings.DEFAULT_API_PROVIDER
        self.evaluation_chains = EvaluationChains(model_name, self.api_provider)
    def build_graph(self):
        """Build the LangGraph workflow for evaluation"""
        workflow = StateGraph(EvaluationState)

        # Add nodes
        workflow.add_node("initialize", self._initialize_node)
        workflow.add_node("evaluate_metric", self._evaluate_metric_node)
        workflow.add_node("aggregate_results", self._aggregate_results_node)

        # Define edges: loop back over evaluate_metric until every metric
        # is scored, then hand off to aggregation
        workflow.set_entry_point("initialize")
        workflow.add_edge("initialize", "evaluate_metric")
        workflow.add_conditional_edges(
            "evaluate_metric",
            self._should_continue,
            {
                "continue": "evaluate_metric",
                "done": "aggregate_results"
            }
        )
        workflow.add_edge("aggregate_results", END)

        return workflow.compile()
    def _initialize_node(self, state: EvaluationState) -> EvaluationState:
        """Initialize evaluation state"""
        return {
            **state,
            "results": {},
            # None (rather than "") signals there is nothing to evaluate
            "current_metric": state["metrics"][0] if state["metrics"] else None
        }
    def _evaluate_metric_node(self, state: EvaluationState) -> EvaluationState:
        """Evaluate a single metric"""
        metric = state["current_metric"]
        if not metric:
            # No metrics were requested; skip straight to aggregation
            return {**state, "current_metric": None}

        chain = self.evaluation_chains.create_evaluation_chain(metric)

        # Prepare tool input
        tool_input = {
            "question": state["question"],
            "ground_truth": state["ground_truth"],
            "response": state["model_response"],
            "metric": metric,
            "chain": chain
        }

        # Add context for context-based metrics (even if empty)
        if metric in ["context_precision", "context_recall"]:
            tool_input["context"] = state.get("context", "No context provided.")

        # Invoke the evaluation tool with the prepared arguments
        result = evaluate_response.invoke(tool_input)

        # Record the result and advance to the next metric (None when done)
        results = state.get("results", {})
        results[metric] = result
        current_index = state["metrics"].index(metric)
        next_index = current_index + 1

        return {
            **state,
            "results": results,
            "current_metric": state["metrics"][next_index] if next_index < len(state["metrics"]) else None
        }
    def _should_continue(self, state: EvaluationState) -> str:
        """Determine if we should continue evaluating metrics"""
        if state["current_metric"] is None:
            return "done"
        return "continue"
    def _aggregate_results_node(self, state: EvaluationState) -> EvaluationState:
        """Aggregate results into final format"""
        metrics_scores = {}
        explanations = {}
        total_time = 0

        for metric, result in state["results"].items():
            metrics_scores[metric] = result.get("score", 0)
            explanations[metric] = result.get("explanation", "")
            total_time += result.get("processing_time", 0)

        # Calculate overall score (weighted average)
        overall_score = self._calculate_overall_score(metrics_scores)

        final_result = EvaluationResult(
            question=state["question"],
            ground_truth=state["ground_truth"],
            model_response=state["model_response"],
            metrics=metrics_scores,
            explanations=explanations,
            processing_time=total_time,
            overall_score=overall_score
        )

        return {**state, "final_result": final_result}
    def _calculate_overall_score(self, metrics_scores: Dict[str, float]) -> float:
        """Calculate the overall score as a weighted average of metric scores"""
        # Weights for the known metrics
        weights = {
            "accuracy": 0.3,
            "faithfulness": 0.25,
            "relevance": 0.2,
            "toxicity": 0.15,
            "context_precision": 0.05,
            "context_recall": 0.05
        }

        # Weighted average over the metrics that were actually evaluated
        total_weight = 0
        weighted_sum = 0
        for metric, score in metrics_scores.items():
            weight = weights.get(metric, 0.1)  # Default weight for unknown metrics
            weighted_sum += score * weight
            total_weight += weight

        # Dividing by the summed weights keeps the result on the same scale
        # as the individual metric scores
        if total_weight > 0:
            return weighted_sum / total_weight
        return 0
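

if __name__ == "__main__":
    # Minimal usage sketch, assuming the sibling modules imported above
    # (schemas, chains, tools, config) resolve and any credentials needed by
    # EvaluationChains are configured. Because of the relative import of
    # .tools, run this via `python -m <package>.<module>` rather than as a
    # standalone script. The question and metric names below are illustrative.
    builder = EvaluationGraphBuilder()
    graph = builder.build_graph()

    final_state = graph.invoke({
        "question": "What is the capital of France?",
        "ground_truth": "Paris",
        "model_response": "The capital of France is Paris.",
        "context": "",
        "metrics": ["accuracy", "relevance"],
    })
    print(final_state["final_result"])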