Rithwik Ravi committed on
Commit
1314b5a
·
1 Parent(s): a2db70a

Fixed submission errors

Browse files
inference.py CHANGED
@@ -137,7 +137,8 @@ def run_task(task_id: int):
137
  if done:
138
  break
139
 
140
- success_str = "true" if final_score >= 1.0 else "false"
 
141
  rewards_str = ",".join(rewards)
142
  print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")
143
 
 
137
  if done:
138
  break
139
 
140
+ final_score = max(0.01, min(0.99, final_score))
141
+ success_str = "true" if final_score >= 0.99 else "false"
142
  rewards_str = ",".join(rewards)
143
  print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")
144
 
server/__pycache__/app.cpython-314.pyc CHANGED
Binary files a/server/__pycache__/app.cpython-314.pyc and b/server/__pycache__/app.cpython-314.pyc differ
 
server/__pycache__/tasks.cpython-314.pyc CHANGED
Binary files a/server/__pycache__/tasks.cpython-314.pyc and b/server/__pycache__/tasks.cpython-314.pyc differ
 
server/app.py CHANGED
@@ -215,8 +215,8 @@ class SQLEnvironment:
215
 
216
  self.current_score = new_score
217
 
218
- # Episode terminates when score is 1.0
219
- done = (self.current_score >= 1.0) or (self.step_count > 30)
220
 
221
  obs = Observation(
222
  goal=task.get_goal(),
 
215
 
216
  self.current_score = new_score
217
 
218
+ # Episode terminates when score is 0.99
219
+ done = (self.current_score >= 0.99) or (self.step_count > 30)
220
 
221
  obs = Observation(
222
  goal=task.get_goal(),
server/tasks.py CHANGED
@@ -39,7 +39,7 @@ class EasyTask(Task):
39
  # Check if view exists
40
  c.execute("SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'")
41
  if not c.fetchone():
42
- return 0.0
43
 
44
  # Check rows
45
  c.execute("SELECT name, total_spent FROM high_value_customers ORDER BY name")
@@ -50,10 +50,10 @@ class EasyTask(Task):
50
 
51
  expected = [("Bob", 1200.0), ("Diana", 3000.0), ("Eve", 1000.01)]
52
  if rows == expected:
53
- return 1.0
54
  return 0.5
55
  except Exception:
56
- return 0.0
57
 
58
 
59
  class MediumTask(Task):
@@ -107,9 +107,9 @@ class MediumTask(Task):
107
  correct_cats = sum(1 for c, e in zip(categories, expected_cats) if c == e)
108
  score += (correct_cats / 5.0) * 0.3 # up to 0.3 for correct categories
109
 
110
- return min(1.0, score)
111
  except Exception:
112
- return score
113
 
114
 
115
  class HardTask(Task):
@@ -160,7 +160,7 @@ class HardTask(Task):
160
  if 'appointments' in tables: score += 0.2
161
 
162
  if score < 0.4:
163
- return score
164
 
165
  # Check data counts (3 unique patients, 2 unique doctors, 4 appointments)
166
  c.execute("SELECT COUNT(*) FROM patients")
@@ -185,9 +185,9 @@ class HardTask(Task):
185
  if len(reconstructed) == 4:
186
  score += 0.3
187
 
188
- return min(1.0, score)
189
  except Exception:
190
- return score
191
 
192
  TASKS = {
193
  1: EasyTask(),
 
39
  # Check if view exists
40
  c.execute("SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'")
41
  if not c.fetchone():
42
+ return 0.01
43
 
44
  # Check rows
45
  c.execute("SELECT name, total_spent FROM high_value_customers ORDER BY name")
 
50
 
51
  expected = [("Bob", 1200.0), ("Diana", 3000.0), ("Eve", 1000.01)]
52
  if rows == expected:
53
+ return 0.99
54
  return 0.5
55
  except Exception:
56
+ return 0.01
57
 
58
 
59
  class MediumTask(Task):
 
107
  correct_cats = sum(1 for c, e in zip(categories, expected_cats) if c == e)
108
  score += (correct_cats / 5.0) * 0.3 # up to 0.3 for correct categories
109
 
110
+ return min(max(score, 0.01), 0.99)
111
  except Exception:
112
+ return min(max(score, 0.01), 0.99)
113
 
114
 
115
  class HardTask(Task):
 
160
  if 'appointments' in tables: score += 0.2
161
 
162
  if score < 0.4:
163
+ return max(0.01, score)
164
 
165
  # Check data counts (3 unique patients, 2 unique doctors, 4 appointments)
166
  c.execute("SELECT COUNT(*) FROM patients")
 
185
  if len(reconstructed) == 4:
186
  score += 0.3
187
 
188
+ return min(max(score, 0.01), 0.99)
189
  except Exception:
190
+ return min(max(score, 0.01), 0.99)
191
 
192
  TASKS = {
193
  1: EasyTask(),