Spaces:
Running
Running
Commit ·
c74d5fa
1
Parent(s): 4e9153c
phase 2 fix
Browse files- inference.py +6 -4
- openenv.yaml +9 -18
- server/environment.py +49 -9
- server/grader.py +53 -0
inference.py
CHANGED
|
@@ -648,7 +648,8 @@ def run_inference(config: Optional[Dict[str, str]] = None) -> None:
|
|
| 648 |
reset_data = reset_response.json()
|
| 649 |
|
| 650 |
observation = reset_data.get("observation", {})
|
| 651 |
-
|
|
|
|
| 652 |
email_subject = observation.get("subject", "")
|
| 653 |
email_body = observation.get("body", "")
|
| 654 |
customer_history = observation.get("customer_history", "")
|
|
@@ -761,9 +762,10 @@ def run_inference(config: Optional[Dict[str, str]] = None) -> None:
|
|
| 761 |
# Log step
|
| 762 |
log_step(step_num, action_str, reward, done, None)
|
| 763 |
|
| 764 |
-
#
|
| 765 |
-
#
|
| 766 |
-
|
|
|
|
| 767 |
|
| 768 |
# Clamp just in case, though the environment already does this
|
| 769 |
normalized_score = min(max(normalized_score, 0.0), 1.0)
|
|
|
|
| 648 |
reset_data = reset_response.json()
|
| 649 |
|
| 650 |
observation = reset_data.get("observation", {})
|
| 651 |
+
info = reset_data.get("info", {})
|
| 652 |
+
task_name = info.get("task_id", observation.get("email_id", "email_workflow"))
|
| 653 |
email_subject = observation.get("subject", "")
|
| 654 |
email_body = observation.get("body", "")
|
| 655 |
customer_history = observation.get("customer_history", "")
|
|
|
|
| 762 |
# Log step
|
| 763 |
log_step(step_num, action_str, reward, done, None)
|
| 764 |
|
| 765 |
+
# PHASE 2 REQUIREMENT: Use the programmatic grader's score if available
|
| 766 |
+
# Fallback to total_reward or manual sum for robust reporting
|
| 767 |
+
final_info = step_data.get("info", {})
|
| 768 |
+
normalized_score = final_info.get("score", final_info.get("total_reward", sum(rewards)))
|
| 769 |
|
| 770 |
# Clamp just in case, though the environment already does this
|
| 771 |
normalized_score = min(max(normalized_score, 0.0), 1.0)
|
openenv.yaml
CHANGED
|
@@ -157,35 +157,26 @@ reward:
|
|
| 157 |
description: "Escalation bonus or penalty for appropriate decision"
|
| 158 |
|
| 159 |
tasks:
|
| 160 |
-
- id:
|
| 161 |
-
name: Easy
|
| 162 |
difficulty: easy
|
| 163 |
-
description:
|
| 164 |
-
Clear billing issue. Straightforward double-charge complaint
|
| 165 |
-
from good customer. Requires correct classification and
|
| 166 |
-
appropriate urgency response.
|
| 167 |
ground_truth:
|
| 168 |
category: billing
|
| 169 |
priority: high
|
| 170 |
|
| 171 |
-
- id:
|
| 172 |
-
name: Medium
|
| 173 |
difficulty: medium
|
| 174 |
-
description:
|
| 175 |
-
Technical issue with app. Requires interpretation of problem
|
| 176 |
-
and prioritization judgment. Customer history is important context.
|
| 177 |
ground_truth:
|
| 178 |
category: tech
|
| 179 |
priority: medium
|
| 180 |
|
| 181 |
-
- id:
|
| 182 |
-
name: Hard
|
| 183 |
difficulty: hard
|
| 184 |
-
description:
|
| 185 |
-
Emotional complaint from enterprise customer. Requires nuanced
|
| 186 |
-
understanding of tone, prior history, and business impact.
|
| 187 |
-
Response must show empathy and urgency. Failure to prioritize
|
| 188 |
-
properly could lead to business loss.
|
| 189 |
ground_truth:
|
| 190 |
category: complaint
|
| 191 |
priority: high
|
|
|
|
| 157 |
description: "Escalation bonus or penalty for appropriate decision"
|
| 158 |
|
| 159 |
tasks:
|
| 160 |
+
- id: easy_refund
|
| 161 |
+
name: Easy Refund Task
|
| 162 |
difficulty: easy
|
| 163 |
+
description: Handle a straightforward billing refund request for a duplicate charge.
|
|
|
|
|
|
|
|
|
|
| 164 |
ground_truth:
|
| 165 |
category: billing
|
| 166 |
priority: high
|
| 167 |
|
| 168 |
+
- id: medium_tech
|
| 169 |
+
name: Medium Tech Task
|
| 170 |
difficulty: medium
|
| 171 |
+
description: Resolve a technical issue regarding app crashes and provide instructions.
|
|
|
|
|
|
|
| 172 |
ground_truth:
|
| 173 |
category: tech
|
| 174 |
priority: medium
|
| 175 |
|
| 176 |
+
- id: hard_escalation
|
| 177 |
+
name: Hard Escalation Task
|
| 178 |
difficulty: hard
|
| 179 |
+
description: Handle a high-value enterprise complaint that requires escalation or financial compensation.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
ground_truth:
|
| 181 |
category: complaint
|
| 182 |
priority: high
|
server/environment.py
CHANGED
|
@@ -18,9 +18,16 @@ from models import (
|
|
| 18 |
from .grader import (
|
| 19 |
calculate_step_reward, grade_workflow_completion,
|
| 20 |
analyze_customer_sentiment, extract_urgency_indicators,
|
| 21 |
-
check_escalation_requirement
|
| 22 |
)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def search_knowledge_base(query: str):
|
| 25 |
if "refund" in query.lower():
|
| 26 |
return {
|
|
@@ -320,7 +327,19 @@ class CustomerSupportEnv:
|
|
| 320 |
if not self.task_queue:
|
| 321 |
self.task_queue = self._load_tasks()
|
| 322 |
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
self.episode_count += 1
|
| 325 |
|
| 326 |
# Initialize workflow state
|
|
@@ -362,6 +381,7 @@ class CustomerSupportEnv:
|
|
| 362 |
"episode_id": self.current_state.episode_id,
|
| 363 |
"difficulty": self.current_task.get("difficulty", "unknown"),
|
| 364 |
"email_id": self.current_task["id"],
|
|
|
|
| 365 |
"workflow_step": 0,
|
| 366 |
"max_steps": 5
|
| 367 |
}
|
|
@@ -502,19 +522,39 @@ class CustomerSupportEnv:
|
|
| 502 |
reward_breakdown["escalation_bonus"] = escalation_bonus
|
| 503 |
reward_breakdown.update(completion_breakdown)
|
| 504 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
return {
|
| 506 |
"observation": observation,
|
| 507 |
"reward": step_reward,
|
| 508 |
"done": done,
|
| 509 |
-
"info":
|
| 510 |
-
"workflow_state": self.workflow_state.copy(),
|
| 511 |
-
"total_reward": self.current_state.total_reward,
|
| 512 |
-
"reward_breakdown": reward_breakdown,
|
| 513 |
-
"step_count": self.current_state.step_count,
|
| 514 |
-
"episode_complete": done
|
| 515 |
-
}
|
| 516 |
}
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
def _is_episode_complete(self) -> bool:
|
| 519 |
"""
|
| 520 |
Check if the current episode is complete.
|
|
|
|
| 18 |
from .grader import (
|
| 19 |
calculate_step_reward, grade_workflow_completion,
|
| 20 |
analyze_customer_sentiment, extract_urgency_indicators,
|
| 21 |
+
check_escalation_requirement, refund_grader, tech_grader, escalation_grader
|
| 22 |
)
|
| 23 |
|
| 24 |
+
# Mandatory Task definitions for OpenEnv validation
|
| 25 |
+
TASKS = [
|
| 26 |
+
{"task_id": "easy_refund", "email_id": "email_001", "difficulty": "easy"},
|
| 27 |
+
{"task_id": "medium_tech", "email_id": "email_002", "difficulty": "medium"},
|
| 28 |
+
{"task_id": "hard_escalation", "email_id": "email_003", "difficulty": "hard"},
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
def search_knowledge_base(query: str):
|
| 32 |
if "refund" in query.lower():
|
| 33 |
return {
|
|
|
|
| 327 |
if not self.task_queue:
|
| 328 |
self.task_queue = self._load_tasks()
|
| 329 |
|
| 330 |
+
# Phase 2 Fix: Wrap selection in structured TASKS context
|
| 331 |
+
# We cycle through the 3 mandatory tasks for consistent evaluation
|
| 332 |
+
task_idx = self.episode_count % len(TASKS)
|
| 333 |
+
selected_task_metadata = TASKS[task_idx]
|
| 334 |
+
task_id = selected_task_metadata["task_id"]
|
| 335 |
+
|
| 336 |
+
# Find corresponding email data in current queue or database
|
| 337 |
+
# For simplicity in this fix, we search task_queue for matching email_id
|
| 338 |
+
email_data = next((t for t in self.task_queue if t["id"] == selected_task_metadata["email_id"]), self.task_queue[0])
|
| 339 |
+
|
| 340 |
+
self.current_task = self._prepare_task_data(email_data)
|
| 341 |
+
self.current_task["task_id"] = task_id # Attach mandatory task_id
|
| 342 |
+
|
| 343 |
self.episode_count += 1
|
| 344 |
|
| 345 |
# Initialize workflow state
|
|
|
|
| 381 |
"episode_id": self.current_state.episode_id,
|
| 382 |
"difficulty": self.current_task.get("difficulty", "unknown"),
|
| 383 |
"email_id": self.current_task["id"],
|
| 384 |
+
"task_id": self.current_task.get("task_id"),
|
| 385 |
"workflow_step": 0,
|
| 386 |
"max_steps": 5
|
| 387 |
}
|
|
|
|
| 522 |
reward_breakdown["escalation_bonus"] = escalation_bonus
|
| 523 |
reward_breakdown.update(completion_breakdown)
|
| 524 |
|
| 525 |
+
info = {
|
| 526 |
+
"workflow_state": self.workflow_state.copy(),
|
| 527 |
+
"total_reward": self.current_state.total_reward,
|
| 528 |
+
"reward_breakdown": reward_breakdown,
|
| 529 |
+
"step_count": self.current_state.step_count,
|
| 530 |
+
"episode_complete": done,
|
| 531 |
+
"task_id": self.current_task.get("task_id")
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
# PHASE 2 REQUIREMENT: Return score only at end
|
| 535 |
+
if done:
|
| 536 |
+
info["score"] = self.compute_score()
|
| 537 |
+
|
| 538 |
return {
|
| 539 |
"observation": observation,
|
| 540 |
"reward": step_reward,
|
| 541 |
"done": done,
|
| 542 |
+
"info": info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
}
|
| 544 |
|
| 545 |
+
def compute_score(self) -> float:
|
| 546 |
+
""" Programmatic score computation for OpenEnv validation. """
|
| 547 |
+
task_id = self.current_task.get("task_id")
|
| 548 |
+
|
| 549 |
+
if task_id == "easy_refund":
|
| 550 |
+
return refund_grader(self.workflow_state)
|
| 551 |
+
elif task_id == "medium_tech":
|
| 552 |
+
return tech_grader(self.workflow_state)
|
| 553 |
+
elif task_id == "hard_escalation":
|
| 554 |
+
return escalation_grader(self.workflow_state)
|
| 555 |
+
|
| 556 |
+
return 0.0
|
| 557 |
+
|
| 558 |
def _is_episode_complete(self) -> bool:
|
| 559 |
"""
|
| 560 |
Check if the current episode is complete.
|
server/grader.py
CHANGED
|
@@ -702,3 +702,56 @@ def check_escalation_requirement(email_task: Dict[str, Any], state: Dict[str, An
|
|
| 702 |
bonus = 0.1 # Bonus for correct escalation
|
| 703 |
|
| 704 |
return penalty, bonus
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 702 |
bonus = 0.1 # Bonus for correct escalation
|
| 703 |
|
| 704 |
return penalty, bonus
|
| 705 |
+
|
| 706 |
+
def refund_grader(state: Dict[str, Any]) -> float:
|
| 707 |
+
""" Programmatic grader for easy_refund task. """
|
| 708 |
+
score = 0.0
|
| 709 |
+
|
| 710 |
+
if state.get("classification") == "billing":
|
| 711 |
+
score += 0.3
|
| 712 |
+
if state.get("priority") == "high":
|
| 713 |
+
score += 0.2
|
| 714 |
+
if state.get("strategy") == "offer_refund":
|
| 715 |
+
score += 0.3
|
| 716 |
+
|
| 717 |
+
response = state.get("response")
|
| 718 |
+
if response and "refund" in response.lower():
|
| 719 |
+
score += 0.2
|
| 720 |
+
|
| 721 |
+
return min(score, 1.0)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def tech_grader(state: Dict[str, Any]) -> float:
|
| 725 |
+
""" Programmatic grader for medium_tech task. """
|
| 726 |
+
score = 0.0
|
| 727 |
+
|
| 728 |
+
if state.get("classification") == "tech":
|
| 729 |
+
score += 0.3
|
| 730 |
+
if state.get("priority") in ["medium", "high"]:
|
| 731 |
+
score += 0.2
|
| 732 |
+
if state.get("strategy") in ["auto_resolve", "request_more_info"]:
|
| 733 |
+
score += 0.3
|
| 734 |
+
|
| 735 |
+
response = state.get("response")
|
| 736 |
+
if response and len(response) > 20:
|
| 737 |
+
score += 0.2
|
| 738 |
+
|
| 739 |
+
return min(score, 1.0)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def escalation_grader(state: Dict[str, Any]) -> float:
|
| 743 |
+
""" Programmatic grader for hard_escalation task. """
|
| 744 |
+
score = 0.0
|
| 745 |
+
|
| 746 |
+
if state.get("classification") == "complaint":
|
| 747 |
+
score += 0.2
|
| 748 |
+
if state.get("priority") == "high":
|
| 749 |
+
score += 0.2
|
| 750 |
+
if state.get("strategy") in ["escalate_to_human", "offer_refund"]:
|
| 751 |
+
score += 0.3
|
| 752 |
+
|
| 753 |
+
# Check if escalation payload exists
|
| 754 |
+
if state.get("escalation"):
|
| 755 |
+
score += 0.3
|
| 756 |
+
|
| 757 |
+
return min(score, 1.0)
|