csa01 / app /env.py
prashantmatlani's picture
implemented agents' self-learning, self-correcting without explicit training
0894e25
# app/env.py
from typing import Tuple, Dict, Any
from app.models import Observation, Action, Reward
from app.dataset import TICKETS
import random
from graders import grade_easy, grade_medium, grade_hard
#from tasks import TASKS
import sys
# =========================
# PURPOSE: Controls difficulty-driven stochasticity
# - noise_prob → message distortion
# - missing_info_prob → partial observability
# =========================
DIFFICULTY_CONFIG = {
"easy": {
"max_steps": 8,
"noise_prob": 0.0,
"missing_info_prob": 0.1
},
"medium": {
"max_steps": 10,
"noise_prob": 0.2,
"missing_info_prob": 0.3
},
"hard": {
"max_steps": 12,
"noise_prob": 0.4,
"missing_info_prob": 0.5
}
}
# =========================
# PURPOSE: Defines tasks exposed to validator
# =========================
AVAILABLE_TASKS = [
{
"id": "easy-info-collection",
"difficulty": "easy",
"grader": grade_easy
},
{
"id": "medium-complete-info",
"difficulty": "medium",
"grader": grade_medium
},
{
"id": "hard-efficient-resolution",
"difficulty": "hard",
"grader": grade_hard
}
]
def get_tasks():
return AVAILABLE_TASKS
class CustomerSupportEnv:
# OBTAIN TASKS FROM GRADERS.PY
def get_tasks(self):
return [
{
"id": "easy-info-collection",
"difficulty": "easy",
"grader": grade_easy,
},
{
"id": "medium-complete-info",
"difficulty": "medium",
"grader": grade_medium,
},
{
"id": "hard-efficient-resolution",
"difficulty": "hard",
"grader": grade_hard,
},
]
# =========================
# PURPOSE: Build observation exposed to agent
# =========================
def _get_observation(self):
required = self.state_data["required_info"]
collected = self.state_data["collected_info"]
total = len(required)
collected_count = sum(1 for f in required if f in collected)
return {
"ticket_id": self.ticket["ticket_id"],
"customer_message": self.state_data["customer_message"],
"known_info": collected,
"required": required,
"missing_required": [f for f in required if f not in collected],
"info_progress": collected_count / max(1, total),
"status": self.state_data["status"],
"step_count": self.state_data["steps_taken"],
"remaining_steps": self.max_steps - self.state_data["steps_taken"],
"difficulty": self.difficulty # difficulty awareness
}
# =========================
# PURPOSE: Initialize environment with difficulty & randomness
# =========================
def __init__(self, difficulty="medium", seed=None):
self.difficulty = difficulty
self.config = DIFFICULTY_CONFIG[difficulty]
if seed is not None:
random.seed(seed)
self.state_data = None
self.max_steps = self.config["max_steps"]
self.last_action = None
# self-correction tracking
self.classification_history = []
# METRICS TRACKING
self.episode_stats = []
def list_tasks(self):
return self.tasks
def reset(self):
self.last_action = None
#self.current_episode_reward = 0.0
self.current_steps = 0
self.success = False
self.ticket = random.choice(TICKETS)
gt = self.ticket["ground_truth"]
msg = random.choice(self.ticket["variants"])
msg = self._inject_noise(msg)
masked_required = self._mask_required_info(gt["required_info"])
self.state_data = {
"ticket_id": self.ticket["ticket_id"],
"customer_message": msg,
"status": "open",
"category": None,
"priority": None,
"required_info": masked_required,
"collected_info": {},
"steps_taken": 0,
"ground_truth": gt
}
return self._get_observation()
# =========================
# PURPOSE: Core transition function with self-correction logic
# =========================
def step(self, action: dict):
if self.state_data is None:
self.reset()
reward = -0.05
done = False
info = {}
collected = self.state_data["collected_info"]
gt = self.ticket["ground_truth"]
action_type = action.get("type") if isinstance(action, dict) else None
# -----------------------
# CLASSIFY (SELF-CORRECTION ENABLED)
# -----------------------
if action_type == "classify":
new_cat = action.get("category")
prev_cat = collected.get("category")
collected["category"] = new_cat
collected["priority"] = action.get("priority")
self.classification_history.append(new_cat)
# correct classification
if new_cat == gt["category"]:
reward += 0.3
# self-correction bonus
if prev_cat and prev_cat != gt["category"] and new_cat == gt["category"]:
reward += 0.5 # major reward
# flip-flop penalty
if len(self.classification_history) >= 3:
if len(set(self.classification_history[-3:])) > 2:
reward -= 0.3
# -----------------------
# ASK INFO
# -----------------------
elif action_type == "ask_info":
field = action.get("field")
if field not in collected:
collected[field] = "value"
reward += 0.25
else:
reward -= 0.2
# -----------------------
# RESOLVE
# -----------------------
elif action_type == "resolve":
done = True
required = gt["required_info"]
all_info = all(f in collected for f in required)
correct_cat = collected.get("category") == gt["category"]
# 🔥 premature penalty
if not all_info:
reward -= 0.7
# scoring
if correct_cat:
reward += 0.3
if all_info:
reward += 0.3
self.success = True
reward += 0.2 # completion bonus
else:
reward -= 0.3
# -----------------------
# STEP UPDATE
# -----------------------
self.state_data["steps_taken"] += 1
if self.state_data["steps_taken"] >= self.max_steps:
done = True
reward -= 1.5
return self._get_observation(), reward, done, {
"task_success": self.success
}
def state(self) -> Dict:
return self.state_data
def get_metrics(self):
if not self.episode_stats:
return {}
total = len(self.episode_stats)
success_rate = sum(e["success"] for e in self.episode_stats) / total
avg_steps = sum(e["steps"] for e in self.episode_stats) / total
avg_reward = sum(e["reward"] for e in self.episode_stats) / total
info_eff = sum(e["info_efficiency"] for e in self.episode_stats) / total
return {
"success_rate": round(success_rate, 3),
"avg_steps": round(avg_steps, 3),
"avg_reward": round(avg_reward, 3),
"info_efficiency": round(info_eff, 3)
}
# =========================
# PURPOSE: Apply noise to simulate real-world messy input
# =========================
def _inject_noise(self, message):
if random.random() < self.config["noise_prob"]:
noise = random.choice([
"pls help asap",
"not sure what's wrong",
"this is urgent",
"been days"
])
return message + " " + noise
return message
# =========================
# PURPOSE: Mask required fields → partial observability
# =========================
def _mask_required_info(self, required_fields):
masked = [
f for f in required_fields
if random.random() > self.config["missing_info_prob"]
]
return masked if masked else required_fields
"""
def _mask_required_info(self, required_fields):
masked = []
for field in required_fields:
if random.random() > self.config["missing_info_prob"]:
masked.append(field)
# ensure at least 1 required field remains
return masked if masked else required_fields
"""