Update env.py
Browse files
env.py
CHANGED
|
@@ -2,43 +2,42 @@ import numpy as np
|
|
| 2 |
import gymnasium as gym
|
| 3 |
from gymnasium import spaces
|
| 4 |
|
| 5 |
-
# --- UI LABELS
|
| 6 |
URGENCY_LABELS = ["General", "Billing", "Security Breach"]
|
| 7 |
ROUTING_LABELS = ["AI Auto-Reply", "Tech Support", "Legal"]
|
| 8 |
RESOLUTION_LABELS = ["Archive", "Draft Reply", "Escalate to Human"]
|
| 9 |
|
| 10 |
-
# --- Vocabulary &
|
| 11 |
KEYWORD_VOCAB = [
|
| 12 |
-
"invoice", "payment", "overdue", "refund",
|
| 13 |
-
"hacked", "breach", "unauthorized", "password",
|
| 14 |
-
"crash", "error", "bug", "slow",
|
| 15 |
-
"lawsuit", "legal", "attorney", "sue",
|
| 16 |
-
"spam", "offer", "win", "free",
|
| 17 |
-
"urgent", "critical", "angry", "threat",
|
| 18 |
]
|
| 19 |
|
| 20 |
SENTIMENT_MAP = {"positive": 0, "neutral": 1, "negative": 2}
|
| 21 |
CONTEXT_MAP = {"spam": 0, "billing": 1, "tech": 2, "security": 3, "legal": 4}
|
| 22 |
OBS_DIM = len(KEYWORD_VOCAB) + len(SENTIMENT_MAP) + len(CONTEXT_MAP)
|
| 23 |
|
| 24 |
-
# --- Environment Class ---
|
| 25 |
class EmailTriageEnv(gym.Env):
|
| 26 |
def __init__(self, task="all", batch=None, shuffle=True):
|
| 27 |
super().__init__()
|
| 28 |
|
| 29 |
-
# Dataset ko import karna (app.py se load hoga)
|
| 30 |
try:
|
| 31 |
from app import EMAIL_DATASET
|
| 32 |
dataset_to_use = EMAIL_DATASET
|
| 33 |
except ImportError:
|
| 34 |
-
dataset_to_use = []
|
| 35 |
|
|
|
|
| 36 |
if batch is not None:
|
| 37 |
-
self.
|
| 38 |
elif task != "all":
|
| 39 |
-
self.
|
| 40 |
else:
|
| 41 |
-
self.
|
| 42 |
|
| 43 |
self.shuffle = shuffle
|
| 44 |
self.action_space = spaces.MultiDiscrete([3, 3, 3])
|
|
@@ -49,48 +48,37 @@ class EmailTriageEnv(gym.Env):
|
|
| 49 |
kw_flags = np.array([1.0 if kw in email.get("keywords", []) else 0.0 for kw in KEYWORD_VOCAB])
|
| 50 |
sent_idx = SENTIMENT_MAP.get(email.get("sentiment", "neutral"), 1)
|
| 51 |
sentiment_vec = np.zeros(len(SENTIMENT_MAP)); sentiment_vec[sent_idx] = 1.0
|
| 52 |
-
|
| 53 |
ctx_idx = CONTEXT_MAP.get(email.get("context", "spam"), 0)
|
| 54 |
context_vec = np.zeros(len(CONTEXT_MAP)); context_vec[ctx_idx] = 1.0
|
| 55 |
-
|
| 56 |
return np.concatenate([kw_flags, sentiment_vec, context_vec]).astype(np.float32)
|
| 57 |
|
| 58 |
def reset(self, seed=None, options=None):
|
| 59 |
super().reset(seed=seed)
|
| 60 |
self._step_idx = 0
|
| 61 |
-
if not self.
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
obs = self._encode(self.email_batch[0])
|
| 65 |
-
return obs, {"description": self.email_batch[0].get("description", "")}
|
| 66 |
|
| 67 |
def step(self, action):
|
| 68 |
-
email = self.
|
| 69 |
correct = email["correct_actions"]
|
| 70 |
|
| 71 |
-
# Reward Logic (Score sudharne ke liye)
|
| 72 |
reward = 0.0
|
|
|
|
| 73 |
if correct[0] == 2 and action[0] != 2:
|
| 74 |
-
reward = -2.0
|
| 75 |
elif tuple(action) == correct:
|
| 76 |
reward = 1.0
|
| 77 |
elif action[0] == correct[0]:
|
| 78 |
reward = 0.2
|
| 79 |
|
| 80 |
self._step_idx += 1
|
| 81 |
-
terminated = self._step_idx >= len(self.
|
| 82 |
-
|
| 83 |
-
# Next observation
|
| 84 |
-
if not terminated:
|
| 85 |
-
next_email = self.email_batch[self._step_idx]
|
| 86 |
-
obs = self._encode(next_email)
|
| 87 |
-
else:
|
| 88 |
-
obs = self._encode(email)
|
| 89 |
|
| 90 |
info = {
|
| 91 |
"description": email.get("description", ""),
|
| 92 |
"correct_actions": correct,
|
| 93 |
"raw_reward": reward
|
| 94 |
}
|
| 95 |
-
|
| 96 |
return obs, float(reward), terminated, False, info
|
|
|
|
| 2 |
import gymnasium as gym
|
| 3 |
from gymnasium import spaces
|
| 4 |
|
| 5 |
+
# --- UI LABELS ---
|
| 6 |
URGENCY_LABELS = ["General", "Billing", "Security Breach"]
|
| 7 |
ROUTING_LABELS = ["AI Auto-Reply", "Tech Support", "Legal"]
|
| 8 |
RESOLUTION_LABELS = ["Archive", "Draft Reply", "Escalate to Human"]
|
| 9 |
|
| 10 |
+
# --- Vocabulary & Config ---
|
| 11 |
KEYWORD_VOCAB = [
|
| 12 |
+
"invoice", "payment", "overdue", "refund",
|
| 13 |
+
"hacked", "breach", "unauthorized", "password",
|
| 14 |
+
"crash", "error", "bug", "slow",
|
| 15 |
+
"lawsuit", "legal", "attorney", "sue",
|
| 16 |
+
"spam", "offer", "win", "free",
|
| 17 |
+
"urgent", "critical", "angry", "threat",
|
| 18 |
]
|
| 19 |
|
| 20 |
SENTIMENT_MAP = {"positive": 0, "neutral": 1, "negative": 2}
|
| 21 |
CONTEXT_MAP = {"spam": 0, "billing": 1, "tech": 2, "security": 3, "legal": 4}
|
| 22 |
OBS_DIM = len(KEYWORD_VOCAB) + len(SENTIMENT_MAP) + len(CONTEXT_MAP)
|
| 23 |
|
|
|
|
| 24 |
class EmailTriageEnv(gym.Env):
|
| 25 |
def __init__(self, task="all", batch=None, shuffle=True):
|
| 26 |
super().__init__()
|
| 27 |
|
|
|
|
| 28 |
try:
|
| 29 |
from app import EMAIL_DATASET
|
| 30 |
dataset_to_use = EMAIL_DATASET
|
| 31 |
except ImportError:
|
| 32 |
+
dataset_to_use = []
|
| 33 |
|
| 34 |
+
# Fix: App.py needs '_queue' for the interface to work
|
| 35 |
if batch is not None:
|
| 36 |
+
self._queue = batch
|
| 37 |
elif task != "all":
|
| 38 |
+
self._queue = [e for e in dataset_to_use if e.get("difficulty") == task]
|
| 39 |
else:
|
| 40 |
+
self._queue = dataset_to_use
|
| 41 |
|
| 42 |
self.shuffle = shuffle
|
| 43 |
self.action_space = spaces.MultiDiscrete([3, 3, 3])
|
|
|
|
| 48 |
kw_flags = np.array([1.0 if kw in email.get("keywords", []) else 0.0 for kw in KEYWORD_VOCAB])
|
| 49 |
sent_idx = SENTIMENT_MAP.get(email.get("sentiment", "neutral"), 1)
|
| 50 |
sentiment_vec = np.zeros(len(SENTIMENT_MAP)); sentiment_vec[sent_idx] = 1.0
|
|
|
|
| 51 |
ctx_idx = CONTEXT_MAP.get(email.get("context", "spam"), 0)
|
| 52 |
context_vec = np.zeros(len(CONTEXT_MAP)); context_vec[ctx_idx] = 1.0
|
|
|
|
| 53 |
return np.concatenate([kw_flags, sentiment_vec, context_vec]).astype(np.float32)
|
| 54 |
|
| 55 |
def reset(self, seed=None, options=None):
|
| 56 |
super().reset(seed=seed)
|
| 57 |
self._step_idx = 0
|
| 58 |
+
if not self._queue: return np.zeros(OBS_DIM, dtype=np.float32), {}
|
| 59 |
+
obs = self._encode(self._queue[0])
|
| 60 |
+
return obs, {"description": self._queue[0].get("description", "")}
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def step(self, action):
|
| 63 |
+
email = self._queue[self._step_idx]
|
| 64 |
correct = email["correct_actions"]
|
| 65 |
|
|
|
|
| 66 |
reward = 0.0
|
| 67 |
+
# Critical Security Check
|
| 68 |
if correct[0] == 2 and action[0] != 2:
|
| 69 |
+
reward = -2.0
|
| 70 |
elif tuple(action) == correct:
|
| 71 |
reward = 1.0
|
| 72 |
elif action[0] == correct[0]:
|
| 73 |
reward = 0.2
|
| 74 |
|
| 75 |
self._step_idx += 1
|
| 76 |
+
terminated = self._step_idx >= len(self._queue)
|
| 77 |
+
obs = self._encode(email)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
info = {
|
| 80 |
"description": email.get("description", ""),
|
| 81 |
"correct_actions": correct,
|
| 82 |
"raw_reward": reward
|
| 83 |
}
|
|
|
|
| 84 |
return obs, float(reward), terminated, False, info
|