# Source: agent-cost-optimizer / aco/classifier.py
# Uploaded by narcolepticchicken with huggingface_hub (commit 1b0e9a1, verified).
"""Task Cost Classifier: Predicts task type, difficulty, and cost requirements."""
from typing import Dict, Tuple, Optional
import re
# Keyword tables for the classifier. Each table is a list of regular
# expressions; the classifier applies them to the lower-cased request text.
#
# Alternatives ending in ``\w*`` are word stems.  They need the ``\w*`` suffix
# because the closing ``\b`` otherwise demands a word boundary immediately
# after the stem, so e.g. ``indemnif`` could never match "indemnify" or
# "indemnification".

# Programming / debugging vocabulary and language names.
CODE_PATTERNS = [r'\b(code|function|bug|debug|refactor|implement|compile|runtime|segfault|thread|async|class|module|python|javascript|typescript|go|rust|java)\b']
# Legal and regulatory vocabulary.
LEGAL_PATTERNS = [r'\b(contract|legal|compliance|gdpr|privacy|policy|regulatory|liability|indemnif\w*|clause)\b']
# Research / investigation vocabulary.
RESEARCH_PATTERNS = [r'\b(research|sources?|literature|investigate|compare|analy[sz]e|survey|paper|arxiv|find)\b']
# Verbs and nouns that suggest external tool or data-source usage.
TOOL_PATTERNS = [r'\b(search|fetch|retrieve|query|api|database|scrape|aggregate|list|download)\b']
# Vocabulary that suggests multi-step, long-running work.
LONG_PATTERNS = [r'\b(plan|roadmap|orchestrat\w*|migrate|pipeline|deploy|architecture|multi-step|end.to.end|entire)\b']
# Mathematical vocabulary.
# NOTE(review): not referenced by the classifier code visible in this file.
MATH_PATTERNS = [r'\b(calculat\w*|comput\w*|solve|equation|formula|optim[iy]\w*|probability|integral|derivative)\b']
# Words that lower the estimated difficulty.
SIMPLE_PATTERNS = [r'\b(typo|simple|quick|brief|just|minor|small|easy|trivial|clarif\w*|only)\b']
# Words that raise the estimated difficulty.
CRITICAL_PATTERNS = [r'\b(critical|production|urgent|now|emergency|live|deployed|safety|security|important)\b']
# Document-writing vocabulary.
DOC_PATTERNS = [r'\b(draft|write|compose|email|proposal|report|memo|letter|document|create)\b']
# Phrases that suggest looking things up in provided material.
RETRIEVAL_PATTERNS = [r'\b(find all|search.*for|look up|based on|according to|in the document|in the file)\b']
# Every label the classifier can emit as a task type.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Starting difficulty for each task type, before keyword adjustments are
# applied by the classifier.
TASK_DIFFICULTY_BASE = {
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "long_horizon": 4,
    "legal_regulated": 5,
}

# Risk label for each task type: "low", "medium", "high" or "critical".
TASK_RISK = {
    "quick_answer": "low",
    "document_drafting": "low",
    "tool_heavy": "medium",
    "retrieval_heavy": "medium",
    "research": "medium",
    "coding": "medium",
    "unknown_ambiguous": "medium",
    "long_horizon": "high",
    "legal_regulated": "critical",
}
class TaskCostClassifier:
    """Keyword-based classifier that labels a request and estimates its cost.

    Classification is purely heuristic: the request text is lower-cased and
    matched against the module-level regex tables; the resulting task type
    drives the difficulty, risk, capability flags and cost estimate.
    """

    def __init__(self):
        # Exposes the module-level list of task type labels on the instance.
        self.task_types = TASK_TYPES

    def classify(self, request: str) -> Dict:
        """Classify *request* and return the full set of estimates.

        Args:
            request: The raw request text.

        Returns:
            A dict with the keys ``task_type``, ``difficulty``, ``risk``,
            ``needs_tools``, ``needs_retrieval``, ``needs_verifier``,
            ``expected_cost`` and ``expected_tier``.  ``expected_tier`` is
            ``difficulty + 1`` capped at 5.
        """
        task_type = self._classify_type(request)
        difficulty = self._estimate_difficulty(request, task_type)
        risk = TASK_RISK.get(task_type, "medium")
        needs_tools = self._needs_tools(request, task_type)
        needs_retrieval = self._needs_retrieval(request, task_type)
        needs_verifier = self._needs_verifier(request, task_type, risk)
        expected_cost = self._estimate_cost(
            difficulty, needs_tools, needs_retrieval, needs_verifier
        )
        return {
            "task_type": task_type,
            "difficulty": difficulty,
            "risk": risk,
            "needs_tools": needs_tools,
            "needs_retrieval": needs_retrieval,
            "needs_verifier": needs_verifier,
            "expected_cost": expected_cost,
            "expected_tier": min(difficulty + 1, 5),
        }

    @staticmethod
    def _count_matches(patterns, text: str) -> int:
        """Return the total number of regex matches of *patterns* in *text*."""
        return sum(len(re.findall(pattern, text)) for pattern in patterns)

    @staticmethod
    def _matches_any(patterns, text: str) -> bool:
        """Return True if at least one of *patterns* matches somewhere in *text*."""
        # re.search stops at the first hit; no need to collect every match.
        return any(re.search(pattern, text) for pattern in patterns)

    def _classify_type(self, request: str) -> str:
        """Pick the task type whose keyword table matches *request* most often.

        Returns ``"unknown_ambiguous"`` when nothing scores above zero.
        Requests shorter than ten words get a 0.5 score for
        ``"quick_answer"``, so that label only wins when no keyword table
        matched at all.
        """
        text = request.lower()
        # The order below is also the tie-break priority: the scores dict is
        # filled in this order and max() returns the first key holding the
        # highest score.
        keyword_tables = (
            ("legal_regulated", LEGAL_PATTERNS),
            ("coding", CODE_PATTERNS),
            ("research", RESEARCH_PATTERNS),
            ("tool_heavy", TOOL_PATTERNS),
            ("long_horizon", LONG_PATTERNS),
            ("retrieval_heavy", RETRIEVAL_PATTERNS),
            ("document_drafting", DOC_PATTERNS),
        )
        scores = {
            label: self._count_matches(patterns, text)
            for label, patterns in keyword_tables
        }
        scores["quick_answer"] = 0.5 if len(text.split()) < 10 else 0

        # No table matched and the request is not short: no usable signal.
        if max(scores.values()) == 0:
            return "unknown_ambiguous"
        return max(scores, key=scores.get)

    def _estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return a difficulty from 1 to 5 for *request*.

        Starts from the base difficulty of *task_type* (3 if the type is not
        in the table), adds one when a critical keyword is present and
        subtracts one when a simplicity keyword is present.  Both adjustments
        can apply to the same request; the result is clamped to 1..5 after
        each step.
        """
        text = request.lower()
        level = TASK_DIFFICULTY_BASE.get(task_type, 3)
        if self._matches_any(CRITICAL_PATTERNS, text):
            level = min(level + 1, 5)
        if self._matches_any(SIMPLE_PATTERNS, text):
            level = max(level - 1, 1)
        return level

    def _needs_tools(self, request: str, task_type: str) -> bool:
        """Return True for task types that are expected to call tools."""
        return task_type in ("tool_heavy", "retrieval_heavy", "coding", "research")

    def _needs_retrieval(self, request: str, task_type: str) -> bool:
        """Return True for task types that are expected to need retrieval."""
        return task_type in ("retrieval_heavy", "research")

    def _needs_verifier(self, request: str, task_type: str, risk: str) -> bool:
        """Return True when *risk* is ``"high"`` or ``"critical"``."""
        return risk in ("high", "critical")

    def _estimate_cost(self, difficulty: int, tools: bool, retrieval: bool, verifier: bool) -> float:
        """Return the estimated cost for a task.

        The base cost is looked up by *difficulty* (1.0 for values outside
        1..5) and then multiplied by 1.3 if tools are needed, 1.2 if
        retrieval is needed and 1.1 if a verifier is needed.
        """
        cost = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}.get(difficulty, 1.0)
        if tools:
            cost *= 1.3
        if retrieval:
            cost *= 1.2
        if verifier:
            cost *= 1.1
        return cost