eeshwar143 committed on
Commit
4c21555
·
1 Parent(s): dd3c25c

Clamp task scores to open interval

Browse files
Files changed (2) hide show
  1. inference.py +12 -4
  2. support_queue_env/grading.py +8 -1
inference.py CHANGED
@@ -49,6 +49,11 @@ ALLOW_DIRECT_OPENAI = os.getenv("ALLOW_DIRECT_OPENAI") == "1"
49
  BENCHMARK = "support_queue_env"
50
  SUCCESS_SCORE_THRESHOLD = 0.80
51
  MAX_TOKENS = 250
 
 
 
 
 
52
 
53
 
54
  def log_start(task: str, env: str, model: str) -> None:
@@ -303,7 +308,7 @@ async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[st
303
  history: List[str] = []
304
  rewards: List[float] = []
305
  steps_taken = 0
306
- score = 0.0
307
  success = False
308
 
309
  log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME)
@@ -344,7 +349,7 @@ async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[st
344
  break
345
 
346
  score = sum(rewards) / len(rewards) if rewards else 0.0
347
- score = min(max(score, 0.0), 1.0)
348
  success = score >= SUCCESS_SCORE_THRESHOLD
349
 
350
  except Exception as exc:
@@ -377,11 +382,11 @@ async def main() -> None:
377
  print(f"[DEBUG] Environment bootstrap failed: {exc}", flush=True)
378
  for task in tasks:
379
  log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME)
380
- log_end(success=False, steps=0, score=0.0, rewards=[])
381
  results.append(
382
  {
383
  "task_id": task.task_id,
384
- "score": 0.0,
385
  "steps": 0,
386
  "rewards": [],
387
  "success": False,
@@ -409,3 +414,6 @@ if __name__ == "__main__":
409
  asyncio.run(main())
410
  except Exception as exc:
411
  print(f"[DEBUG] Fatal inference error: {exc}", flush=True)
 
 
 
 
49
  BENCHMARK = "support_queue_env"
50
  SUCCESS_SCORE_THRESHOLD = 0.80
51
  MAX_TOKENS = 250
52
+ SCORE_EPSILON = 0.0001
53
+
54
+
55
+ def clamp_task_score(score: float) -> float:
56
+ return min(max(score, SCORE_EPSILON), 1.0 - SCORE_EPSILON)
57
 
58
 
59
  def log_start(task: str, env: str, model: str) -> None:
 
308
  history: List[str] = []
309
  rewards: List[float] = []
310
  steps_taken = 0
311
+ score = clamp_task_score(0.0)
312
  success = False
313
 
314
  log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME)
 
349
  break
350
 
351
  score = sum(rewards) / len(rewards) if rewards else 0.0
352
+ score = clamp_task_score(score)
353
  success = score >= SUCCESS_SCORE_THRESHOLD
354
 
355
  except Exception as exc:
 
382
  print(f"[DEBUG] Environment bootstrap failed: {exc}", flush=True)
383
  for task in tasks:
384
  log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME)
385
+ log_end(success=False, steps=0, score=clamp_task_score(0.0), rewards=[])
386
  results.append(
387
  {
388
  "task_id": task.task_id,
389
+ "score": clamp_task_score(0.0),
390
  "steps": 0,
391
  "rewards": [],
392
  "success": False,
 
414
  asyncio.run(main())
415
  except Exception as exc:
416
  print(f"[DEBUG] Fatal inference error: {exc}", flush=True)
417
+
418
+
419
+
support_queue_env/grading.py CHANGED
@@ -8,6 +8,11 @@ from support_queue_env.models import GradingBreakdown, SupportQueueAction, Ticke
8
  from support_queue_env.tasks import TicketSpec
9
 
10
  PRIORITY_ORDER = ["P1", "P2", "P3", "P4"]
 
 
 
 
 
11
 
12
 
13
  def _normalize(text: str) -> str:
@@ -72,7 +77,7 @@ def grade_ticket(ticket: TicketSpec, action: SupportQueueAction) -> TicketFeedba
72
  + breakdown.response_score
73
  + breakdown.penalty
74
  )
75
- breakdown.total = round(max(0.0, min(1.0, total)), 4)
76
 
77
  matched_summary = summary_hits if ticket.summary_keywords else 0
78
  matched_response = response_hits if ticket.response_keywords else 0
@@ -92,3 +97,5 @@ def grade_ticket(ticket: TicketSpec, action: SupportQueueAction) -> TicketFeedba
92
  breakdown=breakdown,
93
  feedback=feedback,
94
  )
 
 
 
8
  from support_queue_env.tasks import TicketSpec
9
 
10
  PRIORITY_ORDER = ["P1", "P2", "P3", "P4"]
11
+ SCORE_EPSILON = 0.001
12
+
13
+
14
+ def _open_unit_interval(score: float) -> float:
15
+ return round(min(max(score, SCORE_EPSILON), 1.0 - SCORE_EPSILON), 4)
16
 
17
 
18
  def _normalize(text: str) -> str:
 
77
  + breakdown.response_score
78
  + breakdown.penalty
79
  )
80
+ breakdown.total = _open_unit_interval(total)
81
 
82
  matched_summary = summary_hits if ticket.summary_keywords else 0
83
  matched_response = response_hits if ticket.response_keywords else 0
 
97
  breakdown=breakdown,
98
  feedback=feedback,
99
  )
100
+
101
+