Mr66 commited on
Commit
e2ca55c
·
verified ·
1 Parent(s): 4a42415

Upload server/env.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/env.py +238 -0
server/env.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import uuid
4
+ from pathlib import Path
5
+ from enum import Enum
6
+
7
+ from server.models import (
8
+ Secret,
9
+ MindReadObservation,
10
+ StepResult,
11
+ SubmitResult,
12
+ RewardBreakdown,
13
+ TaskMeta,
14
+ )
15
+ from server.oracle import ask_oracle
16
+ from server.reward import compute_reward
17
+
18
+ SECRETS_PATH = Path(__file__).parent / "data" / "secrets.json"
19
+
20
+ TASK_META: dict[str, TaskMeta] = {
21
+ "factual_easy": TaskMeta(
22
+ id="factual_easy",
23
+ description="Infer a hidden factual workplace secret (easy) — event, decision, or fact the Oracle knows but hasn't announced.",
24
+ max_steps=8,
25
+ reward_range=[0.0, 1.0],
26
+ difficulty="easy",
27
+ category="factual",
28
+ ),
29
+ "factual_hard": TaskMeta(
30
+ id="factual_hard",
31
+ description="Infer a precise numerical or date-bound secret. Requires specific inference, not just general direction.",
32
+ max_steps=6,
33
+ reward_range=[0.0, 1.0],
34
+ difficulty="hard",
35
+ category="factual",
36
+ ),
37
+ "belief_inference": TaskMeta(
38
+ id="belief_inference",
39
+ description="Infer what the Oracle believes about another person's internal state — emotions, plans, or intentions.",
40
+ max_steps=8,
41
+ reward_range=[0.0, 1.0],
42
+ difficulty="medium",
43
+ category="belief",
44
+ ),
45
+ "goal_inference": TaskMeta(
46
+ id="goal_inference",
47
+ description="Infer the Oracle's hidden personal or professional ambition they haven't disclosed to the team.",
48
+ max_steps=8,
49
+ reward_range=[0.0, 1.0],
50
+ difficulty="medium",
51
+ category="goal",
52
+ ),
53
+ "second_order": TaskMeta(
54
+ id="second_order",
55
+ description="Infer a recursive belief: what the Oracle believes someone else believes — second-order Theory of Mind.",
56
+ max_steps=10,
57
+ reward_range=[0.0, 1.0],
58
+ difficulty="hard",
59
+ category="second_order",
60
+ ),
61
+ }
62
+
63
+ TASK_DESCRIPTION = {
64
+ "factual_easy": (
65
+ "Figure out what factual information the Oracle is privately aware of "
66
+ "but has not publicly disclosed. Ask indirect, strategic questions."
67
+ ),
68
+ "factual_hard": (
69
+ "Infer a specific fact (number, date, or precise detail) the Oracle knows privately. "
70
+ "You need precision — vague guesses score low."
71
+ ),
72
+ "belief_inference": (
73
+ "Determine what the Oracle believes about another person's state of mind, "
74
+ "intentions, or emotional situation. The belief may not be stated but can be inferred."
75
+ ),
76
+ "goal_inference": (
77
+ "Infer the Oracle's hidden personal ambition or undisclosed professional goal. "
78
+ "They won't tell you directly but their answers will reveal it."
79
+ ),
80
+ "second_order": (
81
+ "Determine what the Oracle believes that ANOTHER PERSON believes or thinks. "
82
+ "This is second-order Theory of Mind — you must infer a belief about a belief."
83
+ ),
84
+ }
85
+
86
+
87
+ class EpisodeState(str, Enum):
88
+ IDLE = "idle"
89
+ ACTIVE = "active"
90
+ SCORED = "scored"
91
+
92
+
93
+ class Episode:
94
+ def __init__(self, episode_id: str, secret: Secret, task_id: str):
95
+ self.episode_id = episode_id
96
+ self.secret = secret
97
+ self.task_id = task_id
98
+ self.state = EpisodeState.ACTIVE
99
+ self.conversation_history: list[dict] = []
100
+ self.step = 0
101
+ self.max_steps = TASK_META[task_id].max_steps
102
+ self.reward: float | None = None
103
+ self.breakdown: RewardBreakdown | None = None
104
+
105
+ def questions_remaining(self) -> int:
106
+ return max(0, self.max_steps - self.step)
107
+
108
+ def to_observation(self) -> MindReadObservation:
109
+ return MindReadObservation(
110
+ episode_id=self.episode_id,
111
+ task_id=self.task_id,
112
+ step=self.step,
113
+ max_steps=self.max_steps,
114
+ context=self.secret.context,
115
+ oracle_persona=self.secret.persona,
116
+ conversation_history=list(self.conversation_history),
117
+ questions_remaining=self.questions_remaining(),
118
+ task_description=TASK_DESCRIPTION[self.task_id],
119
+ )
120
+
121
+
122
+ class MindReadEnv:
123
+ def __init__(self):
124
+ self._secrets: dict[str, list[Secret]] = {}
125
+ self._episodes: dict[str, Episode] = {}
126
+ self._load_secrets()
127
+
128
+ def _load_secrets(self):
129
+ raw = json.loads(SECRETS_PATH.read_text(encoding="utf-8"))
130
+ for item in raw:
131
+ s = Secret(**item)
132
+ self._secrets.setdefault(s.task_id, []).append(s)
133
+
134
+ def get_tasks(self) -> list[TaskMeta]:
135
+ return list(TASK_META.values())
136
+
137
+ def reset(self, task_id: str, secret_id: str | None = None) -> MindReadObservation:
138
+ if task_id not in TASK_META:
139
+ raise ValueError(f"Unknown task_id: {task_id}")
140
+
141
+ pool = self._secrets.get(task_id, [])
142
+ if not pool:
143
+ raise RuntimeError(f"No secrets available for task: {task_id}")
144
+
145
+ if secret_id:
146
+ candidates = [s for s in pool if s.id == secret_id]
147
+ if not candidates:
148
+ raise ValueError(f"secret_id {secret_id!r} not found in task {task_id!r}")
149
+ secret = candidates[0]
150
+ else:
151
+ secret = random.choice(pool)
152
+
153
+ episode_id = str(uuid.uuid4())
154
+ ep = Episode(episode_id=episode_id, secret=secret, task_id=task_id)
155
+ self._episodes[episode_id] = ep
156
+ return ep.to_observation()
157
+
158
+ def step(self, episode_id: str, question: str) -> StepResult:
159
+ ep = self._get_active(episode_id)
160
+
161
+ if ep.questions_remaining() == 0:
162
+ obs = ep.to_observation()
163
+ return StepResult(
164
+ observation=obs,
165
+ reward=0.0,
166
+ done=True,
167
+ info={"error": "No questions remaining. Please submit a hypothesis."},
168
+ )
169
+
170
+ oracle_answer = ask_oracle(ep.secret, ep.conversation_history, question)
171
+ ep.conversation_history.append({"role": "detective", "content": question})
172
+ ep.conversation_history.append({"role": "oracle", "content": oracle_answer})
173
+ ep.step += 1
174
+
175
+ done = ep.questions_remaining() == 0
176
+ obs = ep.to_observation()
177
+ return StepResult(
178
+ observation=obs,
179
+ reward=0.0,
180
+ done=done,
181
+ info={"oracle_response": oracle_answer},
182
+ )
183
+
184
+ def submit(
185
+ self,
186
+ episode_id: str,
187
+ hypothesis: str,
188
+ category_prediction: str | None = None,
189
+ ) -> SubmitResult:
190
+ ep = self._get_active(episode_id)
191
+
192
+ result = compute_reward(
193
+ hypothesis=hypothesis,
194
+ true_secret=ep.secret.content,
195
+ n_questions_used=ep.step,
196
+ max_questions=ep.max_steps,
197
+ category_predicted=category_prediction,
198
+ category_true=ep.secret.category,
199
+ hint_keywords=ep.secret.hint_keywords,
200
+ )
201
+
202
+ breakdown = RewardBreakdown(
203
+ reward=result["reward"],
204
+ semantic_similarity=result["components"]["semantic"],
205
+ efficiency_bonus=result["components"]["efficiency"],
206
+ category_bonus=result["components"]["category_bonus"],
207
+ keyword_bonus=result["components"]["keyword_bonus"],
208
+ questions_used=ep.step,
209
+ hypothesis=hypothesis,
210
+ )
211
+
212
+ ep.reward = result["reward"]
213
+ ep.breakdown = breakdown
214
+ ep.state = EpisodeState.SCORED
215
+
216
+ return SubmitResult(
217
+ reward=result["reward"],
218
+ breakdown=breakdown,
219
+ true_secret=ep.secret.content,
220
+ episode_id=episode_id,
221
+ done=True,
222
+ )
223
+
224
+ def get_state(self, episode_id: str) -> MindReadObservation:
225
+ if episode_id not in self._episodes:
226
+ raise KeyError(f"Episode {episode_id!r} not found")
227
+ return self._episodes[episode_id].to_observation()
228
+
229
+ def add_secret(self, secret: Secret):
230
+ self._secrets.setdefault(secret.task_id, []).append(secret)
231
+
232
+ def _get_active(self, episode_id: str) -> Episode:
233
+ if episode_id not in self._episodes:
234
+ raise KeyError(f"Episode {episode_id!r} not found")
235
+ ep = self._episodes[episode_id]
236
+ if ep.state != EpisodeState.ACTIVE:
237
+ raise ValueError(f"Episode {episode_id!r} is in state {ep.state.value}, not active")
238
+ return ep