# Spaces:
# Sleeping
# Sleeping
# app/env.py
import random
import sys
from typing import Any, Dict, List, Optional, Tuple

from app.dataset import TICKETS
from app.models import Observation, Action, Reward
from graders import grade_easy, grade_medium, grade_hard
#from tasks import TASKS
# =========================
# PURPOSE: Controls difficulty-driven stochasticity
#   - max_steps         → step budget of one episode
#   - noise_prob        → message distortion
#   - missing_info_prob → partial observability
# =========================
DIFFICULTY_CONFIG = {
    "easy": {
        "max_steps": 8,
        "noise_prob": 0.0,
        "missing_info_prob": 0.1,
    },
    "medium": {
        "max_steps": 10,
        "noise_prob": 0.2,
        "missing_info_prob": 0.3,
    },
    "hard": {
        "max_steps": 12,
        "noise_prob": 0.4,
        "missing_info_prob": 0.5,
    },
}
# =========================
# PURPOSE: Defines tasks exposed to validator
#   Each entry pairs a task id and difficulty level with the grader
#   function imported from graders.py.
# =========================
AVAILABLE_TASKS = [
    {
        "id": "easy-info-collection",
        "difficulty": "easy",
        "grader": grade_easy,
    },
    {
        "id": "medium-complete-info",
        "difficulty": "medium",
        "grader": grade_medium,
    },
    {
        "id": "hard-efficient-resolution",
        "difficulty": "hard",
        "grader": grade_hard,
    },
]
def get_tasks():
    """Return the module-level task registry.

    The returned object is ``AVAILABLE_TASKS`` itself, not a copy, so
    callers that mutate it change the registry for everyone.
    """
    return AVAILABLE_TASKS
class CustomerSupportEnv:
    """Customer-support ticket environment driven by ``reset``/``step``.

    An episode works on one ticket drawn from ``TICKETS``. Actions are plain
    dicts selected by their ``"type"`` key:

    * ``"classify"`` - stores ``category`` and ``priority`` in the collected
      info and rewards a correct (or corrected) category.
    * ``"ask_info"`` - marks ``field`` as collected.
    * ``"resolve"``  - ends the episode and scores it.

    Anything else is penalised as an invalid action. Every step costs 0.05,
    and an episode that uses up ``max_steps`` steps without a ``"resolve"``
    is cut off with an extra 1.5 penalty.

    NOTE: ``seed`` is applied with ``random.seed``, i.e. it re-seeds the
    module-global generator shared with every other user of ``random``.
    """

    # =========================
    # PURPOSE: Initialize environment with difficulty & randomness
    # =========================
    def __init__(self, difficulty: str = "medium", seed: Optional[int] = None):
        """Configure the environment; call :meth:`reset` to start an episode.

        Args:
            difficulty: Key into ``DIFFICULTY_CONFIG``.
            seed: When not ``None``, passed to ``random.seed``.

        Raises:
            KeyError: If ``difficulty`` is not a key of ``DIFFICULTY_CONFIG``.
        """
        self.difficulty = difficulty
        self.config = DIFFICULTY_CONFIG[difficulty]
        if seed is not None:
            random.seed(seed)
        self.max_steps = self.config["max_steps"]

        # Per-episode state; (re)populated by reset().
        self.state_data = None
        self.ticket = None
        self.last_action = None
        self.current_steps = 0
        self.success = False

        # self-correction tracking
        self.classification_history = []

        # METRICS TRACKING
        # NOTE(review): nothing in this class appends to episode_stats;
        # get_metrics() expects dicts with the keys "success", "steps",
        # "reward" and "info_efficiency" - presumably filled in by the
        # caller, confirm.
        self.episode_stats = []

    # OBTAIN TASKS FROM GRADERS.PY
    def get_tasks(self) -> List[Dict[str, Any]]:
        """Return a freshly built list describing the three graded tasks."""
        return [
            {
                "id": "easy-info-collection",
                "difficulty": "easy",
                "grader": grade_easy,
            },
            {
                "id": "medium-complete-info",
                "difficulty": "medium",
                "grader": grade_medium,
            },
            {
                "id": "hard-efficient-resolution",
                "difficulty": "hard",
                "grader": grade_hard,
            },
        ]

    def list_tasks(self) -> List[Dict[str, Any]]:
        """Return the task descriptions (same content as :meth:`get_tasks`).

        The previous implementation returned ``self.tasks``, an attribute
        that is never assigned, and therefore always raised AttributeError.
        """
        return self.get_tasks()

    # =========================
    # PURPOSE: Build observation exposed to agent
    # =========================
    def _get_observation(self) -> Dict[str, Any]:
        """Build the agent-facing view of the current episode state.

        ``known_info`` and ``required`` are copies, so an observation kept
        by the caller does not change when later steps mutate the state.
        """
        required = self.state_data["required_info"]
        collected = self.state_data["collected_info"]
        steps_taken = self.state_data["steps_taken"]

        missing = [f for f in required if f not in collected]
        collected_count = len(required) - len(missing)

        return {
            "ticket_id": self.ticket["ticket_id"],
            "customer_message": self.state_data["customer_message"],
            "known_info": dict(collected),
            "required": list(required),
            "missing_required": missing,
            # max(1, ...) avoids a division by zero for an empty list.
            "info_progress": collected_count / max(1, len(required)),
            "status": self.state_data["status"],
            "step_count": steps_taken,
            "remaining_steps": self.max_steps - steps_taken,
            "difficulty": self.difficulty,  # difficulty awareness
        }

    def reset(self) -> Dict[str, Any]:
        """Start a new episode on a randomly chosen ticket.

        Picks a ticket and one of its message variants, optionally appends
        noise to the message, masks the required-info list and clears all
        per-episode bookkeeping.

        Returns:
            The initial observation.
        """
        self.last_action = None
        self.current_steps = 0
        self.success = False
        # Start every episode with an empty history; otherwise the
        # flip-flop penalty in step() would count classifications made in
        # earlier episodes.
        self.classification_history = []

        self.ticket = random.choice(TICKETS)
        gt = self.ticket["ground_truth"]
        msg = self._inject_noise(random.choice(self.ticket["variants"]))
        masked_required = self._mask_required_info(gt["required_info"])

        self.state_data = {
            "ticket_id": self.ticket["ticket_id"],
            "customer_message": msg,
            "status": "open",
            "category": None,
            "priority": None,
            "required_info": masked_required,
            "collected_info": {},
            "steps_taken": 0,
            "ground_truth": gt,
        }
        return self._get_observation()

    # =========================
    # PURPOSE: Core transition function with self-correction logic
    # =========================
    def step(self, action: dict) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one action and advance the episode by one step.

        Calls :meth:`reset` first when no episode has been started yet.

        Args:
            action: Dict with a ``"type"`` of ``"classify"`` (plus
                ``category``/``priority``), ``"ask_info"`` (plus ``field``)
                or ``"resolve"``. Anything else, including a non-dict, is
                treated as an invalid action.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` is
            ``{"task_success": bool}``.
        """
        if self.state_data is None:
            self.reset()

        reward = -0.05  # flat cost charged for every step
        done = False
        collected = self.state_data["collected_info"]
        gt = self.ticket["ground_truth"]
        action_type = action.get("type") if isinstance(action, dict) else None

        # -----------------------
        # CLASSIFY (SELF-CORRECTION ENABLED)
        # -----------------------
        if action_type == "classify":
            new_cat = action.get("category")
            prev_cat = collected.get("category")
            collected["category"] = new_cat
            collected["priority"] = action.get("priority")
            self.classification_history.append(new_cat)

            # correct classification
            if new_cat == gt["category"]:
                reward += 0.3
            # self-correction bonus: a wrong earlier answer was fixed
            if prev_cat and prev_cat != gt["category"] and new_cat == gt["category"]:
                reward += 0.5  # major reward
            # flip-flop penalty: the last three classifications all differ
            if len(set(self.classification_history[-3:])) > 2:
                reward -= 0.3

        # -----------------------
        # ASK INFO
        # -----------------------
        elif action_type == "ask_info":
            field = action.get("field")
            if not isinstance(field, str) or not field:
                # No usable field name: treat like an invalid action rather
                # than storing a None key (and rewarding it) or crashing on
                # an unhashable value.
                reward -= 0.3
            elif field not in collected:
                collected[field] = "value"
                reward += 0.25
            else:
                reward -= 0.2  # field was already known

        # -----------------------
        # RESOLVE
        # -----------------------
        elif action_type == "resolve":
            done = True
            # NOTE: checked against the unmasked ground-truth list, which
            # can contain fields the observation's "required" list omits.
            required = gt["required_info"]
            all_info = all(f in collected for f in required)
            correct_cat = collected.get("category") == gt["category"]

            # premature penalty
            if not all_info:
                reward -= 0.7
            # scoring
            if correct_cat:
                reward += 0.3
            # NOTE(review): the source lost its indentation; success and the
            # completion bonus are read as belonging to this branch - confirm.
            if all_info:
                reward += 0.3
                self.success = True
                reward += 0.2  # completion bonus

        else:
            reward -= 0.3  # unknown or malformed action

        # -----------------------
        # STEP UPDATE
        # -----------------------
        self.state_data["steps_taken"] += 1
        # Timeout applies only to episodes that did not end on their own;
        # resolving on the last allowed step is not a timeout.
        if not done and self.state_data["steps_taken"] >= self.max_steps:
            done = True
            reward -= 1.5

        return self._get_observation(), reward, done, {
            "task_success": self.success
        }

    def state(self) -> Optional[Dict[str, Any]]:
        """Return the internal state dict (``None`` before the first reset).

        This is the live object, including the ``"ground_truth"`` entry.
        """
        return self.state_data

    def get_metrics(self) -> Dict[str, float]:
        """Average the entries of ``episode_stats``.

        Returns:
            ``{}`` when no episodes are recorded; otherwise the mean
            ``success``, ``steps``, ``reward`` and ``info_efficiency`` of
            the recorded episodes, each rounded to three decimals.
        """
        if not self.episode_stats:
            return {}

        total = len(self.episode_stats)

        def mean(key: str) -> float:
            return sum(e[key] for e in self.episode_stats) / total

        return {
            "success_rate": round(mean("success"), 3),
            "avg_steps": round(mean("steps"), 3),
            "avg_reward": round(mean("reward"), 3),
            "info_efficiency": round(mean("info_efficiency"), 3),
        }

    # =========================
    # PURPOSE: Apply noise to simulate real-world messy input
    # =========================
    def _inject_noise(self, message: str) -> str:
        """Append a random filler phrase with probability ``noise_prob``."""
        if random.random() < self.config["noise_prob"]:
            noise = random.choice([
                "pls help asap",
                "not sure what's wrong",
                "this is urgent",
                "been days",
            ])
            return message + " " + noise
        return message

    # =========================
    # PURPOSE: Mask required fields → partial observability
    # =========================
    def _mask_required_info(self, required_fields: List[str]) -> List[str]:
        """Drop each field independently with probability ``missing_info_prob``.

        Falls back to all fields when every one of them would be dropped, so
        a non-empty input never yields an empty result. Always returns a new
        list, never ``required_fields`` itself.
        """
        masked = [
            f for f in required_fields
            if random.random() > self.config["missing_info_prob"]
        ]
        return masked if masked else list(required_fields)