# agent-cost-optimizer / aco/optimizer.py
# Uploaded to the Hugging Face Hub via huggingface_hub by narcolepticchicken
# (revision 2cbe770, verified).
"""ACO Optimizer: Main orchestrator that coordinates all modules."""
import json, time, uuid
from typing import Dict, List, Optional, Any
from .config import ACOConfig, RoutingPolicy
from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
from .classifier import TaskCostClassifier
from .router import ModelCascadeRouter, RoutingDecision
from .context_budgeter import ContextBudgeter, ContextBudget
from .cache_layout import CacheAwareLayout, PromptLayout
from .tool_gate import ToolCostGate, ToolDecision
from .verifier_budgeter import VerifierBudgeter, VerifierDecision
from .retry_optimizer import RetryOptimizer, RecoveryAction
from .meta_tool_miner import MetaToolMiner, MacroTool
from .doom_detector import DoomDetector, DoomAssessment
from .execution_feedback import ExecutionFeedbackRouter, CascadeResult, FeedbackSignal
class ACOOptimizer:
    """Main orchestrator that coordinates all ACO cost-optimization modules.

    Wires together task classification, model-cascade routing, context
    budgeting, cache-aware prompt layout, tool gating, verifier budgeting,
    retry/recovery optimization, meta-tool mining, doom detection, and
    execution-feedback cascading for the lifetime of a single agent run:

        start_run() -> record_step()* / check_doom() / gate_tool() / ...
        -> end_run()

    Completed traces are accumulated for aggregate reporting via
    ``get_stats()``.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Construct all sub-modules from *config* (defaults to ``ACOConfig()``)."""
        self.config = config or ACOConfig()
        self.classifier = TaskCostClassifier()
        # Router thresholds and cost tables all come from the shared config.
        self.router = ModelCascadeRouter(
            model_path=self.config.router_model_path,
            safety_threshold=self.config.routing_policy.safety_threshold,
            downgrade_threshold=self.config.routing_policy.downgrade_threshold,
            task_floor=self.config.task_floors,
            tier_costs=self.config.tier_costs,
        )
        self.context_budgeter = ContextBudgeter()
        self.cache_layout = CacheAwareLayout()
        self.tool_gate = ToolCostGate()
        self.verifier_budgeter = VerifierBudgeter()
        self.retry_optimizer = RetryOptimizer(
            max_retries=self.config.routing_policy.max_retries,
        )
        self.meta_tool_miner = MetaToolMiner()
        self.doom_detector = DoomDetector()
        self.execution_feedback = ExecutionFeedbackRouter(
            tier_costs=self.config.tier_costs,
            task_floors=self.config.task_floors,
        )
        # Per-run mutable state: the active trace, its step counter, and
        # the archive of all finished traces (consumed by get_stats()).
        self._current_trace: Optional[AgentTrace] = None
        self._step_num = 0
        self._traces: List[AgentTrace] = []

    def start_run(self, request: str) -> Dict:
        """Begin a new agent run for *request*.

        Classifies the task, routes it to a model tier, computes a context
        budget, checks for a matching macro tool, and opens a fresh
        ``AgentTrace``. Resets per-run state on the retry and verifier
        budgeters.

        Returns a plan dict with keys ``trace_id``, ``prediction``,
        ``routing``, ``context_budget``, and ``macro_tool``.
        """
        prediction = self.classifier.classify(request)
        routing = self.router.route(
            request, prediction["task_type"], prediction["difficulty"], prediction)
        context_budget = self.context_budgeter.budget(
            prediction["task_type"], prediction["difficulty"],
            prediction["needs_retrieval"], prediction["needs_tools"],
        )
        # Check for meta-tool match (only when enabled in config).
        macro = (self.meta_tool_miner.match_macro(request, prediction["task_type"])
                 if self.config.enable_meta_tools else None)
        self._current_trace = AgentTrace(
            request=request,
            task_type=prediction["task_type"],
            difficulty=prediction["difficulty"],
            predicted_tier=routing.tier,
        )
        self._step_num = 0
        self.retry_optimizer.reset_run()
        self.verifier_budgeter.reset_run()
        return {
            "trace_id": self._current_trace.trace_id,
            "prediction": prediction,
            "routing": {
                "model_id": routing.model_id,
                "tier": routing.tier,
                "confidence": routing.confidence,
                "cost_estimate": routing.cost_estimate,
                "dynamic_difficulty": routing.dynamic_difficulty,
                "escalated": routing.escalated,
                "downgraded": routing.downgraded,
                "reasoning": routing.reasoning,
            },
            "context_budget": {
                "total_tokens": context_budget.total_tokens,
                "keep_exact": context_budget.keep_exact,
                "summarize": context_budget.summarize,
                "omit": context_budget.omit,
                "retrieve_on_demand": context_budget.retrieve_on_demand,
                "cache_prefix": context_budget.cache_prefix,
            },
            "macro_tool": macro.name if macro else None,
        }

    def record_step(self, model_call: Optional[Dict] = None,
                    tool_calls: Optional[List[Dict]] = None,
                    context_size: int = 0, verifier_called: bool = False,
                    verifier_result: Optional[str] = None, retry_num: int = 0,
                    recovery_action: Optional[str] = None) -> None:
        """Append one ``TraceStep`` to the active trace.

        *model_call* and each entry of *tool_calls* are plain dicts whose
        keys must match the ``ModelCall`` / ``ToolCall`` constructors.
        The step counter advances even when no trace is active; the step
        itself is silently dropped in that case.
        """
        self._step_num += 1
        mc = None
        if model_call:
            mc = ModelCall(**model_call)
        tcs = [ToolCall(**tc) for tc in (tool_calls or [])]
        step = TraceStep(
            step_num=self._step_num,
            model_call=mc,
            tool_calls=tcs,
            context_size=context_size,
            verifier_called=verifier_called,
            verifier_result=verifier_result,
            retry_num=retry_num,
            recovery_action=recovery_action,
        )
        if self._current_trace:
            self._current_trace.steps.append(step)

    def check_doom(self, current_cost: float = 0.0) -> DoomAssessment:
        """Assess whether the current run is doomed (looping / over budget).

        Returns a neutral "continue" assessment when no trace is active.
        """
        if not self._current_trace:
            return DoomAssessment(False, 0.0, [], "continue", "no active trace")
        return self.doom_detector.assess(
            [s.__dict__ for s in self._current_trace.steps],
            current_cost, self.config.routing_policy.max_cost_per_task, 4)

    def should_verify(self, is_irreversible: bool = False,
                      has_prior_failures: bool = False) -> VerifierDecision:
        """Ask the verifier budgeter whether this step warrants verification.

        Uses fixed "medium" stakes and 0.8 confidence placeholders —
        NOTE(review): presumably these should eventually come from the
        trace/step context; confirm against VerifierBudgeter's contract.
        Returns a skip decision when no trace is active.
        """
        if not self._current_trace:
            return VerifierDecision(False, "skip", 0.0, "no active trace", 0.0)
        return self.verifier_budgeter.should_verify(
            self._current_trace.task_type, "medium", 0.8,
            is_irreversible, has_prior_failures,
            self._current_trace.predicted_tier)

    def gate_tool(self, tool_name: str, args: Dict) -> ToolDecision:
        """Decide whether a proposed tool call is worth its cost.

        Returns a skip decision when no trace is active. The trailing
        positional arguments (step numbers and 0.5) mirror ToolCostGate.gate's
        signature — see that module for their meaning.
        """
        if not self._current_trace:
            return ToolDecision("skip", tool_name, 0.0, "no active trace", 0.0, 0.0)
        return self.tool_gate.gate(tool_name, args, self._current_trace.task_type,
                                   self._step_num, self._step_num + 1, 0.5)

    def cascade_step(self, request: str, initial_tier: int,
                     cheap_logprobs: List[float],
                     cheap_response: str,
                     strong_response: str = "",
                     task_type: Optional[str] = None) -> CascadeResult:
        """Execution-feedback cascade: use cheap model output to decide escalation.

        *task_type* falls back to the active trace's task type, or to
        "unknown_ambiguous" when no trace exists. The task floor defaults
        to tier 1 for unknown task types.
        """
        if not self._current_trace:
            task_type = task_type or "unknown_ambiguous"
        else:
            task_type = task_type or self._current_trace.task_type
        floor = self.config.task_floors.get(task_type, 1)
        return self.execution_feedback.cascade(
            request, initial_tier, cheap_logprobs,
            cheap_response, strong_response,
            task_type=task_type, task_floor=floor,
        )

    def analyze_output_confidence(self, token_logprobs: List[float],
                                  task_type: str = "unknown",
                                  current_tier: int = 2) -> FeedbackSignal:
        """Analyze model output confidence for routing decisions."""
        return self.execution_feedback.analyze_output(
            token_logprobs, task_type=task_type, current_tier=current_tier)

    def get_recovery(self, failure_tag: str, current_tier: int,
                     retry_num: int, previous_actions: Optional[List[str]] = None,
                     run_cost: float = 0.0) -> RecoveryAction:
        """Delegate to the retry optimizer for the next recovery action."""
        return self.retry_optimizer.get_recovery(
            failure_tag, current_tier, retry_num,
            previous_actions, run_cost,
            self.config.routing_policy.max_cost_per_task)

    def end_run(self, success: bool, outcome: str = "completed",
                artifacts: Optional[List[str]] = None,
                failure_tags: Optional[List[str]] = None,
                user_correction: bool = False) -> Optional[AgentTrace]:
        """Finalize the active run and archive its trace.

        Fills in outcome fields, computes the cost/token summary, appends
        the trace to the archive, and clears the active-trace slot.

        Returns the finished ``AgentTrace``, or ``None`` when no run was
        active (hence the ``Optional`` return — the original annotation
        claimed a bare ``AgentTrace`` but returned ``None`` in that case).
        """
        if self._current_trace:
            self._current_trace.task_success = success
            self._current_trace.final_outcome = outcome
            self._current_trace.artifacts_created = artifacts or []
            self._current_trace.failure_tags = failure_tags or []
            self._current_trace.user_correction = user_correction
            summary = self._current_trace.compute_summary()
            self._current_trace.total_cost = summary["total_cost"]
            self._current_trace.total_tokens = summary["total_tokens"]
            self._current_trace.total_tool_calls = summary["total_tool_calls"]
            self._current_trace.total_retries = summary["total_retries"]
            self._current_trace.total_verifier_calls = summary["total_verifier_calls"]
            self._current_trace.cache_hit_rate = summary["cache_hit_rate"]
            self._traces.append(self._current_trace)
        trace = self._current_trace
        self._current_trace = None
        return trace

    def layout_prompt(self, sources: Dict[str, str]) -> PromptLayout:
        """Lay out *sources* into a cache-friendly prompt within budget.

        Uses the active trace's task type/difficulty when available,
        otherwise a conservative "unknown_ambiguous" difficulty-3 budget.
        """
        if not self._current_trace:
            budget = self.context_budgeter.budget("unknown_ambiguous", 3, False, False)
        else:
            budget = self.context_budgeter.budget(
                self._current_trace.task_type,
                self._current_trace.difficulty,
                False, False)
        return self.cache_layout.layout(sources, budget)

    def get_stats(self) -> Dict:
        """Return aggregate statistics over all completed runs."""
        return {
            "total_runs": len(self._traces),
            "successful_runs": sum(1 for t in self._traces if t.task_success),
            # max(..., 1) guards against division by zero on an empty archive.
            "avg_cost": sum(t.total_cost for t in self._traces) / max(len(self._traces), 1),
            "cache_stats": self.cache_layout.stats(),
            "tool_stats": self.tool_gate.call_stats,
            "verifier_stats": self.verifier_budgeter.stats,
            "retry_stats": self.retry_optimizer.recovery_stats,
        }