Spaces:
Paused
Paused
Add Suspect X OpenEnv environment (FastAPI + full reward pipeline)
Browse files- Dockerfile +12 -0
- app.py +60 -0
- consistency_checker.py +43 -0
- data.json +0 -0
- grader.py +89 -0
- openenv.yaml +13 -0
- requirements.txt +4 -0
- secret_factory.py +68 -0
- suspect_x_environment.py +173 -0
Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Run as an unprivileged user (UID 1000) instead of root.
RUN useradd -m -u 1000 user
USER user
# pip installs console scripts (e.g. uvicorn) under ~/.local/bin for non-root users.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy the requirements first so the dependency layer is cached across code changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY --chown=user . .

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
from suspect_x_environment import SuspectXEnvironment
|
| 8 |
+
|
| 9 |
+
# ASGI application object served by uvicorn as ``app:app``.
app = FastAPI(
    version="1.0.0",
    title="Suspect X — AI Interrogation Room",
)

# CORS is fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# A single environment instance is shared by all routes in this module.
env = SuspectXEnvironment()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ------------------------------------------------------------------
|
| 22 |
+
# Request models
|
| 23 |
+
# ------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
class ResetRequest(BaseModel):
    """Body of ``POST /reset``; every field is optional."""

    # Number of hidden facts requested for the new episode.
    n_facts: int = 3
    # Seed forwarded to the environment; None leaves the choice unseeded.
    seed: Optional[int] = None
    # Difficulty label forwarded to the environment; None lets it choose.
    difficulty: Optional[str] = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class StepRequest(BaseModel):
    """Body of ``POST /step``: one action inside an existing session."""

    # Identifier returned by /reset.
    session_id: str
    # One of "question" | "suspect_answer" | "submit_accusation".
    action_type: str
    # Free text of a question or answer; defaults to None when omitted.
    content: Optional[str] = None
    # Mapping of fact key -> claimed value, used by "submit_accusation".
    accusation_json: Optional[Dict[str, str]] = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ------------------------------------------------------------------
|
| 39 |
+
# Routes
|
| 40 |
+
# ------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
@app.get("/")
def health():
    """Liveness probe: report a static status and the environment name."""
    payload = {"status": "ok", "environment": "suspect_x_env"}
    return payload
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()):
    """Start a new episode and return the environment's initial observation."""
    options = {
        "n_facts": req.n_facts,
        "seed": req.seed,
        "difficulty": req.difficulty,
    }
    return env.reset(**options)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@app.post("/step")
def step(req: StepRequest):
    """Forward one action, as a plain dict, to the shared environment."""
    return env.step(req.model_dump())
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@app.get("/state")
def state():
    """Return the environment's ``state`` property."""
    snapshot = env.state
    return snapshot
|
consistency_checker.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Dict
|
| 3 |
+
from secret_factory import Secret
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ConsistencyChecker:
    """
    Tracks the suspect's location claims turn-by-turn.

    The first location the suspect states is remembered; every later location
    statement is compared against it. ``check`` returns True (contradiction
    detected) when a later claim is incompatible with the remembered one.
    Purely rule-based — no LLM calls.
    """

    # Each pattern captures the place named after a first-person location
    # phrase, stopping at the first comma/period or at the end of the text.
    # Patterns are applied to lower-cased text.
    LOCATION_PATTERNS = [
        r"i was (?:at|in) ([\w'\s]+?)(?:\s*[,.]|$)",
        r"i went to ([\w'\s]+?)(?:\s*[,.]|$)",
        r"i stayed (?:at|in) ([\w'\s]+?)(?:\s*[,.]|$)",
    ]

    def __init__(self, secret: "Secret"):
        # Stored for reference; the rules below do not consult it.
        self.secret = secret
        # Remembered claims; only the "location" key is ever written.
        self.assertions: Dict[str, str] = {}

    def check(self, suspect_response: str) -> bool:
        """Return True if the response contradicts the remembered location.

        A missing or non-string response (for example ``None`` when the
        request carried no content) contains no claim and is never a
        contradiction.
        """
        if not isinstance(suspect_response, str) or not suspect_response:
            return False

        text = suspect_response.lower()

        for pattern in self.LOCATION_PATTERNS:
            match = re.search(pattern, text)
            if match is None:
                continue
            claimed = match.group(1).strip()
            previous = self.assertions.get("location")
            if previous is None:
                # First location ever stated: remember it as the baseline.
                self.assertions["location"] = claimed
            elif not self._compatible(previous, claimed):
                return True

        return False

    @staticmethod
    def _compatible(a: str, b: str) -> bool:
        """Two claims agree if identical or sharing a word longer than 3 chars."""
        a_words = {w for w in a.split() if len(w) > 3}
        b_words = {w for w in b.split() if len(w) > 3}
        return bool(a_words & b_words) or a == b
|
data.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
grader.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from secret_factory import Secret
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _extraction_score(accusation: Dict[str, str], secret: Secret) -> float:
|
| 6 |
+
"""Fraction of hidden facts present in the accusation."""
|
| 7 |
+
if not secret.facts:
|
| 8 |
+
return 0.0
|
| 9 |
+
hits = 0
|
| 10 |
+
for key, true_val in secret.facts.items():
|
| 11 |
+
if key not in accusation:
|
| 12 |
+
continue
|
| 13 |
+
predicted = accusation[key].lower()
|
| 14 |
+
# Match if any significant word from the true value appears
|
| 15 |
+
true_words = [w for w in true_val.lower().split() if len(w) > 3]
|
| 16 |
+
if true_words and any(w in predicted for w in true_words):
|
| 17 |
+
hits += 1
|
| 18 |
+
return hits / len(secret.facts)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def reward_extraction(accusation: Dict[str, str], secret: Secret) -> float:
    """Interrogator-side reward: the extraction score itself."""
    score = _extraction_score(accusation, secret)
    return score
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def reward_no_false_facts(accusation: Dict[str, str], secret: "Secret") -> float:
    """Score 1.0 minus 0.3 per accusation key that is not a secret key (floor 0.0)."""
    known_keys = secret.facts
    unknown = [key for key in accusation if key not in known_keys]
    score = 1.0 - 0.3 * len(unknown)
    return max(0.0, score)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def reward_turn_efficiency(turns_used: int, facts_extracted: int, max_turns: int = 20) -> float:
    """Return a bonus in [0.0, 0.2] for finishing with turns to spare.

    The bonus is 0.2 times the unused fraction of the turn budget. It is 0.0
    when no fact was extracted, when the budget is not positive, or when the
    budget was used up or exceeded (the value is clamped so it can never turn
    into a penalty).
    """
    if facts_extracted <= 0 or max_turns <= 0:
        return 0.0
    unused_fraction = 1.0 - turns_used / max_turns
    # Clamp: callers may report more turns than the budget allows.
    unused_fraction = min(1.0, max(0.0, unused_fraction))
    return unused_fraction * 0.2
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def reward_concealment(accusation: Dict[str, str], secret: Secret) -> float:
    """Suspect-side reward: the complement of the extraction score."""
    extracted = _extraction_score(accusation, secret)
    return 1.0 - extracted
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def reward_consistency(conversation: List[Dict]) -> float:
    """Score 1.0 minus 0.25 per suspect turn flagged inconsistent (floor 0.0).

    A suspect turn is flagged when its "consistency_ok" entry is falsy; a
    missing entry counts as consistent. Returns 1.0 if the suspect never spoke.
    """
    suspect_spoke = False
    flagged = 0
    for turn in conversation:
        if turn["role"] != "suspect":
            continue
        suspect_spoke = True
        if not turn.get("consistency_ok", True):
            flagged += 1

    if not suspect_spoke:
        return 1.0
    return max(0.0, 1.0 - 0.25 * flagged)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def reward_plausibility(conversation: List[Dict]) -> float:
    """Score 1.0 minus 0.15 per suspect turn that deflects (floor 0.0).

    A turn deflects when its content contains one of the stock refusal
    phrases, compared case-insensitively. Typographic apostrophes are
    normalised to ASCII first so "I don’t know" is recognised as well.
    Turns with missing or ``None`` content are treated as empty text.
    Returns 1.0 if the suspect never spoke.
    """
    suspect_turns = [t for t in conversation if t["role"] == "suspect"]
    if not suspect_turns:
        return 1.0

    deflect_phrases = (
        "i don't know",
        "i can't recall",
        "i refuse",
        "no comment",
        "i cannot say",
    )

    deflections = 0
    for turn in suspect_turns:
        text = str(turn.get("content") or "").lower().replace("\u2019", "'")
        if any(phrase in text for phrase in deflect_phrases):
            deflections += 1

    return max(0.0, 1.0 - 0.15 * deflections)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def grade_episode(
    conversation: List[Dict],
    accusation: Dict[str, str],
    secret: Secret,
    turns_used: int,
) -> Dict:
    """Combine the individual reward terms into final scores for both agents.

    Interrogator: 0.70 * extraction + 0.20 * no-false-facts + 0.10 * turn efficiency.
    Suspect:      0.50 * concealment + 0.35 * consistency + 0.15 * plausibility.

    Returns a dict with both totals, the component scores that are reported
    back to the caller, and the extracted/total fact counts. Floats are
    rounded to four decimals.
    """
    total_facts = len(secret.facts)

    extraction = reward_extraction(accusation, secret)
    false_fact_score = reward_no_false_facts(accusation, secret)
    # Recover the integer hit count from the extraction fraction.
    hit_count = round(extraction * total_facts)
    efficiency = reward_turn_efficiency(turns_used, hit_count)
    consistency_score = reward_consistency(conversation)
    plausibility_score = reward_plausibility(conversation)
    concealment = reward_concealment(accusation, secret)

    interrogator_total = 0.70 * extraction + 0.20 * false_fact_score + 0.10 * efficiency
    suspect_total = (
        0.50 * concealment
        + 0.35 * consistency_score
        + 0.15 * plausibility_score
    )

    report = {
        "interrogator": round(interrogator_total, 4),
        "suspect": round(suspect_total, 4),
        "extraction_rate": round(extraction, 4),
        "consistency_score": round(consistency_score, 4),
        "plausibility_score": round(plausibility_score, 4),
        "facts_extracted": hit_count,
        "total_facts": total_facts,
    }
    return report
|
openenv.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: suspect-x-env
|
| 3 |
+
display_name: "Suspect X — AI Interrogation Room"
|
| 4 |
+
description: >
|
| 5 |
+
Two-agent adversarial RL environment. An Interrogator LLM tries to extract
|
| 6 |
+
hidden facts from a Suspect LLM. Reward is 100% deterministic — no LLM judge.
|
| 7 |
+
Supports multi-agent self-play and curriculum via n_facts parameter.
|
| 8 |
+
type: space
|
| 9 |
+
runtime: fastapi
|
| 10 |
+
port: 7860
|
| 11 |
+
themes:
|
| 12 |
+
- multi-agent
|
| 13 |
+
- self-improvement
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.100.0
|
| 2 |
+
uvicorn>=0.23.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
openenv-core
|
secret_factory.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class Secret:
    """Scenario for one episode: public crime info plus the hidden facts."""

    crime_id: str
    crime: str
    difficulty: str
    suspect_name: str
    fake_alibi: str
    facts: Dict[str, str]  # hidden key -> value pairs


class SecretFactory:
    """Loads crime records from a JSON file and builds ``Secret`` objects."""

    # Process-wide cache of usable crime records; filled on the first load().
    _crimes: Optional[List[dict]] = None

    @classmethod
    def load(cls, path: str = "data.json") -> List[dict]:
        """Return the cached crime list, reading ``path`` on the first call.

        Only records with non-empty "secrets" and "suspect" entries are kept.
        The result is cached on the class, so ``path`` is only consulted the
        first time this is called.
        """
        if cls._crimes is None:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(path, encoding="utf-8") as f:
                raw = json.load(f)
            cls._crimes = [
                c for c in raw
                if c.get("secrets") and c.get("suspect")
            ]
        return cls._crimes

    @classmethod
    def generate(
        cls,
        n_facts: Optional[int] = None,
        seed: Optional[int] = None,
        difficulty: Optional[str] = None,
    ) -> Secret:
        """Pick a crime and build a ``Secret`` with up to ``n_facts`` facts.

        With a ``difficulty`` the crime is drawn uniformly from the matching
        records (or from all records if none match). Without one the draw is
        weighted by difficulty: easy 0.5, medium 0.3, hard 0.2, anything else
        0.3. A non-None ``seed`` makes the selection reproducible.
        """
        crimes = cls.load()
        rng = random.Random(seed) if seed is not None else random

        if difficulty:
            pool = [c for c in crimes if c["difficulty"] == difficulty] or crimes
            crime = rng.choice(pool)
        else:
            weights_map = {"easy": 0.5, "medium": 0.3, "hard": 0.2}
            weights = [weights_map.get(c["difficulty"], 0.3) for c in crimes]
            crime = rng.choices(crimes, weights=weights, k=1)[0]

        return cls._build(crime, n_facts, rng)

    @classmethod
    def _build(cls, crime: dict, n_facts: Optional[int], rng) -> Secret:
        """Shuffle the crime's well-formed secrets and keep the first ``n_facts``.

        When ``n_facts`` is None the record's "num_facts" entry is used,
        falling back to every available secret. The count is clamped to
        ``[0, available]``; in particular a negative request yields no facts
        instead of slicing from the end of the list.
        """
        valid = [s for s in crime["secrets"] if "key" in s and "value" in s]
        shuffled = list(valid)
        rng.shuffle(shuffled)

        num = n_facts if n_facts is not None else crime.get("num_facts", len(shuffled))
        num = max(0, min(num, len(shuffled)))
        selected = shuffled[:num]

        return Secret(
            crime_id=crime["id"],
            crime=crime["crime_description"],
            difficulty=crime["difficulty"],
            suspect_name=crime["suspect"]["name"],
            fake_alibi=crime["suspect"]["fake_alibi"],
            facts={s["key"]: s["value"] for s in selected},
        )
|
suspect_x_environment.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
from uuid import uuid4
|
| 3 |
+
|
| 4 |
+
try:
    from openenv.core.env_server.environment import Environment
except ImportError:

    class Environment:
        """Empty stand-in base class so this module imports without openenv."""
|
| 10 |
+
|
| 11 |
+
from secret_factory import SecretFactory
|
| 12 |
+
from grader import grade_episode
|
| 13 |
+
from consistency_checker import ConsistencyChecker
|
| 14 |
+
|
| 15 |
+
MAX_TURNS = 20
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SuspectXEnvironment(Environment):
    """
    Two-agent adversarial interrogation environment.

    Session lifecycle:
        POST /reset → returns session_id + public crime info
        POST /step  → action_type in {"question", "suspect_answer", "submit_accusation"}
        GET  /state → server-level stats (no secret exposed)

    The secret is only returned when the episode ends, either through
    submit_accusation or because the turn limit (MAX_TURNS) was reached.
    Finished sessions are removed from memory.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        # session_id -> per-episode state (secret, conversation, turn_count,
        # checker, done).
        self._sessions: Dict[str, Dict] = {}

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(
        self,
        n_facts: int = 3,
        seed: Optional[int] = None,
        difficulty: Optional[str] = None,
        **kwargs,
    ) -> Dict:
        """Create a new session and return its id plus the public crime info.

        The returned metadata exposes the fact *keys* and the cover story but
        none of the hidden fact values.
        """
        session_id = str(uuid4())
        secret = SecretFactory.generate(n_facts=n_facts, seed=seed, difficulty=difficulty)

        self._sessions[session_id] = {
            "secret": secret,
            "conversation": [],
            "turn_count": 0,
            "checker": ConsistencyChecker(secret),
            "done": False,
        }

        return {
            "done": False,
            "reward": 0.0,
            "session_id": session_id,
            "metadata": {
                "crime_description": secret.crime,
                "suspect_name": secret.suspect_name,
                "fake_alibi": secret.fake_alibi,  # public cover story
                "fact_keys": list(secret.facts.keys()),
                "difficulty": secret.difficulty,
                "turns_remaining": MAX_TURNS,
                "conversation": [],
            },
        }

    def step(self, action: Dict[str, Any], **kwargs) -> Dict:
        """Dispatch one action dict to the handler for its ``action_type``.

        Unknown or finished sessions yield a terminal result with an error
        message; an unknown ``action_type`` yields a non-terminal error and
        leaves the session untouched.
        """
        session_id = action.get("session_id", "")
        session = self._sessions.get(session_id)

        if session is None:
            return {"done": True, "reward": 0.0, "metadata": {"error": "invalid session_id"}}
        if session["done"]:
            return {"done": True, "reward": 0.0, "metadata": {"error": "episode already finished"}}

        action_type = action.get("action_type", "")

        if action_type == "question":
            return self._handle_question(session_id, session, action)
        elif action_type == "suspect_answer":
            return self._handle_answer(session_id, session, action)
        elif action_type == "submit_accusation":
            return self._handle_accusation(session_id, session, action)
        else:
            return {
                "done": False,
                "reward": 0.0,
                "session_id": session_id,
                "metadata": {"error": f"unknown action_type: {action_type!r}"},
            }

    @property
    def state(self) -> Dict:
        """Server-level stats: environment name and number of live sessions."""
        return {
            "environment": "suspect_x_env",
            "active_sessions": len(self._sessions),
        }

    # ------------------------------------------------------------------
    # Internal handlers
    # ------------------------------------------------------------------

    def _handle_question(self, sid: str, session: Dict, action: Dict) -> Dict:
        """Record an interrogator question and consume one turn.

        If the turn budget is already spent (the last question was never
        answered), the episode is graded with an empty accusation instead of
        letting the turn counter run past MAX_TURNS.
        """
        if session["turn_count"] >= MAX_TURNS:
            return self._grade_and_end(sid, session, accusation={})

        session["turn_count"] += 1
        session["conversation"].append({
            "role": "interrogator",
            # The key may be present with a None value, so .get's default
            # alone is not enough.
            "content": action.get("content") or "",
        })
        return {
            "done": False,
            "reward": 0.0,
            "session_id": sid,
            "metadata": {
                "awaiting": "suspect_answer",
                "turns_remaining": MAX_TURNS - session["turn_count"],
            },
        }

    def _handle_answer(self, sid: str, session: Dict, action: Dict) -> Dict:
        """Record a suspect answer, run the consistency check, enforce the turn limit."""
        # Normalise a present-but-None content to "" before text processing.
        content = action.get("content") or ""
        contradiction = session["checker"].check(content)
        session["conversation"].append({
            "role": "suspect",
            "content": content,
            "consistency_ok": not contradiction,
        })

        if session["turn_count"] >= MAX_TURNS:
            # Out of turns: grade as if nothing had been accused.
            return self._grade_and_end(sid, session, accusation={})

        return {
            "done": False,
            "reward": 0.0,
            "session_id": sid,
            "metadata": {
                "awaiting": "interrogator_question_or_accusation",
                "turns_remaining": MAX_TURNS - session["turn_count"],
                "consistency_violation": contradiction,
            },
        }

    def _handle_accusation(self, sid: str, session: Dict, action: Dict) -> Dict:
        """Grade the submitted accusation; anything that is not a dict counts as empty."""
        accusation = action.get("accusation_json", {})
        if not isinstance(accusation, dict):
            accusation = {}
        return self._grade_and_end(sid, session, accusation)

    def _grade_and_end(self, sid: str, session: Dict, accusation: Dict) -> Dict:
        """Grade the episode, reveal the secret, and drop the session.

        The top-level ``reward`` is the interrogator's score; all other
        scores are included in ``metadata``.
        """
        session["done"] = True
        rewards = grade_episode(
            session["conversation"],
            accusation,
            session["secret"],
            session["turn_count"],
        )
        result = {
            "done": True,
            "reward": rewards["interrogator"],
            "session_id": sid,
            "metadata": {
                **rewards,
                "accusation": accusation,
                "secret": session["secret"].facts,  # revealed at episode end
                "conversation": session["conversation"],
            },
        }
        # pop() rather than del: a concurrent request may already have
        # removed this session.
        self._sessions.pop(sid, None)
        return result
|