Omkar1806 committed on
Commit
746be7a
·
verified ·
1 Parent(s): 45730bd

Update env.py

Browse files
Files changed (1) hide show
  1. env.py +21 -33
env.py CHANGED
@@ -2,43 +2,42 @@ import numpy as np
2
  import gymnasium as gym
3
  from gymnasium import spaces
4
 
5
- # --- UI LABELS (app.py requires these labels) ---
6
  URGENCY_LABELS = ["General", "Billing", "Security Breach"]
7
  ROUTING_LABELS = ["AI Auto-Reply", "Tech Support", "Legal"]
8
  RESOLUTION_LABELS = ["Archive", "Draft Reply", "Escalate to Human"]
9
 
10
- # --- Vocabulary & Encoding Configuration ---
11
  KEYWORD_VOCAB = [
12
- "invoice", "payment", "overdue", "refund", # billing
13
- "hacked", "breach", "unauthorized", "password", # security
14
- "crash", "error", "bug", "slow", # tech
15
- "lawsuit", "legal", "attorney", "sue", # legal
16
- "spam", "offer", "win", "free", # spam
17
- "urgent", "critical", "angry", "threat", # sentiment signals
18
  ]
19
 
20
  SENTIMENT_MAP = {"positive": 0, "neutral": 1, "negative": 2}
21
  CONTEXT_MAP = {"spam": 0, "billing": 1, "tech": 2, "security": 3, "legal": 4}
22
  OBS_DIM = len(KEYWORD_VOCAB) + len(SENTIMENT_MAP) + len(CONTEXT_MAP)
23
 
24
- # --- Environment Class ---
25
  class EmailTriageEnv(gym.Env):
26
  def __init__(self, task="all", batch=None, shuffle=True):
27
  super().__init__()
28
 
29
- # Import the dataset (it is loaded from app.py)
30
  try:
31
  from app import EMAIL_DATASET
32
  dataset_to_use = EMAIL_DATASET
33
  except ImportError:
34
- dataset_to_use = [] # Fallback if the dataset is not found
35
 
 
36
  if batch is not None:
37
- self.email_batch = batch
38
  elif task != "all":
39
- self.email_batch = [e for e in dataset_to_use if e.get("difficulty") == task]
40
  else:
41
- self.email_batch = dataset_to_use
42
 
43
  self.shuffle = shuffle
44
  self.action_space = spaces.MultiDiscrete([3, 3, 3])
@@ -49,48 +48,37 @@ class EmailTriageEnv(gym.Env):
49
  kw_flags = np.array([1.0 if kw in email.get("keywords", []) else 0.0 for kw in KEYWORD_VOCAB])
50
  sent_idx = SENTIMENT_MAP.get(email.get("sentiment", "neutral"), 1)
51
  sentiment_vec = np.zeros(len(SENTIMENT_MAP)); sentiment_vec[sent_idx] = 1.0
52
-
53
  ctx_idx = CONTEXT_MAP.get(email.get("context", "spam"), 0)
54
  context_vec = np.zeros(len(CONTEXT_MAP)); context_vec[ctx_idx] = 1.0
55
-
56
  return np.concatenate([kw_flags, sentiment_vec, context_vec]).astype(np.float32)
57
 
58
  def reset(self, seed=None, options=None):
59
  super().reset(seed=seed)
60
  self._step_idx = 0
61
- if not self.email_batch:
62
- return np.zeros(OBS_DIM, dtype=np.float32), {}
63
-
64
- obs = self._encode(self.email_batch[0])
65
- return obs, {"description": self.email_batch[0].get("description", "")}
66
 
67
  def step(self, action):
68
- email = self.email_batch[self._step_idx]
69
  correct = email["correct_actions"]
70
 
71
- # Reward logic (to improve the score)
72
  reward = 0.0
 
73
  if correct[0] == 2 and action[0] != 2:
74
- reward = -2.0 # Security missed penalty
75
  elif tuple(action) == correct:
76
  reward = 1.0
77
  elif action[0] == correct[0]:
78
  reward = 0.2
79
 
80
  self._step_idx += 1
81
- terminated = self._step_idx >= len(self.email_batch)
82
-
83
- # Next observation
84
- if not terminated:
85
- next_email = self.email_batch[self._step_idx]
86
- obs = self._encode(next_email)
87
- else:
88
- obs = self._encode(email)
89
 
90
  info = {
91
  "description": email.get("description", ""),
92
  "correct_actions": correct,
93
  "raw_reward": reward
94
  }
95
-
96
  return obs, float(reward), terminated, False, info
 
2
  import gymnasium as gym
3
  from gymnasium import spaces
4
 
5
+ # --- UI LABELS ---
6
  URGENCY_LABELS = ["General", "Billing", "Security Breach"]
7
  ROUTING_LABELS = ["AI Auto-Reply", "Tech Support", "Legal"]
8
  RESOLUTION_LABELS = ["Archive", "Draft Reply", "Escalate to Human"]
9
 
10
+ # --- Vocabulary & Config ---
11
  KEYWORD_VOCAB = [
12
+ "invoice", "payment", "overdue", "refund",
13
+ "hacked", "breach", "unauthorized", "password",
14
+ "crash", "error", "bug", "slow",
15
+ "lawsuit", "legal", "attorney", "sue",
16
+ "spam", "offer", "win", "free",
17
+ "urgent", "critical", "angry", "threat",
18
  ]
19
 
20
  SENTIMENT_MAP = {"positive": 0, "neutral": 1, "negative": 2}
21
  CONTEXT_MAP = {"spam": 0, "billing": 1, "tech": 2, "security": 3, "legal": 4}
22
  OBS_DIM = len(KEYWORD_VOCAB) + len(SENTIMENT_MAP) + len(CONTEXT_MAP)
23
 
 
24
  class EmailTriageEnv(gym.Env):
25
  def __init__(self, task="all", batch=None, shuffle=True):
26
  super().__init__()
27
 
 
28
  try:
29
  from app import EMAIL_DATASET
30
  dataset_to_use = EMAIL_DATASET
31
  except ImportError:
32
+ dataset_to_use = []
33
 
34
+ # Fix: App.py needs '_queue' for the interface to work
35
  if batch is not None:
36
+ self._queue = batch
37
  elif task != "all":
38
+ self._queue = [e for e in dataset_to_use if e.get("difficulty") == task]
39
  else:
40
+ self._queue = dataset_to_use
41
 
42
  self.shuffle = shuffle
43
  self.action_space = spaces.MultiDiscrete([3, 3, 3])
 
48
  kw_flags = np.array([1.0 if kw in email.get("keywords", []) else 0.0 for kw in KEYWORD_VOCAB])
49
  sent_idx = SENTIMENT_MAP.get(email.get("sentiment", "neutral"), 1)
50
  sentiment_vec = np.zeros(len(SENTIMENT_MAP)); sentiment_vec[sent_idx] = 1.0
 
51
  ctx_idx = CONTEXT_MAP.get(email.get("context", "spam"), 0)
52
  context_vec = np.zeros(len(CONTEXT_MAP)); context_vec[ctx_idx] = 1.0
 
53
  return np.concatenate([kw_flags, sentiment_vec, context_vec]).astype(np.float32)
54
 
55
  def reset(self, seed=None, options=None):
56
  super().reset(seed=seed)
57
  self._step_idx = 0
58
+ if not self._queue: return np.zeros(OBS_DIM, dtype=np.float32), {}
59
+ obs = self._encode(self._queue[0])
60
+ return obs, {"description": self._queue[0].get("description", "")}
 
 
61
 
62
  def step(self, action):
63
+ email = self._queue[self._step_idx]
64
  correct = email["correct_actions"]
65
 
 
66
  reward = 0.0
67
+ # Critical Security Check
68
  if correct[0] == 2 and action[0] != 2:
69
+ reward = -2.0
70
  elif tuple(action) == correct:
71
  reward = 1.0
72
  elif action[0] == correct[0]:
73
  reward = 0.2
74
 
75
  self._step_idx += 1
76
+ terminated = self._step_idx >= len(self._queue)
77
+ obs = self._encode(email)
 
 
 
 
 
 
78
 
79
  info = {
80
  "description": email.get("description", ""),
81
  "correct_actions": correct,
82
  "raw_reward": reward
83
  }
 
84
  return obs, float(reward), terminated, False, info