| """Task Cost Classifier: Predicts task type, difficulty, and cost requirements.""" |
| from typing import Dict, Tuple, Optional |
| import re |
|
|
# Keyword patterns used to score a request. Each list holds one regex that
# TaskCostClassifier runs against the lowercased request text.
#
# Word stems ("indemnif", "orchestrat", "calculat", "comput", "optim[iy]",
# "clarif") carry a trailing ``\w*`` so they match the full inflected word
# ("indemnification", "orchestrate", "calculate", ...). A bare stem followed
# by the closing ``\b`` can never match inside a longer word, so without the
# ``\w*`` those keywords were dead.
#
# NOTE(review): "go" in CODE_PATTERNS also matches the everyday English verb
# ("let's go"), which can skew requests toward "coding" — confirm intended.
# NOTE(review): MATH_PATTERNS is not referenced by TaskCostClassifier in this
# file; it may be used elsewhere — verify before removing.
CODE_PATTERNS = [r'\b(code|function|bug|debug|refactor|implement|compile|runtime|segfault|thread|async|class|module|python|javascript|typescript|go|rust|java)\b']
LEGAL_PATTERNS = [r'\b(contract|legal|compliance|gdpr|privacy|policy|regulatory|liability|indemnif\w*|clause)\b']
RESEARCH_PATTERNS = [r'\b(research|sources?|literature|investigate|compare|analy[sz]e|survey|paper|arxiv|find)\b']
TOOL_PATTERNS = [r'\b(search|fetch|retrieve|query|api|database|scrape|aggregate|list|download)\b']
LONG_PATTERNS = [r'\b(plan|roadmap|orchestrat\w*|migrate|pipeline|deploy|architecture|multi-step|end.to.end|entire)\b']
MATH_PATTERNS = [r'\b(calculat\w*|comput\w*|solve|equation|formula|optim[iy]\w*|probability|integral|derivative)\b']
SIMPLE_PATTERNS = [r'\b(typo|simple|quick|brief|just|minor|small|easy|trivial|clarif\w*|only)\b']
CRITICAL_PATTERNS = [r'\b(critical|production|urgent|now|emergency|live|deployed|safety|security|important)\b']
DOC_PATTERNS = [r'\b(draft|write|compose|email|proposal|report|memo|letter|document|create)\b']
RETRIEVAL_PATTERNS = [r'\b(find all|search.*for|look up|based on|according to|in the document|in the file)\b']
|
|
# Ordered catalogue of task-type labels used by this module.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]


# Starting difficulty (1 = easiest, 5 = hardest) for each task type, before
# keyword-based adjustments.
TASK_DIFFICULTY_BASE = dict(
    quick_answer=1,
    document_drafting=2,
    tool_heavy=2,
    retrieval_heavy=2,
    research=3,
    coding=3,
    unknown_ambiguous=3,
    long_horizon=4,
    legal_regulated=5,
)


# Risk label per task type: "low", "medium", "high" or "critical".
TASK_RISK = dict(
    quick_answer="low",
    document_drafting="low",
    tool_heavy="medium",
    retrieval_heavy="medium",
    research="medium",
    coding="medium",
    unknown_ambiguous="medium",
    long_horizon="high",
    legal_regulated="critical",
)
|
|
class TaskCostClassifier:
    """Keyword-heuristic estimator of a request's task type and cost profile.

    ``classify`` lowercases the request, scores it against the module-level
    regex pattern lists, and maps the winning task type to a difficulty,
    a risk label, tool/retrieval/verifier flags, a cost estimate and a tier.
    """

    # Base cost per difficulty level; any other level falls back to 1.0.
    # Kept at class level so the table is not rebuilt on every call.
    _BASE_COST_BY_DIFFICULTY = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}

    def __init__(self):
        # Copy the module-level list so mutating one instance's task_types
        # cannot leak into TASK_TYPES (and thus into every other instance).
        self.task_types = list(TASK_TYPES)

    def classify(self, request: str) -> Dict:
        """Return the full cost profile for *request*.

        Returns:
            A dict with keys ``task_type`` (str), ``difficulty`` (int, 1-5),
            ``risk`` (str), ``needs_tools``, ``needs_retrieval`` and
            ``needs_verifier`` (bool), ``expected_cost`` (float), and
            ``expected_tier`` (int: difficulty + 1, capped at 5).
        """
        task_type = self._classify_type(request)
        difficulty = self._estimate_difficulty(request, task_type)
        risk = TASK_RISK.get(task_type, "medium")
        needs_tools = self._needs_tools(request, task_type)
        needs_retrieval = self._needs_retrieval(request, task_type)
        needs_verifier = self._needs_verifier(request, task_type, risk)
        expected_cost = self._estimate_cost(difficulty, needs_tools, needs_retrieval, needs_verifier)
        return {
            "task_type": task_type,
            "difficulty": difficulty,
            "risk": risk,
            "needs_tools": needs_tools,
            "needs_retrieval": needs_retrieval,
            "needs_verifier": needs_verifier,
            "expected_cost": expected_cost,
            "expected_tier": min(difficulty + 1, 5),
        }

    def _classify_type(self, request: str) -> str:
        """Return the task type whose keyword patterns match *request* most often.

        Ties go to whichever type appears first in ``pattern_groups``. With no
        keyword hits, a short (1-9 word) request is a "quick_answer". Anything
        else, including an empty or whitespace-only request, is
        "unknown_ambiguous".
        """
        r = request.lower()
        words = r.split()
        if not words:
            # Nothing to answer; previously this scored as a "quick_answer"
            # because 0 words counts as "short".
            return "unknown_ambiguous"

        # Order matters: max() returns the first key with the highest score.
        pattern_groups = (
            ("legal_regulated", LEGAL_PATTERNS),
            ("coding", CODE_PATTERNS),
            ("research", RESEARCH_PATTERNS),
            ("tool_heavy", TOOL_PATTERNS),
            ("long_horizon", LONG_PATTERNS),
            ("retrieval_heavy", RETRIEVAL_PATTERNS),
            ("document_drafting", DOC_PATTERNS),
        )
        scores = {
            name: sum(len(re.findall(p, r)) for p in patterns)
            for name, patterns in pattern_groups
        }
        # A fractional score so that any real keyword hit (>= 1) outranks the
        # "short request" fallback.
        scores["quick_answer"] = 0.5 if len(words) < 10 else 0

        if max(scores.values()) == 0:
            return "unknown_ambiguous"
        return max(scores, key=scores.get)

    def _estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return a 1-5 difficulty for *request*.

        Start from the task type's base (default 3). Add 1 if any critical
        cue appears, and subtract 1 if any simplicity cue appears, clamping
        to the 1-5 range at each step.
        """
        r = request.lower()
        level = TASK_DIFFICULTY_BASE.get(task_type, 3)
        if any(re.search(p, r) for p in CRITICAL_PATTERNS):
            level = min(level + 1, 5)
        if any(re.search(p, r) for p in SIMPLE_PATTERNS):
            level = max(level - 1, 1)
        return level

    def _needs_tools(self, request: str, task_type: str) -> bool:
        """Return True for task types expected to call tools.

        *request* is currently unused.
        """
        return task_type in ("tool_heavy", "retrieval_heavy", "coding", "research")

    def _needs_retrieval(self, request: str, task_type: str) -> bool:
        """Return True for task types expected to retrieve documents.

        *request* is currently unused.
        """
        return task_type in ("retrieval_heavy", "research")

    def _needs_verifier(self, request: str, task_type: str, risk: str) -> bool:
        """Return True when *risk* is "high" or "critical".

        *request* and *task_type* are currently unused.
        """
        return risk in ("high", "critical")

    def _estimate_cost(self, difficulty: int, tools: bool, retrieval: bool, verifier: bool) -> float:
        """Return the difficulty's base cost scaled by each enabled requirement.

        Multipliers are x1.3 for tools, x1.2 for retrieval and x1.1 for a
        verifier. They are applied in that order, which keeps float results
        stable.
        """
        cost = self._BASE_COST_BY_DIFFICULTY.get(difficulty, 1.0)
        if tools:
            cost *= 1.3
        if retrieval:
            cost *= 1.2
        if verifier:
            cost *= 1.1
        return cost
|
|