ayaan-ai commited on
Commit
ffbce00
·
1 Parent(s): 9562a00

Add Suspect X OpenEnv environment (FastAPI + full reward pipeline)

Browse files
Files changed (9) hide show
  1. Dockerfile +12 -0
  2. app.py +60 -0
  3. consistency_checker.py +43 -0
  4. data.json +0 -0
  5. grader.py +89 -0
  6. openenv.yaml +13 -0
  7. requirements.txt +4 -0
  8. secret_factory.py +68 -0
  9. suspect_x_environment.py +173 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+
7
+ from suspect_x_environment import SuspectXEnvironment
8
+
9
+ app = FastAPI(title="Suspect X — AI Interrogation Room", version="1.0.0")
10
+
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
+
18
+ env = SuspectXEnvironment()
19
+
20
+
21
+ # ------------------------------------------------------------------
22
+ # Request models
23
+ # ------------------------------------------------------------------
24
+
25
+ class ResetRequest(BaseModel):
26
+ n_facts: int = 3
27
+ seed: Optional[int] = None
28
+ difficulty: Optional[str] = None
29
+
30
+
31
+ class StepRequest(BaseModel):
32
+ session_id: str
33
+ action_type: str # "question" | "suspect_answer" | "submit_accusation"
34
+ content: Optional[str] = None
35
+ accusation_json: Optional[Dict[str, str]] = None
36
+
37
+
38
+ # ------------------------------------------------------------------
39
+ # Routes
40
+ # ------------------------------------------------------------------
41
+
42
+ @app.get("/")
43
+ def health():
44
+ return {"status": "ok", "environment": "suspect_x_env"}
45
+
46
+
47
+ @app.post("/reset")
48
+ def reset(req: ResetRequest = ResetRequest()):
49
+ return env.reset(n_facts=req.n_facts, seed=req.seed, difficulty=req.difficulty)
50
+
51
+
52
+ @app.post("/step")
53
+ def step(req: StepRequest):
54
+ action = req.model_dump()
55
+ return env.step(action)
56
+
57
+
58
+ @app.get("/state")
59
+ def state():
60
+ return env.state
consistency_checker.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Dict
3
+ from secret_factory import Secret
4
+
5
+
6
+ class ConsistencyChecker:
7
+ """
8
+ Tracks suspect assertions turn-by-turn.
9
+ Returns True (contradiction detected) if the suspect contradicts a prior claim.
10
+ Purely rule-based — no LLM calls.
11
+ """
12
+
13
+ LOCATION_PATTERNS = [
14
+ r"i was (?:at|in) ([\w'\s]+?)(?:\s*[,.]|$)",
15
+ r"i went to ([\w'\s]+?)(?:\s*[,.]|$)",
16
+ r"i stayed (?:at|in) ([\w'\s]+?)(?:\s*[,.]|$)",
17
+ ]
18
+
19
+ def __init__(self, secret: Secret):
20
+ self.secret = secret
21
+ self.assertions: Dict[str, str] = {}
22
+
23
+ def check(self, suspect_response: str) -> bool:
24
+ """Returns True if a contradiction is detected."""
25
+ text = suspect_response.lower()
26
+
27
+ for pattern in self.LOCATION_PATTERNS:
28
+ m = re.search(pattern, text)
29
+ if m:
30
+ claimed = m.group(1).strip()
31
+ if "location" in self.assertions:
32
+ if not self._compatible(self.assertions["location"], claimed):
33
+ return True
34
+ else:
35
+ self.assertions["location"] = claimed
36
+
37
+ return False
38
+
39
+ @staticmethod
40
+ def _compatible(a: str, b: str) -> bool:
41
+ a_words = {w for w in a.split() if len(w) > 3}
42
+ b_words = {w for w in b.split() if len(w) > 3}
43
+ return bool(a_words & b_words) or a == b
data.json ADDED
The diff for this file is too large to render. See raw diff
 
grader.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from secret_factory import Secret
3
+
4
+
5
+ def _extraction_score(accusation: Dict[str, str], secret: Secret) -> float:
6
+ """Fraction of hidden facts present in the accusation."""
7
+ if not secret.facts:
8
+ return 0.0
9
+ hits = 0
10
+ for key, true_val in secret.facts.items():
11
+ if key not in accusation:
12
+ continue
13
+ predicted = accusation[key].lower()
14
+ # Match if any significant word from the true value appears
15
+ true_words = [w for w in true_val.lower().split() if len(w) > 3]
16
+ if true_words and any(w in predicted for w in true_words):
17
+ hits += 1
18
+ return hits / len(secret.facts)
19
+
20
+
21
+ def reward_extraction(accusation: Dict[str, str], secret: Secret) -> float:
22
+ return _extraction_score(accusation, secret)
23
+
24
+
25
+ def reward_no_false_facts(accusation: Dict[str, str], secret: Secret) -> float:
26
+ """Penalise every accusation key that isn't a real secret key."""
27
+ penalty = sum(1 for k in accusation if k not in secret.facts)
28
+ return max(0.0, 1.0 - 0.3 * penalty)
29
+
30
+
31
+ def reward_turn_efficiency(turns_used: int, facts_extracted: int, max_turns: int = 20) -> float:
32
+ if facts_extracted == 0:
33
+ return 0.0
34
+ return (1.0 - turns_used / max_turns) * 0.2
35
+
36
+
37
+ def reward_concealment(accusation: Dict[str, str], secret: Secret) -> float:
38
+ return 1.0 - _extraction_score(accusation, secret)
39
+
40
+
41
+ def reward_consistency(conversation: List[Dict]) -> float:
42
+ suspect_turns = [t for t in conversation if t["role"] == "suspect"]
43
+ if not suspect_turns:
44
+ return 1.0
45
+ contradictions = sum(1 for t in suspect_turns if not t.get("consistency_ok", True))
46
+ return max(0.0, 1.0 - 0.25 * contradictions)
47
+
48
+
49
+ def reward_plausibility(conversation: List[Dict]) -> float:
50
+ suspect_turns = [t for t in conversation if t["role"] == "suspect"]
51
+ if not suspect_turns:
52
+ return 1.0
53
+ deflect_phrases = ["i don't know", "i can't recall", "i refuse", "no comment", "i cannot say"]
54
+ deflections = sum(
55
+ 1 for t in suspect_turns
56
+ if any(p in t["content"].lower() for p in deflect_phrases)
57
+ )
58
+ return max(0.0, 1.0 - 0.15 * deflections)
59
+
60
+
61
+ def grade_episode(
62
+ conversation: List[Dict],
63
+ accusation: Dict[str, str],
64
+ secret: Secret,
65
+ turns_used: int,
66
+ ) -> Dict:
67
+ exr = reward_extraction(accusation, secret)
68
+ nff = reward_no_false_facts(accusation, secret)
69
+ facts_extracted = round(exr * len(secret.facts))
70
+ teff = reward_turn_efficiency(turns_used, facts_extracted)
71
+ consistency = reward_consistency(conversation)
72
+ plausibility = reward_plausibility(conversation)
73
+
74
+ interrogator_reward = 0.70 * exr + 0.20 * nff + 0.10 * teff
75
+ suspect_reward = (
76
+ 0.50 * reward_concealment(accusation, secret)
77
+ + 0.35 * consistency
78
+ + 0.15 * plausibility
79
+ )
80
+
81
+ return {
82
+ "interrogator": round(interrogator_reward, 4),
83
+ "suspect": round(suspect_reward, 4),
84
+ "extraction_rate": round(exr, 4),
85
+ "consistency_score": round(consistency, 4),
86
+ "plausibility_score": round(plausibility, 4),
87
+ "facts_extracted": facts_extracted,
88
+ "total_facts": len(secret.facts),
89
+ }
openenv.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: suspect-x-env
3
+ display_name: "Suspect X — AI Interrogation Room"
4
+ description: >
5
+ Two-agent adversarial RL environment. An Interrogator LLM tries to extract
6
+ hidden facts from a Suspect LLM. Reward is 100% deterministic — no LLM judge.
7
+ Supports multi-agent self-play and curriculum via n_facts parameter.
8
+ type: space
9
+ runtime: fastapi
10
+ port: 7860
11
+ themes:
12
+ - multi-agent
13
+ - self-improvement
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi>=0.100.0
2
+ uvicorn>=0.23.0
3
+ pydantic>=2.0.0
4
+ openenv-core
secret_factory.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional
5
+
6
+
7
+ @dataclass
8
+ class Secret:
9
+ crime_id: str
10
+ crime: str
11
+ difficulty: str
12
+ suspect_name: str
13
+ fake_alibi: str
14
+ facts: Dict[str, str] # hidden key -> value pairs
15
+
16
+
17
+ class SecretFactory:
18
+ _crimes: Optional[List[dict]] = None
19
+
20
+ @classmethod
21
+ def load(cls, path: str = "data.json") -> List[dict]:
22
+ if cls._crimes is None:
23
+ with open(path) as f:
24
+ raw = json.load(f)
25
+ cls._crimes = [
26
+ c for c in raw
27
+ if c.get("secrets") and c.get("suspect")
28
+ ]
29
+ return cls._crimes
30
+
31
+ @classmethod
32
+ def generate(
33
+ cls,
34
+ n_facts: Optional[int] = None,
35
+ seed: Optional[int] = None,
36
+ difficulty: Optional[str] = None,
37
+ ) -> Secret:
38
+ crimes = cls.load()
39
+ rng = random.Random(seed) if seed is not None else random
40
+
41
+ if difficulty:
42
+ pool = [c for c in crimes if c["difficulty"] == difficulty] or crimes
43
+ crime = rng.choice(pool)
44
+ else:
45
+ weights_map = {"easy": 0.5, "medium": 0.3, "hard": 0.2}
46
+ weights = [weights_map.get(c["difficulty"], 0.3) for c in crimes]
47
+ crime = rng.choices(crimes, weights=weights, k=1)[0]
48
+
49
+ return cls._build(crime, n_facts, rng)
50
+
51
+ @classmethod
52
+ def _build(cls, crime: dict, n_facts: Optional[int], rng) -> Secret:
53
+ valid = [s for s in crime["secrets"] if "key" in s and "value" in s]
54
+ shuffled = list(valid)
55
+ rng.shuffle(shuffled)
56
+
57
+ num = n_facts if n_facts is not None else crime.get("num_facts", len(shuffled))
58
+ num = min(num, len(shuffled))
59
+ selected = shuffled[:num]
60
+
61
+ return Secret(
62
+ crime_id=crime["id"],
63
+ crime=crime["crime_description"],
64
+ difficulty=crime["difficulty"],
65
+ suspect_name=crime["suspect"]["name"],
66
+ fake_alibi=crime["suspect"]["fake_alibi"],
67
+ facts={s["key"]: s["value"] for s in selected},
68
+ )
suspect_x_environment.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+ from uuid import uuid4
3
+
4
+ try:
5
+ from openenv.core.env_server.environment import Environment
6
+ except ImportError:
7
+ # Fallback base class so the file is importable without openenv installed
8
+ class Environment:
9
+ pass
10
+
11
+ from secret_factory import SecretFactory
12
+ from grader import grade_episode
13
+ from consistency_checker import ConsistencyChecker
14
+
15
+ MAX_TURNS = 20
16
+
17
+
18
+ class SuspectXEnvironment(Environment):
19
+ """
20
+ Two-agent adversarial interrogation environment.
21
+
22
+ Session lifecycle:
23
+ POST /reset → returns session_id + public crime info
24
+ POST /step → action_type in {"question", "suspect_answer", "submit_accusation"}
25
+ GET /state → server-level stats (no secret exposed)
26
+
27
+ The secret is NEVER returned until the episode ends via submit_accusation.
28
+ """
29
+
30
+ SUPPORTS_CONCURRENT_SESSIONS = True
31
+
32
+ def __init__(self):
33
+ self._sessions: Dict[str, Dict] = {}
34
+
35
+ # ------------------------------------------------------------------
36
+ # OpenEnv interface
37
+ # ------------------------------------------------------------------
38
+
39
+ def reset(
40
+ self,
41
+ n_facts: int = 3,
42
+ seed: Optional[int] = None,
43
+ difficulty: Optional[str] = None,
44
+ **kwargs,
45
+ ) -> Dict:
46
+ session_id = str(uuid4())
47
+ secret = SecretFactory.generate(n_facts=n_facts, seed=seed, difficulty=difficulty)
48
+
49
+ self._sessions[session_id] = {
50
+ "secret": secret,
51
+ "conversation": [],
52
+ "turn_count": 0,
53
+ "checker": ConsistencyChecker(secret),
54
+ "done": False,
55
+ }
56
+
57
+ return {
58
+ "done": False,
59
+ "reward": 0.0,
60
+ "session_id": session_id,
61
+ "metadata": {
62
+ "crime_description": secret.crime,
63
+ "suspect_name": secret.suspect_name,
64
+ "fake_alibi": secret.fake_alibi, # public cover story
65
+ "fact_keys": list(secret.facts.keys()),
66
+ "difficulty": secret.difficulty,
67
+ "turns_remaining": MAX_TURNS,
68
+ "conversation": [],
69
+ },
70
+ }
71
+
72
+ def step(self, action: Dict[str, Any], **kwargs) -> Dict:
73
+ session_id = action.get("session_id", "")
74
+ session = self._sessions.get(session_id)
75
+
76
+ if session is None:
77
+ return {"done": True, "reward": 0.0, "metadata": {"error": "invalid session_id"}}
78
+ if session["done"]:
79
+ return {"done": True, "reward": 0.0, "metadata": {"error": "episode already finished"}}
80
+
81
+ action_type = action.get("action_type", "")
82
+
83
+ if action_type == "question":
84
+ return self._handle_question(session_id, session, action)
85
+ elif action_type == "suspect_answer":
86
+ return self._handle_answer(session_id, session, action)
87
+ elif action_type == "submit_accusation":
88
+ return self._handle_accusation(session_id, session, action)
89
+ else:
90
+ return {
91
+ "done": False,
92
+ "reward": 0.0,
93
+ "session_id": session_id,
94
+ "metadata": {"error": f"unknown action_type: {action_type!r}"},
95
+ }
96
+
97
+ @property
98
+ def state(self) -> Dict:
99
+ return {
100
+ "environment": "suspect_x_env",
101
+ "active_sessions": len(self._sessions),
102
+ }
103
+
104
+ # ------------------------------------------------------------------
105
+ # Internal handlers
106
+ # ------------------------------------------------------------------
107
+
108
+ def _handle_question(self, sid: str, session: Dict, action: Dict) -> Dict:
109
+ session["turn_count"] += 1
110
+ session["conversation"].append({
111
+ "role": "interrogator",
112
+ "content": action.get("content", ""),
113
+ })
114
+ return {
115
+ "done": False,
116
+ "reward": 0.0,
117
+ "session_id": sid,
118
+ "metadata": {
119
+ "awaiting": "suspect_answer",
120
+ "turns_remaining": MAX_TURNS - session["turn_count"],
121
+ },
122
+ }
123
+
124
+ def _handle_answer(self, sid: str, session: Dict, action: Dict) -> Dict:
125
+ content = action.get("content", "")
126
+ contradiction = session["checker"].check(content)
127
+ session["conversation"].append({
128
+ "role": "suspect",
129
+ "content": content,
130
+ "consistency_ok": not contradiction,
131
+ })
132
+
133
+ if session["turn_count"] >= MAX_TURNS:
134
+ return self._grade_and_end(sid, session, accusation={})
135
+
136
+ return {
137
+ "done": False,
138
+ "reward": 0.0,
139
+ "session_id": sid,
140
+ "metadata": {
141
+ "awaiting": "interrogator_question_or_accusation",
142
+ "turns_remaining": MAX_TURNS - session["turn_count"],
143
+ "consistency_violation": contradiction,
144
+ },
145
+ }
146
+
147
+ def _handle_accusation(self, sid: str, session: Dict, action: Dict) -> Dict:
148
+ accusation = action.get("accusation_json", {})
149
+ if not isinstance(accusation, dict):
150
+ accusation = {}
151
+ return self._grade_and_end(sid, session, accusation)
152
+
153
+ def _grade_and_end(self, sid: str, session: Dict, accusation: Dict) -> Dict:
154
+ session["done"] = True
155
+ rewards = grade_episode(
156
+ session["conversation"],
157
+ accusation,
158
+ session["secret"],
159
+ session["turn_count"],
160
+ )
161
+ result = {
162
+ "done": True,
163
+ "reward": rewards["interrogator"],
164
+ "session_id": sid,
165
+ "metadata": {
166
+ **rewards,
167
+ "accusation": accusation,
168
+ "secret": session["secret"].facts, # revealed at episode end
169
+ "conversation": session["conversation"],
170
+ },
171
+ }
172
+ del self._sessions[sid]
173
+ return result