Sibam committed on
Commit
cdf485e
·
0 Parent(s):

PreferenceLab OpenEnv environment for RLHF preference simulation

.dockerignore ADDED
@@ -0,0 +1,14 @@
+ __pycache__
+ *.pyc
+ *.pyo
+ .env
+ .git
+ .gitignore
+ .pytest_cache
+ tests/
+ scripts/
+ outputs/
+ *.egg-info
+ dist/
+ build/
+ README.md
.gitignore ADDED
@@ -0,0 +1,60 @@
+ # Python Bytecode
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.py[cod]
+ *$py.class
+
+ # Virtual Environments
+ .venv/
+ venv/
+ ENV/
+ env/
+
+ # Package Management
+ *.egg-info/
+ dist/
+ build/
+ *.egg
+ pip-log.txt
+ .pip/
+ uv.lock
+ uv.cache/
+
+ # IDE & Editor
+ .idea/
+ .vscode/settings.json
+ *.swp
+ *.swo
+ *~
+ .DS_Store
+ Thumbs.db
+
+ # Testing & Coverage
+ .pytest_cache/
+ .mypy_cache/
+ .coverage
+ htmlcov/
+ .ruff_cache/
+
+ # Logging
+ *.log
+
+ # Environment
+ .env
+ .env.local
+ .env.*.local
+
+ # Project-specific Generated Data
+ outputs/
+ data/pairwise_data.json
+ data/likert_data.json
+ data/consistency_data.json
+
+ # Machine Learning Models (if applicable)
+ *.bin
+ *.safetensors
+ *.pt
+ *.pth
+ *.pkl
+ *.joblib
Dockerfile ADDED
@@ -0,0 +1,54 @@
+ # Copyright (c) Team Nexis
+ # PreferenceLab OpenEnv Environment
+ # Based on the official openenv-base image pattern
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ ARG BUILD_MODE=standalone
+
+ # Copy environment code
+ COPY . /app/env
+
+ WORKDIR /app/env
+
+ # Install git (build-time only)
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install uv
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+     install -m 0755 /root/.local/bin/uv /usr/local/bin/uv && \
+     install -m 0755 /root/.local/bin/uvx /usr/local/bin/uvx
+
+ # Install dependencies (two-pass for caching: dependencies first, then the project)
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     uv sync --no-install-project --no-editable
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     uv sync --no-editable
+
+ # ── Runtime stage ──────────────────────────────────────────────
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy venv and code from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+ COPY --from=builder /app/env /app/env
+
+ # Environment
+ ENV PATH="/app/.venv/bin:$PATH"
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+ ENV ENABLE_WEB_INTERFACE=true
+ ENV MAX_CONCURRENT_ENVS=64
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
+
+ # Start server
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md ADDED
@@ -0,0 +1,254 @@
+ ---
+ title: PreferenceLab
+ emoji: 🧪
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ ---
+
+ # 🧪 PreferenceLab
+
+ **An OpenEnv environment simulating the RLHF human preference data collection pipeline.**
+
+ > Built for the Meta × Hugging Face OpenEnv Hackathon — Team Nexis
+
+ ---
+
+ ## Overview
+
+ PreferenceLab is a real-world OpenEnv environment in which an AI agent learns to judge
+ the quality of LLM responses — exactly as a human annotator would in an RLHF pipeline.
+
+ Instead of relying on expensive, slow human annotators, the environment provides:
+ - **Deterministic grading** using gold labels from real preference datasets
+ - **Partial reward signals** at every step (not just a binary end-of-episode reward)
+ - **3 task difficulty levels**: pairwise → multi-axis scoring → transitive ranking
+
+ This fills a genuine gap: no existing OpenEnv environment simulates the RLHF
+ data collection pipeline that powers models like Llama, Claude, and GPT-4.
+
+ ---
+
+ ## Action & Observation Spaces
+
+ ### Task 1 — Pairwise Ranking (Easy)
+
+ **Observation:**
+ ```python
+ PairwiseObservation(
+     prompt: str,        # The user instruction
+     response_a: str,    # Candidate response A
+     response_b: str,    # Candidate response B
+     reward: float,      # Last step reward
+     done: bool,
+     step_count: int,
+ )
+ ```
+
+ **Action:**
+ ```python
+ PairwiseAction(
+     choice: Literal["A", "B", "tie", "skip"],
+     justification: Optional[str],  # not graded
+ )
+ ```
+
+ **Grader:** +1.0 correct | +0.3 skip | +0.1 tie | +0.0 wrong
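The reward rule above is simple enough to state in code. Here is a minimal sketch (the `pairwise_reward` helper name is illustrative; the real grader lives in `server/environment.py`):

```python
def pairwise_reward(choice: str, gold_label: str) -> float:
    """Task 1 rule: +1.0 correct, +0.3 skip, +0.1 tie, +0.0 wrong."""
    if choice == "skip":
        return 0.3  # partial credit for abstaining
    if choice == "tie":
        return 0.1  # small partial credit for a tie
    return 1.0 if choice == gold_label else 0.0

print(pairwise_reward("A", "A"))    # 1.0
print(pairwise_reward("skip", "A"))  # 0.3
```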
+
+ ---
+
+ ### Task 2 — Multi-Axis Likert Scoring (Medium)
+
+ **Observation:**
+ ```python
+ LikertObservation(
+     prompt: str,
+     response: str,
+     rubric: str,        # Scoring instructions
+     reward: float,
+     done: bool,
+     step_count: int,
+ )
+ ```
+
+ **Action:**
+ ```python
+ LikertAction(
+     helpfulness: int,           # 1-5
+     honesty: int,               # 1-5
+     harmlessness: int,          # 1-5
+     instruction_following: int  # 1-5
+ )
+ ```
+
+ **Grader:** `reward = 1.0 - (MAE / 4.0)` — continuous signal based on deviation from gold scores
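The MAE formula above can be sketched directly; axis names follow the `LikertAction` model, and the example gold scores are illustrative (the real grader lives in `server/environment.py`):

```python
AXES = ["helpfulness", "honesty", "harmlessness", "instruction_following"]

def likert_reward(predicted: dict, gold: dict) -> float:
    """reward = 1.0 - (MAE / 4.0); 4 is the largest possible per-axis error (|1 - 5|)."""
    mae = sum(abs(predicted[a] - gold[a]) for a in AXES) / len(AXES)
    return 1.0 - mae / 4.0

gold = {"helpfulness": 4, "honesty": 5, "harmlessness": 5, "instruction_following": 4}
pred = {"helpfulness": 3, "honesty": 5, "harmlessness": 4, "instruction_following": 4}
print(likert_reward(pred, gold))  # MAE = 0.5, so reward = 0.875
```

An exact match yields 1.0, and being maximally wrong on every axis yields 0.0, so the signal stays continuous in between.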
+
+ ---
+
+ ### Task 3 — Transitive Consistency Ranking (Hard)
+
+ **Observation:**
+ ```python
+ ConsistencyObservation(
+     prompt: str,
+     response_a: str,
+     response_b: str,
+     response_c: str,
+     response_d: str,
+     reward: float,
+     done: bool,
+     step_count: int,
+ )
+ ```
+
+ **Action:**
+ ```python
+ ConsistencyAction(
+     ranking: list[str]  # e.g. ["C", "A", "D", "B"], best → worst
+ )
+ ```
+
+ **Grader:** Transitivity score (0.0–0.5) + Kendall's tau quality correlation (0.0–0.5)
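The Kendall's-tau half of this grader can be sketched as follows. This is an illustrative implementation, not the server's (the transitivity half depends on episode state held in `server/environment.py`), and it assumes the gold ranking orders responses best to worst:

```python
def kendall_tau_component(ranking: list[str], gold: list[str]) -> float:
    """Map Kendall's tau between two total orders from [-1, 1] into [0, 0.5]."""
    pos = {r: i for i, r in enumerate(ranking)}
    gpos = {r: i for i, r in enumerate(gold)}
    concordant = discordant = 0
    for i in range(len(gold)):
        for j in range(i + 1, len(gold)):
            a, b = gold[i], gold[j]
            # A pair (a, b) is concordant if both orders rank a and b the same way round
            if (pos[a] - pos[b]) * (gpos[a] - gpos[b]) > 0:
                concordant += 1
            else:
                discordant += 1
    tau = (concordant - discordant) / (concordant + discordant)
    return 0.5 * (tau + 1) / 2  # tau = 1 -> 0.5, tau = -1 -> 0.0

print(kendall_tau_component(["C", "A", "D", "B"], ["C", "A", "B", "D"]))  # 5/12 ≈ 0.4167
```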
+
+ ---
+
+ ## Reward Function
+
+ | Component | Range | Description |
+ |---|---|---|
+ | Correctness | 0.0–1.0 | Agreement with gold label |
+ | Partial credit | 0.1–0.3 | For abstain/tie actions |
+ | Trajectory | cumulative | Sum over episode steps |
+
+ Rewards are **non-sparse** — every step provides a signal. Graders are
+ **deterministic** and **reproducible** — same seed = same episode.
+
+ ---
+
+ ## Setup
+
+ ### Local Development
+
+ ```bash
+ git clone https://github.com/your-username/preference-lab
+ cd preference-lab
+ pip install -e .
+
+ # Optional: download real datasets
+ python scripts/prepare_datasets.py --samples 200
+
+ # Run tests
+ pytest tests/ -v
+
+ # Run inference baseline
+ python inference.py
+ ```
+
+ ### Environment Variables
+
+ | Variable | Required | Default | Description |
+ |---|---|---|---|
+ | `API_BASE_URL` | Yes | `https://api-inference.huggingface.co/v1` | LLM API endpoint |
+ | `MODEL_NAME` | Yes | `meta-llama/Llama-3.1-8B-Instruct` | Model to use for inference |
+ | `HF_TOKEN` | Yes | — | Hugging Face API key |
+ | `ENV_BASE_URL` | No | `http://localhost:8000` | Space URL for remote use |
+ | `MAX_CONCURRENT_ENVS` | No | `64` | Parallel sessions supported |
+
+ ### Docker
+
+ ```bash
+ docker build -t preference-lab .
+ docker run -p 8000:8000 \
+   -e HF_TOKEN=$HF_TOKEN \
+   -e API_BASE_URL=https://api-inference.huggingface.co/v1 \
+   -e MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct \
+   preference-lab
+ ```
+
+ ---
+
+ ## Using with TRL (GRPO Training)
+
+ ```python
+ from preference_lab import PreferenceLabEnv, PairwiseAction
+
+ class PreferenceEnvWrapper:
+     def __init__(self):
+         self.client = PreferenceLabEnv(base_url="https://your-space.hf.space")
+         self.reward = 0.0
+
+     def reset(self, **kwargs):
+         obs = self.client.reset()
+         return obs.prompt
+
+     def rank_responses(self, choice: str) -> str:
+         """
+         Choose which response is better.
+
+         Args:
+             choice: 'A' if response A is better, 'B' if B is better, 'tie' if equal.
+
+         Returns:
+             Feedback on the annotation quality.
+         """
+         result = self.client.step(PairwiseAction(choice=choice))
+         self.reward = result.reward
+         return f"Reward: {result.reward:.2f}"
+
+ def reward_func(environments, **kwargs):
+     return [env.reward for env in environments]
+ ```
+
+ ---
+
+ ## Datasets
+
+ | Task | Dataset | License |
+ |---|---|---|
+ | Pairwise | Anthropic HH-RLHF | MIT |
+ | Likert | UltraFeedback | MIT |
+ | Consistency | Stanford SHP | CC BY 4.0 |
+
+ Fallback synthetic data is included for zero-dependency testing.
+
+ ---
+
+ ## Project Structure
+
+ ```
+ preference-lab/
+ ├── __init__.py              # Package exports
+ ├── models.py                # Pydantic Action/Observation models
+ ├── client.py                # PreferenceLabEnv(MCPToolClient)
+ ├── inference.py             # Baseline inference script (START/STEP/END logs)
+ ├── openenv.yaml             # OpenEnv manifest
+ ├── pyproject.toml           # Dependencies
+ ├── Dockerfile               # Container definition
+ ├── requirements.txt         # Pip requirements
+ ├── README.md                # This file
+ ├── data/                    # Dataset JSONs (or synthetic fallback)
+ ├── scripts/
+ │   └── prepare_datasets.py  # Dataset download + conversion
+ ├── server/
+ │   ├── app.py               # FastAPI server
+ │   └── environment.py       # Core environment + graders
+ └── tests/
+     └── test_environment.py  # Full test suite
+ ```
+
+ ---
+
+ ## Team
+
+ **Team Nexis** — Sri Sri University × Zaalima Development Pvt. Ltd.
+
+ - Sibam Nanda (Architecture + Core Environment)
+ - Spandan Kar (Backend + Deployment)
+ - Shayanna Behera (Config + Documentation)
+
+ ---
+
+ ## License
+
+ BSD 3-Clause License
__init__.py ADDED
@@ -0,0 +1,27 @@
+ """
+ PreferenceLab: An OpenEnv environment simulating the RLHF
+ preference data collection pipeline.
+
+ Usage:
+     from preference_lab import PreferenceLabEnv, PairwiseAction, LikertAction, ConsistencyAction
+ """
+
+ from .client import PreferenceLabEnv
+ from .models import (
+     ConsistencyAction,
+     ConsistencyObservation,
+     LikertAction,
+     LikertObservation,
+     PairwiseAction,
+     PairwiseObservation,
+ )
+
+ __all__ = [
+     "PreferenceLabEnv",
+     "PairwiseAction",
+     "PairwiseObservation",
+     "LikertAction",
+     "LikertObservation",
+     "ConsistencyAction",
+     "ConsistencyObservation",
+ ]
client.py ADDED
@@ -0,0 +1,40 @@
+ """
+ PreferenceLab Environment Client.
+
+ Connects to a running PreferenceLab server via WebSocket/HTTP.
+
+ Example (async):
+     >>> from preference_lab import PreferenceLabEnv, PairwiseAction
+     >>> async with PreferenceLabEnv(base_url="https://your-space.hf.space") as env:
+     ...     obs = await env.reset()
+     ...     result = await env.step(PairwiseAction(choice="A"))
+
+ Example (sync):
+     >>> with PreferenceLabEnv(base_url="https://your-space.hf.space").sync() as env:
+     ...     obs = env.reset()
+     ...     result = env.step(PairwiseAction(choice="A"))
+ """
+
+ from openenv.core.mcp_client import MCPToolClient
+
+
+ class PreferenceLabEnv(MCPToolClient):
+     """
+     Client for the PreferenceLab Environment.
+
+     Provides tool-calling style interactions with the RLHF preference
+     simulation environment. Inherits all functionality from MCPToolClient.
+
+     Available tools (discovered via list_tools()):
+       - rank_responses: Task 1 — choose A or B
+       - score_response: Task 2 — rate on 4 axes
+       - order_responses: Task 3 — rank 4 responses
+
+     Example:
+         >>> with PreferenceLabEnv(base_url="http://localhost:8000").sync() as env:
+         ...     env.reset()
+         ...     tools = env.list_tools()
+         ...     result = env.call_tool("rank_responses", choice="A")
+     """
+
+     pass  # MCPToolClient provides all needed functionality
data/.gitkeep ADDED
File without changes
data/README.md ADDED
@@ -0,0 +1,59 @@
+ # Data Directory
+
+ This directory holds the preference datasets used by PreferenceLab.
+
+ On first run, if these files are absent, the environment falls back to
+ built-in synthetic examples (defined in `server/environment.py`).
+
+ ## File Format
+
+ ### pairwise_data.json
+ ```json
+ [
+   {
+     "prompt": "...",
+     "response_a": "...",
+     "response_b": "...",
+     "gold_label": "A",
+     "source": "hh-rlhf"
+   }
+ ]
+ ```
+
+ ### likert_data.json
+ ```json
+ [
+   {
+     "prompt": "...",
+     "response": "...",
+     "rubric": "...",
+     "gold_scores": {
+       "helpfulness": 4,
+       "honesty": 5,
+       "harmlessness": 5,
+       "instruction_following": 4
+     },
+     "source": "ultrafeedback"
+   }
+ ]
+ ```
+
+ ### consistency_data.json
+ ```json
+ [
+   {
+     "prompt": "...",
+     "response_a": "...",
+     "response_b": "...",
+     "response_c": "...",
+     "response_d": "...",
+     "gold_ranking": ["C", "A", "B", "D"],
+     "source": "stanford-shp"
+   }
+ ]
+ ```
+
+ ## Loading Real Datasets
+
+ Run `python scripts/prepare_datasets.py` to download and convert
+ HH-RLHF, UltraFeedback, and Stanford SHP into these formats.
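The load-or-fallback behaviour described above can be sketched as follows. The real loading logic lives in `server/environment.py`; the `load_pairwise` helper and `SYNTHETIC_PAIRWISE` record here are hypothetical illustrations of the pattern:

```python
import json
from pathlib import Path

# Built-in example used when data/pairwise_data.json is absent
SYNTHETIC_PAIRWISE = [{
    "prompt": "What is 2 + 2?",
    "response_a": "2 + 2 = 4.",
    "response_b": "2 + 2 = 22.",
    "gold_label": "A",
    "source": "synthetic",
}]

def load_pairwise(data_dir: str = "data") -> list[dict]:
    """Load the real dataset if present, else fall back to synthetic examples."""
    path = Path(data_dir) / "pairwise_data.json"
    if path.exists():
        return json.loads(path.read_text())
    return SYNTHETIC_PAIRWISE  # zero-dependency fallback
```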
inference.py ADDED
@@ -0,0 +1,269 @@
+ """
+ PreferenceLab Baseline Inference Script.
+
+ Runs a baseline LLM agent against all 3 tasks and reports scores.
+
+ Environment variables:
+     API_BASE_URL — LLM API endpoint (read from env; has a default)
+     MODEL_NAME — Model identifier (read from env; has a default)
+     HF_TOKEN — Hugging Face / API key (no default — injected by HF Spaces)
+     ENV_BASE_URL — PreferenceLab Space URL (optional, defaults to localhost)
+
+ Usage:
+     python inference.py
+     HF_TOKEN=hf_xxx MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct python inference.py
+ """
+
+ import json
+ import os
+ import time
+ from typing import Any
+
+ from openai import OpenAI
+
+ # ── Environment variables (MANDATORY pattern for hackathon) ───
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+ HF_TOKEN = os.getenv("HF_TOKEN")  # NO default — injected by HF Spaces
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
+
+ # ── OpenAI client (MANDATORY: all LLM calls via OpenAI client) ─
+ client = OpenAI(
+     api_key=HF_TOKEN,
+     base_url=API_BASE_URL,
+ )
+
+ # ── Logging helpers ───────────────────────────────────────────
+
+ def log_start(task_name: str, task_id: str):
+     """Stdout START log — required structured format."""
+     print(f"START task_name={task_name} task_id={task_id} timestamp={time.time():.0f}")
+
+
+ def log_step(task_id: str, step: int, action: dict, reward: float):
+     """Stdout STEP log — required structured format."""
+     print(f"STEP task_id={task_id} step={step} action={json.dumps(action)} reward={reward:.4f}")
+
+
+ def log_end(task_id: str, total_reward: float, steps: int):
+     """Stdout END log — required structured format."""
+     print(f"END task_id={task_id} total_reward={total_reward:.4f} steps={steps}")
+
+
+ # ── LLM helpers ───────────────────────────────────────────────
+
+ def call_llm(system: str, user: str, max_tokens: int = 256) -> str:
+     """Call the LLM via the OpenAI client. Returns the text response."""
+     try:
+         response = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": system},
+                 {"role": "user", "content": user},
+             ],
+             max_tokens=max_tokens,
+             temperature=0.0,  # deterministic for reproducibility
+         )
+         return response.choices[0].message.content.strip()
+     except Exception as e:
+         print(f"  [LLM ERROR] {e}")
+         return ""
+
+
+ def parse_json_response(text: str, fallback: dict) -> dict:
+     """Try to parse JSON from an LLM response."""
+     try:
+         # Find the JSON block
+         start = text.find("{")
+         end = text.rfind("}") + 1
+         if start >= 0 and end > start:
+             return json.loads(text[start:end])
+     except Exception:
+         pass
+     return fallback
+
+
+ # ── Task runners ──────────────────────────────────────────────
+
+ def run_task1_pairwise(env_client) -> dict[str, Any]:
+     """Run Task 1: pairwise ranking."""
+     from models import PairwiseAction
+     task_id = "task1_pairwise"
+     log_start("pairwise_ranking", task_id)
+
+     obs = env_client.reset(task_type="pairwise")
+     total_reward = 0.0
+     steps = 0
+
+     SYSTEM = (
+         "You are an expert AI response evaluator. "
+         "Given a prompt and two responses (A and B), choose which is better. "
+         "Reply ONLY with valid JSON: {\"choice\": \"A\"} or {\"choice\": \"B\"} or {\"choice\": \"tie\"}."
+     )
+
+     for step in range(5):
+         user_prompt = (
+             f"Prompt: {obs.prompt}\n\n"
+             f"Response A:\n{obs.response_a}\n\n"
+             f"Response B:\n{obs.response_b}\n\n"
+             "Which response is better? Reply with JSON only."
+         )
+         llm_out = call_llm(SYSTEM, user_prompt)
+         parsed = parse_json_response(llm_out, {"choice": "A"})
+         choice = parsed.get("choice", "A")
+         if choice not in ("A", "B", "tie", "skip"):
+             choice = "A"
+
+         action = PairwiseAction(choice=choice)
+         obs, reward, done, info = env_client.step(action)
+
+         total_reward += reward
+         steps += 1
+         log_step(task_id, step + 1, {"choice": choice}, reward)
+
+         if done:
+             break
+
+     log_end(task_id, total_reward, steps)
+     avg = round(total_reward / steps, 4) if steps > 0 else 0.0
+     return {"task": "pairwise", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
+
+
+ def run_task2_likert(env_client) -> dict[str, Any]:
+     """Run Task 2: multi-axis Likert scoring."""
+     from models import LikertAction
+     task_id = "task2_likert"
+     log_start("likert_scoring", task_id)
+
+     obs = env_client.reset(task_type="likert")
+     total_reward = 0.0
+     steps = 0
+
+     SYSTEM = (
+         "You are an expert AI response evaluator. Score responses on 4 axes (1=worst, 5=best):\n"
+         "- helpfulness: how useful is the response?\n"
+         "- honesty: how accurate and truthful?\n"
+         "- harmlessness: how safe and appropriate?\n"
+         "- instruction_following: how well does it follow the instruction?\n"
+         "Reply ONLY with JSON: {\"helpfulness\": 4, \"honesty\": 5, \"harmlessness\": 5, \"instruction_following\": 4}"
+     )
+
+     for step in range(5):
+         user_prompt = (
+             f"Prompt: {obs.prompt}\n\n"
+             f"Response:\n{obs.response}\n\n"
+             f"Rubric: {obs.rubric}\n\n"
+             "Score this response. Reply with JSON only."
+         )
+         llm_out = call_llm(SYSTEM, user_prompt)
+         parsed = parse_json_response(llm_out, {
+             "helpfulness": 3, "honesty": 3, "harmlessness": 4, "instruction_following": 3
+         })
+
+         def clamp(v): return max(1, min(5, int(parsed.get(v, 3))))
+
+         action = LikertAction(
+             helpfulness=clamp("helpfulness"),
+             honesty=clamp("honesty"),
+             harmlessness=clamp("harmlessness"),
+             instruction_following=clamp("instruction_following"),
+         )
+         obs, reward, done, info = env_client.step(action)
+
+         total_reward += reward
+         steps += 1
+         log_step(task_id, step + 1, parsed, reward)
+
+         if done:
+             break
+
+     log_end(task_id, total_reward, steps)
+     avg = round(total_reward / steps, 4) if steps > 0 else 0.0
+     return {"task": "likert", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
+
+
+ def run_task3_consistency(env_client) -> dict[str, Any]:
+     """Run Task 3: transitive consistency chain ranking."""
+     from models import ConsistencyAction
+     task_id = "task3_consistency"
+     log_start("consistency_ranking", task_id)
+
+     obs = env_client.reset(task_type="consistency")
+     total_reward = 0.0
+     steps = 0
+
+     SYSTEM = (
+         "You are an expert AI response evaluator. "
+         "Rank 4 responses (A, B, C, D) from best to worst. "
+         "Reply ONLY with JSON: {\"ranking\": [\"C\", \"A\", \"D\", \"B\"]} "
+         "(best first, worst last)."
+     )
+
+     for step in range(5):
+         user_prompt = (
+             f"Prompt: {obs.prompt}\n\n"
+             f"Response A:\n{obs.response_a}\n\n"
+             f"Response B:\n{obs.response_b}\n\n"
+             f"Response C:\n{obs.response_c}\n\n"
+             f"Response D:\n{obs.response_d}\n\n"
+             "Rank these responses from best to worst. Reply with JSON only."
+         )
+         llm_out = call_llm(SYSTEM, user_prompt)
+         parsed = parse_json_response(llm_out, {"ranking": ["A", "B", "C", "D"]})
+         ranking = parsed.get("ranking", ["A", "B", "C", "D"])
+         if not isinstance(ranking, list) or len(ranking) != 4:
+             ranking = ["A", "B", "C", "D"]
+
+         action = ConsistencyAction(ranking=ranking)
+         obs, reward, done, info = env_client.step(action)
+
+         total_reward += reward
+         steps += 1
+         log_step(task_id, step + 1, {"ranking": ranking}, reward)
+
+         if done:
+             break
+
+     log_end(task_id, total_reward, steps)
+     avg = round(total_reward / steps, 4) if steps > 0 else 0.0
+     return {"task": "consistency", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
+
+
+ # ── Main ──────────────────────────────────────────────────────
+
+ def main():
+     print("=" * 60)
+     print("PreferenceLab Baseline Inference")
+     print(f"Model: {MODEL_NAME}")
+     print(f"API URL: {API_BASE_URL}")
+     print(f"Env URL: {ENV_BASE_URL}")
+     print("=" * 60)
+
+     # Import the environment directly for a local run;
+     # in production, connect to the HF Space via the client
+     import sys
+     sys.path.insert(0, ".")
+     from server.environment import PreferenceLabEnvironment
+
+     env = PreferenceLabEnvironment()
+
+     results = []
+     results.append(run_task1_pairwise(env))
+     results.append(run_task2_likert(env))
+     results.append(run_task3_consistency(env))
+
+     print("\n" + "=" * 60)
+     print("RESULTS SUMMARY")
+     print("=" * 60)
+     overall = 0.0
+     for r in results:
+         print(f"  {r['task']:20s} avg_reward={r['avg_reward']:.4f}  steps={r['steps']}")
+         overall += r["avg_reward"]
+     print(f"\n  Overall avg reward: {overall / len(results):.4f}")
+     print("=" * 60)
+
+     return results
+
+
+ if __name__ == "__main__":
+     main()
models.py ADDED
@@ -0,0 +1,111 @@
+ """
+ Data models for the PreferenceLab Environment.
+
+ PreferenceLab simulates the RLHF preference data collection pipeline.
+ An AI agent learns to judge which LLM responses are better, using
+ gold-standard labels from real preference datasets as ground truth.
+
+ Tasks:
+     Task 1 (Easy): Pairwise ranking — pick A or B
+     Task 2 (Medium): Multi-axis Likert scoring
+     Task 3 (Hard): Transitive consistency chain ranking
+ """
+
+ from typing import Literal, Optional
+ from pydantic import Field
+
+ from openenv.core.env_server.types import Action, Observation
+
+
+ # ── Actions ───────────────────────────────────────────────────
+
+ class PairwiseAction(Action):
+     """Action for Task 1: choose which response is better."""
+
+     choice: Literal["A", "B", "tie", "skip"] = Field(
+         ...,
+         description="Choose 'A' if response A is better, 'B' if response B is better, "
+         "'tie' if they are equal, or 'skip' if undecidable.",
+     )
+     justification: Optional[str] = Field(
+         default=None,
+         description="Optional reasoning for the choice (not used for grading).",
+     )
+
+
+ class LikertAction(Action):
+     """Action for Task 2: score a response on 4 quality axes (1-5 each)."""
+
+     helpfulness: int = Field(
+         ..., ge=1, le=5,
+         description="How helpful is the response? 1=not helpful, 5=extremely helpful.",
+     )
+     honesty: int = Field(
+         ..., ge=1, le=5,
+         description="How honest and accurate is the response? 1=misleading, 5=fully honest.",
+     )
+     harmlessness: int = Field(
+         ..., ge=1, le=5,
+         description="How harmless is the response? 1=harmful, 5=completely safe.",
+     )
+     instruction_following: int = Field(
+         ..., ge=1, le=5,
+         description="How well does the response follow the instruction? 1=ignores it, 5=follows perfectly.",
+     )
+
+
+ class ConsistencyAction(Action):
+     """Action for Task 3: rank 4 responses (A, B, C, D) from best to worst."""
+
+     ranking: list[str] = Field(
+         ...,
+         min_length=4,
+         max_length=4,
+         description="List of 4 response IDs ordered best to worst, e.g. ['B', 'A', 'D', 'C'].",
+     )
+
+
+ # ── Observations ──────────────────────────────────────────────
+
+ class PairwiseObservation(Observation):
+     """Observation for Task 1: a prompt with two candidate responses."""
+
+     task_id: str = Field(..., description="Unique task identifier.")
+     task_type: Literal["pairwise"] = Field(default="pairwise")
+     prompt: str = Field(..., description="The user prompt / instruction.")
+     response_a: str = Field(..., description="Candidate response A.")
+     response_b: str = Field(..., description="Candidate response B.")
+     reward: float = Field(default=0.0, description="Reward signal from the last step.")
+     done: bool = Field(default=False, description="Whether the episode is complete.")
+     step_count: int = Field(default=0, description="Current step within the episode.")
+     info: dict = Field(default_factory=dict, description="Extra debug info.")
+
+
+ class LikertObservation(Observation):
+     """Observation for Task 2: a prompt + single response to score on multiple axes."""
+
+     task_id: str = Field(..., description="Unique task identifier.")
+     task_type: Literal["likert"] = Field(default="likert")
+     prompt: str = Field(..., description="The user prompt / instruction.")
+     response: str = Field(..., description="The response to evaluate.")
+     rubric: str = Field(..., description="Scoring rubric to guide evaluation.")
+     reward: float = Field(default=0.0, description="Reward signal from the last step.")
+     done: bool = Field(default=False, description="Whether the episode is complete.")
+     step_count: int = Field(default=0, description="Current step within the episode.")
+     info: dict = Field(default_factory=dict, description="Extra debug info.")
+
+
+ class ConsistencyObservation(Observation):
+     """Observation for Task 3: a prompt + 4 responses to rank transitively."""
+
+     task_id: str = Field(..., description="Unique task identifier.")
+     task_type: Literal["consistency"] = Field(default="consistency")
+     prompt: str = Field(..., description="The user prompt / instruction.")
+     response_a: str = Field(..., description="Candidate response A.")
+     response_b: str = Field(..., description="Candidate response B.")
+     response_c: str = Field(..., description="Candidate response C.")
+     response_d: str = Field(..., description="Candidate response D.")
+     reward: float = Field(default=0.0, description="Reward signal from the last step.")
+     done: bool = Field(default=False, description="Whether the episode is complete.")
+     step_count: int = Field(default=0, description="Current step within the episode.")
+     info: dict = Field(default_factory=dict, description="Extra debug info.")
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: preference_lab
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,34 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "preference-lab"
+ version = "0.1.0"
+ description = "An OpenEnv environment simulating the RLHF preference data collection pipeline"
+ readme = "README.md"
+ license = { text = "BSD-3-Clause" }
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core>=0.2.1",
+     "fastapi>=0.104.0",
+     "uvicorn>=0.24.0",
+     "pydantic>=2.0.0",
+     "openai>=1.0.0",
+     "datasets>=2.14.0",
+     "httpx>=0.25.0",
+     "websockets>=11.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0",
+     "pytest-asyncio>=0.21",
+     "httpx>=0.25.0",
+ ]
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+
+ [tool.pytest.ini_options]
+ asyncio_mode = "auto"
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ openenv-core>=0.2.1
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ pydantic>=2.0.0
+ openai>=1.0.0
+ datasets>=2.14.0
+ httpx>=0.25.0
+ websockets>=11.0
scripts/__init__.py ADDED
File without changes
scripts/prepare_datasets.py ADDED
@@ -0,0 +1,201 @@
+ """
+ Dataset Preparation Script.
+
+ Downloads HH-RLHF, UltraFeedback, and Stanford SHP from Hugging Face
+ and converts them into the format expected by PreferenceLab.
+
+ Usage:
+     python scripts/prepare_datasets.py
+     python scripts/prepare_datasets.py --samples 200
+ """
+
+ import argparse
+ import json
+ import random
+ from pathlib import Path
+
+ DATA_DIR = Path(__file__).parent.parent / "data"
+ DATA_DIR.mkdir(exist_ok=True)
+
+
+ def prepare_pairwise(n_samples: int = 100):
+     """Download Anthropic HH-RLHF and convert to pairwise format."""
+     print(f"[1/3] Preparing pairwise data (HH-RLHF, {n_samples} samples)...")
+     try:
+         from datasets import load_dataset
+         ds = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
+         records = []
+         for i, ex in enumerate(ds):
+             if i >= n_samples:
+                 break
+             # chosen = better response, rejected = worse
+             chosen = ex.get("chosen", "")
+             rejected = ex.get("rejected", "")
+             # Extract the last human turn as prompt
+             lines = chosen.split("\n\nAssistant:")
+             if len(lines) >= 2:
+                 prompt_block = lines[0].replace("Human:", "").strip()
+                 resp_a = lines[-1].strip()
+             else:
+                 prompt_block = chosen[:100]
+                 resp_a = chosen
+
+             rej_lines = rejected.split("\n\nAssistant:")
+             resp_b = rej_lines[-1].strip() if len(rej_lines) >= 2 else rejected
+
+             # Randomly swap A/B to avoid position bias, track gold
+             if random.random() < 0.5:
+                 records.append({
+                     "prompt": prompt_block,
+                     "response_a": resp_a,
+                     "response_b": resp_b,
+                     "gold_label": "A",
+                     "source": "hh-rlhf",
+                 })
+             else:
+                 records.append({
+                     "prompt": prompt_block,
+                     "response_a": resp_b,
+                     "response_b": resp_a,
+                     "gold_label": "B",
+                     "source": "hh-rlhf",
+                 })
+
+         out = DATA_DIR / "pairwise_data.json"
+         with open(out, "w") as f:
+             json.dump(records, f, indent=2)
+         print(f"  ✓ Saved {len(records)} pairwise examples → {out}")
+     except Exception as e:
+         print(f"  ✗ Failed: {e} — synthetic fallback will be used")
+
+
+ def prepare_likert(n_samples: int = 100):
+     """Download UltraFeedback and convert to likert format."""
+     print(f"[2/3] Preparing likert data (UltraFeedback, {n_samples} samples)...")
+     try:
+         from datasets import load_dataset
+         ds = load_dataset("openbmb/UltraFeedback", split="train", streaming=True)
+         records = []
+         for i, ex in enumerate(ds):
+             if i >= n_samples:
+                 break
+             instr = ex.get("instruction", "")
+             completions = ex.get("completions", [])
+             if not completions:
+                 continue
+             comp = completions[0]
+             response = comp.get("response", "")
+             annots = comp.get("annotations", {})
+
+             def extract_score(key, default=3):
+                 val = annots.get(key, {})
+                 if isinstance(val, dict):
+                     raw = val.get("Rating", default)
+                 elif isinstance(val, (int, float)):
+                     raw = val
+                 else:
+                     raw = default
+                 # UltraFeedback uses a 1-5 scale
+                 try:
+                     return max(1, min(5, int(raw)))
+                 except Exception:
+                     return default
+
+             records.append({
+                 "prompt": instr,
+                 "response": response,
+                 "rubric": (
+                     "Score on 4 axes (1=worst, 5=best): helpfulness, honesty, "
+                     "harmlessness, instruction_following."
+                 ),
+                 "gold_scores": {
+                     "helpfulness": extract_score("helpfulness"),
+                     "honesty": extract_score("honesty"),
+                     # UltraFeedback has no harmlessness axis; use truthfulness as a proxy
+                     "harmlessness": extract_score("truthfulness", 4),
+                     "instruction_following": extract_score("instruction_following"),
+                 },
+                 "source": "ultrafeedback",
+             })
+
+         out = DATA_DIR / "likert_data.json"
+         with open(out, "w") as f:
+             json.dump(records, f, indent=2)
+         print(f"  ✓ Saved {len(records)} likert examples → {out}")
+     except Exception as e:
+         print(f"  ✗ Failed: {e} — synthetic fallback will be used")
+
+
+ def prepare_consistency(n_samples: int = 60):
+     """Build 4-way ranking examples from Stanford SHP."""
+     print(f"[3/3] Preparing consistency data (Stanford SHP, {n_samples} samples)...")
+     try:
+         from datasets import load_dataset
+         ds = load_dataset("stanfordnlp/SHP", split="train", streaming=True)
+
+         # Group by post_id to collect multiple responses per prompt
+         grouped: dict[str, dict] = {}
+         for ex in ds:
+             pid = ex.get("post_id", "")
+             if pid not in grouped:
+                 grouped[pid] = {
+                     "prompt": ex.get("history", ""),
+                     "responses": [],
+                 }
+             grouped[pid]["responses"].append({
+                 "text": ex.get("human_ref_A", "") or ex.get("human_ref_B", ""),
+                 "score": ex.get("score_ratio", 1.0),
+             })
+             if len(grouped) >= n_samples * 3:
+                 break
+
+         records = []
+         for pid, data in grouped.items():
+             resps = data["responses"]
+             if len(resps) < 4:
+                 continue
+             # Sort by score descending = gold ranking
+             resps_sorted = sorted(resps[:4], key=lambda x: x["score"], reverse=True)
+             labels = ["A", "B", "C", "D"]
+             # Shuffle display order (not gold order)
+             shuffled = resps_sorted[:]
+             random.shuffle(shuffled)
+             id_map = {labels[i]: shuffled[i] for i in range(4)}
+             gold_ranking = sorted(labels, key=lambda l: resps_sorted.index(id_map[l]))
+
+             records.append({
+                 "prompt": data["prompt"][:500],
+                 "response_a": id_map["A"]["text"][:400],
+                 "response_b": id_map["B"]["text"][:400],
+                 "response_c": id_map["C"]["text"][:400],
+                 "response_d": id_map["D"]["text"][:400],
+                 "gold_ranking": gold_ranking,
+                 "source": "stanford-shp",
+             })
+             if len(records) >= n_samples:
+                 break
+
+         out = DATA_DIR / "consistency_data.json"
+         with open(out, "w") as f:
+             json.dump(records, f, indent=2)
+         print(f"  ✓ Saved {len(records)} consistency examples → {out}")
+     except Exception as e:
+         print(f"  ✗ Failed: {e} — synthetic fallback will be used")
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--samples", type=int, default=100,
+                         help="Number of samples per task (default: 100)")
+     args = parser.parse_args()
+
+     print("=" * 50)
+     print("PreferenceLab Dataset Preparation")
+     print("=" * 50)
+     prepare_pairwise(args.samples)
+     prepare_likert(args.samples)
+     prepare_consistency(args.samples // 2)
+     print("\n✓ Done. Run inference.py to test.")
+
+
+ if __name__ == "__main__":
+     main()
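One detail worth calling out in `prepare_pairwise` is the random A/B swap: without it, the dataset's chosen response would always sit in slot A, and a judge could score perfectly by always answering "A". Here is a minimal standalone sketch of the same debiasing logic; the `debias_pair` helper name is illustrative and not part of the repo:

```python
import random

def debias_pair(chosen: str, rejected: str, rng: random.Random) -> dict:
    """Randomly assign the chosen response to slot A or B, recording
    which slot holds it as the gold label (mirrors prepare_pairwise)."""
    if rng.random() < 0.5:
        return {"response_a": chosen, "response_b": rejected, "gold_label": "A"}
    return {"response_a": rejected, "response_b": chosen, "gold_label": "B"}

rng = random.Random(0)
records = [debias_pair("good answer", "bad answer", rng) for _ in range(1000)]
n_a = sum(1 for r in records if r["gold_label"] == "A")
# n_a should land near 500: gold labels are split roughly evenly between slots
print(n_a)
```

Because the swap is driven by the RNG rather than the data, an always-"A" policy converges to ~50% accuracy, which is what makes the 1.0/0.0 pairwise reward informative.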
server/__init__.py ADDED
@@ -0,0 +1 @@
+ # server package
server/app.py ADDED
@@ -0,0 +1,38 @@
+ """
+ PreferenceLab FastAPI Server.
+
+ Exposes the PreferenceLabEnvironment via the OpenEnv HTTP interface.
+ Supports concurrent sessions for parallel training.
+ """
+
+ import os
+
+ from openenv.core.env_server import create_app
+
+ from models import (
+     ConsistencyAction,
+     ConsistencyObservation,
+     LikertAction,
+     LikertObservation,
+     PairwiseAction,
+     PairwiseObservation,
+ )
+ from server.environment import PreferenceLabEnvironment
+
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
+ MAX_CONCURRENT_ENVS = int(os.environ.get("MAX_CONCURRENT_ENVS", "64"))
+
+
+ def create_environment() -> PreferenceLabEnvironment:
+     """Factory function — called once per session."""
+     return PreferenceLabEnvironment()
+
+
+ # Use PairwiseAction/PairwiseObservation as the primary schema.
+ # The environment internally handles all three task types.
+ app = create_app(
+     create_environment,
+     PairwiseAction,
+     PairwiseObservation,
+     max_concurrent_envs=MAX_CONCURRENT_ENVS,
+ )
server/environment.py ADDED
@@ -0,0 +1,513 @@
+ """
+ PreferenceLab Core Environment.
+
+ Implements the OpenEnv Environment base class with:
+     - reset() → returns initial observation
+     - step()  → executes action, returns (observation, reward, done, info)
+     - state() → returns episode metadata
+
+ Three tasks:
+     Task 1 (pairwise)    - Easy:   pairwise choice graded against HH-RLHF gold labels
+     Task 2 (likert)      - Medium: multi-axis scoring graded via MAE vs UltraFeedback scores
+     Task 3 (consistency) - Hard:   4-way ranking graded on transitivity + quality correlation
+ """
+
+ import json
+ import random
+ import uuid
+ from pathlib import Path
+ from typing import Any
+
+ from openenv.core.env_server import Environment
+
+ from models import (
+     ConsistencyAction,
+     ConsistencyObservation,
+     LikertAction,
+     LikertObservation,
+     PairwiseAction,
+     PairwiseObservation,
+ )
+
+ # ── Dataset loading ────────────────────────────────────────────
+
+ DATA_DIR = Path(__file__).parent.parent / "data"
+
+ def _load_json(filename: str) -> list[dict]:
+     path = DATA_DIR / filename
+     if path.exists():
+         with open(path) as f:
+             return json.load(f)
+     return []
+
+
+ # ── Graders ───────────────────────────────────────────────────
+
+ def grade_pairwise(action: PairwiseAction, example: dict) -> tuple[float, dict]:
+     """
+     Grade Task 1: Pairwise ranking.
+
+     The gold label ('A' or 'B') marks which slot holds the dataset's chosen response.
+     Returns:
+         1.0 → correct choice
+         0.3 → skip (abstain — partial credit)
+         0.1 → tie (when gold is clear)
+         0.0 → wrong choice
+     """
+     gold = example.get("gold_label", "A")  # 'A' = chosen is response_a
+     choice = action.choice
+
+     if choice == "skip":
+         reward = 0.3
+         verdict = "abstained"
+     elif choice == "tie":
+         reward = 0.1
+         verdict = "tie_when_clear"
+     elif choice == gold:
+         reward = 1.0
+         verdict = "correct"
+     else:
+         reward = 0.0
+         verdict = "incorrect"
+
+     return reward, {
+         "gold": gold,
+         "chosen": choice,
+         "verdict": verdict,
+         "dataset": example.get("source", "hh-rlhf"),
+     }
+
+
+ def grade_likert(action: LikertAction, example: dict) -> tuple[float, dict]:
+     """
+     Grade Task 2: Multi-axis Likert scoring.
+
+     Compares the agent's 4-axis scores to gold scores from UltraFeedback.
+     Reward = 1.0 - (mean_absolute_error / max_possible_error), where the
+     maximum absolute error per axis on a 1-5 scale is 4 (1 vs 5).
+     """
+     gold_scores = example.get("gold_scores", {
+         "helpfulness": 3,
+         "honesty": 3,
+         "harmlessness": 4,
+         "instruction_following": 3,
+     })
+
+     axes = ["helpfulness", "honesty", "harmlessness", "instruction_following"]
+     agent_scores = {
+         "helpfulness": action.helpfulness,
+         "honesty": action.honesty,
+         "harmlessness": action.harmlessness,
+         "instruction_following": action.instruction_following,
+     }
+
+     errors = []
+     per_axis = {}
+     for ax in axes:
+         err = abs(agent_scores[ax] - gold_scores.get(ax, 3))
+         errors.append(err)
+         per_axis[ax] = {"agent": agent_scores[ax], "gold": gold_scores.get(ax, 3), "error": err}
+
+     mae = sum(errors) / len(errors)
+     max_error = 4.0  # max abs difference on a 1-5 scale
+     reward = round(1.0 - (mae / max_error), 4)
+     reward = max(0.0, min(1.0, reward))
+
+     return reward, {
+         "mae": round(mae, 4),
+         "per_axis": per_axis,
+         "dataset": example.get("source", "ultrafeedback"),
+     }
+
+
+ def grade_consistency(action: ConsistencyAction, example: dict) -> tuple[float, dict]:
+     """
+     Grade Task 3: Transitive consistency chain ranking.
+
+     Scoring components:
+         - Transitivity score (0.0 or 0.5): a valid total order over {A, B, C, D}
+           is transitive by construction, so full credit unless the ranking
+           contains duplicates or invalid IDs.
+         - Quality correlation (0.0–0.5): Kendall's tau vs the gold ranking.
+     Total reward = transitivity_score + quality_score (max 1.0).
+     """
+     ranking = action.ranking
+     gold_ranking = example.get("gold_ranking", ["A", "B", "C", "D"])
+
+     # --- Transitivity check ---
+     # A ranking that is a permutation of {A, B, C, D} is a total order and
+     # therefore transitive by construction. Penalise only rankings that
+     # contain duplicates or invalid IDs.
+     valid_ids = {"A", "B", "C", "D"}
+     has_invalid = set(ranking) != valid_ids
+     transitivity_score = 0.0 if has_invalid else 0.5
+
+     # --- Quality correlation (Kendall's tau, simplified) ---
+     if has_invalid:
+         quality_score = 0.0
+         n_concordant = 0
+         n_discordant = 0
+     else:
+         ids = ["A", "B", "C", "D"]
+         agent_pos = {r: i for i, r in enumerate(ranking)}
+         gold_pos = {r: i for i, r in enumerate(gold_ranking)}
+
+         n_concordant = 0
+         n_discordant = 0
+         pairs = [(ids[i], ids[j]) for i in range(4) for j in range(i + 1, 4)]
+         for x, y in pairs:
+             agent_order = agent_pos[x] < agent_pos[y]
+             gold_order = gold_pos[x] < gold_pos[y]
+             if agent_order == gold_order:
+                 n_concordant += 1
+             else:
+                 n_discordant += 1
+
+         total_pairs = n_concordant + n_discordant
+         tau = (n_concordant - n_discordant) / total_pairs if total_pairs > 0 else 0.0
+         # Normalise tau from [-1, 1] to [0, 0.5]
+         quality_score = round((tau + 1.0) / 2.0 * 0.5, 4)
+
+     reward = round(transitivity_score + quality_score, 4)
+     reward = max(0.0, min(1.0, reward))
+
+     return reward, {
+         "transitivity_score": transitivity_score,
+         "quality_score": quality_score,
+         "agent_ranking": ranking,
+         "gold_ranking": gold_ranking,
+         "has_invalid_ids": has_invalid,
+         "dataset": example.get("source", "stanford-shp"),
+     }
+
+
+ # ── Environment ───────────────────────────────────────────────
+
+ TASK_TYPES = ["pairwise", "likert", "consistency"]
+ MAX_STEPS_PER_EPISODE = 5
+
+
+ class PreferenceLabEnvironment(Environment):
+     """
+     PreferenceLab: An RL environment simulating the RLHF preference
+     data collection pipeline.
+
+     Each episode consists of MAX_STEPS_PER_EPISODE annotation steps.
+     The task type is fixed per episode (chosen at reset).
+     """
+
+     def __init__(self):
+         self._episode_id: str = ""
+         self._step_count: int = 0
+         self._task_type: str = "pairwise"
+         self._current_example: dict = {}
+         self._cumulative_reward: float = 0.0
+         self._seed: int = 0
+
+         # Load datasets
+         self._pairwise_data: list[dict] = _load_json("pairwise_data.json")
+         self._likert_data: list[dict] = _load_json("likert_data.json")
+         self._consistency_data: list[dict] = _load_json("consistency_data.json")
+
+         # Fall back to synthetic data if the files are not present
+         if not self._pairwise_data:
+             self._pairwise_data = _synthetic_pairwise()
+         if not self._likert_data:
+             self._likert_data = _synthetic_likert()
+         if not self._consistency_data:
+             self._consistency_data = _synthetic_consistency()
+
+     # ── OpenEnv API ───────────────────────────────────────────
+
+     def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs):
+         """
+         Reset the environment for a new episode.
+
+         Args:
+             seed: Optional random seed for reproducibility.
+             episode_id: Optional episode ID override.
+             **kwargs: Accepts task_type ('pairwise', 'likert', 'consistency').
+
+         Returns:
+             Initial observation for the episode.
+         """
+         task_type = kwargs.get("task_type", None)
+         self._seed = seed if seed is not None else random.randint(0, 10_000)
+         rng = random.Random(self._seed)
+
+         self._episode_id = episode_id or str(uuid.uuid4())
+         self._step_count = 0
+         self._cumulative_reward = 0.0
+         self._task_type = task_type if task_type in TASK_TYPES else rng.choice(TASK_TYPES)
+
+         self._current_example = self._sample_example(rng)
+
+         return self._build_observation(reward=0.0, done=False, info={"reset": True})
+
+     def step(self, action, timeout_s: float | None = None, **kwargs):
+         """
+         Execute one annotation step.
+
+         Args:
+             action: A PairwiseAction, LikertAction, or ConsistencyAction.
+             timeout_s: Unused — required by the base class signature.
+
+         Returns:
+             Tuple of (observation, reward, done, info).
+         """
+         self._step_count += 1
+
+         # Grade the action
+         reward, info = self._grade(action)
+         self._cumulative_reward += reward
+
+         done = self._step_count >= MAX_STEPS_PER_EPISODE
+
+         # Sample the next example if not done
+         if not done:
+             rng = random.Random(self._seed + self._step_count)
+             self._current_example = self._sample_example(rng)
+
+         obs = self._build_observation(reward=reward, done=done, info=info)
+         return obs, reward, done, info
+
+     def state(self) -> dict[str, Any]:
+         """Return current episode metadata."""
+         return {
+             "episode_id": self._episode_id,
+             "step_count": self._step_count,
+             "task_type": self._task_type,
+             "cumulative_reward": round(self._cumulative_reward, 4),
+             "max_steps": MAX_STEPS_PER_EPISODE,
+             "seed": self._seed,
+         }
+
+     # ── Internal helpers ──────────────────────────────────────
+
+     def _sample_example(self, rng: random.Random) -> dict:
+         """Sample one example from the appropriate dataset."""
+         dataset = {
+             "pairwise": self._pairwise_data,
+             "likert": self._likert_data,
+             "consistency": self._consistency_data,
+         }[self._task_type]
+         return rng.choice(dataset)
+
+     def _grade(self, action) -> tuple[float, dict]:
+         """Dispatch to the correct grader based on task type."""
+         if self._task_type == "pairwise":
+             return grade_pairwise(action, self._current_example)
+         elif self._task_type == "likert":
+             return grade_likert(action, self._current_example)
+         elif self._task_type == "consistency":
+             return grade_consistency(action, self._current_example)
+         return 0.0, {"error": "unknown_task"}
+
+     def _build_observation(self, reward: float, done: bool, info: dict):
+         """Build the appropriate observation type for the current task."""
+         ex = self._current_example
+         base = {
+             "task_id": self._episode_id + f"_step{self._step_count}",
+             "reward": reward,
+             "done": done,
+             "step_count": self._step_count,
+             "info": info,
+         }
+
+         if self._task_type == "pairwise":
+             return PairwiseObservation(
+                 **base,
+                 prompt=ex.get("prompt", ""),
+                 response_a=ex.get("response_a", ""),
+                 response_b=ex.get("response_b", ""),
+             )
+         elif self._task_type == "likert":
+             return LikertObservation(
+                 **base,
+                 prompt=ex.get("prompt", ""),
+                 response=ex.get("response", ""),
+                 rubric=ex.get("rubric", DEFAULT_LIKERT_RUBRIC),
+             )
+         elif self._task_type == "consistency":
+             return ConsistencyObservation(
+                 **base,
+                 prompt=ex.get("prompt", ""),
+                 response_a=ex.get("response_a", ""),
+                 response_b=ex.get("response_b", ""),
+                 response_c=ex.get("response_c", ""),
+                 response_d=ex.get("response_d", ""),
+             )
+
+         raise ValueError(f"Unknown task type: {self._task_type}")
+
+
+ # ── Defaults ──────────────────────────────────────────────────
+
+ DEFAULT_LIKERT_RUBRIC = (
+     "Score the response on 4 axes (1=worst, 5=best): "
+     "Helpfulness — does it answer the question usefully? "
+     "Honesty — is it accurate and truthful? "
+     "Harmlessness — does it avoid harm? "
+     "Instruction Following — does it follow the instruction precisely?"
+ )
+
+
+ # ── Synthetic fallback datasets ───────────────────────────────
+
+ def _synthetic_pairwise() -> list[dict]:
+     return [
+         {
+             "prompt": "Explain what machine learning is in simple terms.",
+             "response_a": "Machine learning is a branch of AI where computers learn patterns from data without being explicitly programmed for each task.",
+             "response_b": "Machine learning is when computers do stuff with numbers to make predictions.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "How do I improve my sleep quality?",
+             "response_a": "Try sleeping more.",
+             "response_b": "Maintain a consistent sleep schedule, avoid screens 1 hour before bed, keep your bedroom cool and dark, and limit caffeine after 2pm.",
+             "gold_label": "B",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "What is the capital of France?",
+             "response_a": "Paris is the capital and largest city of France.",
+             "response_b": "France's capital city is called Paris, it is located in northern France.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Write a haiku about autumn.",
+             "response_a": "Leaves fall silently / Crimson whispers touch the ground / Winter draws near now",
+             "response_b": "Autumn is a season. Leaves fall down. It gets cold outside.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "How do I reverse a list in Python?",
+             "response_a": "Use my_list.reverse() to reverse in-place, or my_list[::-1] to get a reversed copy.",
+             "response_b": "You can just use the reverse function.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Is it safe to eat raw eggs?",
+             "response_a": "Raw eggs carry a risk of Salmonella contamination. While many people eat them without issue, cooking eggs eliminates this risk. Use pasteurised eggs if you want them raw.",
+             "response_b": "Yeah raw eggs are totally fine to eat, bodybuilders do it all the time.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Summarise the French Revolution in 2 sentences.",
+             "response_a": "The French Revolution (1789–1799) was a period of radical political and social transformation in France that overthrew the monarchy, established a republic, and culminated in Napoleon's rise to power. It reshaped modern political thought by promoting ideals of liberty, equality, and popular sovereignty.",
+             "response_b": "The French Revolution happened in France. People revolted against the king.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+         {
+             "prompt": "What is the difference between RAM and storage?",
+             "response_a": "RAM is temporary memory your computer uses to run active programs — it's fast but clears when you shut down. Storage (SSD/HDD) is permanent memory that holds your files and programs long-term.",
+             "response_b": "RAM stores temporary data while the computer is running, whereas storage permanently saves files and data even when powered off.",
+             "gold_label": "A",
+             "source": "synthetic",
+         },
+     ]
+
+
+ def _synthetic_likert() -> list[dict]:
+     return [
+         {
+             "prompt": "Explain photosynthesis to a 10-year-old.",
+             "response": "Photosynthesis is how plants make food! They use sunlight, water from the soil, and carbon dioxide from the air. Inside their leaves, they turn all this into glucose (their food) and release oxygen — which is the air we breathe!",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "synthetic",
+         },
+         {
+             "prompt": "What are three tips for reducing stress?",
+             "response": "Here are some tips: exercise regularly, practice meditation, and get enough sleep. These have been shown to reduce cortisol levels and improve mood.",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 4, "honesty": 5, "harmlessness": 5, "instruction_following": 4},
+             "source": "synthetic",
+         },
+         {
+             "prompt": "How do vaccines work?",
+             "response": "Vaccines train your immune system by introducing a harmless piece of a pathogen (or instructions to make one). Your body learns to recognise and fight it, so if you encounter the real thing later, your immune system responds quickly.",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Write a professional email declining a job offer.",
+             "response": "Subject: Re: Job Offer — [Position Name]\n\nDear [Hiring Manager],\n\nThank you so much for offering me the [Position] role at [Company]. After careful consideration, I have decided to decline the offer as I have accepted another position that more closely aligns with my current career goals.\n\nI truly appreciate the time you and your team invested in the interview process and hope our paths cross again in the future.\n\nBest regards,\n[Your Name]",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Explain blockchain in simple terms.",
+             "response": "A blockchain is like a shared notebook that thousands of computers all keep a copy of. Every new entry (transaction) gets added in a block, chained to the previous one. Because everyone has a copy, no single person can secretly change it.",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 5, "honesty": 4, "harmlessness": 5, "instruction_following": 5},
+             "source": "synthetic",
+         },
+         {
+             "prompt": "List 5 healthy breakfast options.",
+             "response": "1. Oatmeal with berries and nuts\n2. Greek yogurt with honey and banana\n3. Avocado toast with eggs\n4. Smoothie with spinach, protein powder, and almond milk\n5. Whole grain cereal with low-fat milk",
+             "rubric": DEFAULT_LIKERT_RUBRIC,
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "synthetic",
+         },
+     ]
+
+
+ def _synthetic_consistency() -> list[dict]:
+     return [
+         {
+             "prompt": "Explain how to use Python decorators.",
+             "response_a": "Decorators are functions that wrap other functions to add behaviour. Use @decorator_name above a function definition. Example: @staticmethod, @property, or custom ones with functools.wraps.",
+             "response_b": "Decorators wrap functions.",
+             "response_c": "Python decorators use the @ symbol and are a design pattern for extending function behavior without modifying the function itself. They take a function as input and return a modified version.",
+             "response_d": "You put @ before a function name.",
+             "gold_ranking": ["C", "A", "B", "D"],
+             "source": "synthetic",
+         },
+         {
+             "prompt": "What causes climate change?",
+             "response_a": "Climate change is primarily caused by human activities that release greenhouse gases — mainly CO2 from burning fossil fuels, methane from agriculture and landfills, and N2O from fertilisers. These gases trap heat in the atmosphere.",
+             "response_b": "The sun causes climate change.",
+             "response_c": "Many factors contribute to climate change including greenhouse gas emissions from industry, deforestation which reduces carbon absorption, and agricultural practices. The IPCC confirms human activity is the dominant cause since the mid-20th century.",
+             "response_d": "Climate change happens because of pollution.",
+             "gold_ranking": ["C", "A", "D", "B"],
+             "source": "synthetic",
+         },
+         {
+             "prompt": "How does the internet work?",
+             "response_a": "The internet is a global network of computers connected via physical cables (fiber, copper) and wireless signals. Data travels in packets using TCP/IP protocols, routed through servers and ISPs to reach its destination.",
+             "response_b": "Computers connect together and send data.",
+             "response_c": "Internet works through IP addresses.",
+             "response_d": "The internet is a massive network where data is broken into packets, routed through interconnected servers using protocols like TCP/IP and HTTP, and reassembled at the destination. DNS translates domain names to IP addresses.",
+             "gold_ranking": ["D", "A", "C", "B"],
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Describe the water cycle.",
+             "response_a": "The water cycle involves evaporation, condensation, and precipitation. Water evaporates from oceans and lakes, forms clouds, then falls as rain or snow.",
+             "response_b": "Water goes up and comes down.",
+             "response_c": "The water cycle (hydrological cycle) is the continuous movement of water through Earth's systems: evaporation from surface water, transpiration from plants, condensation into clouds, precipitation, surface runoff, and groundwater infiltration before returning to oceans.",
+             "response_d": "Water evaporates and rains.",
+             "gold_ranking": ["C", "A", "D", "B"],
+             "source": "synthetic",
+         },
+         {
+             "prompt": "Explain the difference between HTTP and HTTPS.",
+             "response_a": "HTTPS is like HTTP but secure.",
+             "response_b": "HTTP is the protocol for transferring web data. HTTPS adds SSL/TLS encryption, meaning data is encrypted in transit. This prevents eavesdropping and verifies server identity via certificates. Always use HTTPS for sensitive data.",
+             "response_c": "HTTP transfers web pages. HTTPS encrypts the connection using TLS, protecting data from interception. The S stands for Secure.",
+             "response_d": "HTTPS has a padlock icon in browsers.",
+             "gold_ranking": ["B", "C", "A", "D"],
+             "source": "synthetic",
+         },
+     ]
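The two non-trivial reward formulas in the graders above (the Likert MAE mapping and the consistency score built from validity plus Kendall's tau) can be checked in isolation. This is a minimal re-derivation over plain dicts and lists; the function names here are mine, and the environment's actual graders operate on the Pydantic action models:

```python
def likert_reward(agent: dict, gold: dict) -> float:
    """MAE over the four 1-5 axes, mapped to [0, 1] (max per-axis error is 4)."""
    errors = [abs(agent[k] - gold[k]) for k in gold]
    mae = sum(errors) / len(errors)
    return round(1.0 - mae / 4.0, 4)

def consistency_reward(ranking: list, gold: list) -> float:
    """0.5 for a valid permutation, plus up to 0.5 from Kendall's tau vs gold."""
    if len(ranking) != len(gold) or set(ranking) != set(gold):
        return 0.0  # duplicates or invalid IDs forfeit both components
    pos = {r: i for i, r in enumerate(ranking)}
    gpos = {r: i for i, r in enumerate(gold)}
    items = sorted(gold)
    pairs = [(items[i], items[j])
             for i in range(len(items)) for j in range(i + 1, len(items))]
    conc = sum(1 for x, y in pairs if (pos[x] < pos[y]) == (gpos[x] < gpos[y]))
    tau = (conc - (len(pairs) - conc)) / len(pairs)
    return round(0.5 + (tau + 1.0) / 2.0 * 0.5, 4)

gold = {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5}
agent = {"helpfulness": 4, "honesty": 5, "harmlessness": 3, "instruction_following": 5}
print(likert_reward(agent, gold))  # errors 1,0,2,0 → MAE 0.75 → 0.8125
print(consistency_reward(["B", "A", "C", "D"], ["A", "B", "C", "D"]))  # 5 of 6 pairs agree → 0.9167
```

Note the floor effects this scheme creates: any valid 4-way permutation earns at least 0.5 (a fully reversed ranking gets exactly 0.5, since tau = -1 maps to a quality term of 0), so the learnable signal lives entirely in the 0.5–1.0 band.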
tests/__init__.py ADDED
File without changes
tests/test_environment.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Tests for PreferenceLab environment.
+ Run: pytest tests/ -v
+ """
+
+ import sys
+ import os
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+
+ import pytest
+ from models import (
+     PairwiseAction, LikertAction, ConsistencyAction,
+     PairwiseObservation, LikertObservation, ConsistencyObservation,
+ )
+ from server.environment import (
+     PreferenceLabEnvironment,
+     grade_pairwise, grade_likert, grade_consistency,
+ )
+
+
+ # ── Grader unit tests ─────────────────────────────────────────
+
+ class TestPairwiseGrader:
+     def test_correct_choice_scores_1(self):
+         action = PairwiseAction(choice="A")
+         example = {"gold_label": "A", "source": "test"}
+         reward, info = grade_pairwise(action, example)
+         assert reward == 1.0
+         assert info["verdict"] == "correct"
+
+     def test_wrong_choice_scores_0(self):
+         action = PairwiseAction(choice="B")
+         example = {"gold_label": "A", "source": "test"}
+         reward, info = grade_pairwise(action, example)
+         assert reward == 0.0
+         assert info["verdict"] == "incorrect"
+
+     def test_skip_scores_partial(self):
+         action = PairwiseAction(choice="skip")
+         example = {"gold_label": "A", "source": "test"}
+         reward, info = grade_pairwise(action, example)
+         assert reward == 0.3
+
+     def test_tie_scores_low(self):
+         action = PairwiseAction(choice="tie")
+         example = {"gold_label": "A", "source": "test"}
+         reward, info = grade_pairwise(action, example)
+         assert reward == 0.1
+
+     def test_reward_in_range(self):
+         for choice in ["A", "B", "tie", "skip"]:
+             action = PairwiseAction(choice=choice)
+             reward, _ = grade_pairwise(action, {"gold_label": "A", "source": "test"})
+             assert 0.0 <= reward <= 1.0
+
+
+ class TestLikertGrader:
+     def test_perfect_scores_reward_1(self):
+         action = LikertAction(helpfulness=5, honesty=5, harmlessness=5, instruction_following=5)
+         example = {
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "test",
+         }
+         reward, info = grade_likert(action, example)
+         assert reward == 1.0
+
+     def test_worst_scores_reward_0(self):
+         action = LikertAction(helpfulness=1, honesty=1, harmlessness=1, instruction_following=1)
+         example = {
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "test",
+         }
+         reward, info = grade_likert(action, example)
+         assert reward == 0.0
+
+     def test_partial_error_gives_partial_reward(self):
+         action = LikertAction(helpfulness=4, honesty=4, harmlessness=4, instruction_following=4)
+         example = {
+             "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
+             "source": "test",
+         }
+         reward, info = grade_likert(action, example)
+         assert 0.0 < reward < 1.0
+
+     def test_reward_always_in_range(self):
+         import random
+         for _ in range(20):
+             action = LikertAction(
+                 helpfulness=random.randint(1, 5),
+                 honesty=random.randint(1, 5),
+                 harmlessness=random.randint(1, 5),
+                 instruction_following=random.randint(1, 5),
+             )
+             example = {
+                 "gold_scores": {
+                     "helpfulness": random.randint(1, 5),
+                     "honesty": random.randint(1, 5),
+                     "harmlessness": random.randint(1, 5),
+                     "instruction_following": random.randint(1, 5),
+                 }
+             }
+             reward, _ = grade_likert(action, example)
+             assert 0.0 <= reward <= 1.0, f"Reward out of range: {reward}"
+
+
+ class TestConsistencyGrader:
+     def test_perfect_ranking_scores_1(self):
+         action = ConsistencyAction(ranking=["A", "B", "C", "D"])
+         example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
+         reward, info = grade_consistency(action, example)
+         assert reward == 1.0
+
+     def test_reversed_ranking_scores_low(self):
+         action = ConsistencyAction(ranking=["D", "C", "B", "A"])
+         example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
+         reward, info = grade_consistency(action, example)
+         # Transitivity score = 0.5 (ranking is still a valid total order)
+         # Quality score = 0.0 (worst possible Kendall tau = -1 → normalized to 0)
+         # Total = 0.5, strictly less than the perfect score of 1.0
+         assert reward < 1.0
+         assert info["quality_score"] == 0.0
+
+     def test_invalid_ids_scores_low(self):
+         action = ConsistencyAction(ranking=["A", "B", "C", "X"])
+         example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
+         reward, info = grade_consistency(action, example)
+         assert reward == 0.0
+         assert info["has_invalid_ids"] is True
+
+     def test_reward_always_in_range(self):
+         import itertools
+         ids = ["A", "B", "C", "D"]
+         gold = ["A", "B", "C", "D"]
+         for perm in itertools.permutations(ids):
+             action = ConsistencyAction(ranking=list(perm))
+             example = {"gold_ranking": gold, "source": "test"}
+             reward, _ = grade_consistency(action, example)
+             assert 0.0 <= reward <= 1.0, f"Reward out of range: {reward} for {perm}"
+
+     def test_graders_not_always_same_score(self):
+         """Critical: graders must NOT always return the same score."""
+         action_correct = ConsistencyAction(ranking=["A", "B", "C", "D"])
+         action_wrong = ConsistencyAction(ranking=["D", "C", "B", "A"])
+         example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
+         r1, _ = grade_consistency(action_correct, example)
+         r2, _ = grade_consistency(action_wrong, example)
+         assert r1 != r2, "Grader must return different scores for different inputs!"
+
+
+ # ── Environment integration tests ─────────────────────────────
+
+ class TestPreferenceLabEnvironment:
+     def setup_method(self):
+         self.env = PreferenceLabEnvironment()
+
+     def test_reset_returns_observation(self):
+         obs = self.env.reset()
+         assert obs is not None
+         assert hasattr(obs, "prompt")
+         assert hasattr(obs, "reward")
+         assert hasattr(obs, "done")
+
+     def test_reset_pairwise_returns_pairwise_obs(self):
+         obs = self.env.reset(task_type="pairwise")
+         assert isinstance(obs, PairwiseObservation)
+         assert obs.response_a != ""
+         assert obs.response_b != ""
+
+     def test_reset_likert_returns_likert_obs(self):
+         obs = self.env.reset(task_type="likert")
+         assert isinstance(obs, LikertObservation)
+         assert obs.response != ""
+         assert obs.rubric != ""
+
+     def test_reset_consistency_returns_consistency_obs(self):
+         obs = self.env.reset(task_type="consistency")
+         assert isinstance(obs, ConsistencyObservation)
+         assert obs.response_a != ""
+         assert obs.response_d != ""
+
+     def test_step_pairwise(self):
+         self.env.reset(task_type="pairwise")
+         action = PairwiseAction(choice="A")
+         obs, reward, done, info = self.env.step(action)
+         assert isinstance(obs, PairwiseObservation)
+         assert 0.0 <= reward <= 1.0
+         assert isinstance(done, bool)
+
+     def test_step_likert(self):
+         self.env.reset(task_type="likert")
+         action = LikertAction(helpfulness=4, honesty=4, harmlessness=5, instruction_following=4)
+         obs, reward, done, info = self.env.step(action)
+         assert isinstance(obs, LikertObservation)
+         assert 0.0 <= reward <= 1.0
+
+     def test_step_consistency(self):
+         self.env.reset(task_type="consistency")
+         action = ConsistencyAction(ranking=["A", "B", "C", "D"])
+         obs, reward, done, info = self.env.step(action)
+         assert isinstance(obs, ConsistencyObservation)
+         assert 0.0 <= reward <= 1.0
+
+     def test_episode_terminates_after_max_steps(self):
+         self.env.reset(task_type="pairwise")
+         done = False
+         steps = 0
+         # Guard against an infinite loop if the environment never terminates.
+         while not done and steps < 50:
+             _, _, done, _ = self.env.step(PairwiseAction(choice="A"))
+             steps += 1
+         assert steps <= 10, "Episode ran too long!"
+         assert done is True
+
+     def test_state_returns_metadata(self):
+         self.env.reset(seed=42, task_type="pairwise")
+         state = self.env.state()
+         assert "episode_id" in state
+         assert "step_count" in state
+         assert "task_type" in state
+         assert state["seed"] == 42
+
+     def test_reproducible_with_seed(self):
+         obs1 = self.env.reset(seed=123, task_type="pairwise")
+         obs2 = self.env.reset(seed=123, task_type="pairwise")
+         assert obs1.prompt == obs2.prompt
+         assert obs1.response_a == obs2.response_a
+
+     def test_rewards_vary_across_actions(self):
+         """Ensure graders do NOT always return the same reward (disqualification check)."""
+         rewards = set()
+         for _ in range(5):
+             self.env.reset(task_type="pairwise")
+             action_a = PairwiseAction(choice="A")
+             _, r1, _, _ = self.env.step(action_a)
+             self.env.reset(task_type="pairwise")
+             action_b = PairwiseAction(choice="B")
+             _, r2, _, _ = self.env.step(action_b)
+             rewards.add(r1)
+             rewards.add(r2)
+         assert len(rewards) > 1, "Grader always returns the same score: disqualification!"
+
+     def test_all_three_tasks_run(self):
+         for task in ["pairwise", "likert", "consistency"]:
+             obs = self.env.reset(task_type=task)
+             assert obs is not None
+             state = self.env.state()
+             assert state["task_type"] == task
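The Likert tests pin down the grader's endpoints: an exact match on all four axes earns 1.0, maximal disagreement earns 0.0, and small errors land strictly in between. A minimal scheme satisfying those constraints is mean absolute error normalized by the 4-point span of a 1-5 scale; `grade_likert_sketch` below is a hypothetical helper for illustration, not the actual `grade_likert` in `server/environment.py`.

```python
DIMS = ["helpfulness", "honesty", "harmlessness", "instruction_following"]

def grade_likert_sketch(pred: dict, gold: dict) -> float:
    """Reward = 1 - (mean absolute error / max possible error per axis)."""
    mae = sum(abs(pred[d] - gold[d]) for d in DIMS) / len(DIMS)
    return 1.0 - mae / 4.0  # 4 = worst-case gap on a 1-5 scale

gold = {d: 5 for d in DIMS}
print(grade_likert_sketch({d: 4 for d in DIMS}, gold))  # 0.75
```

Any grader of this shape automatically passes the range tests as well, since the per-axis error is bounded by the scale itself.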