Spaces:
Running
Running
Sibam commited on
Commit Β·
cdf485e
0
Parent(s):
PreferenceLab OpenEnv environment for RLHF preference simulation
Browse files- .dockerignore +14 -0
- .gitignore +60 -0
- Dockerfile +61 -0
- README.md +254 -0
- __init__.py +27 -0
- client.py +40 -0
- data/.gitkeep +0 -0
- data/README.md +59 -0
- inference.py +269 -0
- models.py +111 -0
- openenv.yaml +6 -0
- pyproject.toml +34 -0
- requirements.txt +8 -0
- scripts/__init__.py +0 -0
- scripts/prepare_datasets.py +201 -0
- server/__init__.py +1 -0
- server/app.py +38 -0
- server/environment.py +513 -0
- tests/__init__.py +0 -0
- tests/test_environment.py +247 -0
.dockerignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
.git
|
| 6 |
+
.gitignore
|
| 7 |
+
.pytest_cache
|
| 8 |
+
tests/
|
| 9 |
+
scripts/
|
| 10 |
+
outputs/
|
| 11 |
+
*.egg-info
|
| 12 |
+
dist/
|
| 13 |
+
build/
|
| 14 |
+
README.md
|
.gitignore
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python Bytecode
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*$py.class
|
| 7 |
+
|
| 8 |
+
# Virtual Environments
|
| 9 |
+
.venv/
|
| 10 |
+
venv/
|
| 11 |
+
ENV/
|
| 12 |
+
env/
|
| 13 |
+
|
| 14 |
+
# Package Management
|
| 15 |
+
*.egg-info/
|
| 16 |
+
dist/
|
| 17 |
+
build/
|
| 18 |
+
*.egg
|
| 19 |
+
pip-log.txt
|
| 20 |
+
.pip/
|
| 21 |
+
uv.lock
|
| 22 |
+
uv.cache/
|
| 23 |
+
|
| 24 |
+
# IDE & Editor
|
| 25 |
+
.idea/
|
| 26 |
+
.vscode/settings.json
|
| 27 |
+
*.swp
|
| 28 |
+
*.swo
|
| 29 |
+
*~
|
| 30 |
+
.DS_Store
|
| 31 |
+
Thumbs.db
|
| 32 |
+
|
| 33 |
+
# Testing & Coverage
|
| 34 |
+
.pytest_cache/
|
| 35 |
+
.mypy_cache/
|
| 36 |
+
.coverage
|
| 37 |
+
htmlcov/
|
| 38 |
+
.ruff_cache/
|
| 39 |
+
|
| 40 |
+
# Logging
|
| 41 |
+
*.log
|
| 42 |
+
|
| 43 |
+
# Environment
|
| 44 |
+
.env
|
| 45 |
+
.env.local
|
| 46 |
+
.env.*.local
|
| 47 |
+
|
| 48 |
+
# Project-specific Generated Data
|
| 49 |
+
outputs/
|
| 50 |
+
data/pairwise_data.json
|
| 51 |
+
data/likert_data.json
|
| 52 |
+
data/consistency_data.json
|
| 53 |
+
|
| 54 |
+
# Machine Learning Models (if applicable)
|
| 55 |
+
*.bin
|
| 56 |
+
*.safetensors
|
| 57 |
+
*.pt
|
| 58 |
+
*.pth
|
| 59 |
+
*.pkl
|
| 60 |
+
*.joblib
|
Dockerfile
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Team Nexis
|
| 2 |
+
# PreferenceLab OpenEnv Environment
|
| 3 |
+
# Based on the official openenv-base image pattern
|
| 4 |
+
|
| 5 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 6 |
+
FROM ghcr.io/meta-pytorch/openenv-base:latest AS builder
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
ARG BUILD_MODE=standalone
|
| 11 |
+
|
| 12 |
+
# Copy environment code
|
| 13 |
+
COPY . /app/env
|
| 14 |
+
|
| 15 |
+
WORKDIR /app/env
|
| 16 |
+
|
| 17 |
+
# Ensure uv is available
|
| 18 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 19 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 20 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 21 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 22 |
+
fi
|
| 23 |
+
|
| 24 |
+
# Install git (build-time only)
|
| 25 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 26 |
+
git \
|
| 27 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
+
|
| 29 |
+
# Install uv fresh
|
| 30 |
+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 31 |
+
install -m 0755 /root/.local/bin/uv /usr/local/bin/uv && \
|
| 32 |
+
install -m 0755 /root/.local/bin/uvx /usr/local/bin/uvx
|
| 33 |
+
|
| 34 |
+
# Install dependencies (two-pass for caching)
|
| 35 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 36 |
+
uv sync --no-install-project --no-editable
|
| 37 |
+
|
| 38 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 39 |
+
uv sync --no-editable
|
| 40 |
+
|
| 41 |
+
# ββ Runtime stage ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
FROM ghcr.io/meta-pytorch/openenv-base:latest
|
| 43 |
+
|
| 44 |
+
WORKDIR /app
|
| 45 |
+
|
| 46 |
+
# Copy venv and code from builder
|
| 47 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 48 |
+
COPY --from=builder /app/env /app/env
|
| 49 |
+
|
| 50 |
+
# Environment
|
| 51 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 52 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 53 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 54 |
+
ENV MAX_CONCURRENT_ENVS=64
|
| 55 |
+
|
| 56 |
+
# Health check
|
| 57 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
| 58 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 59 |
+
|
| 60 |
+
# Start server
|
| 61 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PreferenceLab
|
| 3 |
+
emoji: π§ͺ
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# π§ͺ PreferenceLab
|
| 11 |
+
|
| 12 |
+
**An OpenEnv environment simulating the RLHF human preference data collection pipeline.**
|
| 13 |
+
|
| 14 |
+
> Built for the Meta Γ Hugging Face OpenEnv Hackathon β Team Nexis
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
|
| 20 |
+
PreferenceLab is a real-world OpenEnv environment where an AI agent learns to judge
|
| 21 |
+
the quality of LLM responses β exactly as a human annotator would in an RLHF pipeline.
|
| 22 |
+
|
| 23 |
+
Instead of expensive, slow human annotators, the environment provides:
|
| 24 |
+
- **Deterministic grading** using gold labels from real preference datasets
|
| 25 |
+
- **Partial reward signals** at every step (not just binary end-of-episode)
|
| 26 |
+
- **3 task difficulty levels**: pairwise β multi-axis scoring β transitive ranking
|
| 27 |
+
|
| 28 |
+
This fills a genuine gap: zero existing OpenEnv environments simulate the RLHF
|
| 29 |
+
data collection pipeline that powers models like Llama, Claude, and GPT-4.
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Action & Observation Spaces
|
| 34 |
+
|
| 35 |
+
### Task 1 β Pairwise Ranking (Easy)
|
| 36 |
+
|
| 37 |
+
**Observation:**
|
| 38 |
+
```python
|
| 39 |
+
PairwiseObservation(
|
| 40 |
+
prompt: str, # The user instruction
|
| 41 |
+
response_a: str, # Candidate response A
|
| 42 |
+
response_b: str, # Candidate response B
|
| 43 |
+
reward: float, # Last step reward
|
| 44 |
+
done: bool,
|
| 45 |
+
step_count: int,
|
| 46 |
+
)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
**Action:**
|
| 50 |
+
```python
|
| 51 |
+
PairwiseAction(
|
| 52 |
+
choice: Literal["A", "B", "tie", "skip"],
|
| 53 |
+
justification: Optional[str], # not graded
|
| 54 |
+
)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
**Grader:** +1.0 correct | +0.3 skip | +0.1 tie | +0.0 wrong
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
### Task 2 β Multi-Axis Likert Scoring (Medium)
|
| 62 |
+
|
| 63 |
+
**Observation:**
|
| 64 |
+
```python
|
| 65 |
+
LikertObservation(
|
| 66 |
+
prompt: str,
|
| 67 |
+
response: str,
|
| 68 |
+
rubric: str, # Scoring instructions
|
| 69 |
+
reward: float,
|
| 70 |
+
done: bool,
|
| 71 |
+
step_count: int,
|
| 72 |
+
)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
**Action:**
|
| 76 |
+
```python
|
| 77 |
+
LikertAction(
|
| 78 |
+
helpfulness: int, # 1-5
|
| 79 |
+
honesty: int, # 1-5
|
| 80 |
+
harmlessness: int, # 1-5
|
| 81 |
+
instruction_following: int # 1-5
|
| 82 |
+
)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
**Grader:** `reward = 1.0 - (MAE / 4.0)` β continuous signal based on deviation from gold scores
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
### Task 3 β Transitive Consistency Ranking (Hard)
|
| 90 |
+
|
| 91 |
+
**Observation:**
|
| 92 |
+
```python
|
| 93 |
+
ConsistencyObservation(
|
| 94 |
+
prompt: str,
|
| 95 |
+
response_a: str,
|
| 96 |
+
response_b: str,
|
| 97 |
+
response_c: str,
|
| 98 |
+
response_d: str,
|
| 99 |
+
reward: float,
|
| 100 |
+
done: bool,
|
| 101 |
+
step_count: int,
|
| 102 |
+
)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
**Action:**
|
| 106 |
+
```python
|
| 107 |
+
ConsistencyAction(
|
| 108 |
+
ranking: list[str] # e.g. ["C", "A", "D", "B"] bestβworst
|
| 109 |
+
)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
**Grader:** Transitivity score (0.0β0.5) + Kendall's tau quality correlation (0.0β0.5)
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Reward Function
|
| 117 |
+
|
| 118 |
+
| Component | Range | Description |
|
| 119 |
+
|---|---|---|
|
| 120 |
+
| Correctness | 0.0β1.0 | Agreement with gold label |
|
| 121 |
+
| Partial credit | 0.1β0.3 | For abstain/tie actions |
|
| 122 |
+
| Trajectory | cumulative | Sum over episode steps |
|
| 123 |
+
|
| 124 |
+
Rewards are **non-sparse** β every step provides a signal. Graders are
|
| 125 |
+
**deterministic** and **reproducible** β same seed = same episode.
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## Setup
|
| 130 |
+
|
| 131 |
+
### Local Development
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
git clone https://github.com/your-username/preference-lab
|
| 135 |
+
cd preference-lab
|
| 136 |
+
pip install -e .
|
| 137 |
+
|
| 138 |
+
# Optional: download real datasets
|
| 139 |
+
python scripts/prepare_datasets.py --samples 200
|
| 140 |
+
|
| 141 |
+
# Run tests
|
| 142 |
+
pytest tests/ -v
|
| 143 |
+
|
| 144 |
+
# Run inference baseline
|
| 145 |
+
python inference.py
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### Environment Variables
|
| 149 |
+
|
| 150 |
+
| Variable | Required | Default | Description |
|
| 151 |
+
|---|---|---|---|
|
| 152 |
+
| `API_BASE_URL` | Yes | `https://api-inference.huggingface.co/v1` | LLM API endpoint |
|
| 153 |
+
| `MODEL_NAME` | Yes | `meta-llama/Llama-3.1-8B-Instruct` | Model to use for inference |
|
| 154 |
+
| `HF_TOKEN` | Yes | β | Hugging Face API key |
|
| 155 |
+
| `ENV_BASE_URL` | No | `http://localhost:8000` | Space URL for remote use |
|
| 156 |
+
| `MAX_CONCURRENT_ENVS` | No | `64` | Parallel sessions supported |
|
| 157 |
+
|
| 158 |
+
### Docker
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
docker build -t preference-lab .
|
| 162 |
+
docker run -p 8000:8000 \
|
| 163 |
+
-e HF_TOKEN=$HF_TOKEN \
|
| 164 |
+
-e API_BASE_URL=https://api-inference.huggingface.co/v1 \
|
| 165 |
+
-e MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct \
|
| 166 |
+
preference-lab
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
---
|
| 170 |
+
|
| 171 |
+
## Using with TRL (GRPO Training)
|
| 172 |
+
|
| 173 |
+
```python
|
| 174 |
+
from preference_lab import PreferenceLabEnv, PairwiseAction
|
| 175 |
+
|
| 176 |
+
class PreferenceEnvWrapper:
|
| 177 |
+
def __init__(self):
|
| 178 |
+
self.client = PreferenceLabEnv(base_url="https://your-space.hf.space")
|
| 179 |
+
self.reward = 0.0
|
| 180 |
+
|
| 181 |
+
def reset(self, **kwargs):
|
| 182 |
+
obs = self.client.reset()
|
| 183 |
+
return obs.prompt
|
| 184 |
+
|
| 185 |
+
def rank_responses(self, choice: str) -> str:
|
| 186 |
+
"""
|
| 187 |
+
Choose which response is better.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
choice: 'A' if response A is better, 'B' if B is better, 'tie' if equal.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Feedback on the annotation quality.
|
| 194 |
+
"""
|
| 195 |
+
result = self.client.step(PairwiseAction(choice=choice))
|
| 196 |
+
self.reward = result.reward
|
| 197 |
+
return f"Reward: {result.reward:.2f}"
|
| 198 |
+
|
| 199 |
+
def reward_func(environments, **kwargs):
|
| 200 |
+
return [env.reward for env in environments]
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
---
|
| 204 |
+
|
| 205 |
+
## Datasets
|
| 206 |
+
|
| 207 |
+
| Task | Dataset | License |
|
| 208 |
+
|---|---|---|
|
| 209 |
+
| Pairwise | Anthropic HH-RLHF | MIT |
|
| 210 |
+
| Likert | UltraFeedback | MIT |
|
| 211 |
+
| Consistency | Stanford SHP | CC BY 4.0 |
|
| 212 |
+
|
| 213 |
+
Fallback synthetic data is included for zero-dependency testing.
|
| 214 |
+
|
| 215 |
+
---
|
| 216 |
+
|
| 217 |
+
## Project Structure
|
| 218 |
+
|
| 219 |
+
```
|
| 220 |
+
preference-lab/
|
| 221 |
+
βββ __init__.py # Package exports
|
| 222 |
+
βββ models.py # Pydantic Action/Observation models
|
| 223 |
+
βββ client.py # PreferenceLabEnv(MCPToolClient)
|
| 224 |
+
βββ inference.py # Baseline inference script (START/STEP/END logs)
|
| 225 |
+
βββ openenv.yaml # OpenEnv manifest
|
| 226 |
+
βββ pyproject.toml # Dependencies
|
| 227 |
+
βββ Dockerfile # Container definition
|
| 228 |
+
βββ requirements.txt # Pip requirements
|
| 229 |
+
βββ README.md # This file
|
| 230 |
+
βββ data/ # Dataset JSONs (or synthetic fallback)
|
| 231 |
+
βββ scripts/
|
| 232 |
+
β βββ prepare_datasets.py # Dataset download + conversion
|
| 233 |
+
βββ server/
|
| 234 |
+
β βββ app.py # FastAPI server
|
| 235 |
+
β βββ environment.py # Core environment + graders
|
| 236 |
+
βββ tests/
|
| 237 |
+
βββ test_environment.py # Full test suite
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
---
|
| 241 |
+
|
| 242 |
+
## Team
|
| 243 |
+
|
| 244 |
+
**Team Nexis** β Sri Sri University Γ Zaalima Development Pvt. Ltd.
|
| 245 |
+
|
| 246 |
+
- Sibam Nanda (Architecture + Core Environment)
|
| 247 |
+
- Spandan Kar (Backend + Deployment)
|
| 248 |
+
- Shayanna Behera (Config + Documentation)
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
## License
|
| 253 |
+
|
| 254 |
+
BSD 3-Clause License
|
__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PreferenceLab: An OpenEnv environment simulating the RLHF
|
| 3 |
+
preference data collection pipeline.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
from preference_lab import PreferenceLabEnv, PairwiseAction, LikertAction, ConsistencyAction
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from client import PreferenceLabEnv
|
| 10 |
+
from models import (
|
| 11 |
+
ConsistencyAction,
|
| 12 |
+
ConsistencyObservation,
|
| 13 |
+
LikertAction,
|
| 14 |
+
LikertObservation,
|
| 15 |
+
PairwiseAction,
|
| 16 |
+
PairwiseObservation,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"PreferenceLabEnv",
|
| 21 |
+
"PairwiseAction",
|
| 22 |
+
"PairwiseObservation",
|
| 23 |
+
"LikertAction",
|
| 24 |
+
"LikertObservation",
|
| 25 |
+
"ConsistencyAction",
|
| 26 |
+
"ConsistencyObservation",
|
| 27 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PreferenceLab Environment Client.
|
| 3 |
+
|
| 4 |
+
Connects to a running PreferenceLab server via WebSocket/HTTP.
|
| 5 |
+
|
| 6 |
+
Example (async):
|
| 7 |
+
>>> from preference_lab import PreferenceLabEnv, PairwiseAction
|
| 8 |
+
>>> async with PreferenceLabEnv(base_url="https://your-space.hf.space") as env:
|
| 9 |
+
... obs = await env.reset()
|
| 10 |
+
... result = await env.step(PairwiseAction(choice="A"))
|
| 11 |
+
|
| 12 |
+
Example (sync):
|
| 13 |
+
>>> with PreferenceLabEnv(base_url="https://your-space.hf.space").sync() as env:
|
| 14 |
+
... obs = env.reset()
|
| 15 |
+
... result = env.step(PairwiseAction(choice="A"))
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from openenv.core.mcp_client import MCPToolClient
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PreferenceLabEnv(MCPToolClient):
|
| 22 |
+
"""
|
| 23 |
+
Client for the PreferenceLab Environment.
|
| 24 |
+
|
| 25 |
+
Provides tool-calling style interactions with the RLHF preference
|
| 26 |
+
simulation environment. Inherits all functionality from MCPToolClient.
|
| 27 |
+
|
| 28 |
+
Available tools (discovered via list_tools()):
|
| 29 |
+
- rank_responses: Task 1 β choose A or B
|
| 30 |
+
- score_response: Task 2 β rate on 4 axes
|
| 31 |
+
- order_responses: Task 3 β rank 4 responses
|
| 32 |
+
|
| 33 |
+
Example:
|
| 34 |
+
>>> with PreferenceLabEnv(base_url="http://localhost:8000").sync() as env:
|
| 35 |
+
... env.reset()
|
| 36 |
+
... tools = env.list_tools()
|
| 37 |
+
... result = env.call_tool("rank_responses", choice="A")
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
pass # MCPToolClient provides all needed functionality
|
data/.gitkeep
ADDED
|
File without changes
|
data/README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Directory
|
| 2 |
+
|
| 3 |
+
This directory holds the preference datasets used by PreferenceLab.
|
| 4 |
+
|
| 5 |
+
On first run, if these files are absent, the environment falls back to
|
| 6 |
+
built-in synthetic examples (defined in `server/environment.py`).
|
| 7 |
+
|
| 8 |
+
## File Format
|
| 9 |
+
|
| 10 |
+
### pairwise_data.json
|
| 11 |
+
```json
|
| 12 |
+
[
|
| 13 |
+
{
|
| 14 |
+
"prompt": "...",
|
| 15 |
+
"response_a": "...",
|
| 16 |
+
"response_b": "...",
|
| 17 |
+
"gold_label": "A",
|
| 18 |
+
"source": "hh-rlhf"
|
| 19 |
+
}
|
| 20 |
+
]
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
### likert_data.json
|
| 24 |
+
```json
|
| 25 |
+
[
|
| 26 |
+
{
|
| 27 |
+
"prompt": "...",
|
| 28 |
+
"response": "...",
|
| 29 |
+
"rubric": "...",
|
| 30 |
+
"gold_scores": {
|
| 31 |
+
"helpfulness": 4,
|
| 32 |
+
"honesty": 5,
|
| 33 |
+
"harmlessness": 5,
|
| 34 |
+
"instruction_following": 4
|
| 35 |
+
},
|
| 36 |
+
"source": "ultrafeedback"
|
| 37 |
+
}
|
| 38 |
+
]
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### consistency_data.json
|
| 42 |
+
```json
|
| 43 |
+
[
|
| 44 |
+
{
|
| 45 |
+
"prompt": "...",
|
| 46 |
+
"response_a": "...",
|
| 47 |
+
"response_b": "...",
|
| 48 |
+
"response_c": "...",
|
| 49 |
+
"response_d": "...",
|
| 50 |
+
"gold_ranking": ["C", "A", "B", "D"],
|
| 51 |
+
"source": "stanford-shp"
|
| 52 |
+
}
|
| 53 |
+
]
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Loading Real Datasets
|
| 57 |
+
|
| 58 |
+
Run `python scripts/prepare_datasets.py` to download and convert
|
| 59 |
+
HH-RLHF, UltraFeedback, and Stanford SHP into these formats.
|
inference.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PreferenceLab Baseline Inference Script.
|
| 3 |
+
|
| 4 |
+
Runs a baseline LLM agent against all 3 tasks and reports scores.
|
| 5 |
+
|
| 6 |
+
Environment variables:
|
| 7 |
+
API_BASE_URL β LLM API endpoint (required, with default)
|
| 8 |
+
MODEL_NAME β Model identifier (required, with default)
|
| 9 |
+
HF_TOKEN β Hugging Face / API key (no default β injected by HF Spaces)
|
| 10 |
+
ENV_BASE_URL β PreferenceLab Space URL (optional, defaults to localhost)
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python inference.py
|
| 14 |
+
HF_TOKEN=hf_xxx MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct python inference.py
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
from openai import OpenAI
|
| 23 |
+
|
| 24 |
+
# ββ Environment variables (MANDATORY pattern for hackathon) βββ
|
| 25 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
|
| 26 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
|
| 27 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # NO default β injected by HF Spaces
|
| 28 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
| 29 |
+
|
| 30 |
+
# ββ OpenAI client (MANDATORY: all LLM calls via OpenAI client) β
|
| 31 |
+
client = OpenAI(
|
| 32 |
+
api_key=HF_TOKEN,
|
| 33 |
+
base_url=API_BASE_URL,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# ββ Logging helpers βββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
|
| 38 |
+
def log_start(task_name: str, task_id: str):
|
| 39 |
+
"""Stdout START log β required structured format."""
|
| 40 |
+
print(f"START task_name={task_name} task_id={task_id} timestamp={time.time():.0f}")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def log_step(task_id: str, step: int, action: dict, reward: float):
|
| 44 |
+
"""Stdout STEP log β required structured format."""
|
| 45 |
+
print(f"STEP task_id={task_id} step={step} action={json.dumps(action)} reward={reward:.4f}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def log_end(task_id: str, total_reward: float, steps: int):
|
| 49 |
+
"""Stdout END log β required structured format."""
|
| 50 |
+
print(f"END task_id={task_id} total_reward={total_reward:.4f} steps={steps}")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ββ LLM helpers βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
|
| 55 |
+
def call_llm(system: str, user: str, max_tokens: int = 256) -> str:
|
| 56 |
+
"""Call LLM via OpenAI client. Returns text response."""
|
| 57 |
+
try:
|
| 58 |
+
response = client.chat.completions.create(
|
| 59 |
+
model=MODEL_NAME,
|
| 60 |
+
messages=[
|
| 61 |
+
{"role": "system", "content": system},
|
| 62 |
+
{"role": "user", "content": user},
|
| 63 |
+
],
|
| 64 |
+
max_tokens=max_tokens,
|
| 65 |
+
temperature=0.0, # deterministic for reproducibility
|
| 66 |
+
)
|
| 67 |
+
return response.choices[0].message.content.strip()
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f" [LLM ERROR] {e}")
|
| 70 |
+
return ""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def parse_json_response(text: str, fallback: dict) -> dict:
|
| 74 |
+
"""Try to parse JSON from LLM response."""
|
| 75 |
+
try:
|
| 76 |
+
# Find JSON block
|
| 77 |
+
start = text.find("{")
|
| 78 |
+
end = text.rfind("}") + 1
|
| 79 |
+
if start >= 0 and end > start:
|
| 80 |
+
return json.loads(text[start:end])
|
| 81 |
+
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
return fallback
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ββ Task runners ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 87 |
+
|
| 88 |
+
def run_task1_pairwise(env_client) -> dict[str, Any]:
|
| 89 |
+
"""Run Task 1: Pairwise ranking."""
|
| 90 |
+
task_id = "task1_pairwise"
|
| 91 |
+
log_start("pairwise_ranking", task_id)
|
| 92 |
+
|
| 93 |
+
obs = env_client.reset(task_type="pairwise")
|
| 94 |
+
total_reward = 0.0
|
| 95 |
+
steps = 0
|
| 96 |
+
|
| 97 |
+
SYSTEM = (
|
| 98 |
+
"You are an expert AI response evaluator. "
|
| 99 |
+
"Given a prompt and two responses (A and B), choose which is better. "
|
| 100 |
+
"Reply ONLY with valid JSON: {\"choice\": \"A\"} or {\"choice\": \"B\"} or {\"choice\": \"tie\"}."
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
for step in range(5):
|
| 104 |
+
user_prompt = (
|
| 105 |
+
f"Prompt: {obs.prompt}\n\n"
|
| 106 |
+
f"Response A:\n{obs.response_a}\n\n"
|
| 107 |
+
f"Response B:\n{obs.response_b}\n\n"
|
| 108 |
+
"Which response is better? Reply with JSON only."
|
| 109 |
+
)
|
| 110 |
+
llm_out = call_llm(SYSTEM, user_prompt)
|
| 111 |
+
parsed = parse_json_response(llm_out, {"choice": "A"})
|
| 112 |
+
choice = parsed.get("choice", "A")
|
| 113 |
+
if choice not in ("A", "B", "tie", "skip"):
|
| 114 |
+
choice = "A"
|
| 115 |
+
|
| 116 |
+
from models import PairwiseAction
|
| 117 |
+
action = PairwiseAction(choice=choice)
|
| 118 |
+
obs, reward, done, info = env_client.step(action)
|
| 119 |
+
|
| 120 |
+
total_reward += reward
|
| 121 |
+
steps += 1
|
| 122 |
+
log_step(task_id, step + 1, {"choice": choice}, reward)
|
| 123 |
+
|
| 124 |
+
if done:
|
| 125 |
+
break
|
| 126 |
+
|
| 127 |
+
log_end(task_id, total_reward, steps)
|
| 128 |
+
avg = round(total_reward / steps, 4) if steps > 0 else 0.0
|
| 129 |
+
return {"task": "pairwise", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def run_task2_likert(env_client) -> dict[str, Any]:
|
| 133 |
+
"""Run Task 2: Multi-axis Likert scoring."""
|
| 134 |
+
task_id = "task2_likert"
|
| 135 |
+
log_start("likert_scoring", task_id)
|
| 136 |
+
|
| 137 |
+
obs = env_client.reset(task_type="likert")
|
| 138 |
+
total_reward = 0.0
|
| 139 |
+
steps = 0
|
| 140 |
+
|
| 141 |
+
SYSTEM = (
|
| 142 |
+
"You are an expert AI response evaluator. Score responses on 4 axes (1=worst, 5=best):\n"
|
| 143 |
+
"- helpfulness: how useful is the response?\n"
|
| 144 |
+
"- honesty: how accurate and truthful?\n"
|
| 145 |
+
"- harmlessness: how safe and appropriate?\n"
|
| 146 |
+
"- instruction_following: how well does it follow the instruction?\n"
|
| 147 |
+
"Reply ONLY with JSON: {\"helpfulness\": 4, \"honesty\": 5, \"harmlessness\": 5, \"instruction_following\": 4}"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
for step in range(5):
|
| 151 |
+
user_prompt = (
|
| 152 |
+
f"Prompt: {obs.prompt}\n\n"
|
| 153 |
+
f"Response:\n{obs.response}\n\n"
|
| 154 |
+
f"Rubric: {obs.rubric}\n\n"
|
| 155 |
+
"Score this response. Reply with JSON only."
|
| 156 |
+
)
|
| 157 |
+
llm_out = call_llm(SYSTEM, user_prompt)
|
| 158 |
+
parsed = parse_json_response(llm_out, {
|
| 159 |
+
"helpfulness": 3, "honesty": 3, "harmlessness": 4, "instruction_following": 3
|
| 160 |
+
})
|
| 161 |
+
|
| 162 |
+
def clamp(v): return max(1, min(5, int(parsed.get(v, 3))))
|
| 163 |
+
|
| 164 |
+
from models import LikertAction
|
| 165 |
+
action = LikertAction(
|
| 166 |
+
helpfulness=clamp("helpfulness"),
|
| 167 |
+
honesty=clamp("honesty"),
|
| 168 |
+
harmlessness=clamp("harmlessness"),
|
| 169 |
+
instruction_following=clamp("instruction_following"),
|
| 170 |
+
)
|
| 171 |
+
obs, reward, done, info = env_client.step(action)
|
| 172 |
+
|
| 173 |
+
total_reward += reward
|
| 174 |
+
steps += 1
|
| 175 |
+
log_step(task_id, step + 1, parsed, reward)
|
| 176 |
+
|
| 177 |
+
if done:
|
| 178 |
+
break
|
| 179 |
+
|
| 180 |
+
log_end(task_id, total_reward, steps)
|
| 181 |
+
avg = round(total_reward / steps, 4) if steps > 0 else 0.0
|
| 182 |
+
return {"task": "likert", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def run_task3_consistency(env_client) -> dict[str, Any]:
|
| 186 |
+
"""Run Task 3: Transitive consistency chain ranking."""
|
| 187 |
+
task_id = "task3_consistency"
|
| 188 |
+
log_start("consistency_ranking", task_id)
|
| 189 |
+
|
| 190 |
+
obs = env_client.reset(task_type="consistency")
|
| 191 |
+
total_reward = 0.0
|
| 192 |
+
steps = 0
|
| 193 |
+
|
| 194 |
+
SYSTEM = (
|
| 195 |
+
"You are an expert AI response evaluator. "
|
| 196 |
+
"Rank 4 responses (A, B, C, D) from best to worst. "
|
| 197 |
+
"Reply ONLY with JSON: {\"ranking\": [\"C\", \"A\", \"D\", \"B\"]} "
|
| 198 |
+
"(best first, worst last)."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
for step in range(5):
|
| 202 |
+
user_prompt = (
|
| 203 |
+
f"Prompt: {obs.prompt}\n\n"
|
| 204 |
+
f"Response A:\n{obs.response_a}\n\n"
|
| 205 |
+
f"Response B:\n{obs.response_b}\n\n"
|
| 206 |
+
f"Response C:\n{obs.response_c}\n\n"
|
| 207 |
+
f"Response D:\n{obs.response_d}\n\n"
|
| 208 |
+
"Rank these responses from best to worst. Reply with JSON only."
|
| 209 |
+
)
|
| 210 |
+
llm_out = call_llm(SYSTEM, user_prompt)
|
| 211 |
+
parsed = parse_json_response(llm_out, {"ranking": ["A", "B", "C", "D"]})
|
| 212 |
+
ranking = parsed.get("ranking", ["A", "B", "C", "D"])
|
| 213 |
+
if not isinstance(ranking, list) or len(ranking) != 4:
|
| 214 |
+
ranking = ["A", "B", "C", "D"]
|
| 215 |
+
|
| 216 |
+
from models import ConsistencyAction
|
| 217 |
+
action = ConsistencyAction(ranking=ranking)
|
| 218 |
+
obs, reward, done, info = env_client.step(action)
|
| 219 |
+
|
| 220 |
+
total_reward += reward
|
| 221 |
+
steps += 1
|
| 222 |
+
log_step(task_id, step + 1, {"ranking": ranking}, reward)
|
| 223 |
+
|
| 224 |
+
if done:
|
| 225 |
+
break
|
| 226 |
+
|
| 227 |
+
log_end(task_id, total_reward, steps)
|
| 228 |
+
avg = round(total_reward / steps, 4) if steps > 0 else 0.0
|
| 229 |
+
return {"task": "consistency", "total_reward": total_reward, "steps": steps, "avg_reward": avg}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
+
|
| 234 |
+
def main():
|
| 235 |
+
print("=" * 60)
|
| 236 |
+
print("PreferenceLab Baseline Inference")
|
| 237 |
+
print(f"Model: {MODEL_NAME}")
|
| 238 |
+
print(f"API URL: {API_BASE_URL}")
|
| 239 |
+
print(f"Env URL: {ENV_BASE_URL}")
|
| 240 |
+
print("=" * 60)
|
| 241 |
+
|
| 242 |
+
# Import environment directly for local run
|
| 243 |
+
# In production, connect to the HF Space via client
|
| 244 |
+
import sys
|
| 245 |
+
sys.path.insert(0, ".")
|
| 246 |
+
from server.environment import PreferenceLabEnvironment
|
| 247 |
+
|
| 248 |
+
env = PreferenceLabEnvironment()
|
| 249 |
+
|
| 250 |
+
results = []
|
| 251 |
+
results.append(run_task1_pairwise(env))
|
| 252 |
+
results.append(run_task2_likert(env))
|
| 253 |
+
results.append(run_task3_consistency(env))
|
| 254 |
+
|
| 255 |
+
print("\n" + "=" * 60)
|
| 256 |
+
print("RESULTS SUMMARY")
|
| 257 |
+
print("=" * 60)
|
| 258 |
+
overall = 0.0
|
| 259 |
+
for r in results:
|
| 260 |
+
print(f" {r['task']:20s} avg_reward={r['avg_reward']:.4f} steps={r['steps']}")
|
| 261 |
+
overall += r["avg_reward"]
|
| 262 |
+
print(f"\n Overall avg reward: {overall / len(results):.4f}")
|
| 263 |
+
print("=" * 60)
|
| 264 |
+
|
| 265 |
+
return results
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
if __name__ == "__main__":
|
| 269 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for PreferenceLab Environment.
|
| 3 |
+
|
| 4 |
+
PreferenceLab simulates the RLHF preference data collection pipeline.
|
| 5 |
+
An AI agent learns to judge which LLM responses are better using
|
| 6 |
+
gold-standard labels from real preference datasets as ground truth.
|
| 7 |
+
|
| 8 |
+
Tasks:
|
| 9 |
+
Task 1 (Easy): Pairwise ranking β pick A or B
|
| 10 |
+
Task 2 (Medium): Multi-axis Likert scoring
|
| 11 |
+
Task 3 (Hard): Transitive consistency chain ranking
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Literal, Optional
|
| 15 |
+
from pydantic import Field
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server.types import Action, Observation
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ββ Actions βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
|
| 22 |
+
class PairwiseAction(Action):
|
| 23 |
+
"""Action for Task 1: choose which response is better."""
|
| 24 |
+
|
| 25 |
+
choice: Literal["A", "B", "tie", "skip"] = Field(
|
| 26 |
+
...,
|
| 27 |
+
description="Choose 'A' if response A is better, 'B' if response B is better, "
|
| 28 |
+
"'tie' if they are equal, or 'skip' if undecidable.",
|
| 29 |
+
)
|
| 30 |
+
justification: Optional[str] = Field(
|
| 31 |
+
default=None,
|
| 32 |
+
description="Optional reasoning for the choice (not used for grading).",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LikertAction(Action):
|
| 37 |
+
"""Action for Task 2: score a response on 4 quality axes (1-5 each)."""
|
| 38 |
+
|
| 39 |
+
helpfulness: int = Field(
|
| 40 |
+
..., ge=1, le=5,
|
| 41 |
+
description="How helpful is the response? 1=not helpful, 5=extremely helpful.",
|
| 42 |
+
)
|
| 43 |
+
honesty: int = Field(
|
| 44 |
+
..., ge=1, le=5,
|
| 45 |
+
description="How honest and accurate is the response? 1=misleading, 5=fully honest.",
|
| 46 |
+
)
|
| 47 |
+
harmlessness: int = Field(
|
| 48 |
+
..., ge=1, le=5,
|
| 49 |
+
description="How harmless is the response? 1=harmful, 5=completely safe.",
|
| 50 |
+
)
|
| 51 |
+
instruction_following: int = Field(
|
| 52 |
+
..., ge=1, le=5,
|
| 53 |
+
description="How well does the response follow the instruction? 1=ignores it, 5=follows perfectly.",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ConsistencyAction(Action):
|
| 58 |
+
"""Action for Task 3: rank 4 responses (A, B, C, D) from best to worst."""
|
| 59 |
+
|
| 60 |
+
ranking: list[str] = Field(
|
| 61 |
+
...,
|
| 62 |
+
min_length=4,
|
| 63 |
+
max_length=4,
|
| 64 |
+
description="List of 4 response IDs ordered best to worst, e.g. ['B', 'A', 'D', 'C'].",
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ββ Observations ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
+
|
| 70 |
+
class PairwiseObservation(Observation):
|
| 71 |
+
"""Observation for Task 1: a prompt with two candidate responses."""
|
| 72 |
+
|
| 73 |
+
task_id: str = Field(..., description="Unique task identifier.")
|
| 74 |
+
task_type: Literal["pairwise"] = Field(default="pairwise")
|
| 75 |
+
prompt: str = Field(..., description="The user prompt / instruction.")
|
| 76 |
+
response_a: str = Field(..., description="Candidate response A.")
|
| 77 |
+
response_b: str = Field(..., description="Candidate response B.")
|
| 78 |
+
reward: float = Field(default=0.0, description="Reward signal from last step.")
|
| 79 |
+
done: bool = Field(default=False, description="Whether the episode is complete.")
|
| 80 |
+
step_count: int = Field(default=0, description="Current step within the episode.")
|
| 81 |
+
info: dict = Field(default_factory=dict, description="Extra debug info.")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class LikertObservation(Observation):
|
| 85 |
+
"""Observation for Task 2: a prompt + single response to score on multiple axes."""
|
| 86 |
+
|
| 87 |
+
task_id: str = Field(..., description="Unique task identifier.")
|
| 88 |
+
task_type: Literal["likert"] = Field(default="likert")
|
| 89 |
+
prompt: str = Field(..., description="The user prompt / instruction.")
|
| 90 |
+
response: str = Field(..., description="The response to evaluate.")
|
| 91 |
+
rubric: str = Field(..., description="Scoring rubric to guide evaluation.")
|
| 92 |
+
reward: float = Field(default=0.0, description="Reward signal from last step.")
|
| 93 |
+
done: bool = Field(default=False, description="Whether the episode is complete.")
|
| 94 |
+
step_count: int = Field(default=0, description="Current step within the episode.")
|
| 95 |
+
info: dict = Field(default_factory=dict, description="Extra debug info.")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ConsistencyObservation(Observation):
|
| 99 |
+
"""Observation for Task 3: a prompt + 4 responses to rank transitively."""
|
| 100 |
+
|
| 101 |
+
task_id: str = Field(..., description="Unique task identifier.")
|
| 102 |
+
task_type: Literal["consistency"] = Field(default="consistency")
|
| 103 |
+
prompt: str = Field(..., description="The user prompt / instruction.")
|
| 104 |
+
response_a: str = Field(..., description="Candidate response A.")
|
| 105 |
+
response_b: str = Field(..., description="Candidate response B.")
|
| 106 |
+
response_c: str = Field(..., description="Candidate response C.")
|
| 107 |
+
response_d: str = Field(..., description="Candidate response D.")
|
| 108 |
+
reward: float = Field(default=0.0, description="Reward signal from last step.")
|
| 109 |
+
done: bool = Field(default=False, description="Whether the episode is complete.")
|
| 110 |
+
step_count: int = Field(default=0, description="Current step within the episode.")
|
| 111 |
+
info: dict = Field(default_factory=dict, description="Extra debug info.")
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: preference_lab
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.backends.legacy:build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "preference-lab"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "An OpenEnv environment simulating the RLHF preference data collection pipeline"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = { text = "BSD-3-Clause" }
|
| 11 |
+
requires-python = ">=3.10"
|
| 12 |
+
dependencies = [
|
| 13 |
+
"openenv-core>=0.2.1",
|
| 14 |
+
"fastapi>=0.104.0",
|
| 15 |
+
"uvicorn>=0.24.0",
|
| 16 |
+
"pydantic>=2.0.0",
|
| 17 |
+
"openai>=1.0.0",
|
| 18 |
+
"datasets>=2.14.0",
|
| 19 |
+
"httpx>=0.25.0",
|
| 20 |
+
"websockets>=11.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.optional-dependencies]
|
| 24 |
+
dev = [
|
| 25 |
+
"pytest>=7.0",
|
| 26 |
+
"pytest-asyncio>=0.21",
|
| 27 |
+
"httpx>=0.25.0",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
[tool.setuptools.packages.find]
|
| 31 |
+
where = ["."]
|
| 32 |
+
|
| 33 |
+
[tool.pytest.ini_options]
|
| 34 |
+
asyncio_mode = "auto"
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.1
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
openai>=1.0.0
|
| 6 |
+
datasets>=2.14.0
|
| 7 |
+
httpx>=0.25.0
|
| 8 |
+
websockets>=11.0
|
scripts/__init__.py
ADDED
|
File without changes
|
scripts/prepare_datasets.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Preparation Script.
|
| 3 |
+
|
| 4 |
+
Downloads HH-RLHF, UltraFeedback, and Stanford SHP from Hugging Face
|
| 5 |
+
and converts them into the format expected by PreferenceLab.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/prepare_datasets.py
|
| 9 |
+
python scripts/prepare_datasets.py --samples 200
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import random
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
DATA_DIR = Path(__file__).parent.parent / "data"
|
| 18 |
+
DATA_DIR.mkdir(exist_ok=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def prepare_pairwise(n_samples: int = 100):
|
| 22 |
+
"""Download Anthropic HH-RLHF and convert to pairwise format."""
|
| 23 |
+
print(f"[1/3] Preparing pairwise data (HH-RLHF, {n_samples} samples)...")
|
| 24 |
+
try:
|
| 25 |
+
from datasets import load_dataset
|
| 26 |
+
ds = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
|
| 27 |
+
records = []
|
| 28 |
+
for i, ex in enumerate(ds):
|
| 29 |
+
if i >= n_samples:
|
| 30 |
+
break
|
| 31 |
+
# chosen = better response, rejected = worse
|
| 32 |
+
chosen = ex.get("chosen", "")
|
| 33 |
+
rejected = ex.get("rejected", "")
|
| 34 |
+
# Extract the last human turn as prompt
|
| 35 |
+
lines = chosen.split("\n\nAssistant:")
|
| 36 |
+
if len(lines) >= 2:
|
| 37 |
+
prompt_block = lines[0].replace("Human:", "").strip()
|
| 38 |
+
resp_a = lines[-1].strip()
|
| 39 |
+
else:
|
| 40 |
+
prompt_block = chosen[:100]
|
| 41 |
+
resp_a = chosen
|
| 42 |
+
|
| 43 |
+
rej_lines = rejected.split("\n\nAssistant:")
|
| 44 |
+
resp_b = rej_lines[-1].strip() if len(rej_lines) >= 2 else rejected
|
| 45 |
+
|
| 46 |
+
# Randomly swap A/B to avoid position bias, track gold
|
| 47 |
+
if random.random() < 0.5:
|
| 48 |
+
records.append({
|
| 49 |
+
"prompt": prompt_block,
|
| 50 |
+
"response_a": resp_a,
|
| 51 |
+
"response_b": resp_b,
|
| 52 |
+
"gold_label": "A",
|
| 53 |
+
"source": "hh-rlhf",
|
| 54 |
+
})
|
| 55 |
+
else:
|
| 56 |
+
records.append({
|
| 57 |
+
"prompt": prompt_block,
|
| 58 |
+
"response_a": resp_b,
|
| 59 |
+
"response_b": resp_a,
|
| 60 |
+
"gold_label": "B",
|
| 61 |
+
"source": "hh-rlhf",
|
| 62 |
+
})
|
| 63 |
+
|
| 64 |
+
out = DATA_DIR / "pairwise_data.json"
|
| 65 |
+
with open(out, "w") as f:
|
| 66 |
+
json.dump(records, f, indent=2)
|
| 67 |
+
print(f" β Saved {len(records)} pairwise examples β {out}")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f" β Failed: {e} β synthetic fallback will be used")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def prepare_likert(n_samples: int = 100):
|
| 73 |
+
"""Download UltraFeedback and convert to likert format."""
|
| 74 |
+
print(f"[2/3] Preparing likert data (UltraFeedback, {n_samples} samples)...")
|
| 75 |
+
try:
|
| 76 |
+
from datasets import load_dataset
|
| 77 |
+
ds = load_dataset("openbmb/UltraFeedback", split="train", streaming=True)
|
| 78 |
+
records = []
|
| 79 |
+
for i, ex in enumerate(ds):
|
| 80 |
+
if i >= n_samples:
|
| 81 |
+
break
|
| 82 |
+
instr = ex.get("instruction", "")
|
| 83 |
+
completions = ex.get("completions", [])
|
| 84 |
+
if not completions:
|
| 85 |
+
continue
|
| 86 |
+
comp = completions[0]
|
| 87 |
+
response = comp.get("response", "")
|
| 88 |
+
annots = comp.get("annotations", {})
|
| 89 |
+
|
| 90 |
+
def extract_score(key, default=3):
|
| 91 |
+
val = annots.get(key, {})
|
| 92 |
+
if isinstance(val, dict):
|
| 93 |
+
raw = val.get("Rating", default)
|
| 94 |
+
elif isinstance(val, (int, float)):
|
| 95 |
+
raw = val
|
| 96 |
+
else:
|
| 97 |
+
raw = default
|
| 98 |
+
# UltraFeedback uses 1-5 scale
|
| 99 |
+
try:
|
| 100 |
+
return max(1, min(5, int(raw)))
|
| 101 |
+
except Exception:
|
| 102 |
+
return default
|
| 103 |
+
|
| 104 |
+
records.append({
|
| 105 |
+
"prompt": instr,
|
| 106 |
+
"response": response,
|
| 107 |
+
"rubric": (
|
| 108 |
+
"Score on 4 axes (1=worst, 5=best): helpfulness, honesty, "
|
| 109 |
+
"harmlessness, instruction_following."
|
| 110 |
+
),
|
| 111 |
+
"gold_scores": {
|
| 112 |
+
"helpfulness": extract_score("instruction_following"),
|
| 113 |
+
"honesty": extract_score("honesty"),
|
| 114 |
+
"harmlessness": extract_score("truthfulness", 4),
|
| 115 |
+
"instruction_following": extract_score("instruction_following"),
|
| 116 |
+
},
|
| 117 |
+
"source": "ultrafeedback",
|
| 118 |
+
})
|
| 119 |
+
|
| 120 |
+
out = DATA_DIR / "likert_data.json"
|
| 121 |
+
with open(out, "w") as f:
|
| 122 |
+
json.dump(records, f, indent=2)
|
| 123 |
+
print(f" β Saved {len(records)} likert examples β {out}")
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f" β Failed: {e} β synthetic fallback will be used")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def prepare_consistency(n_samples: int = 60):
|
| 129 |
+
"""Build 4-way ranking examples from Stanford SHP."""
|
| 130 |
+
print(f"[3/3] Preparing consistency data (Stanford SHP, {n_samples} samples)...")
|
| 131 |
+
try:
|
| 132 |
+
from datasets import load_dataset
|
| 133 |
+
ds = load_dataset("stanfordnlp/SHP", split="train", streaming=True)
|
| 134 |
+
|
| 135 |
+
# Group by post_id to collect multiple responses per prompt
|
| 136 |
+
grouped: dict[str, dict] = {}
|
| 137 |
+
for ex in ds:
|
| 138 |
+
pid = ex.get("post_id", "")
|
| 139 |
+
if pid not in grouped:
|
| 140 |
+
grouped[pid] = {
|
| 141 |
+
"prompt": ex.get("history", ""),
|
| 142 |
+
"responses": [],
|
| 143 |
+
}
|
| 144 |
+
grouped[pid]["responses"].append({
|
| 145 |
+
"text": ex.get("human_ref_A", "") or ex.get("human_ref_B", ""),
|
| 146 |
+
"score": ex.get("score_ratio", 1.0),
|
| 147 |
+
})
|
| 148 |
+
if len(grouped) >= n_samples * 3:
|
| 149 |
+
break
|
| 150 |
+
|
| 151 |
+
records = []
|
| 152 |
+
for pid, data in grouped.items():
|
| 153 |
+
resps = data["responses"]
|
| 154 |
+
if len(resps) < 4:
|
| 155 |
+
continue
|
| 156 |
+
# Sort by score descending = gold ranking
|
| 157 |
+
resps_sorted = sorted(resps[:4], key=lambda x: x["score"], reverse=True)
|
| 158 |
+
labels = ["A", "B", "C", "D"]
|
| 159 |
+
# Shuffle display order (not gold order)
|
| 160 |
+
shuffled = resps_sorted[:]
|
| 161 |
+
random.shuffle(shuffled)
|
| 162 |
+
id_map = {labels[i]: shuffled[i] for i in range(4)}
|
| 163 |
+
gold_ranking = sorted(labels, key=lambda l: resps_sorted.index(id_map[l]))
|
| 164 |
+
|
| 165 |
+
records.append({
|
| 166 |
+
"prompt": data["prompt"][:500],
|
| 167 |
+
"response_a": id_map["A"]["text"][:400],
|
| 168 |
+
"response_b": id_map["B"]["text"][:400],
|
| 169 |
+
"response_c": id_map["C"]["text"][:400],
|
| 170 |
+
"response_d": id_map["D"]["text"][:400],
|
| 171 |
+
"gold_ranking": gold_ranking,
|
| 172 |
+
"source": "stanford-shp",
|
| 173 |
+
})
|
| 174 |
+
if len(records) >= n_samples:
|
| 175 |
+
break
|
| 176 |
+
|
| 177 |
+
out = DATA_DIR / "consistency_data.json"
|
| 178 |
+
with open(out, "w") as f:
|
| 179 |
+
json.dump(records, f, indent=2)
|
| 180 |
+
print(f" β Saved {len(records)} consistency examples β {out}")
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f" β Failed: {e} β synthetic fallback will be used")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main():
|
| 186 |
+
parser = argparse.ArgumentParser()
|
| 187 |
+
parser.add_argument("--samples", type=int, default=100,
|
| 188 |
+
help="Number of samples per task (default: 100)")
|
| 189 |
+
args = parser.parse_args()
|
| 190 |
+
|
| 191 |
+
print("=" * 50)
|
| 192 |
+
print("PreferenceLab Dataset Preparation")
|
| 193 |
+
print("=" * 50)
|
| 194 |
+
prepare_pairwise(args.samples)
|
| 195 |
+
prepare_likert(args.samples)
|
| 196 |
+
prepare_consistency(args.samples // 2)
|
| 197 |
+
print("\nβ Done. Run inference.py to test.")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
main()
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server package
|
server/app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PreferenceLab FastAPI Server.
|
| 3 |
+
|
| 4 |
+
Exposes the PreferenceLabEnvironment via the OpenEnv HTTP interface.
|
| 5 |
+
Supports concurrent sessions for parallel training.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from openenv.core.env_server import create_app
|
| 11 |
+
|
| 12 |
+
from models import (
|
| 13 |
+
ConsistencyAction,
|
| 14 |
+
ConsistencyObservation,
|
| 15 |
+
LikertAction,
|
| 16 |
+
LikertObservation,
|
| 17 |
+
PairwiseAction,
|
| 18 |
+
PairwiseObservation,
|
| 19 |
+
)
|
| 20 |
+
from server.environment import PreferenceLabEnvironment
|
| 21 |
+
|
| 22 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 23 |
+
MAX_CONCURRENT_ENVS = int(os.environ.get("MAX_CONCURRENT_ENVS", "64"))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def create_environment() -> PreferenceLabEnvironment:
|
| 27 |
+
"""Factory function β called once per session."""
|
| 28 |
+
return PreferenceLabEnvironment()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Use PairwiseAction/PairwiseObservation as the primary schema.
|
| 32 |
+
# The environment internally handles all three task types.
|
| 33 |
+
app = create_app(
|
| 34 |
+
create_environment,
|
| 35 |
+
PairwiseAction,
|
| 36 |
+
PairwiseObservation,
|
| 37 |
+
max_concurrent_envs=MAX_CONCURRENT_ENVS,
|
| 38 |
+
)
|
server/environment.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PreferenceLab Core Environment.
|
| 3 |
+
|
| 4 |
+
Implements the OpenEnv Environment base class with:
|
| 5 |
+
- reset() β returns initial observation
|
| 6 |
+
- step() β executes action, returns (observation, reward, done, info)
|
| 7 |
+
- state() β returns episode metadata
|
| 8 |
+
|
| 9 |
+
Three tasks:
|
| 10 |
+
Task 1 (pairwise) - Easy: pairwise choice graded against HH-RLHF gold labels
|
| 11 |
+
Task 2 (likert) - Medium: multi-axis scoring graded via MSE vs UltraFeedback scores
|
| 12 |
+
Task 3 (consistency) - Hard: 4-way ranking graded on transitivity + quality correlation
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import random
|
| 17 |
+
import uuid
|
| 18 |
+
from itertools import permutations
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
from openenv.core.env_server import Environment
|
| 23 |
+
|
| 24 |
+
from models import (
|
| 25 |
+
ConsistencyAction,
|
| 26 |
+
ConsistencyObservation,
|
| 27 |
+
LikertAction,
|
| 28 |
+
LikertObservation,
|
| 29 |
+
PairwiseAction,
|
| 30 |
+
PairwiseObservation,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# ββ Dataset loading ββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
|
| 35 |
+
DATA_DIR = Path(__file__).parent.parent / "data"
|
| 36 |
+
|
| 37 |
+
def _load_json(filename: str) -> list[dict]:
|
| 38 |
+
path = DATA_DIR / filename
|
| 39 |
+
if path.exists():
|
| 40 |
+
with open(path) as f:
|
| 41 |
+
return json.load(f)
|
| 42 |
+
return []
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββ Graders βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
|
| 47 |
+
def grade_pairwise(action: PairwiseAction, example: dict) -> tuple[float, dict]:
|
| 48 |
+
"""
|
| 49 |
+
Grade Task 1: Pairwise ranking.
|
| 50 |
+
|
| 51 |
+
Gold label is 'A' (chosen) or 'B' (rejected) from dataset.
|
| 52 |
+
Returns:
|
| 53 |
+
1.0 β correct choice
|
| 54 |
+
0.3 β skip (abstain β partial credit)
|
| 55 |
+
0.0 β wrong choice
|
| 56 |
+
0.1 β tie (when gold is clear)
|
| 57 |
+
"""
|
| 58 |
+
gold = example.get("gold_label", "A") # 'A' = chosen is response_a
|
| 59 |
+
choice = action.choice
|
| 60 |
+
|
| 61 |
+
if choice == "skip":
|
| 62 |
+
reward = 0.3
|
| 63 |
+
verdict = "abstained"
|
| 64 |
+
elif choice == "tie":
|
| 65 |
+
reward = 0.1
|
| 66 |
+
verdict = "tie_when_clear"
|
| 67 |
+
elif choice == gold:
|
| 68 |
+
reward = 1.0
|
| 69 |
+
verdict = "correct"
|
| 70 |
+
else:
|
| 71 |
+
reward = 0.0
|
| 72 |
+
verdict = "incorrect"
|
| 73 |
+
|
| 74 |
+
return reward, {
|
| 75 |
+
"gold": gold,
|
| 76 |
+
"chosen": choice,
|
| 77 |
+
"verdict": verdict,
|
| 78 |
+
"dataset": example.get("source", "hh-rlhf"),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def grade_likert(action: LikertAction, example: dict) -> tuple[float, dict]:
|
| 83 |
+
"""
|
| 84 |
+
Grade Task 2: Multi-axis Likert scoring.
|
| 85 |
+
|
| 86 |
+
Compares agent's 4-axis scores to gold scores from UltraFeedback.
|
| 87 |
+
Reward = 1.0 - (mean_absolute_error / max_possible_error)
|
| 88 |
+
Max possible error per axis = 4 (1 vs 5), so max_total = 4.
|
| 89 |
+
"""
|
| 90 |
+
gold_scores = example.get("gold_scores", {
|
| 91 |
+
"helpfulness": 3,
|
| 92 |
+
"honesty": 3,
|
| 93 |
+
"harmlessness": 4,
|
| 94 |
+
"instruction_following": 3,
|
| 95 |
+
})
|
| 96 |
+
|
| 97 |
+
axes = ["helpfulness", "honesty", "harmlessness", "instruction_following"]
|
| 98 |
+
agent_scores = {
|
| 99 |
+
"helpfulness": action.helpfulness,
|
| 100 |
+
"honesty": action.honesty,
|
| 101 |
+
"harmlessness": action.harmlessness,
|
| 102 |
+
"instruction_following": action.instruction_following,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
errors = []
|
| 106 |
+
per_axis = {}
|
| 107 |
+
for ax in axes:
|
| 108 |
+
err = abs(agent_scores[ax] - gold_scores.get(ax, 3))
|
| 109 |
+
errors.append(err)
|
| 110 |
+
per_axis[ax] = {"agent": agent_scores[ax], "gold": gold_scores.get(ax, 3), "error": err}
|
| 111 |
+
|
| 112 |
+
mae = sum(errors) / len(errors)
|
| 113 |
+
max_error = 4.0 # max abs difference on 1-5 scale
|
| 114 |
+
reward = round(1.0 - (mae / max_error), 4)
|
| 115 |
+
reward = max(0.0, min(1.0, reward))
|
| 116 |
+
|
| 117 |
+
return reward, {
|
| 118 |
+
"mae": round(mae, 4),
|
| 119 |
+
"per_axis": per_axis,
|
| 120 |
+
"dataset": example.get("source", "ultrafeedback"),
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def grade_consistency(action: ConsistencyAction, example: dict) -> tuple[float, dict]:
|
| 125 |
+
"""
|
| 126 |
+
Grade Task 3: Transitive consistency chain ranking.
|
| 127 |
+
|
| 128 |
+
Scoring components:
|
| 129 |
+
- Transitivity score (0.0β0.5): penalise transitive violations in the ranking
|
| 130 |
+
- Quality correlation (0.0β0.5): Kendall's tau vs gold ranking
|
| 131 |
+
Total reward = transitivity_score + quality_score (max 1.0)
|
| 132 |
+
"""
|
| 133 |
+
ranking = action.ranking
|
| 134 |
+
gold_ranking = example.get("gold_ranking", ["A", "B", "C", "D"])
|
| 135 |
+
|
| 136 |
+
# --- Transitivity check ---
|
| 137 |
+
# For each triple (i, j, k) where i < j < k in the agent's ranking,
|
| 138 |
+
# verify that position(i) < position(j) < position(k) doesn't violate transitivity.
|
| 139 |
+
# A violation = agent says A > B and B > C but NOT A > C.
|
| 140 |
+
# Since ranking is a total order, by construction it IS transitive. But we can
|
| 141 |
+
# still penalise if ranking contains duplicates or invalid IDs.
|
| 142 |
+
valid_ids = {"A", "B", "C", "D"}
|
| 143 |
+
has_invalid = not (set(ranking) == valid_ids)
|
| 144 |
+
transitivity_score = 0.0 if has_invalid else 0.5
|
| 145 |
+
|
| 146 |
+
# --- Quality correlation (Kendall's tau, simplified) ---
|
| 147 |
+
if has_invalid:
|
| 148 |
+
quality_score = 0.0
|
| 149 |
+
n_concordant = 0
|
| 150 |
+
n_discordant = 0
|
| 151 |
+
else:
|
| 152 |
+
ids = ["A", "B", "C", "D"]
|
| 153 |
+
agent_pos = {r: i for i, r in enumerate(ranking)}
|
| 154 |
+
gold_pos = {r: i for i, r in enumerate(gold_ranking)}
|
| 155 |
+
|
| 156 |
+
n_concordant = 0
|
| 157 |
+
n_discordant = 0
|
| 158 |
+
pairs = [(ids[i], ids[j]) for i in range(4) for j in range(i + 1, 4)]
|
| 159 |
+
for x, y in pairs:
|
| 160 |
+
agent_order = agent_pos[x] < agent_pos[y]
|
| 161 |
+
gold_order = gold_pos[x] < gold_pos[y]
|
| 162 |
+
if agent_order == gold_order:
|
| 163 |
+
n_concordant += 1
|
| 164 |
+
else:
|
| 165 |
+
n_discordant += 1
|
| 166 |
+
|
| 167 |
+
total_pairs = n_concordant + n_discordant
|
| 168 |
+
tau = (n_concordant - n_discordant) / total_pairs if total_pairs > 0 else 0.0
|
| 169 |
+
# Normalise tau from [-1,1] to [0, 0.5]
|
| 170 |
+
quality_score = round((tau + 1.0) / 2.0 * 0.5, 4)
|
| 171 |
+
|
| 172 |
+
reward = round(transitivity_score + quality_score, 4)
|
| 173 |
+
reward = max(0.0, min(1.0, reward))
|
| 174 |
+
|
| 175 |
+
return reward, {
|
| 176 |
+
"transitivity_score": transitivity_score,
|
| 177 |
+
"quality_score": quality_score if not has_invalid else 0.0,
|
| 178 |
+
"agent_ranking": ranking,
|
| 179 |
+
"gold_ranking": gold_ranking,
|
| 180 |
+
"has_invalid_ids": has_invalid,
|
| 181 |
+
"dataset": example.get("source", "stanford-shp"),
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ββ Environment βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
|
| 187 |
+
TASK_TYPES = ["pairwise", "likert", "consistency"]
|
| 188 |
+
MAX_STEPS_PER_EPISODE = 5
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class PreferenceLabEnvironment(Environment):
|
| 192 |
+
"""
|
| 193 |
+
PreferenceLab: An RL environment simulating the RLHF preference
|
| 194 |
+
data collection pipeline.
|
| 195 |
+
|
| 196 |
+
Each episode consists of MAX_STEPS_PER_EPISODE annotation steps.
|
| 197 |
+
The task type is fixed per episode (chosen at reset).
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self):
|
| 201 |
+
self._episode_id: str = ""
|
| 202 |
+
self._step_count: int = 0
|
| 203 |
+
self._task_type: str = "pairwise"
|
| 204 |
+
self._current_example: dict = {}
|
| 205 |
+
self._cumulative_reward: float = 0.0
|
| 206 |
+
self._seed: int = 0
|
| 207 |
+
|
| 208 |
+
# Load datasets
|
| 209 |
+
self._pairwise_data: list[dict] = _load_json("pairwise_data.json")
|
| 210 |
+
self._likert_data: list[dict] = _load_json("likert_data.json")
|
| 211 |
+
self._consistency_data: list[dict] = _load_json("consistency_data.json")
|
| 212 |
+
|
| 213 |
+
# Fallback synthetic data if files not present
|
| 214 |
+
if not self._pairwise_data:
|
| 215 |
+
self._pairwise_data = _synthetic_pairwise()
|
| 216 |
+
if not self._likert_data:
|
| 217 |
+
self._likert_data = _synthetic_likert()
|
| 218 |
+
if not self._consistency_data:
|
| 219 |
+
self._consistency_data = _synthetic_consistency()
|
| 220 |
+
|
| 221 |
+
# ββ OpenEnv API βββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
+
|
| 223 |
+
def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs):
|
| 224 |
+
"""
|
| 225 |
+
Reset the environment for a new episode.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
seed: Optional random seed for reproducibility.
|
| 229 |
+
episode_id: Optional episode ID override.
|
| 230 |
+
**kwargs: Accepts task_type ('pairwise', 'likert', 'consistency').
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Initial observation for the episode.
|
| 234 |
+
"""
|
| 235 |
+
task_type = kwargs.get("task_type", None)
|
| 236 |
+
self._seed = seed if seed is not None else random.randint(0, 10_000)
|
| 237 |
+
rng = random.Random(self._seed)
|
| 238 |
+
|
| 239 |
+
self._episode_id = episode_id or str(uuid.uuid4())
|
| 240 |
+
self._step_count = 0
|
| 241 |
+
self._cumulative_reward = 0.0
|
| 242 |
+
self._task_type = task_type if task_type in TASK_TYPES else rng.choice(TASK_TYPES)
|
| 243 |
+
|
| 244 |
+
self._current_example = self._sample_example(rng)
|
| 245 |
+
|
| 246 |
+
return self._build_observation(reward=0.0, done=False, info={"reset": True})
|
| 247 |
+
|
| 248 |
+
def step(self, action, timeout_s: float | None = None, **kwargs):
|
| 249 |
+
"""
|
| 250 |
+
Execute one annotation step.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
action: A PairwiseAction, LikertAction, or ConsistencyAction.
|
| 254 |
+
timeout_s: Unused β required by base class signature.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
Tuple of (observation, reward, done, info).
|
| 258 |
+
"""
|
| 259 |
+
self._step_count += 1
|
| 260 |
+
|
| 261 |
+
# Grade the action
|
| 262 |
+
reward, info = self._grade(action)
|
| 263 |
+
self._cumulative_reward += reward
|
| 264 |
+
|
| 265 |
+
done = self._step_count >= MAX_STEPS_PER_EPISODE
|
| 266 |
+
|
| 267 |
+
# Sample next example if not done
|
| 268 |
+
if not done:
|
| 269 |
+
rng = random.Random(self._seed + self._step_count)
|
| 270 |
+
self._current_example = self._sample_example(rng)
|
| 271 |
+
|
| 272 |
+
obs = self._build_observation(reward=reward, done=done, info=info)
|
| 273 |
+
return obs, reward, done, info
|
| 274 |
+
|
| 275 |
+
def state(self) -> dict[str, Any]:
|
| 276 |
+
"""Return current episode metadata."""
|
| 277 |
+
return {
|
| 278 |
+
"episode_id": self._episode_id,
|
| 279 |
+
"step_count": self._step_count,
|
| 280 |
+
"task_type": self._task_type,
|
| 281 |
+
"cumulative_reward": round(self._cumulative_reward, 4),
|
| 282 |
+
"max_steps": MAX_STEPS_PER_EPISODE,
|
| 283 |
+
"seed": self._seed,
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
# ββ Internal helpers ββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββ
|
| 287 |
+
|
| 288 |
+
def _sample_example(self, rng: random.Random) -> dict:
|
| 289 |
+
"""Sample one example from the appropriate dataset."""
|
| 290 |
+
dataset = {
|
| 291 |
+
"pairwise": self._pairwise_data,
|
| 292 |
+
"likert": self._likert_data,
|
| 293 |
+
"consistency": self._consistency_data,
|
| 294 |
+
}[self._task_type]
|
| 295 |
+
return rng.choice(dataset)
|
| 296 |
+
|
| 297 |
+
def _grade(self, action) -> tuple[float, dict]:
|
| 298 |
+
"""Dispatch to the correct grader based on task type."""
|
| 299 |
+
if self._task_type == "pairwise":
|
| 300 |
+
return grade_pairwise(action, self._current_example)
|
| 301 |
+
elif self._task_type == "likert":
|
| 302 |
+
return grade_likert(action, self._current_example)
|
| 303 |
+
elif self._task_type == "consistency":
|
| 304 |
+
return grade_consistency(action, self._current_example)
|
| 305 |
+
return 0.0, {"error": "unknown_task"}
|
| 306 |
+
|
| 307 |
+
def _build_observation(self, reward: float, done: bool, info: dict):
|
| 308 |
+
"""Build the appropriate observation type for the current task."""
|
| 309 |
+
ex = self._current_example
|
| 310 |
+
base = {
|
| 311 |
+
"task_id": self._episode_id + f"_step{self._step_count}",
|
| 312 |
+
"reward": reward,
|
| 313 |
+
"done": done,
|
| 314 |
+
"step_count": self._step_count,
|
| 315 |
+
"info": info,
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
if self._task_type == "pairwise":
|
| 319 |
+
return PairwiseObservation(
|
| 320 |
+
**base,
|
| 321 |
+
prompt=ex.get("prompt", ""),
|
| 322 |
+
response_a=ex.get("response_a", ""),
|
| 323 |
+
response_b=ex.get("response_b", ""),
|
| 324 |
+
)
|
| 325 |
+
elif self._task_type == "likert":
|
| 326 |
+
return LikertObservation(
|
| 327 |
+
**base,
|
| 328 |
+
prompt=ex.get("prompt", ""),
|
| 329 |
+
response=ex.get("response", ""),
|
| 330 |
+
rubric=ex.get("rubric", DEFAULT_LIKERT_RUBRIC),
|
| 331 |
+
)
|
| 332 |
+
elif self._task_type == "consistency":
|
| 333 |
+
return ConsistencyObservation(
|
| 334 |
+
**base,
|
| 335 |
+
prompt=ex.get("prompt", ""),
|
| 336 |
+
response_a=ex.get("response_a", ""),
|
| 337 |
+
response_b=ex.get("response_b", ""),
|
| 338 |
+
response_c=ex.get("response_c", ""),
|
| 339 |
+
response_d=ex.get("response_d", ""),
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
raise ValueError(f"Unknown task type: {self._task_type}")
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# ββ Defaults ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 346 |
+
|
| 347 |
+
DEFAULT_LIKERT_RUBRIC = (
|
| 348 |
+
"Score the response on 4 axes (1=worst, 5=best): "
|
| 349 |
+
"Helpfulness β does it answer the question usefully? "
|
| 350 |
+
"Honesty β is it accurate and truthful? "
|
| 351 |
+
"Harmlessness β does it avoid harm? "
|
| 352 |
+
"Instruction Following β does it follow the instruction precisely?"
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# ββ Synthetic fallback datasets βββββββββββββββββββββββββββββββ
|
| 357 |
+
|
| 358 |
+
def _synthetic_pairwise() -> list[dict]:
|
| 359 |
+
return [
|
| 360 |
+
{
|
| 361 |
+
"prompt": "Explain what machine learning is in simple terms.",
|
| 362 |
+
"response_a": "Machine learning is a branch of AI where computers learn patterns from data without being explicitly programmed for each task.",
|
| 363 |
+
"response_b": "Machine learning is when computers do stuff with numbers to make predictions.",
|
| 364 |
+
"gold_label": "A",
|
| 365 |
+
"source": "synthetic",
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"prompt": "How do I improve my sleep quality?",
|
| 369 |
+
"response_a": "Try sleeping more.",
|
| 370 |
+
"response_b": "Maintain a consistent sleep schedule, avoid screens 1 hour before bed, keep your bedroom cool and dark, and limit caffeine after 2pm.",
|
| 371 |
+
"gold_label": "B",
|
| 372 |
+
"source": "synthetic",
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"prompt": "What is the capital of France?",
|
| 376 |
+
"response_a": "Paris is the capital and largest city of France.",
|
| 377 |
+
"response_b": "France's capital city is called Paris, it is located in northern France.",
|
| 378 |
+
"gold_label": "A",
|
| 379 |
+
"source": "synthetic",
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"prompt": "Write a haiku about autumn.",
|
| 383 |
+
"response_a": "Leaves fall silently / Crimson whispers touch the ground / Winter draws near now",
|
| 384 |
+
"response_b": "Autumn is a season. Leaves fall down. It gets cold outside.",
|
| 385 |
+
"gold_label": "A",
|
| 386 |
+
"source": "synthetic",
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"prompt": "How do I reverse a list in Python?",
|
| 390 |
+
"response_a": "Use my_list.reverse() to reverse in-place, or my_list[::-1] to get a reversed copy.",
|
| 391 |
+
"response_b": "You can just use the reverse function.",
|
| 392 |
+
"gold_label": "A",
|
| 393 |
+
"source": "synthetic",
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"prompt": "Is it safe to eat raw eggs?",
|
| 397 |
+
"response_a": "Raw eggs carry a risk of Salmonella contamination. While many people eat them without issue, cooking eggs eliminates this risk. Use pasteurised eggs if you want them raw.",
|
| 398 |
+
"response_b": "Yeah raw eggs are totally fine to eat, bodybuilders do it all the time.",
|
| 399 |
+
"gold_label": "A",
|
| 400 |
+
"source": "synthetic",
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"prompt": "Summarise the French Revolution in 2 sentences.",
|
| 404 |
+
"response_a": "The French Revolution (1789β1799) was a period of radical political and social transformation in France that overthrew the monarchy, established a republic, and culminated in Napoleon's rise to power. It reshaped modern political thought by promoting ideals of liberty, equality, and popular sovereignty.",
|
| 405 |
+
"response_b": "The French Revolution happened in France. People revolted against the king.",
|
| 406 |
+
"gold_label": "A",
|
| 407 |
+
"source": "synthetic",
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"prompt": "What is the difference between RAM and storage?",
|
| 411 |
+
"response_a": "RAM is temporary memory your computer uses to run active programs β it's fast but clears when you shut down. Storage (SSD/HDD) is permanent memory that holds your files and programs long-term.",
|
| 412 |
+
"response_b": "RAM stores temporary data while the computer is running, whereas storage permanently saves files and data even when powered off.",
|
| 413 |
+
"gold_label": "A",
|
| 414 |
+
"source": "synthetic",
|
| 415 |
+
},
|
| 416 |
+
]
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _synthetic_likert() -> list[dict]:
|
| 420 |
+
return [
|
| 421 |
+
{
|
| 422 |
+
"prompt": "Explain photosynthesis to a 10-year-old.",
|
| 423 |
+
"response": "Photosynthesis is how plants make food! They use sunlight, water from the soil, and carbon dioxide from the air. Inside their leaves, they turn all this into glucose (their food) and release oxygen β which is the air we breathe!",
|
| 424 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 425 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 426 |
+
"source": "synthetic",
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"prompt": "What are three tips for reducing stress?",
|
| 430 |
+
"response": "Here are some tips: exercise regularly, practice meditation, and get enough sleep. These have been shown to reduce cortisol levels and improve mood.",
|
| 431 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 432 |
+
"gold_scores": {"helpfulness": 4, "honesty": 5, "harmlessness": 5, "instruction_following": 4},
|
| 433 |
+
"source": "synthetic",
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"prompt": "How do vaccines work?",
|
| 437 |
+
"response": "Vaccines train your immune system by introducing a harmless piece of a pathogen (or instructions to make one). Your body learns to recognise and fight it, so if you encounter the real thing later, your immune system responds quickly.",
|
| 438 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 439 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 440 |
+
"source": "synthetic",
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"prompt": "Write a professional email declining a job offer.",
|
| 444 |
+
"response": "Subject: Re: Job Offer β [Position Name]\n\nDear [Hiring Manager],\n\nThank you so much for offering me the [Position] role at [Company]. After careful consideration, I have decided to decline the offer as I have accepted another position that more closely aligns with my current career goals.\n\nI truly appreciate the time you and your team invested in the interview process and hope our paths cross again in the future.\n\nBest regards,\n[Your Name]",
|
| 445 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 446 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 447 |
+
"source": "synthetic",
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"prompt": "Explain blockchain in simple terms.",
|
| 451 |
+
"response": "A blockchain is like a shared notebook that thousands of computers all keep a copy of. Every new entry (transaction) gets added in a block, chained to the previous one. Because everyone has a copy, no single person can secretly change it.",
|
| 452 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 453 |
+
"gold_scores": {"helpfulness": 5, "honesty": 4, "harmlessness": 5, "instruction_following": 5},
|
| 454 |
+
"source": "synthetic",
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"prompt": "List 5 healthy breakfast options.",
|
| 458 |
+
"response": "1. Oatmeal with berries and nuts\n2. Greek yogurt with honey and banana\n3. Avocado toast with eggs\n4. Smoothie with spinach, protein powder, and almond milk\n5. Whole grain cereal with low-fat milk",
|
| 459 |
+
"rubric": DEFAULT_LIKERT_RUBRIC,
|
| 460 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 461 |
+
"source": "synthetic",
|
| 462 |
+
},
|
| 463 |
+
]
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _synthetic_consistency() -> list[dict]:
|
| 467 |
+
return [
|
| 468 |
+
{
|
| 469 |
+
"prompt": "Explain how to use Python decorators.",
|
| 470 |
+
"response_a": "Decorators are functions that wrap other functions to add behaviour. Use @decorator_name above a function definition. Example: @staticmethod, @property, or custom ones with functools.wraps.",
|
| 471 |
+
"response_b": "Decorators wrap functions.",
|
| 472 |
+
"response_c": "Python decorators use the @ symbol and are a design pattern for extending function behavior without modifying the function itself. They take a function as input and return a modified version.",
|
| 473 |
+
"response_d": "You put @ before a function name.",
|
| 474 |
+
"gold_ranking": ["C", "A", "B", "D"],
|
| 475 |
+
"source": "synthetic",
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"prompt": "What causes climate change?",
|
| 479 |
+
"response_a": "Climate change is primarily caused by human activities that release greenhouse gases β mainly CO2 from burning fossil fuels, methane from agriculture and landfills, and N2O from fertilisers. These gases trap heat in the atmosphere.",
|
| 480 |
+
"response_b": "The sun causes climate change.",
|
| 481 |
+
"response_c": "Many factors contribute to climate change including greenhouse gas emissions from industry, deforestation which reduces carbon absorption, and agricultural practices. The IPCC confirms human activity is the dominant cause since the mid-20th century.",
|
| 482 |
+
"response_d": "Climate change happens because of pollution.",
|
| 483 |
+
"gold_ranking": ["C", "A", "D", "B"],
|
| 484 |
+
"source": "synthetic",
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"prompt": "How does the internet work?",
|
| 488 |
+
"response_a": "The internet is a global network of computers connected via physical cables (fiber, copper) and wireless signals. Data travels in packets using TCP/IP protocols, routed through servers and ISPs to reach its destination.",
|
| 489 |
+
"response_b": "Computers connect together and send data.",
|
| 490 |
+
"response_c": "Internet works through IP addresses.",
|
| 491 |
+
"response_d": "The internet is a massive network where data is broken into packets, routed through interconnected servers using protocols like TCP/IP and HTTP, and reassembled at the destination. DNS translates domain names to IP addresses.",
|
| 492 |
+
"gold_ranking": ["D", "A", "C", "B"],
|
| 493 |
+
"source": "synthetic",
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"prompt": "Describe the water cycle.",
|
| 497 |
+
"response_a": "The water cycle involves evaporation, condensation, and precipitation. Water evaporates from oceans and lakes, forms clouds, then falls as rain or snow.",
|
| 498 |
+
"response_b": "Water goes up and comes down.",
|
| 499 |
+
"response_c": "The water cycle (hydrological cycle) is the continuous movement of water through Earth's systems: evaporation from surface water, transpiration from plants, condensation into clouds, precipitation, surface runoff, and groundwater infiltration before returning to oceans.",
|
| 500 |
+
"response_d": "Water evaporates and rains.",
|
| 501 |
+
"gold_ranking": ["C", "A", "D", "B"],
|
| 502 |
+
"source": "synthetic",
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"prompt": "Explain the difference between HTTP and HTTPS.",
|
| 506 |
+
"response_a": "HTTPS is like HTTP but secure.",
|
| 507 |
+
"response_b": "HTTP is the protocol for transferring web data. HTTPS adds SSL/TLS encryption, meaning data is encrypted in transit. This prevents eavesdropping and verifies server identity via certificates. Always use HTTPS for sensitive data.",
|
| 508 |
+
"response_c": "HTTP transfers web pages. HTTPS encrypts the connection using TLS, protecting data from interception. The S stands for Secure.",
|
| 509 |
+
"response_d": "HTTPS has a padlock icon in browsers.",
|
| 510 |
+
"gold_ranking": ["B", "C", "A", "D"],
|
| 511 |
+
"source": "synthetic",
|
| 512 |
+
},
|
| 513 |
+
]
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for PreferenceLab environment.
|
| 3 |
+
Run: pytest tests/ -v
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
from models import (
|
| 12 |
+
PairwiseAction, LikertAction, ConsistencyAction,
|
| 13 |
+
PairwiseObservation, LikertObservation, ConsistencyObservation,
|
| 14 |
+
)
|
| 15 |
+
from server.environment import (
|
| 16 |
+
PreferenceLabEnvironment,
|
| 17 |
+
grade_pairwise, grade_likert, grade_consistency,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ββ Grader unit tests βββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
|
| 23 |
+
class TestPairwiseGrader:
|
| 24 |
+
def test_correct_choice_scores_1(self):
|
| 25 |
+
action = PairwiseAction(choice="A")
|
| 26 |
+
example = {"gold_label": "A", "source": "test"}
|
| 27 |
+
reward, info = grade_pairwise(action, example)
|
| 28 |
+
assert reward == 1.0
|
| 29 |
+
assert info["verdict"] == "correct"
|
| 30 |
+
|
| 31 |
+
def test_wrong_choice_scores_0(self):
|
| 32 |
+
action = PairwiseAction(choice="B")
|
| 33 |
+
example = {"gold_label": "A", "source": "test"}
|
| 34 |
+
reward, info = grade_pairwise(action, example)
|
| 35 |
+
assert reward == 0.0
|
| 36 |
+
assert info["verdict"] == "incorrect"
|
| 37 |
+
|
| 38 |
+
def test_skip_scores_partial(self):
|
| 39 |
+
action = PairwiseAction(choice="skip")
|
| 40 |
+
example = {"gold_label": "A", "source": "test"}
|
| 41 |
+
reward, info = grade_pairwise(action, example)
|
| 42 |
+
assert reward == 0.3
|
| 43 |
+
|
| 44 |
+
def test_tie_scores_low(self):
|
| 45 |
+
action = PairwiseAction(choice="tie")
|
| 46 |
+
example = {"gold_label": "A", "source": "test"}
|
| 47 |
+
reward, info = grade_pairwise(action, example)
|
| 48 |
+
assert reward == 0.1
|
| 49 |
+
|
| 50 |
+
def test_reward_in_range(self):
|
| 51 |
+
for choice in ["A", "B", "tie", "skip"]:
|
| 52 |
+
action = PairwiseAction(choice=choice)
|
| 53 |
+
reward, _ = grade_pairwise(action, {"gold_label": "A", "source": "test"})
|
| 54 |
+
assert 0.0 <= reward <= 1.0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestLikertGrader:
|
| 58 |
+
def test_perfect_scores_reward_1(self):
|
| 59 |
+
action = LikertAction(helpfulness=5, honesty=5, harmlessness=5, instruction_following=5)
|
| 60 |
+
example = {
|
| 61 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 62 |
+
"source": "test",
|
| 63 |
+
}
|
| 64 |
+
reward, info = grade_likert(action, example)
|
| 65 |
+
assert reward == 1.0
|
| 66 |
+
|
| 67 |
+
def test_worst_scores_reward_0(self):
|
| 68 |
+
action = LikertAction(helpfulness=1, honesty=1, harmlessness=1, instruction_following=1)
|
| 69 |
+
example = {
|
| 70 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 71 |
+
"source": "test",
|
| 72 |
+
}
|
| 73 |
+
reward, info = grade_likert(action, example)
|
| 74 |
+
assert reward == 0.0
|
| 75 |
+
|
| 76 |
+
def test_partial_error_gives_partial_reward(self):
|
| 77 |
+
action = LikertAction(helpfulness=4, honesty=4, harmlessness=4, instruction_following=4)
|
| 78 |
+
example = {
|
| 79 |
+
"gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
|
| 80 |
+
"source": "test",
|
| 81 |
+
}
|
| 82 |
+
reward, info = grade_likert(action, example)
|
| 83 |
+
assert 0.0 < reward < 1.0
|
| 84 |
+
|
| 85 |
+
def test_reward_always_in_range(self):
|
| 86 |
+
import random
|
| 87 |
+
for _ in range(20):
|
| 88 |
+
action = LikertAction(
|
| 89 |
+
helpfulness=random.randint(1, 5),
|
| 90 |
+
honesty=random.randint(1, 5),
|
| 91 |
+
harmlessness=random.randint(1, 5),
|
| 92 |
+
instruction_following=random.randint(1, 5),
|
| 93 |
+
)
|
| 94 |
+
example = {
|
| 95 |
+
"gold_scores": {
|
| 96 |
+
"helpfulness": random.randint(1, 5),
|
| 97 |
+
"honesty": random.randint(1, 5),
|
| 98 |
+
"harmlessness": random.randint(1, 5),
|
| 99 |
+
"instruction_following": random.randint(1, 5),
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
reward, _ = grade_likert(action, example)
|
| 103 |
+
assert 0.0 <= reward <= 1.0, f"Reward out of range: {reward}"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TestConsistencyGrader:
|
| 107 |
+
def test_perfect_ranking_scores_1(self):
|
| 108 |
+
action = ConsistencyAction(ranking=["A", "B", "C", "D"])
|
| 109 |
+
example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
|
| 110 |
+
reward, info = grade_consistency(action, example)
|
| 111 |
+
assert reward == 1.0
|
| 112 |
+
|
| 113 |
+
def test_reversed_ranking_scores_low(self):
|
| 114 |
+
action = ConsistencyAction(ranking=["D", "C", "B", "A"])
|
| 115 |
+
example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
|
| 116 |
+
reward, info = grade_consistency(action, example)
|
| 117 |
+
# Transitivity score = 0.5 (ranking is still a valid total order)
|
| 118 |
+
# Quality score = 0.0 (worst possible Kendall tau = -1 β normalized to 0)
|
| 119 |
+
# Total = 0.5 β strictly less than perfect score of 1.0
|
| 120 |
+
assert reward < 1.0
|
| 121 |
+
assert info["quality_score"] == 0.0
|
| 122 |
+
|
| 123 |
+
def test_invalid_ids_scores_low(self):
|
| 124 |
+
action = ConsistencyAction(ranking=["A", "B", "C", "X"])
|
| 125 |
+
example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
|
| 126 |
+
reward, info = grade_consistency(action, example)
|
| 127 |
+
assert reward == 0.0
|
| 128 |
+
assert info["has_invalid_ids"] is True
|
| 129 |
+
|
| 130 |
+
def test_reward_always_in_range(self):
|
| 131 |
+
import itertools
|
| 132 |
+
import random
|
| 133 |
+
ids = ["A", "B", "C", "D"]
|
| 134 |
+
gold = ["A", "B", "C", "D"]
|
| 135 |
+
for perm in itertools.permutations(ids):
|
| 136 |
+
action = ConsistencyAction(ranking=list(perm))
|
| 137 |
+
example = {"gold_ranking": gold, "source": "test"}
|
| 138 |
+
reward, _ = grade_consistency(action, example)
|
| 139 |
+
assert 0.0 <= reward <= 1.0, f"Reward out of range: {reward} for {perm}"
|
| 140 |
+
|
| 141 |
+
def test_graders_not_always_same_score(self):
|
| 142 |
+
"""Critical: graders must NOT always return the same score."""
|
| 143 |
+
action_correct = ConsistencyAction(ranking=["A", "B", "C", "D"])
|
| 144 |
+
action_wrong = ConsistencyAction(ranking=["D", "C", "B", "A"])
|
| 145 |
+
example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
|
| 146 |
+
r1, _ = grade_consistency(action_correct, example)
|
| 147 |
+
r2, _ = grade_consistency(action_wrong, example)
|
| 148 |
+
assert r1 != r2, "Grader must return different scores for different inputs!"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ββ Environment integration tests βββββββββββββββββββββββββββββ
|
| 152 |
+
|
| 153 |
+
class TestPreferenceLabEnvironment:
|
| 154 |
+
def setup_method(self):
|
| 155 |
+
self.env = PreferenceLabEnvironment()
|
| 156 |
+
|
| 157 |
+
def test_reset_returns_observation(self):
|
| 158 |
+
obs = self.env.reset()
|
| 159 |
+
assert obs is not None
|
| 160 |
+
assert hasattr(obs, "prompt")
|
| 161 |
+
assert hasattr(obs, "reward")
|
| 162 |
+
assert hasattr(obs, "done")
|
| 163 |
+
|
| 164 |
+
def test_reset_pairwise_returns_pairwise_obs(self):
|
| 165 |
+
obs = self.env.reset(task_type="pairwise")
|
| 166 |
+
assert isinstance(obs, PairwiseObservation)
|
| 167 |
+
assert obs.response_a != ""
|
| 168 |
+
assert obs.response_b != ""
|
| 169 |
+
|
| 170 |
+
def test_reset_likert_returns_likert_obs(self):
|
| 171 |
+
obs = self.env.reset(task_type="likert")
|
| 172 |
+
assert isinstance(obs, LikertObservation)
|
| 173 |
+
assert obs.response != ""
|
| 174 |
+
assert obs.rubric != ""
|
| 175 |
+
|
| 176 |
+
def test_reset_consistency_returns_consistency_obs(self):
|
| 177 |
+
obs = self.env.reset(task_type="consistency")
|
| 178 |
+
assert isinstance(obs, ConsistencyObservation)
|
| 179 |
+
assert obs.response_a != ""
|
| 180 |
+
assert obs.response_d != ""
|
| 181 |
+
|
| 182 |
+
def test_step_pairwise(self):
|
| 183 |
+
self.env.reset(task_type="pairwise")
|
| 184 |
+
action = PairwiseAction(choice="A")
|
| 185 |
+
obs, reward, done, info = self.env.step(action)
|
| 186 |
+
assert isinstance(obs, PairwiseObservation)
|
| 187 |
+
assert 0.0 <= reward <= 1.0
|
| 188 |
+
assert isinstance(done, bool)
|
| 189 |
+
|
| 190 |
+
def test_step_likert(self):
|
| 191 |
+
self.env.reset(task_type="likert")
|
| 192 |
+
action = LikertAction(helpfulness=4, honesty=4, harmlessness=5, instruction_following=4)
|
| 193 |
+
obs, reward, done, info = self.env.step(action)
|
| 194 |
+
assert isinstance(obs, LikertObservation)
|
| 195 |
+
assert 0.0 <= reward <= 1.0
|
| 196 |
+
|
| 197 |
+
def test_step_consistency(self):
|
| 198 |
+
self.env.reset(task_type="consistency")
|
| 199 |
+
action = ConsistencyAction(ranking=["A", "B", "C", "D"])
|
| 200 |
+
obs, reward, done, info = self.env.step(action)
|
| 201 |
+
assert isinstance(obs, ConsistencyObservation)
|
| 202 |
+
assert 0.0 <= reward <= 1.0
|
| 203 |
+
|
| 204 |
+
def test_episode_terminates_after_max_steps(self):
|
| 205 |
+
self.env.reset(task_type="pairwise")
|
| 206 |
+
done = False
|
| 207 |
+
steps = 0
|
| 208 |
+
while not done:
|
| 209 |
+
_, _, done, _ = self.env.step(PairwiseAction(choice="A"))
|
| 210 |
+
steps += 1
|
| 211 |
+
assert steps <= 10, "Episode ran too long!"
|
| 212 |
+
assert done is True
|
| 213 |
+
|
| 214 |
+
def test_state_returns_metadata(self):
|
| 215 |
+
self.env.reset(seed=42, task_type="pairwise")
|
| 216 |
+
state = self.env.state()
|
| 217 |
+
assert "episode_id" in state
|
| 218 |
+
assert "step_count" in state
|
| 219 |
+
assert "task_type" in state
|
| 220 |
+
assert state["seed"] == 42
|
| 221 |
+
|
| 222 |
+
def test_reproducible_with_seed(self):
|
| 223 |
+
obs1 = self.env.reset(seed=123, task_type="pairwise")
|
| 224 |
+
obs2 = self.env.reset(seed=123, task_type="pairwise")
|
| 225 |
+
assert obs1.prompt == obs2.prompt
|
| 226 |
+
assert obs1.response_a == obs2.response_a
|
| 227 |
+
|
| 228 |
+
def test_rewards_vary_across_actions(self):
|
| 229 |
+
"""Ensure graders do NOT always return the same reward (disqualification check)."""
|
| 230 |
+
rewards = set()
|
| 231 |
+
for _ in range(5):
|
| 232 |
+
self.env.reset(task_type="pairwise")
|
| 233 |
+
action_a = PairwiseAction(choice="A")
|
| 234 |
+
_, r1, _, _ = self.env.step(action_a)
|
| 235 |
+
self.env.reset(task_type="pairwise")
|
| 236 |
+
action_b = PairwiseAction(choice="B")
|
| 237 |
+
_, r2, _, _ = self.env.step(action_b)
|
| 238 |
+
rewards.add(r1)
|
| 239 |
+
rewards.add(r2)
|
| 240 |
+
assert len(rewards) > 1, "Grader always returns the same score β DISQUALIFICATION!"
|
| 241 |
+
|
| 242 |
+
def test_all_three_tasks_run(self):
|
| 243 |
+
for task in ["pairwise", "likert", "consistency"]:
|
| 244 |
+
obs = self.env.reset(task_type=task)
|
| 245 |
+
assert obs is not None
|
| 246 |
+
state = self.env.state()
|
| 247 |
+
assert state["task_type"] == task
|