Mmanikandan committed on
Commit
c74d5fa
·
1 Parent(s): 4e9153c

phase 2 fix

Browse files
Files changed (4) hide show
  1. inference.py +6 -4
  2. openenv.yaml +9 -18
  3. server/environment.py +49 -9
  4. server/grader.py +53 -0
inference.py CHANGED
@@ -648,7 +648,8 @@ def run_inference(config: Optional[Dict[str, str]] = None) -> None:
648
  reset_data = reset_response.json()
649
 
650
  observation = reset_data.get("observation", {})
651
- task_name = observation.get("email_id", "email_workflow")
 
652
  email_subject = observation.get("subject", "")
653
  email_body = observation.get("body", "")
654
  customer_history = observation.get("customer_history", "")
@@ -761,9 +762,10 @@ def run_inference(config: Optional[Dict[str, str]] = None) -> None:
761
  # Log step
762
  log_step(step_num, action_str, reward, done, None)
763
 
764
- # Prepare final metrics
765
- # CRITICAL FIX: Use the environment's official cumulative reward instead of manual summation
766
- normalized_score = step_data.get("info", {}).get("total_reward", sum(rewards))
 
767
 
768
  # Clamp just in case, though the environment already does this
769
  normalized_score = min(max(normalized_score, 0.0), 1.0)
 
648
  reset_data = reset_response.json()
649
 
650
  observation = reset_data.get("observation", {})
651
+ info = reset_data.get("info", {})
652
+ task_name = info.get("task_id", observation.get("email_id", "email_workflow"))
653
  email_subject = observation.get("subject", "")
654
  email_body = observation.get("body", "")
655
  customer_history = observation.get("customer_history", "")
 
762
  # Log step
763
  log_step(step_num, action_str, reward, done, None)
764
 
765
+ # PHASE 2 REQUIREMENT: Use the programmatic grader's score if available
766
+ # Fallback to total_reward or manual sum for robust reporting
767
+ final_info = step_data.get("info", {})
768
+ normalized_score = final_info.get("score", final_info.get("total_reward", sum(rewards)))
769
 
770
  # Clamp just in case, though the environment already does this
771
  normalized_score = min(max(normalized_score, 0.0), 1.0)
openenv.yaml CHANGED
@@ -157,35 +157,26 @@ reward:
157
  description: "Escalation bonus or penalty for appropriate decision"
158
 
159
  tasks:
160
- - id: email_001
161
- name: Easy Email
162
  difficulty: easy
163
- description: >
164
- Clear billing issue. Straightforward double-charge complaint
165
- from good customer. Requires correct classification and
166
- appropriate urgency response.
167
  ground_truth:
168
  category: billing
169
  priority: high
170
 
171
- - id: email_002
172
- name: Medium Email
173
  difficulty: medium
174
- description: >
175
- Technical issue with app. Requires interpretation of problem
176
- and prioritization judgment. Customer history is important context.
177
  ground_truth:
178
  category: tech
179
  priority: medium
180
 
181
- - id: email_003
182
- name: Hard Email
183
  difficulty: hard
184
- description: >
185
- Emotional complaint from enterprise customer. Requires nuanced
186
- understanding of tone, prior history, and business impact.
187
- Response must show empathy and urgency. Failure to prioritize
188
- properly could lead to business loss.
189
  ground_truth:
190
  category: complaint
191
  priority: high
 
157
  description: "Escalation bonus or penalty for appropriate decision"
158
 
159
  tasks:
160
+ - id: easy_refund
161
+ name: Easy Refund Task
162
  difficulty: easy
163
+ description: Handle a straightforward billing refund request for a duplicate charge.
 
 
 
164
  ground_truth:
165
  category: billing
166
  priority: high
167
 
168
+ - id: medium_tech
169
+ name: Medium Tech Task
170
  difficulty: medium
171
+ description: Resolve a technical issue regarding app crashes and provide instructions.
 
 
172
  ground_truth:
173
  category: tech
174
  priority: medium
175
 
176
+ - id: hard_escalation
177
+ name: Hard Escalation Task
178
  difficulty: hard
179
+ description: Handle a high-value enterprise complaint that requires escalation or financial compensation.
 
 
 
 
180
  ground_truth:
181
  category: complaint
182
  priority: high
server/environment.py CHANGED
@@ -18,9 +18,16 @@ from models import (
18
  from .grader import (
19
  calculate_step_reward, grade_workflow_completion,
20
  analyze_customer_sentiment, extract_urgency_indicators,
21
- check_escalation_requirement
22
  )
23
 
 
 
 
 
 
 
 
24
  def search_knowledge_base(query: str):
25
  if "refund" in query.lower():
26
  return {
@@ -320,7 +327,19 @@ class CustomerSupportEnv:
320
  if not self.task_queue:
321
  self.task_queue = self._load_tasks()
322
 
323
- self.current_task = self._prepare_task_data(self.task_queue.pop(0))
 
 
 
 
 
 
 
 
 
 
 
 
324
  self.episode_count += 1
325
 
326
  # Initialize workflow state
@@ -362,6 +381,7 @@ class CustomerSupportEnv:
362
  "episode_id": self.current_state.episode_id,
363
  "difficulty": self.current_task.get("difficulty", "unknown"),
364
  "email_id": self.current_task["id"],
 
365
  "workflow_step": 0,
366
  "max_steps": 5
367
  }
@@ -502,19 +522,39 @@ class CustomerSupportEnv:
502
  reward_breakdown["escalation_bonus"] = escalation_bonus
503
  reward_breakdown.update(completion_breakdown)
504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  return {
506
  "observation": observation,
507
  "reward": step_reward,
508
  "done": done,
509
- "info": {
510
- "workflow_state": self.workflow_state.copy(),
511
- "total_reward": self.current_state.total_reward,
512
- "reward_breakdown": reward_breakdown,
513
- "step_count": self.current_state.step_count,
514
- "episode_complete": done
515
- }
516
  }
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  def _is_episode_complete(self) -> bool:
519
  """
520
  Check if the current episode is complete.
 
18
  from .grader import (
19
  calculate_step_reward, grade_workflow_completion,
20
  analyze_customer_sentiment, extract_urgency_indicators,
21
+ check_escalation_requirement, refund_grader, tech_grader, escalation_grader
22
  )
23
 
24
+ # Mandatory Task definitions for OpenEnv validation
25
# Mandatory Task definitions for OpenEnv validation.
# Each entry maps a public task_id (used by the grader dispatch) to the
# underlying email fixture id and a difficulty tag.
# NOTE(review): email_id values presumably match entries returned by
# _load_tasks() — confirm against the task database.
TASKS = [
    {"task_id": "easy_refund", "email_id": "email_001", "difficulty": "easy"},
    {"task_id": "medium_tech", "email_id": "email_002", "difficulty": "medium"},
    {"task_id": "hard_escalation", "email_id": "email_003", "difficulty": "hard"},
]
30
+
31
  def search_knowledge_base(query: str):
32
  if "refund" in query.lower():
33
  return {
 
327
  if not self.task_queue:
328
  self.task_queue = self._load_tasks()
329
 
330
+ # Phase 2 Fix: Wrap selection in structured TASKS context
331
+ # We cycle through the 3 mandatory tasks for consistent evaluation
332
+ task_idx = self.episode_count % len(TASKS)
333
+ selected_task_metadata = TASKS[task_idx]
334
+ task_id = selected_task_metadata["task_id"]
335
+
336
+ # Find corresponding email data in current queue or database
337
+ # For simplicity in this fix, we search task_queue for matching email_id
338
+ email_data = next((t for t in self.task_queue if t["id"] == selected_task_metadata["email_id"]), self.task_queue[0])
339
+
340
+ self.current_task = self._prepare_task_data(email_data)
341
+ self.current_task["task_id"] = task_id # Attach mandatory task_id
342
+
343
  self.episode_count += 1
344
 
345
  # Initialize workflow state
 
381
  "episode_id": self.current_state.episode_id,
382
  "difficulty": self.current_task.get("difficulty", "unknown"),
383
  "email_id": self.current_task["id"],
384
+ "task_id": self.current_task.get("task_id"),
385
  "workflow_step": 0,
386
  "max_steps": 5
387
  }
 
522
  reward_breakdown["escalation_bonus"] = escalation_bonus
523
  reward_breakdown.update(completion_breakdown)
524
 
525
+ info = {
526
+ "workflow_state": self.workflow_state.copy(),
527
+ "total_reward": self.current_state.total_reward,
528
+ "reward_breakdown": reward_breakdown,
529
+ "step_count": self.current_state.step_count,
530
+ "episode_complete": done,
531
+ "task_id": self.current_task.get("task_id")
532
+ }
533
+
534
+ # PHASE 2 REQUIREMENT: Return score only at end
535
+ if done:
536
+ info["score"] = self.compute_score()
537
+
538
  return {
539
  "observation": observation,
540
  "reward": step_reward,
541
  "done": done,
542
+ "info": info
 
 
 
 
 
 
543
  }
544
 
545
+ def compute_score(self) -> float:
546
+ """ Programmatic score computation for OpenEnv validation. """
547
+ task_id = self.current_task.get("task_id")
548
+
549
+ if task_id == "easy_refund":
550
+ return refund_grader(self.workflow_state)
551
+ elif task_id == "medium_tech":
552
+ return tech_grader(self.workflow_state)
553
+ elif task_id == "hard_escalation":
554
+ return escalation_grader(self.workflow_state)
555
+
556
+ return 0.0
557
+
558
  def _is_episode_complete(self) -> bool:
559
  """
560
  Check if the current episode is complete.
server/grader.py CHANGED
@@ -702,3 +702,56 @@ def check_escalation_requirement(email_task: Dict[str, Any], state: Dict[str, An
702
  bonus = 0.1 # Bonus for correct escalation
703
 
704
  return penalty, bonus
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
  bonus = 0.1 # Bonus for correct escalation
703
 
704
  return penalty, bonus
705
+
706
def refund_grader(state: Dict[str, Any]) -> float:
    """Programmatic grader for the easy_refund task.

    Awards partial credit for correct classification, priority, and
    strategy, plus a bonus when the drafted response mentions a refund.
    The result is capped at 1.0.
    """
    reply = state.get("response") or ""
    # (weight, passed) pairs — summed left-to-right in fixed order.
    checks = [
        (0.3, state.get("classification") == "billing"),
        (0.2, state.get("priority") == "high"),
        (0.3, state.get("strategy") == "offer_refund"),
        (0.2, "refund" in reply.lower()),
    ]
    total = sum(weight for weight, passed in checks if passed)
    return min(total, 1.0)
722
+
723
+
724
def tech_grader(state: Dict[str, Any]) -> float:
    """Programmatic grader for the medium_tech task.

    Partial credit for tech classification, acceptable priority, an
    appropriate resolution strategy, and a non-trivial (>20 chars)
    response. Capped at 1.0.
    """
    classified_ok = state.get("classification") == "tech"
    priority_ok = state.get("priority") in ("medium", "high")
    strategy_ok = state.get("strategy") in ("auto_resolve", "request_more_info")
    reply = state.get("response")
    reply_ok = bool(reply) and len(reply) > 20

    # Booleans multiply to 0.0 or the weight; accumulation order matches
    # the fixed check order above.
    total = (0.3 * classified_ok
             + 0.2 * priority_ok
             + 0.3 * strategy_ok
             + 0.2 * reply_ok)
    return min(total, 1.0)
740
+
741
+
742
def escalation_grader(state: Dict[str, Any]) -> float:
    """Programmatic grader for the hard_escalation task.

    Partial credit for complaint classification, high priority, an
    escalation/refund strategy, and the presence of an escalation
    payload. Capped at 1.0.
    """
    awarded = []
    if state.get("classification") == "complaint":
        awarded.append(0.2)
    if state.get("priority") == "high":
        awarded.append(0.2)
    if state.get("strategy") in ("escalate_to_human", "offer_refund"):
        awarded.append(0.3)
    # Truthiness check: any non-empty escalation payload earns the bonus.
    if state.get("escalation"):
        awarded.append(0.3)
    return min(sum(awarded), 1.0)