Rithwik Ravi committed on
Commit ·
1314b5a
1
Parent(s): a2db70a
Fixed submission errors
Browse files
- inference.py +2 -1
- server/__pycache__/app.cpython-314.pyc +0 -0
- server/__pycache__/tasks.cpython-314.pyc +0 -0
- server/app.py +2 -2
- server/tasks.py +8 -8
inference.py
CHANGED
|
@@ -137,7 +137,8 @@ def run_task(task_id: int):
|
|
| 137 |
if done:
|
| 138 |
break
|
| 139 |
|
| 140 |
-
|
|
|
|
| 141 |
rewards_str = ",".join(rewards)
|
| 142 |
print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")
|
| 143 |
|
|
|
|
| 137 |
if done:
|
| 138 |
break
|
| 139 |
|
| 140 |
+
final_score = max(0.01, min(0.99, final_score))
|
| 141 |
+
success_str = "true" if final_score >= 0.99 else "false"
|
| 142 |
rewards_str = ",".join(rewards)
|
| 143 |
print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")
|
| 144 |
|
server/__pycache__/app.cpython-314.pyc
CHANGED
|
Binary files a/server/__pycache__/app.cpython-314.pyc and b/server/__pycache__/app.cpython-314.pyc differ
|
|
|
server/__pycache__/tasks.cpython-314.pyc
CHANGED
|
Binary files a/server/__pycache__/tasks.cpython-314.pyc and b/server/__pycache__/tasks.cpython-314.pyc differ
|
|
|
server/app.py
CHANGED
|
@@ -215,8 +215,8 @@ class SQLEnvironment:
|
|
| 215 |
|
| 216 |
self.current_score = new_score
|
| 217 |
|
| 218 |
-
# Episode terminates when score is
|
| 219 |
-
done = (self.current_score >=
|
| 220 |
|
| 221 |
obs = Observation(
|
| 222 |
goal=task.get_goal(),
|
|
|
|
| 215 |
|
| 216 |
self.current_score = new_score
|
| 217 |
|
| 218 |
+
# Episode terminates when score is 0.99
|
| 219 |
+
done = (self.current_score >= 0.99) or (self.step_count > 30)
|
| 220 |
|
| 221 |
obs = Observation(
|
| 222 |
goal=task.get_goal(),
|
server/tasks.py
CHANGED
|
@@ -39,7 +39,7 @@ class EasyTask(Task):
|
|
| 39 |
# Check if view exists
|
| 40 |
c.execute("SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'")
|
| 41 |
if not c.fetchone():
|
| 42 |
-
return 0.
|
| 43 |
|
| 44 |
# Check rows
|
| 45 |
c.execute("SELECT name, total_spent FROM high_value_customers ORDER BY name")
|
|
@@ -50,10 +50,10 @@ class EasyTask(Task):
|
|
| 50 |
|
| 51 |
expected = [("Bob", 1200.0), ("Diana", 3000.0), ("Eve", 1000.01)]
|
| 52 |
if rows == expected:
|
| 53 |
-
return
|
| 54 |
return 0.5
|
| 55 |
except Exception:
|
| 56 |
-
return 0.
|
| 57 |
|
| 58 |
|
| 59 |
class MediumTask(Task):
|
|
@@ -107,9 +107,9 @@ class MediumTask(Task):
|
|
| 107 |
correct_cats = sum(1 for c, e in zip(categories, expected_cats) if c == e)
|
| 108 |
score += (correct_cats / 5.0) * 0.3 # up to 0.3 for correct categories
|
| 109 |
|
| 110 |
-
return min(
|
| 111 |
except Exception:
|
| 112 |
-
return score
|
| 113 |
|
| 114 |
|
| 115 |
class HardTask(Task):
|
|
@@ -160,7 +160,7 @@ class HardTask(Task):
|
|
| 160 |
if 'appointments' in tables: score += 0.2
|
| 161 |
|
| 162 |
if score < 0.4:
|
| 163 |
-
return score
|
| 164 |
|
| 165 |
# Check data counts (3 unique patients, 2 unique doctors, 4 appointments)
|
| 166 |
c.execute("SELECT COUNT(*) FROM patients")
|
|
@@ -185,9 +185,9 @@ class HardTask(Task):
|
|
| 185 |
if len(reconstructed) == 4:
|
| 186 |
score += 0.3
|
| 187 |
|
| 188 |
-
return min(
|
| 189 |
except Exception:
|
| 190 |
-
return score
|
| 191 |
|
| 192 |
TASKS = {
|
| 193 |
1: EasyTask(),
|
|
|
|
| 39 |
# Check if view exists
|
| 40 |
c.execute("SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'")
|
| 41 |
if not c.fetchone():
|
| 42 |
+
return 0.01
|
| 43 |
|
| 44 |
# Check rows
|
| 45 |
c.execute("SELECT name, total_spent FROM high_value_customers ORDER BY name")
|
|
|
|
| 50 |
|
| 51 |
expected = [("Bob", 1200.0), ("Diana", 3000.0), ("Eve", 1000.01)]
|
| 52 |
if rows == expected:
|
| 53 |
+
return 0.99
|
| 54 |
return 0.5
|
| 55 |
except Exception:
|
| 56 |
+
return 0.01
|
| 57 |
|
| 58 |
|
| 59 |
class MediumTask(Task):
|
|
|
|
| 107 |
correct_cats = sum(1 for c, e in zip(categories, expected_cats) if c == e)
|
| 108 |
score += (correct_cats / 5.0) * 0.3 # up to 0.3 for correct categories
|
| 109 |
|
| 110 |
+
return min(max(score, 0.01), 0.99)
|
| 111 |
except Exception:
|
| 112 |
+
return min(max(score, 0.01), 0.99)
|
| 113 |
|
| 114 |
|
| 115 |
class HardTask(Task):
|
|
|
|
| 160 |
if 'appointments' in tables: score += 0.2
|
| 161 |
|
| 162 |
if score < 0.4:
|
| 163 |
+
return max(0.01, score)
|
| 164 |
|
| 165 |
# Check data counts (3 unique patients, 2 unique doctors, 4 appointments)
|
| 166 |
c.execute("SELECT COUNT(*) FROM patients")
|
|
|
|
| 185 |
if len(reconstructed) == 4:
|
| 186 |
score += 0.3
|
| 187 |
|
| 188 |
+
return min(max(score, 0.01), 0.99)
|
| 189 |
except Exception:
|
| 190 |
+
return min(max(score, 0.01), 0.99)
|
| 191 |
|
| 192 |
TASKS = {
|
| 193 |
1: EasyTask(),
|