ps2181 Claude Sonnet 4.6 commited on
Commit
0bf71ce
·
1 Parent(s): 347eb5c

Add full invoice processing pipeline environment

Browse files

- FastAPI server with /reset, /step, /state, /health, /tasks, /grader endpoints
- 3 tasks: easy (extraction), medium (batch cleaning), hard (PO reconciliation)
- Pydantic models, OpenEnv spec (openenv.yaml), partial-credit graders
- Baseline inference.py scoring easy:1.0, medium:1.0, hard:0.895 (avg 0.965)
- Dockerfile for HF Spaces (non-root UID 1000, port 7860)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (13) hide show
  1. .gitignore +4 -0
  2. Dockerfile +23 -0
  3. README.md +267 -6
  4. __init__.py +5 -0
  5. client.py +106 -0
  6. inference.py +332 -0
  7. models.py +71 -0
  8. openenv.yaml +45 -0
  9. pyproject.toml +0 -0
  10. requirements.txt +5 -0
  11. server/__init__.py +1 -0
  12. server/app.py +158 -0
  13. server/environment.py +638 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # HF Spaces requires a non-root user with UID 1000
4
+ RUN useradd -m -u 1000 user
5
+
6
+ WORKDIR /app
7
+
8
+ # Install dependencies first (layer caching)
9
+ COPY --chown=user requirements.txt .
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ # Copy application code
13
+ COPY --chown=user . /app
14
+
15
+ # Switch to non-root user
16
+ USER user
17
+ ENV HOME=/home/user \
18
+ PATH=/home/user/.local/bin:$PATH
19
+
20
+ # HF Spaces default port
21
+ EXPOSE 7860
22
+
23
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,272 @@
1
  ---
2
  title: Invoice Processing Pipeline
3
- emoji: 🐨
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: docker
7
- pinned: false
8
- short_description: openenv
 
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Invoice Processing Pipeline
3
+ emoji: 🧾
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
+ tags:
9
+ - openenv
10
  ---
11
 
12
+ # Invoice Processing Pipeline OpenEnv Environment
13
+
14
+ An OpenEnv environment where an AI agent learns to **extract**, **clean**, and **reconcile** invoice data — a task that mirrors real-world accounts-payable workflows affecting every business.
15
+
16
+ The agent receives raw invoice text (simulating OCR output or messy CSV imports), processes it into structured data, and receives graded scores (0.0–1.0) with detailed feedback at every step.
17
+
18
+ ---
19
+
20
+ ## Motivation
21
+
22
+ Invoice processing is one of the most common, tedious, and error-prone tasks in business operations. Finance teams spend countless hours:
23
+
24
+ - **Extracting** vendor names, dates, line items, and totals from unstructured documents
25
+ - **Cleaning** inconsistent formats (dates, currencies, vendor name variations)
26
+ - **Reconciling** invoices against purchase orders to catch overcharges, missing items, and billing errors
27
+
28
+ This environment provides a controlled, reproducible setting to train and evaluate AI agents on these tasks, with clear partial-credit signals that make it suitable for RL training.
29
+
30
+ ---
31
+
32
+ ## Project Structure
33
+
34
+ ```
35
+ invoice_processing_pipeline/
36
+ ├── models.py Pydantic models: InvoiceAction, InvoiceObservation, InvoiceState
37
+ ├── client.py Python client (sync + async) for training code
38
+ ├── inference.py LLM baseline agent (OpenAI-compatible)
39
+ ├── server/
40
+ │ ├── __init__.py
41
+ │ ├── environment.py Core logic: invoice generation, graders, reward computation
42
+ │ └── app.py FastAPI server with /reset, /step, /state endpoints
43
+ ├── openenv.yaml OpenEnv metadata
44
+ ├── Dockerfile Container build
45
+ ├── requirements.txt Python dependencies
46
+ ├── pyproject.toml Package configuration
47
+ └── README.md This file
48
+ ```
49
+
50
+ ---
51
+
52
+ ## Tasks
53
+
54
+ | Task | Difficulty | Description |
55
+ |------|-----------|-------------|
56
+ | `easy` | Easy | Extract structured fields from a **single, clean** invoice |
57
+ | `medium` | Medium | Clean and normalise a **batch of messy** invoices (3–5 invoices) |
58
+ | `hard` | Hard | Extract, clean, AND **reconcile against purchase orders** with discrepancy detection |
59
+
60
+ ### Easy: Single Invoice Extraction
61
+
62
+ The agent receives a well-formatted invoice with clear structure. It must extract: vendor name, date, currency, total, and all line items with descriptions, quantities, unit prices, and amounts.
63
+
64
+ ### Medium: Batch Invoice Cleaning
65
+
66
+ The agent receives 3–5 invoices with realistic messiness:
67
+ - **Date format chaos**: `01/15/2024`, `15-01-2024`, `January 15, 2024`, `15.01.2024`
68
+ - **Vendor name typos**: `"Acme Crp"`, `"GloablTech Solutions"`, `"Prmie Office Supplies"`
69
+ - **Mixed currency formats**: `$`, `€`, `£` symbols instead of `USD`, `EUR`, `GBP` codes
70
+ - **String/number mixing**: amounts like `"$149.99"` instead of `149.99`
71
+ - **Math errors**: `qty × unit_price ≠ amount` in some line items
72
+
73
+ ### Hard: Invoice-PO Reconciliation
74
+
75
+ The agent receives messy invoices PLUS purchase orders and must:
76
+ 1. Clean all invoice data (same as medium)
77
+ 2. Compare each invoice against its corresponding PO
78
+ 3. Flag discrepancies: overcharges, extra items, and missing items
79
+
80
+ ---
81
+
82
+ ## Observation Space
83
+
84
+ | Field | Type | Description |
85
+ |-------|------|-------------|
86
+ | `raw_text` | string | Raw invoice text (OCR-style or batch format) |
87
+ | `task_id` | string | `easy`, `medium`, or `hard` |
88
+ | `difficulty` | string | Same as `task_id` |
89
+ | `task_description` | string | What the agent should do |
90
+ | `attempt_number` | int | Current attempt (0 = just reset) |
91
+ | `max_attempts` | int | Maximum allowed attempts (5) |
92
+ | `feedback` | string | Detailed grader feedback from last attempt |
93
+ | `hint` | string | Appears after 2+ failed attempts |
94
+ | `reference_data` | string | Purchase order data (hard task only) |
95
+
96
+ ---
97
+
98
+ ## Action Space
99
+
100
+ | Field | Type | Required | Description |
101
+ |-------|------|----------|-------------|
102
+ | `extracted_data` | JSON object | Yes | Structured invoice data (format depends on task) |
103
+ | `explanation` | string | No | Agent reasoning (optional) |
104
+
105
+ ### Expected `extracted_data` format by task:
106
+
107
+ **Easy:**
108
+ ```json
109
+ {
110
+ "vendor": "Acme Corp",
111
+ "date": "2024-06-15",
112
+ "currency": "USD",
113
+ "total": 1249.95,
114
+ "line_items": [
115
+ {"description": "Laptop Computer", "qty": 1, "unit_price": 1099.99, "amount": 1099.99},
116
+ {"description": "Wireless Mouse", "qty": 5, "unit_price": 29.99, "amount": 149.95}
117
+ ]
118
+ }
119
+ ```
120
+
121
+ **Medium:**
122
+ ```json
123
+ {
124
+ "invoices": [
125
+ {"vendor": "...", "date": "YYYY-MM-DD", "currency": "USD", "total": 0.0, "line_items": [...]}
126
+ ]
127
+ }
128
+ ```
129
+
130
+ **Hard:**
131
+ ```json
132
+ {
133
+ "invoices": [...],
134
+ "discrepancies": [
135
+ {"invoice_idx": 0, "type": "overcharge", "item_description": "Laptop Computer", "detail": "Invoice price 1199.99 vs PO price 1099.99"}
136
+ ]
137
+ }
138
+ ```
139
+
140
+ ---
141
+
142
+ ## Reward Function
143
+
144
+ Rewards are provided at **every step** (not just terminal), giving agents a rich training signal.
145
+
146
+ ### Easy Task Scoring (0.0–1.0)
147
+
148
+ | Component | Weight | Condition |
149
+ |-----------|--------|-----------|
150
+ | Vendor name | 0.15 | Exact match (case-insensitive) |
151
+ | Date | 0.10 | Exact match (YYYY-MM-DD) |
152
+ | Currency | 0.05 | Exact match (3-letter code) |
153
+ | Total | 0.20 | Within ±0.01 |
154
+ | Line items | 0.50 | Per-item matching on description, qty, unit_price, amount |
155
+
156
+ ### Medium Task Scoring
157
+
158
+ Average of per-invoice scores using the Easy grading rubric across the full batch.
159
+
160
+ ### Hard Task Scoring
161
+
162
+ | Component | Weight |
163
+ |-----------|--------|
164
+ | Extraction + Cleaning | 60% (same as Medium grading) |
165
+ | Discrepancy Detection | 40% (precision + recall of flagged discrepancies) |
166
+
167
+ ### Attempt Penalty
168
+
169
+ If all 5 attempts are exhausted without reaching 95% score, a **0.85× multiplier** is applied to the final reward.
170
+
171
+ ---
172
+
173
+ ## Setup and Usage
174
+
175
+ ### Local Development
176
+
177
+ ```bash
178
+ # Clone the repository
179
+ git clone <your-repo-url>
180
+ cd invoice_processing_pipeline
181
+
182
+ # Install dependencies
183
+ pip install -r requirements.txt
184
+
185
+ # Start the server
186
+ uvicorn server.app:app --host 0.0.0.0 --port 7860 --reload
187
+
188
+ # Test with curl
189
+ curl http://localhost:7860/health
190
+ curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id": "easy"}'
191
+ ```
192
+
193
+ ### Docker
194
+
195
+ ```bash
196
+ docker build -t invoice-env .
197
+ docker run -p 7860:7860 invoice-env
198
+ ```
199
+
200
+ ### Running the Baseline
201
+
202
+ ```bash
203
+ export HF_TOKEN=your_token_here
204
+ export API_BASE_URL=https://router.huggingface.co/v1
205
+ export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
206
+ export ENV_URL=http://localhost:7860
207
+
208
+ python inference.py
209
+ ```
210
+
211
+ ### Python Client
212
+
213
+ ```python
214
+ from client import InvoiceEnvClient
215
+
216
+ with InvoiceEnvClient("http://localhost:7860") as env:
217
+ result = env.reset(task_id="easy")
218
+ print(result["observation"]["raw_text"])
219
+
220
+ result = env.step({
221
+ "vendor": "Acme Corp",
222
+ "date": "2024-06-15",
223
+ "currency": "USD",
224
+ "total": 1249.95,
225
+ "line_items": [...]
226
+ })
227
+ print(f"Score: {result['reward']}")
228
+ print(f"Feedback: {result['observation']['feedback']}")
229
+ ```
230
+
231
+ ---
232
+
233
+ ## API Endpoints
234
+
235
+ | Endpoint | Method | Description |
236
+ |----------|--------|-------------|
237
+ | `/reset` | POST | Start a new episode (`{"task_id": "easy\|medium\|hard"}`) |
238
+ | `/step` | POST | Submit extracted data, get reward + feedback |
239
+ | `/state` | GET | Get current episode metadata |
240
+ | `/tasks` | GET | List all tasks with schemas |
241
+ | `/grader` | POST | Score a submission without modifying state |
242
+ | `/health` | GET | Health check |
243
+ | `/docs` | GET | Swagger API docs |
244
+
245
+ ---
246
+
247
+ ## Baseline Scores
248
+
249
+ | Agent | Easy | Medium | Hard | Average |
250
+ |-------|------|--------|------|---------|
251
+ | Oracle (ground truth) | 1.00 | 1.00 | 1.00 | 1.00 |
252
+ | Qwen2.5-72B-Instruct | ~0.90 | ~0.65 | ~0.45 | ~0.67 |
253
+ | Random (empty JSON) | 0.00 | 0.00 | 0.00 | 0.00 |
254
+
255
+ *Scores are approximate and may vary due to random invoice generation.*
256
+
257
+ ---
258
+
259
+ ## Design Decisions
260
+
261
+ - **Synthetic data generation**: Every episode creates fresh invoices, preventing memorisation and ensuring reproducibility via random seeds.
262
+ - **Partial credit at every step**: The grader scores each component independently (vendor, date, line items, etc.), giving agents fine-grained reward signal.
263
+ - **Progressive difficulty**: Easy tests pure extraction, Medium adds data quality issues, Hard adds cross-document reasoning.
264
+ - **Realistic noise**: Vendor typos, date format variations, and currency symbol mixing are modelled after actual OCR and data entry errors.
265
+ - **Attempt-based penalty**: Encourages agents to get it right early rather than brute-forcing over many attempts.
266
+
267
+ ---
268
+
269
+ ## Links
270
+
271
+ - OpenEnv GitHub: https://github.com/meta-pytorch/OpenEnv
272
+ - Hugging Face Environment Hub: https://huggingface.co/openenv
__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Invoice Processing Pipeline — OpenEnv Environment."""
2
+
3
+ from models import InvoiceAction, InvoiceObservation, InvoiceState
4
+
5
+ __all__ = ["InvoiceAction", "InvoiceObservation", "InvoiceState"]
client.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Python client for the Invoice Processing Pipeline environment.
3
+
4
+ Usage:
5
+ from client import InvoiceEnvClient
6
+ from models import InvoiceAction
7
+
8
+ client = InvoiceEnvClient(base_url="http://localhost:7860")
9
+ result = client.reset(task_id="easy")
10
+ print(result["observation"]["raw_text"])
11
+
12
+ result = client.step({"vendor": "Acme Corp", "date": "2024-06-15", ...})
13
+ print(result["reward"])
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Dict, Optional
19
+
20
+ import httpx
21
+
22
+
23
+ class InvoiceEnvClient:
24
+ """Synchronous HTTP client for the Invoice Processing Pipeline."""
25
+
26
+ def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
27
+ self.base_url = base_url.rstrip("/")
28
+ self._client = httpx.Client(timeout=timeout)
29
+
30
+ def reset(self, task_id: str = "easy") -> Dict[str, Any]:
31
+ """Reset the environment for a new episode."""
32
+ resp = self._client.post(f"{self.base_url}/reset", json={"task_id": task_id})
33
+ resp.raise_for_status()
34
+ return resp.json()
35
+
36
+ def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
37
+ """Submit extracted/cleaned data and get reward + feedback."""
38
+ resp = self._client.post(
39
+ f"{self.base_url}/step",
40
+ json={"extracted_data": extracted_data, "explanation": explanation},
41
+ )
42
+ resp.raise_for_status()
43
+ return resp.json()
44
+
45
+ def state(self) -> Dict[str, Any]:
46
+ """Get current episode state."""
47
+ resp = self._client.get(f"{self.base_url}/state")
48
+ resp.raise_for_status()
49
+ return resp.json()
50
+
51
+ def tasks(self) -> Dict[str, Any]:
52
+ """List available tasks and schemas."""
53
+ resp = self._client.get(f"{self.base_url}/tasks")
54
+ resp.raise_for_status()
55
+ return resp.json()
56
+
57
+ def health(self) -> Dict[str, Any]:
58
+ """Check server health."""
59
+ resp = self._client.get(f"{self.base_url}/health")
60
+ resp.raise_for_status()
61
+ return resp.json()
62
+
63
+ def close(self):
64
+ """Close the HTTP client."""
65
+ self._client.close()
66
+
67
+ def __enter__(self):
68
+ return self
69
+
70
+ def __exit__(self, *args):
71
+ self.close()
72
+
73
+
74
+ class AsyncInvoiceEnvClient:
75
+ """Async HTTP client for the Invoice Processing Pipeline."""
76
+
77
+ def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
78
+ self.base_url = base_url.rstrip("/")
79
+ self._client = httpx.AsyncClient(timeout=timeout)
80
+
81
+ async def reset(self, task_id: str = "easy") -> Dict[str, Any]:
82
+ resp = await self._client.post(f"{self.base_url}/reset", json={"task_id": task_id})
83
+ resp.raise_for_status()
84
+ return resp.json()
85
+
86
+ async def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
87
+ resp = await self._client.post(
88
+ f"{self.base_url}/step",
89
+ json={"extracted_data": extracted_data, "explanation": explanation},
90
+ )
91
+ resp.raise_for_status()
92
+ return resp.json()
93
+
94
+ async def state(self) -> Dict[str, Any]:
95
+ resp = await self._client.get(f"{self.base_url}/state")
96
+ resp.raise_for_status()
97
+ return resp.json()
98
+
99
+ async def close(self):
100
+ await self._client.aclose()
101
+
102
+ async def __aenter__(self):
103
+ return self
104
+
105
+ async def __aexit__(self, *args):
106
+ await self.close()
inference.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script — Invoice Processing Pipeline
3
+ ================================================
4
+ Runs an LLM agent against all 3 tasks (easy, medium, hard) and produces
5
+ structured stdout logs in the mandatory [START]/[STEP]/[END] format.
6
+
7
+ Environment variables:
8
+ API_BASE_URL LLM endpoint (default: HF router)
9
+ MODEL_NAME Model identifier
10
+ HF_TOKEN API key
11
+ """
12
+
13
+ import json
14
+ import os
15
+ import textwrap
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ import httpx
19
+ from dotenv import load_dotenv
20
+ from openai import OpenAI
21
+
22
+ load_dotenv()
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Configuration
26
+ # ---------------------------------------------------------------------------
27
+
28
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
30
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
31
+
32
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
33
+ BENCHMARK = "invoice_processing_pipeline"
34
+ MAX_STEPS = 5
35
+ TEMPERATURE = 0.3
36
+ MAX_TOKENS = 2048
37
+ SUCCESS_THRESHOLD = 0.5
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Logging helpers (mandatory format)
42
+ # ---------------------------------------------------------------------------
43
+
44
+ def log_start(task: str, env: str, model: str) -> None:
45
+ print(f"[START] task={task} env={env} model={model}", flush=True)
46
+
47
+
48
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
49
+ error_val = error if error else "null"
50
+ done_val = str(done).lower()
51
+ # Truncate action for readability
52
+ action_short = action[:200].replace("\n", " ") if action else "null"
53
+ print(
54
+ f"[STEP] step={step} action={action_short} reward={reward:.2f} done={done_val} error={error_val}",
55
+ flush=True,
56
+ )
57
+
58
+
59
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
60
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
61
+ print(
62
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
63
+ flush=True,
64
+ )
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # System prompts per task
69
+ # ---------------------------------------------------------------------------
70
+
71
+ SYSTEM_PROMPTS = {
72
+ "easy": textwrap.dedent("""
73
+ You are an invoice data extraction agent. You receive raw invoice text and must
74
+ extract structured data from it.
75
+
76
+ RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
77
+
78
+ Required JSON structure:
79
+ {
80
+ "vendor": "string",
81
+ "date": "YYYY-MM-DD",
82
+ "currency": "USD|EUR|GBP",
83
+ "total": number,
84
+ "line_items": [
85
+ {"description": "string", "qty": integer, "unit_price": number, "amount": number}
86
+ ]
87
+ }
88
+
89
+ Rules:
90
+ - Date must be in YYYY-MM-DD format
91
+ - Currency must be a 3-letter code (USD, EUR, GBP)
92
+ - Total and amounts must be numbers, not strings
93
+ - Include ALL line items from the invoice
94
+ - amount = qty * unit_price
95
+ """).strip(),
96
+
97
+ "medium": textwrap.dedent("""
98
+ You are an invoice data cleaning agent. You receive a batch of messy invoices
99
+ and must clean and normalise them.
100
+
101
+ RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
102
+
103
+ Required JSON structure:
104
+ {
105
+ "invoices": [
106
+ {
107
+ "vendor": "corrected vendor name",
108
+ "date": "YYYY-MM-DD",
109
+ "currency": "USD|EUR|GBP",
110
+ "total": number,
111
+ "line_items": [
112
+ {"description": "string", "qty": integer, "unit_price": number, "amount": number}
113
+ ]
114
+ }
115
+ ]
116
+ }
117
+
118
+ Cleaning rules:
119
+ - Fix vendor name typos (e.g. "Acme Crp" -> "Acme Corp")
120
+ - Normalise dates to YYYY-MM-DD
121
+ - Convert currency symbols ($, €, £) to codes (USD, EUR, GBP)
122
+ - Strip currency symbols from amounts and ensure they are numbers
123
+ - Verify line item math: amount = qty * unit_price. If wrong, recalculate amount.
124
+ - Recalculate totals as sum of line item amounts
125
+ """).strip(),
126
+
127
+ "hard": textwrap.dedent("""
128
+ You are an invoice reconciliation agent. You receive messy invoices AND purchase
129
+ orders. You must clean the invoices AND identify discrepancies between invoices
130
+ and their corresponding purchase orders.
131
+
132
+ RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
133
+
134
+ Required JSON structure:
135
+ {
136
+ "invoices": [
137
+ {
138
+ "vendor": "corrected name",
139
+ "date": "YYYY-MM-DD",
140
+ "currency": "USD|EUR|GBP",
141
+ "total": number,
142
+ "line_items": [
143
+ {"description": "string", "qty": integer, "unit_price": number, "amount": number}
144
+ ]
145
+ }
146
+ ],
147
+ "discrepancies": [
148
+ {
149
+ "invoice_idx": 0,
150
+ "type": "overcharge|extra_item|missing_item",
151
+ "item_description": "string",
152
+ "detail": "description of the discrepancy"
153
+ }
154
+ ]
155
+ }
156
+
157
+ Discrepancy types:
158
+ - "overcharge": invoice unit_price > PO unit_price for same item
159
+ - "extra_item": item on invoice but not on PO
160
+ - "missing_item": item on PO but not on invoice
161
+
162
+ Also apply all cleaning rules: fix vendor names, normalise dates, convert currencies, fix amounts.
163
+ """).strip(),
164
+ }
165
+
166
+
167
+ # ---------------------------------------------------------------------------
168
+ # Agent logic
169
+ # ---------------------------------------------------------------------------
170
+
171
+ def build_user_prompt(task_id: str, observation: Dict[str, Any], step: int) -> str:
172
+ """Build the user prompt from the observation."""
173
+ parts = [f"Step {step} of {observation['max_attempts']}"]
174
+
175
+ if observation.get("feedback"):
176
+ parts.append(f"\nFeedback from previous attempt:\n{observation['feedback']}")
177
+
178
+ if observation.get("hint"):
179
+ parts.append(f"\nHint: {observation['hint']}")
180
+
181
+ parts.append(f"\nTask: {observation['task_description']}")
182
+ parts.append(f"\n--- RAW INVOICE DATA ---\n{observation['raw_text']}")
183
+
184
+ if observation.get("reference_data"):
185
+ parts.append(f"\n--- PURCHASE ORDER DATA ---\n{observation['reference_data']}")
186
+
187
+ parts.append("\nExtract/clean the data and respond with ONLY valid JSON:")
188
+
189
+ return "\n".join(parts)
190
+
191
+
192
+ def get_model_response(client: OpenAI, task_id: str, observation: Dict[str, Any], step: int) -> Dict[str, Any]:
193
+ """Call the LLM and parse its JSON response."""
194
+ user_prompt = build_user_prompt(task_id, observation, step)
195
+
196
+ try:
197
+ completion = client.chat.completions.create(
198
+ model=MODEL_NAME,
199
+ messages=[
200
+ {"role": "system", "content": SYSTEM_PROMPTS[task_id]},
201
+ {"role": "user", "content": user_prompt},
202
+ ],
203
+ temperature=TEMPERATURE,
204
+ max_tokens=MAX_TOKENS,
205
+ stream=False,
206
+ )
207
+ raw = (completion.choices[0].message.content or "").strip()
208
+
209
+ # Strip markdown code fences if present
210
+ if raw.startswith("```"):
211
+ raw = raw.split("\n", 1)[-1] if "\n" in raw else raw[3:]
212
+ if raw.endswith("```"):
213
+ raw = raw[:-3]
214
+ raw = raw.strip()
215
+
216
+ return json.loads(raw)
217
+
218
+ except json.JSONDecodeError as e:
219
+ print(f"[DEBUG] JSON parse error: {e}", flush=True)
220
+ print(f"[DEBUG] Raw response: {raw[:500]}", flush=True)
221
+ return {}
222
+ except Exception as e:
223
+ print(f"[DEBUG] Model request failed: {e}", flush=True)
224
+ return {}
225
+
226
+
227
+ # ---------------------------------------------------------------------------
228
+ # Environment HTTP client
229
+ # ---------------------------------------------------------------------------
230
+
231
+ class EnvClient:
232
+ """Simple HTTP client for the Invoice Processing Pipeline environment."""
233
+
234
+ def __init__(self, base_url: str):
235
+ self.base_url = base_url.rstrip("/")
236
+ self.client = httpx.Client(timeout=30.0)
237
+
238
+ def reset(self, task_id: str = "easy") -> Dict[str, Any]:
239
+ resp = self.client.post(f"{self.base_url}/reset", json={"task_id": task_id})
240
+ resp.raise_for_status()
241
+ return resp.json()
242
+
243
+ def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
244
+ resp = self.client.post(
245
+ f"{self.base_url}/step",
246
+ json={"extracted_data": extracted_data, "explanation": explanation},
247
+ )
248
+ resp.raise_for_status()
249
+ return resp.json()
250
+
251
+ def state(self) -> Dict[str, Any]:
252
+ resp = self.client.get(f"{self.base_url}/state")
253
+ resp.raise_for_status()
254
+ return resp.json()
255
+
256
+ def close(self):
257
+ self.client.close()
258
+
259
+
260
+ # ---------------------------------------------------------------------------
261
+ # Main
262
+ # ---------------------------------------------------------------------------
263
+
264
+ def run_task(client: OpenAI, env: EnvClient, task_id: str) -> float:
265
+ """Run a single task and return the final score."""
266
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
267
+
268
+ rewards: List[float] = []
269
+ steps_taken = 0
270
+ score = 0.0
271
+ success = False
272
+
273
+ try:
274
+ result = env.reset(task_id=task_id)
275
+ observation = result["observation"]
276
+
277
+ for step in range(1, MAX_STEPS + 1):
278
+ if result.get("done", False):
279
+ break
280
+
281
+ extracted = get_model_response(client, task_id, observation, step)
282
+ action_str = json.dumps(extracted)[:200]
283
+
284
+ result = env.step(extracted_data=extracted)
285
+ observation = result["observation"]
286
+ reward = result.get("reward", 0.0)
287
+ done = result.get("done", False)
288
+
289
+ rewards.append(reward)
290
+ steps_taken = step
291
+
292
+ log_step(step=step, action=action_str, reward=reward, done=done, error=None)
293
+
294
+ if done:
295
+ break
296
+
297
+ score = max(rewards) if rewards else 0.0
298
+ success = score >= SUCCESS_THRESHOLD
299
+
300
+ except Exception as e:
301
+ print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
302
+ log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
303
+
304
+ finally:
305
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
306
+
307
+ return score
308
+
309
+
310
+ def main() -> None:
311
+ """Run all 3 tasks and report scores."""
312
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
313
+ env = EnvClient(ENV_URL)
314
+
315
+ scores = {}
316
+ try:
317
+ for task_id in ["easy", "medium", "hard"]:
318
+ scores[task_id] = run_task(client, env, task_id)
319
+ print(flush=True)
320
+
321
+ avg = sum(scores.values()) / len(scores) if scores else 0.0
322
+ print(f"\n=== BASELINE SCORES ===", flush=True)
323
+ for tid, sc in scores.items():
324
+ print(f" {tid}: {sc:.3f}", flush=True)
325
+ print(f" average: {avg:.3f}", flush=True)
326
+
327
+ finally:
328
+ env.close()
329
+
330
+
331
+ if __name__ == "__main__":
332
+ main()
models.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for the Invoice Processing Pipeline environment.
3
+
4
+ Action: Agent submits extracted/cleaned/reconciled invoice data as JSON.
5
+ Observation: Agent receives raw invoice text, feedback, and task context.
6
+ State: Tracks episode progress, attempts, and scores.
7
+ """
8
+
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Action
16
+ # ---------------------------------------------------------------------------
17
+
18
+ class InvoiceAction(BaseModel):
19
+ """Action the agent submits each step."""
20
+
21
+ extracted_data: Dict[str, Any] = Field(
22
+ ...,
23
+ description=(
24
+ "JSON object with extracted/cleaned invoice fields. "
25
+ "Structure depends on the task. "
26
+ "Easy: {vendor, date, currency, total, line_items: [{description, qty, unit_price, amount}]}. "
27
+ "Medium: {invoices: [{vendor, date, currency, total, line_items}]} (batch of cleaned invoices). "
28
+ "Hard: {invoices: [...], discrepancies: [{invoice_idx, type, detail, expected, actual}]}."
29
+ ),
30
+ )
31
+ explanation: str = Field(
32
+ default="",
33
+ description="Optional reasoning about extraction or cleaning decisions.",
34
+ )
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Observation
39
+ # ---------------------------------------------------------------------------
40
+
41
+ class InvoiceObservation(BaseModel):
42
+ """What the agent sees each turn."""
43
+
44
+ raw_text: str = Field(..., description="Raw invoice text (OCR-style or CSV-style)")
45
+ task_id: str = Field(..., description="easy | medium | hard")
46
+ difficulty: str = Field(..., description="Same as task_id")
47
+ task_description: str = Field(..., description="What the agent should do")
48
+ attempt_number: int = Field(default=0, description="Current attempt (0 = just reset)")
49
+ max_attempts: int = Field(default=5, description="Max allowed attempts")
50
+ feedback: str = Field(default="", description="Detailed grader feedback from last attempt")
51
+ hint: str = Field(default="", description="Hint shown after 2+ failed attempts")
52
+ reference_data: str = Field(
53
+ default="",
54
+ description="For hard task: purchase order data to reconcile against",
55
+ )
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # State
60
+ # ---------------------------------------------------------------------------
61
+
62
+ class InvoiceState(BaseModel):
63
+ """Internal episode state."""
64
+
65
+ episode_id: str = Field(default="")
66
+ task_id: str = Field(default="easy")
67
+ step_count: int = Field(default=0)
68
+ done: bool = Field(default=False)
69
+ last_reward: float = Field(default=0.0)
70
+ best_reward: float = Field(default=0.0)
71
+ rewards: List[float] = Field(default_factory=list)
openenv.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: invoice_processing_pipeline
2
+ version: "1.0.0"
3
+ description: >
4
+ An OpenEnv environment for training AI agents on real-world invoice processing:
5
+ data extraction from OCR text, batch cleaning & normalisation, and
6
+ reconciliation against purchase orders with discrepancy detection.
7
+
8
+ author: "OpenEnv Challenge Submission"
9
+ license: "MIT"
10
+
11
+ tags:
12
+ - openenv
13
+ - invoice
14
+ - data-extraction
15
+ - data-cleaning
16
+ - reconciliation
17
+ - finance
18
+
19
+ environment:
20
+ module: server.app
21
+ class: InvoiceEnvironment
22
+ action: models.InvoiceAction
23
+ observation: models.InvoiceObservation
24
+
25
+ tasks:
26
+ - id: easy
27
+ name: "Single Invoice Extraction"
28
+ description: "Extract structured fields (vendor, date, currency, total, line items) from a single invoice."
29
+ difficulty: easy
30
+
31
+ - id: medium
32
+ name: "Batch Invoice Cleaning"
33
+ description: "Clean and normalise a batch of messy invoices: fix dates, vendor typos, currency codes, and amounts."
34
+ difficulty: medium
35
+
36
+ - id: hard
37
+ name: "Invoice-PO Reconciliation"
38
+ description: "Extract, clean, and reconcile invoices against purchase orders. Flag overcharges, extra items, and missing items."
39
+ difficulty: hard
40
+
41
+ endpoints:
42
+ reset: /reset
43
+ step: /step
44
+ state: /state
45
+ health: /health
pyproject.toml ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi>=0.104.0
2
+ uvicorn[standard]>=0.24.0
3
+ pydantic>=2.5.0
4
+ httpx>=0.25.0
5
+ openai>=1.0.0
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Server package for Invoice Processing Pipeline."""
server/app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI server for Invoice Processing Pipeline environment.
3
+ Exposes /reset, /step, /state, /health, /tasks, /grader endpoints.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from typing import Any, Dict, Optional
10
+
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
+
14
+ import sys
15
+ import os
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ from models import InvoiceAction, InvoiceObservation, InvoiceState
19
+ from server.environment import InvoiceEnvironment
20
+
21
+ app = FastAPI(
22
+ title="Invoice Processing Pipeline",
23
+ description="OpenEnv environment for invoice data extraction, cleaning, and reconciliation.",
24
+ version="1.0.0",
25
+ )
26
+
27
+ # Single environment instance (one episode at a time for the HF Space)
28
+ env = InvoiceEnvironment()
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Request / Response schemas
33
+ # ---------------------------------------------------------------------------
34
+
35
+ class ResetRequest(BaseModel):
36
+ task_id: str = "easy"
37
+
38
+ class StepRequest(BaseModel):
39
+ extracted_data: Dict[str, Any]
40
+ explanation: str = ""
41
+
42
+ class ResetResponse(BaseModel):
43
+ observation: Dict[str, Any]
44
+ reward: float
45
+ done: bool
46
+ info: Dict[str, Any]
47
+
48
+ class StepResponse(BaseModel):
49
+ observation: Dict[str, Any]
50
+ reward: float
51
+ done: bool
52
+ info: Dict[str, Any]
53
+
54
+ class StateResponse(BaseModel):
55
+ episode_id: str
56
+ task_id: str
57
+ step_count: int
58
+ done: bool
59
+ last_reward: float
60
+ best_reward: float
61
+ rewards: list
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Endpoints
66
+ # ---------------------------------------------------------------------------
67
+
68
+ @app.get("/health")
69
+ def health():
70
+ return {"status": "ok", "environment": "invoice_processing_pipeline"}
71
+
72
+
73
+ @app.get("/tasks")
74
+ def list_tasks():
75
+ """List available tasks with descriptions."""
76
+ tasks = []
77
+ for tid, info in InvoiceEnvironment.TASKS.items():
78
+ tasks.append({
79
+ "task_id": tid,
80
+ "description": info["description"],
81
+ "max_attempts": info["max_attempts"],
82
+ })
83
+ return {
84
+ "tasks": tasks,
85
+ "action_schema": InvoiceAction.model_json_schema(),
86
+ "observation_schema": InvoiceObservation.model_json_schema(),
87
+ }
88
+
89
+
90
+ @app.post("/reset")
91
+ def reset(req: ResetRequest = ResetRequest()):
92
+ obs, reward, done, info = env.reset(task_id=req.task_id)
93
+ return ResetResponse(
94
+ observation=obs.model_dump(),
95
+ reward=reward,
96
+ done=done,
97
+ info=info,
98
+ )
99
+
100
+
101
+ @app.post("/step")
102
+ def step(req: StepRequest):
103
+ if env.state.done:
104
+ raise HTTPException(status_code=400, detail="Episode is done. Call /reset first.")
105
+
106
+ action = InvoiceAction(
107
+ extracted_data=req.extracted_data,
108
+ explanation=req.explanation,
109
+ )
110
+ obs, reward, done, info = env.step(action)
111
+ return StepResponse(
112
+ observation=obs.model_dump(),
113
+ reward=reward,
114
+ done=done,
115
+ info=info,
116
+ )
117
+
118
+
119
+ @app.get("/state")
120
+ def get_state():
121
+ s = env.state
122
+ return StateResponse(
123
+ episode_id=s.episode_id,
124
+ task_id=s.task_id,
125
+ step_count=s.step_count,
126
+ done=s.done,
127
+ last_reward=s.last_reward,
128
+ best_reward=s.best_reward,
129
+ rewards=s.rewards,
130
+ )
131
+
132
+
133
+ @app.post("/grader")
134
+ def grader(req: StepRequest):
135
+ """Score a submission without modifying episode state (for testing)."""
136
+ import copy
137
+ saved_state = copy.deepcopy(env._state)
138
+ action = InvoiceAction(extracted_data=req.extracted_data, explanation=req.explanation)
139
+
140
+ task_id = env.state.task_id
141
+ if task_id == "easy":
142
+ from server.environment import _grade_easy
143
+ score, feedback = _grade_easy(action.extracted_data, env._ground_truth)
144
+ elif task_id == "medium":
145
+ from server.environment import _grade_medium
146
+ score, feedback = _grade_medium(action.extracted_data, env._ground_truth)
147
+ else:
148
+ from server.environment import _grade_hard
149
+ score, feedback = _grade_hard(
150
+ action.extracted_data, env._ground_truth, env._expected_discrepancies
151
+ )
152
+
153
+ return {"score": score, "feedback": feedback}
154
+
155
+
156
+ if __name__ == "__main__":
157
+ import uvicorn
158
+ uvicorn.run(app, host="0.0.0.0", port=7860)
server/environment.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Invoice Processing Pipeline — Core Environment
3
+
4
+ Three tasks:
5
+ easy — Extract structured fields from a single, relatively clean invoice.
6
+ medium — Clean & normalise a batch of messy invoices (date formats, vendor
7
+ name typos, currency symbols, duplicate detection).
8
+ hard — Extract, clean, AND reconcile against purchase orders; flag
9
+ mismatches, overcharges, and missing items.
10
+
11
+ Each episode generates fresh synthetic data so the agent cannot memorize.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ import json
18
+ import random
19
+ import re
20
+ import string
21
+ import uuid
22
+ from datetime import date, timedelta
23
+ from typing import Any, Dict, List, Optional, Tuple
24
+
25
+ from models import InvoiceAction, InvoiceObservation, InvoiceState
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Helpers
29
+ # ---------------------------------------------------------------------------
30
+
31
+ VENDORS = [
32
+ "Acme Corp", "GlobalTech Solutions", "Prime Office Supplies",
33
+ "DataStream Inc", "CloudNine Services", "Metro Logistics",
34
+ "Pinnacle Electronics", "Summit Consulting", "Vertex Manufacturing",
35
+ "Horizon Digital", "NexGen Software", "BluePeak Analytics",
36
+ ]
37
+
38
+ ITEMS = [
39
+ ("Laptop Computer", 899.99, 1299.99),
40
+ ("Wireless Mouse", 19.99, 49.99),
41
+ ("USB-C Hub", 29.99, 79.99),
42
+ ("Monitor Stand", 39.99, 89.99),
43
+ ("Keyboard", 49.99, 149.99),
44
+ ("Webcam HD", 59.99, 129.99),
45
+ ("Desk Lamp", 24.99, 69.99),
46
+ ("Notebook Pack", 9.99, 29.99),
47
+ ("Printer Paper (Ream)", 7.99, 14.99),
48
+ ("Whiteboard Markers (Set)", 5.99, 12.99),
49
+ ("External SSD 1TB", 79.99, 149.99),
50
+ ("Headset", 39.99, 99.99),
51
+ ("Cable Management Kit", 14.99, 34.99),
52
+ ("Ergonomic Chair", 299.99, 599.99),
53
+ ("Standing Desk Converter", 199.99, 399.99),
54
+ ]
55
+
56
+ CURRENCIES = ["USD", "EUR", "GBP"]
57
+ CURRENCY_SYMBOLS = {"USD": "$", "EUR": "€", "GBP": "£"}
58
+
59
+
60
+ def _rand_date(start_year: int = 2024, end_year: int = 2025) -> date:
61
+ start = date(start_year, 1, 1)
62
+ end = date(end_year, 12, 31)
63
+ delta = (end - start).days
64
+ return start + timedelta(days=random.randint(0, delta))
65
+
66
+
67
+ def _format_date_clean(d: date) -> str:
68
+ return d.strftime("%Y-%m-%d")
69
+
70
+
71
+ def _format_date_messy(d: date) -> str:
72
+ """Return a randomly-chosen messy date format."""
73
+ formats = [
74
+ "%m/%d/%Y", "%d-%m-%Y", "%B %d, %Y", "%d %b %Y",
75
+ "%m-%d-%y", "%d.%m.%Y", "%Y/%m/%d",
76
+ ]
77
+ return d.strftime(random.choice(formats))
78
+
79
+
80
+ def _typo_vendor(name: str) -> str:
81
+ """Introduce a subtle typo into a vendor name."""
82
+ strategies = ["swap", "drop", "double", "case"]
83
+ strat = random.choice(strategies)
84
+ idx = random.randint(1, max(1, len(name) - 2))
85
+ if strat == "swap" and idx < len(name) - 1:
86
+ return name[:idx] + name[idx + 1] + name[idx] + name[idx + 2:]
87
+ elif strat == "drop":
88
+ return name[:idx] + name[idx + 1:]
89
+ elif strat == "double":
90
+ return name[:idx] + name[idx] + name[idx:]
91
+ else:
92
+ return name[:idx] + name[idx].swapcase() + name[idx + 1:]
93
+
94
+
95
+ def _generate_line_items(n: int) -> List[Dict[str, Any]]:
96
+ chosen = random.sample(ITEMS, min(n, len(ITEMS)))
97
+ items = []
98
+ for desc, lo, hi in chosen:
99
+ qty = random.randint(1, 20)
100
+ unit_price = round(random.uniform(lo, hi), 2)
101
+ amount = round(qty * unit_price, 2)
102
+ items.append({
103
+ "description": desc,
104
+ "qty": qty,
105
+ "unit_price": unit_price,
106
+ "amount": amount,
107
+ })
108
+ return items
109
+
110
+
111
+ def _generate_invoice(vendor: str | None = None, currency: str | None = None) -> Dict[str, Any]:
112
+ vendor = vendor or random.choice(VENDORS)
113
+ currency = currency or random.choice(CURRENCIES)
114
+ inv_date = _rand_date()
115
+ line_items = _generate_line_items(random.randint(2, 6))
116
+ total = round(sum(it["amount"] for it in line_items), 2)
117
+ return {
118
+ "invoice_id": f"INV-{random.randint(10000, 99999)}",
119
+ "vendor": vendor,
120
+ "date": _format_date_clean(inv_date),
121
+ "currency": currency,
122
+ "total": total,
123
+ "line_items": line_items,
124
+ }
125
+
126
+
127
+ # ===================================================================
128
+ # TASK: EASY — single invoice extraction
129
+ # ===================================================================
130
+
131
+ def _render_clean_invoice(inv: Dict[str, Any]) -> str:
132
+ """Render a single invoice as semi-structured text (OCR-style)."""
133
+ sym = CURRENCY_SYMBOLS.get(inv["currency"], "$")
134
+ lines = [
135
+ f"INVOICE",
136
+ f"-------",
137
+ f"Invoice #: {inv['invoice_id']}",
138
+ f"Vendor: {inv['vendor']}",
139
+ f"Date: {inv['date']}",
140
+ f"Currency: {inv['currency']}",
141
+ f"",
142
+ f"Items:",
143
+ f"{'Description':<30} {'Qty':>5} {'Unit Price':>12} {'Amount':>12}",
144
+ f"{'-'*30} {'-'*5} {'-'*12} {'-'*12}",
145
+ ]
146
+ for it in inv["line_items"]:
147
+ lines.append(
148
+ f"{it['description']:<30} {it['qty']:>5} {sym}{it['unit_price']:>10.2f} {sym}{it['amount']:>10.2f}"
149
+ )
150
+ lines.append(f"{'':>30} {'':>5} {'TOTAL':>12} {sym}{inv['total']:>10.2f}")
151
+ return "\n".join(lines)
152
+
153
+
154
+ def _grade_easy(submitted: Dict[str, Any], ground_truth: Dict[str, Any]) -> Tuple[float, str]:
155
+ """Grade single-invoice extraction. Returns (score, feedback)."""
156
+ score = 0.0
157
+ feedback_parts = []
158
+
159
+ # Vendor (0.15)
160
+ sub_vendor = submitted.get("vendor", "").strip()
161
+ if sub_vendor.lower() == ground_truth["vendor"].lower():
162
+ score += 0.15
163
+ feedback_parts.append("Vendor: correct")
164
+ else:
165
+ feedback_parts.append(f"Vendor: wrong (expected '{ground_truth['vendor']}', got '{sub_vendor}')")
166
+
167
+ # Date (0.10)
168
+ sub_date = submitted.get("date", "").strip()
169
+ if sub_date == ground_truth["date"]:
170
+ score += 0.10
171
+ feedback_parts.append("Date: correct")
172
+ else:
173
+ feedback_parts.append(f"Date: wrong (expected '{ground_truth['date']}', got '{sub_date}')")
174
+
175
+ # Currency (0.05)
176
+ sub_cur = submitted.get("currency", "").strip().upper()
177
+ if sub_cur == ground_truth["currency"]:
178
+ score += 0.05
179
+ feedback_parts.append("Currency: correct")
180
+ else:
181
+ feedback_parts.append(f"Currency: wrong (expected '{ground_truth['currency']}', got '{sub_cur}')")
182
+
183
+ # Total (0.20)
184
+ try:
185
+ sub_total = float(submitted.get("total", 0))
186
+ if abs(sub_total - ground_truth["total"]) < 0.01:
187
+ score += 0.20
188
+ feedback_parts.append("Total: correct")
189
+ else:
190
+ feedback_parts.append(f"Total: wrong (expected {ground_truth['total']}, got {sub_total})")
191
+ except (ValueError, TypeError):
192
+ feedback_parts.append("Total: could not parse")
193
+
194
+ # Line items (0.50)
195
+ sub_items = submitted.get("line_items", [])
196
+ gt_items = ground_truth["line_items"]
197
+ if not isinstance(sub_items, list):
198
+ feedback_parts.append("Line items: not a list")
199
+ else:
200
+ item_score = _grade_line_items(sub_items, gt_items)
201
+ score += item_score * 0.50
202
+ feedback_parts.append(f"Line items: {item_score:.0%} match ({len(sub_items)} submitted, {len(gt_items)} expected)")
203
+
204
+ return round(min(score, 1.0), 4), "; ".join(feedback_parts)
205
+
206
+
207
+ def _grade_line_items(submitted: List[Dict], expected: List[Dict]) -> float:
208
+ """Compare line items, return fraction matched (0-1)."""
209
+ if not expected:
210
+ return 1.0 if not submitted else 0.0
211
+
212
+ matched = 0
213
+ used = set()
214
+ for gt in expected:
215
+ best = -1
216
+ best_score = 0.0
217
+ for i, sub in enumerate(submitted):
218
+ if i in used:
219
+ continue
220
+ s = _item_similarity(sub, gt)
221
+ if s > best_score:
222
+ best_score = s
223
+ best = i
224
+ if best >= 0 and best_score > 0.3:
225
+ matched += best_score
226
+ used.add(best)
227
+
228
+ return matched / len(expected)
229
+
230
+
231
+ def _item_similarity(sub: Dict, gt: Dict) -> float:
232
+ """Score a single line item match (0-1)."""
233
+ s = 0.0
234
+ # description
235
+ sd = sub.get("description", "").lower().strip()
236
+ gd = gt["description"].lower().strip()
237
+ if sd == gd:
238
+ s += 0.25
239
+ elif sd in gd or gd in sd:
240
+ s += 0.15
241
+
242
+ # qty
243
+ try:
244
+ if int(sub.get("qty", -1)) == gt["qty"]:
245
+ s += 0.25
246
+ except (ValueError, TypeError):
247
+ pass
248
+
249
+ # unit_price
250
+ try:
251
+ if abs(float(sub.get("unit_price", -1)) - gt["unit_price"]) < 0.01:
252
+ s += 0.25
253
+ except (ValueError, TypeError):
254
+ pass
255
+
256
+ # amount
257
+ try:
258
+ if abs(float(sub.get("amount", -1)) - gt["amount"]) < 0.01:
259
+ s += 0.25
260
+ except (ValueError, TypeError):
261
+ pass
262
+
263
+ return s
264
+
265
+
266
+ # ===================================================================
267
+ # TASK: MEDIUM — batch cleaning & normalisation
268
+ # ===================================================================
269
+
270
+ def _make_messy_invoice(inv: Dict[str, Any]) -> Dict[str, Any]:
271
+ """Take a clean invoice dict and introduce messiness."""
272
+ messy = copy.deepcopy(inv)
273
+
274
+ # Messy date
275
+ d = date.fromisoformat(inv["date"])
276
+ messy["date"] = _format_date_messy(d)
277
+
278
+ # Possibly typo the vendor
279
+ if random.random() < 0.5:
280
+ messy["vendor"] = _typo_vendor(inv["vendor"])
281
+
282
+ # Mix currency symbol into amounts (remove currency field sometimes)
283
+ sym = CURRENCY_SYMBOLS.get(inv["currency"], "$")
284
+ if random.random() < 0.4:
285
+ messy["currency"] = sym # symbol instead of code
286
+ if random.random() < 0.3:
287
+ messy["total"] = f"{sym}{inv['total']}" # string instead of number
288
+
289
+ # Mess up some line item amounts
290
+ for it in messy["line_items"]:
291
+ if random.random() < 0.3:
292
+ it["amount"] = f"{sym}{it['amount']}"
293
+ if random.random() < 0.2:
294
+ it["unit_price"] = f"{sym}{it['unit_price']}"
295
+ if random.random() < 0.15:
296
+ # Wrong amount (qty * unit_price ≠ amount)
297
+ it["amount"] = round(it["qty"] * float(str(it["unit_price"]).replace(sym, "")) + random.uniform(0.5, 5.0), 2)
298
+
299
+ return messy
300
+
301
+
302
+ def _render_messy_batch(invoices: List[Dict[str, Any]]) -> str:
303
+ """Render a batch of messy invoices as CSV-ish text."""
304
+ lines = ["=== INVOICE BATCH (requires cleaning) ===", ""]
305
+ for i, inv in enumerate(invoices):
306
+ lines.append(f"--- Invoice {i+1} ---")
307
+ lines.append(f"Vendor: {inv['vendor']}")
308
+ lines.append(f"Date: {inv['date']}")
309
+ lines.append(f"Currency: {inv.get('currency', 'N/A')}")
310
+ lines.append(f"Total: {inv.get('total', 'N/A')}")
311
+ lines.append("Items:")
312
+ for it in inv["line_items"]:
313
+ lines.append(f" - {it['description']} | qty: {it.get('qty','?')} | price: {it.get('unit_price','?')} | amount: {it.get('amount','?')}")
314
+ lines.append("")
315
+ return "\n".join(lines)
316
+
317
+
318
+ def _grade_medium(submitted: Dict[str, Any], ground_truths: List[Dict[str, Any]]) -> Tuple[float, str]:
319
+ """Grade batch cleaning. submitted should have 'invoices' key."""
320
+ sub_invoices = submitted.get("invoices", [])
321
+ if not isinstance(sub_invoices, list):
322
+ return 0.0, "Expected 'invoices' key with a list of cleaned invoices."
323
+
324
+ n_expected = len(ground_truths)
325
+ if len(sub_invoices) != n_expected:
326
+ # Partial credit still possible
327
+ pass
328
+
329
+ total_score = 0.0
330
+ feedback_parts = []
331
+
332
+ for idx, gt in enumerate(ground_truths):
333
+ if idx < len(sub_invoices):
334
+ s, fb = _grade_easy(sub_invoices[idx], gt)
335
+ total_score += s
336
+ feedback_parts.append(f"Invoice {idx+1}: {s:.2f} ({fb})")
337
+ else:
338
+ feedback_parts.append(f"Invoice {idx+1}: missing")
339
+
340
+ # Penalise extra invoices
341
+ if len(sub_invoices) > n_expected:
342
+ feedback_parts.append(f"Extra invoices submitted: {len(sub_invoices) - n_expected}")
343
+
344
+ avg = total_score / n_expected if n_expected > 0 else 0.0
345
+ return round(min(avg, 1.0), 4), "; ".join(feedback_parts)
346
+
347
+
348
+ # ===================================================================
349
+ # TASK: HARD — extraction + cleaning + reconciliation against POs
350
+ # ===================================================================
351
+
352
+ def _generate_purchase_order(inv: Dict[str, Any]) -> Dict[str, Any]:
353
+ """Generate a PO that mostly matches the invoice but may differ."""
354
+ po = copy.deepcopy(inv)
355
+ po["po_id"] = f"PO-{random.randint(10000, 99999)}"
356
+
357
+ discrepancies = []
358
+
359
+ # Possibly change a price (overcharge)
360
+ if random.random() < 0.6 and po["line_items"]:
361
+ idx = random.randint(0, len(po["line_items"]) - 1)
362
+ original_price = po["line_items"][idx]["unit_price"]
363
+ # PO has the CORRECT price; invoice will be higher (overcharge)
364
+ overcharge = round(original_price * random.uniform(1.05, 1.25), 2)
365
+ discrepancies.append({
366
+ "type": "overcharge",
367
+ "item_description": po["line_items"][idx]["description"],
368
+ "po_price": original_price,
369
+ "invoice_price": overcharge,
370
+ })
371
+ # We'll modify the invoice later
372
+ inv["line_items"][idx]["unit_price"] = overcharge
373
+ inv["line_items"][idx]["amount"] = round(inv["line_items"][idx]["qty"] * overcharge, 2)
374
+
375
+ # Possibly add an extra item to invoice (not in PO)
376
+ if random.random() < 0.4:
377
+ extra = _generate_line_items(1)[0]
378
+ inv["line_items"].append(extra)
379
+ discrepancies.append({
380
+ "type": "extra_item",
381
+ "item_description": extra["description"],
382
+ "detail": "Item on invoice but not on purchase order",
383
+ })
384
+
385
+ # Possibly remove an item from invoice (missing from invoice)
386
+ if random.random() < 0.3 and len(po["line_items"]) > 2:
387
+ removed = po["line_items"].pop(random.randint(0, len(po["line_items"]) - 1))
388
+ discrepancies.append({
389
+ "type": "missing_item",
390
+ "item_description": removed["description"],
391
+ "detail": "Item on purchase order but not on invoice",
392
+ })
393
+
394
+ # Recalculate totals
395
+ inv["total"] = round(sum(it["amount"] for it in inv["line_items"]), 2)
396
+ po["total"] = round(sum(it["amount"] for it in po["line_items"]), 2)
397
+
398
+ return po, discrepancies
399
+
400
+
401
+ def _render_po(po: Dict[str, Any]) -> str:
402
+ """Render purchase order text."""
403
+ lines = [
404
+ f"PURCHASE ORDER: {po['po_id']}",
405
+ f"Vendor: {po['vendor']}",
406
+ f"Date: {po['date']}",
407
+ f"Currency: {po['currency']}",
408
+ f"",
409
+ "Ordered Items:",
410
+ ]
411
+ sym = CURRENCY_SYMBOLS.get(po["currency"], "$")
412
+ for it in po["line_items"]:
413
+ lines.append(f" - {it['description']} x{it['qty']} @ {sym}{it['unit_price']:.2f} = {sym}{it['amount']:.2f}")
414
+ lines.append(f"PO Total: {sym}{po['total']:.2f}")
415
+ return "\n".join(lines)
416
+
417
+
418
+ def _grade_hard(submitted: Dict[str, Any], ground_truths: List[Dict[str, Any]],
419
+ expected_discrepancies: List[List[Dict]]) -> Tuple[float, str]:
420
+ """Grade extraction + cleaning + reconciliation."""
421
+ # Extraction/cleaning portion (60%)
422
+ extraction_score, extraction_fb = _grade_medium(submitted, ground_truths)
423
+
424
+ # Discrepancy detection portion (40%)
425
+ sub_discrepancies = submitted.get("discrepancies", [])
426
+ if not isinstance(sub_discrepancies, list):
427
+ disc_score = 0.0
428
+ disc_fb = "No discrepancies list submitted"
429
+ else:
430
+ all_expected = []
431
+ for disc_list in expected_discrepancies:
432
+ all_expected.extend(disc_list)
433
+
434
+ if not all_expected:
435
+ disc_score = 1.0 if not sub_discrepancies else 0.5
436
+ disc_fb = "No discrepancies expected"
437
+ else:
438
+ matched = 0
439
+ for exp in all_expected:
440
+ for sub in sub_discrepancies:
441
+ if _discrepancy_match(sub, exp):
442
+ matched += 1
443
+ break
444
+ precision = matched / len(sub_discrepancies) if sub_discrepancies else 0.0
445
+ recall = matched / len(all_expected) if all_expected else 1.0
446
+ disc_score = (precision + recall) / 2 # F1-like
447
+ disc_fb = f"Discrepancies: {matched}/{len(all_expected)} found, precision={precision:.2f}, recall={recall:.2f}"
448
+
449
+ total = extraction_score * 0.60 + disc_score * 0.40
450
+ feedback = f"Extraction: {extraction_score:.2f}; {disc_fb}"
451
+ return round(min(total, 1.0), 4), feedback
452
+
453
+
454
+ def _discrepancy_match(submitted: Dict, expected: Dict) -> bool:
455
+ """Check if a submitted discrepancy matches an expected one."""
456
+ # Type must match
457
+ sub_type = submitted.get("type", "").lower().strip()
458
+ exp_type = expected.get("type", "").lower().strip()
459
+ if sub_type != exp_type:
460
+ return False
461
+
462
+ # Item description should roughly match
463
+ sub_desc = submitted.get("item_description", "").lower().strip()
464
+ exp_desc = expected.get("item_description", "").lower().strip()
465
+ if sub_desc and exp_desc:
466
+ if sub_desc == exp_desc or sub_desc in exp_desc or exp_desc in sub_desc:
467
+ return True
468
+ return False
469
+
470
+
471
+ # ===================================================================
472
+ # Environment
473
+ # ===================================================================
474
+
475
+ class InvoiceEnvironment:
476
+ """Core invoice processing environment."""
477
+
478
+ TASKS = {
479
+ "easy": {
480
+ "description": (
481
+ "Extract structured data from a single invoice. "
482
+ "Return a JSON object with keys: vendor, date (YYYY-MM-DD), "
483
+ "currency (3-letter code), total (number), "
484
+ "line_items (list of {description, qty, unit_price, amount})."
485
+ ),
486
+ "max_attempts": 5,
487
+ },
488
+ "medium": {
489
+ "description": (
490
+ "Clean and normalise a batch of messy invoices. "
491
+ "Fix date formats to YYYY-MM-DD, correct vendor name typos, "
492
+ "standardise currency to 3-letter codes, ensure amounts are numbers, "
493
+ "and verify line item math (qty * unit_price = amount). "
494
+ "Return {invoices: [cleaned invoice objects]}."
495
+ ),
496
+ "max_attempts": 5,
497
+ },
498
+ "hard": {
499
+ "description": (
500
+ "Extract and clean invoice data, then reconcile against purchase orders. "
501
+ "Identify discrepancies: overcharges (invoice price > PO price), "
502
+ "extra items (on invoice but not PO), missing items (on PO but not invoice). "
503
+ "Return {invoices: [cleaned], discrepancies: [{invoice_idx, type, item_description, detail}]}."
504
+ ),
505
+ "max_attempts": 5,
506
+ },
507
+ }
508
+
509
+ def __init__(self):
510
+ self._state = InvoiceState()
511
+ self._ground_truth: Any = None
512
+ self._raw_text: str = ""
513
+ self._reference_data: str = ""
514
+ self._messy_invoices: List[Dict] = []
515
+ self._expected_discrepancies: List[List[Dict]] = []
516
+
517
+ def reset(self, task_id: str = "easy") -> Tuple[InvoiceObservation, float, bool, Dict]:
518
+ """Reset the environment for a new episode."""
519
+ if task_id not in self.TASKS:
520
+ task_id = "easy"
521
+
522
+ self._state = InvoiceState(
523
+ episode_id=str(uuid.uuid4()),
524
+ task_id=task_id,
525
+ step_count=0,
526
+ done=False,
527
+ last_reward=0.0,
528
+ best_reward=0.0,
529
+ rewards=[],
530
+ )
531
+
532
+ self._reference_data = ""
533
+ self._expected_discrepancies = []
534
+
535
+ if task_id == "easy":
536
+ inv = _generate_invoice()
537
+ self._ground_truth = inv
538
+ self._raw_text = _render_clean_invoice(inv)
539
+
540
+ elif task_id == "medium":
541
+ n = random.randint(3, 5)
542
+ clean_invoices = [_generate_invoice() for _ in range(n)]
543
+ self._ground_truth = clean_invoices
544
+ messy = [_make_messy_invoice(copy.deepcopy(inv)) for inv in clean_invoices]
545
+ self._messy_invoices = messy
546
+ self._raw_text = _render_messy_batch(messy)
547
+
548
+ elif task_id == "hard":
549
+ n = random.randint(2, 4)
550
+ clean_invoices = [_generate_invoice() for _ in range(n)]
551
+ self._expected_discrepancies = []
552
+ po_texts = []
553
+
554
+ for inv in clean_invoices:
555
+ po, discs = _generate_purchase_order(inv)
556
+ self._expected_discrepancies.append(discs)
557
+ po_texts.append(_render_po(po))
558
+
559
+ self._ground_truth = clean_invoices
560
+ messy = [_make_messy_invoice(copy.deepcopy(inv)) for inv in clean_invoices]
561
+ self._raw_text = _render_messy_batch(messy)
562
+ self._reference_data = "\n\n".join(po_texts)
563
+
564
+ task_info = self.TASKS[task_id]
565
+ obs = InvoiceObservation(
566
+ raw_text=self._raw_text,
567
+ task_id=task_id,
568
+ difficulty=task_id,
569
+ task_description=task_info["description"],
570
+ attempt_number=0,
571
+ max_attempts=task_info["max_attempts"],
572
+ feedback="",
573
+ hint="",
574
+ reference_data=self._reference_data,
575
+ )
576
+ return obs, 0.0, False, {"episode_id": self._state.episode_id}
577
+
578
+ def step(self, action: InvoiceAction) -> Tuple[InvoiceObservation, float, bool, Dict]:
579
+ """Process one agent action."""
580
+ self._state.step_count += 1
581
+ task_id = self._state.task_id
582
+ task_info = self.TASKS[task_id]
583
+ attempt = self._state.step_count
584
+
585
+ # Grade
586
+ if task_id == "easy":
587
+ score, feedback = _grade_easy(action.extracted_data, self._ground_truth)
588
+ elif task_id == "medium":
589
+ score, feedback = _grade_medium(action.extracted_data, self._ground_truth)
590
+ else:
591
+ score, feedback = _grade_hard(
592
+ action.extracted_data, self._ground_truth, self._expected_discrepancies
593
+ )
594
+
595
+ # Track best
596
+ self._state.best_reward = max(self._state.best_reward, score)
597
+ self._state.last_reward = score
598
+ self._state.rewards.append(score)
599
+
600
+ # Done conditions
601
+ done = score >= 0.95 or attempt >= task_info["max_attempts"]
602
+ self._state.done = done
603
+
604
+ # Attempt penalty for using all attempts
605
+ reward = score
606
+ if done and attempt >= task_info["max_attempts"] and score < 0.95:
607
+ reward = score * 0.85 # penalty
608
+
609
+ # Hint after 2 failed attempts
610
+ hint = ""
611
+ if attempt >= 2 and score < 0.7:
612
+ if task_id == "easy":
613
+ hint = "Make sure dates are YYYY-MM-DD, amounts are numbers, and all line items are included."
614
+ elif task_id == "medium":
615
+ hint = "Check for vendor name typos, mixed date formats, and currency symbols mixed into amounts."
616
+ else:
617
+ hint = "Compare each invoice line item against the PO. Look for price differences and items present in one but not the other."
618
+
619
+ obs = InvoiceObservation(
620
+ raw_text=self._raw_text,
621
+ task_id=task_id,
622
+ difficulty=task_id,
623
+ task_description=task_info["description"],
624
+ attempt_number=attempt,
625
+ max_attempts=task_info["max_attempts"],
626
+ feedback=feedback,
627
+ hint=hint,
628
+ reference_data=self._reference_data,
629
+ )
630
+
631
+ return obs, round(reward, 4), done, {
632
+ "episode_id": self._state.episode_id,
633
+ "best_reward": self._state.best_reward,
634
+ }
635
+
636
+ @property
637
+ def state(self) -> InvoiceState:
638
+ return self._state