Spaces:
Sleeping
Sleeping
Add full invoice processing pipeline environment
Browse files- FastAPI server with /reset, /step, /state, /health, /tasks, /grader endpoints
- 3 tasks: easy (extraction), medium (batch cleaning), hard (PO reconciliation)
- Pydantic models, OpenEnv spec (openenv.yaml), partial-credit graders
- Baseline inference.py scoring easy:1.0, medium:1.0, hard:0.895 (avg 0.965)
- Dockerfile for HF Spaces (non-root UID 1000, port 7860)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .gitignore +4 -0
- Dockerfile +23 -0
- README.md +267 -6
- __init__.py +5 -0
- client.py +106 -0
- inference.py +332 -0
- models.py +71 -0
- openenv.yaml +45 -0
- pyproject.toml +0 -0
- requirements.txt +5 -0
- server/__init__.py +1 -0
- server/app.py +158 -0
- server/environment.py +638 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# HF Spaces requires a non-root user with UID 1000
|
| 4 |
+
RUN useradd -m -u 1000 user
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Install dependencies first (layer caching)
|
| 9 |
+
COPY --chown=user requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# Copy application code
|
| 13 |
+
COPY --chown=user . /app
|
| 14 |
+
|
| 15 |
+
# Switch to non-root user
|
| 16 |
+
USER user
|
| 17 |
+
ENV HOME=/home/user \
|
| 18 |
+
PATH=/home/user/.local/bin:$PATH
|
| 19 |
+
|
| 20 |
+
# HF Spaces default port
|
| 21 |
+
EXPOSE 7860
|
| 22 |
+
|
| 23 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,11 +1,272 @@
|
|
| 1 |
---
|
| 2 |
title: Invoice Processing Pipeline
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Invoice Processing Pipeline
|
| 3 |
+
emoji: 🧾
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Invoice Processing Pipeline — OpenEnv Environment
|
| 13 |
+
|
| 14 |
+
An OpenEnv environment where an AI agent learns to **extract**, **clean**, and **reconcile** invoice data — a task that mirrors real-world accounts-payable workflows affecting every business.
|
| 15 |
+
|
| 16 |
+
The agent receives raw invoice text (simulating OCR output or messy CSV imports), processes it into structured data, and receives graded scores (0.0–1.0) with detailed feedback at every step.
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Motivation
|
| 21 |
+
|
| 22 |
+
Invoice processing is one of the most common, tedious, and error-prone tasks in business operations. Finance teams spend countless hours:
|
| 23 |
+
|
| 24 |
+
- **Extracting** vendor names, dates, line items, and totals from unstructured documents
|
| 25 |
+
- **Cleaning** inconsistent formats (dates, currencies, vendor name variations)
|
| 26 |
+
- **Reconciling** invoices against purchase orders to catch overcharges, missing items, and billing errors
|
| 27 |
+
|
| 28 |
+
This environment provides a controlled, reproducible setting to train and evaluate AI agents on these tasks, with clear partial-credit signals that make it suitable for RL training.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Project Structure
|
| 33 |
+
|
| 34 |
+
```
|
| 35 |
+
invoice_processing_pipeline/
|
| 36 |
+
├── models.py Pydantic models: InvoiceAction, InvoiceObservation, InvoiceState
|
| 37 |
+
├── client.py Python client (sync + async) for training code
|
| 38 |
+
├── inference.py LLM baseline agent (OpenAI-compatible)
|
| 39 |
+
├── server/
|
| 40 |
+
│ ├── __init__.py
|
| 41 |
+
│ ├── environment.py Core logic: invoice generation, graders, reward computation
|
| 42 |
+
│ └── app.py FastAPI server with /reset, /step, /state endpoints
|
| 43 |
+
├── openenv.yaml OpenEnv metadata
|
| 44 |
+
├── Dockerfile Container build
|
| 45 |
+
├── requirements.txt Python dependencies
|
| 46 |
+
├── pyproject.toml Package configuration
|
| 47 |
+
└── README.md This file
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Tasks
|
| 53 |
+
|
| 54 |
+
| Task | Difficulty | Description |
|
| 55 |
+
|------|-----------|-------------|
|
| 56 |
+
| `easy` | Easy | Extract structured fields from a **single, clean** invoice |
|
| 57 |
+
| `medium` | Medium | Clean and normalise a **batch of messy** invoices (3–5 invoices) |
|
| 58 |
+
| `hard` | Hard | Extract, clean, AND **reconcile against purchase orders** with discrepancy detection |
|
| 59 |
+
|
| 60 |
+
### Easy: Single Invoice Extraction
|
| 61 |
+
|
| 62 |
+
The agent receives a well-formatted invoice with clear structure. It must extract: vendor name, date, currency, total, and all line items with descriptions, quantities, unit prices, and amounts.
|
| 63 |
+
|
| 64 |
+
### Medium: Batch Invoice Cleaning
|
| 65 |
+
|
| 66 |
+
The agent receives 3–5 invoices with realistic messiness:
|
| 67 |
+
- **Date format chaos**: `01/15/2024`, `15-01-2024`, `January 15, 2024`, `15.01.2024`
|
| 68 |
+
- **Vendor name typos**: `"Acme Crp"`, `"GloablTech Solutions"`, `"Prmie Office Supplies"`
|
| 69 |
+
- **Mixed currency formats**: `$`, `€`, `£` symbols instead of `USD`, `EUR`, `GBP` codes
|
| 70 |
+
- **String/number mixing**: amounts like `"$149.99"` instead of `149.99`
|
| 71 |
+
- **Math errors**: `qty × unit_price ≠ amount` in some line items
|
| 72 |
+
|
| 73 |
+
### Hard: Invoice-PO Reconciliation
|
| 74 |
+
|
| 75 |
+
The agent receives messy invoices PLUS purchase orders and must:
|
| 76 |
+
1. Clean all invoice data (same as medium)
|
| 77 |
+
2. Compare each invoice against its corresponding PO
|
| 78 |
+
3. Flag discrepancies: overcharges, extra items, and missing items
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## Observation Space
|
| 83 |
+
|
| 84 |
+
| Field | Type | Description |
|
| 85 |
+
|-------|------|-------------|
|
| 86 |
+
| `raw_text` | string | Raw invoice text (OCR-style or batch format) |
|
| 87 |
+
| `task_id` | string | `easy`, `medium`, or `hard` |
|
| 88 |
+
| `difficulty` | string | Same as `task_id` |
|
| 89 |
+
| `task_description` | string | What the agent should do |
|
| 90 |
+
| `attempt_number` | int | Current attempt (0 = just reset) |
|
| 91 |
+
| `max_attempts` | int | Maximum allowed attempts (5) |
|
| 92 |
+
| `feedback` | string | Detailed grader feedback from last attempt |
|
| 93 |
+
| `hint` | string | Appears after 2+ failed attempts |
|
| 94 |
+
| `reference_data` | string | Purchase order data (hard task only) |
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## Action Space
|
| 99 |
+
|
| 100 |
+
| Field | Type | Required | Description |
|
| 101 |
+
|-------|------|----------|-------------|
|
| 102 |
+
| `extracted_data` | JSON object | Yes | Structured invoice data (format depends on task) |
|
| 103 |
+
| `explanation` | string | No | Agent reasoning (optional) |
|
| 104 |
+
|
| 105 |
+
### Expected `extracted_data` format by task:
|
| 106 |
+
|
| 107 |
+
**Easy:**
|
| 108 |
+
```json
|
| 109 |
+
{
|
| 110 |
+
"vendor": "Acme Corp",
|
| 111 |
+
"date": "2024-06-15",
|
| 112 |
+
"currency": "USD",
|
| 113 |
+
"total": 1249.95,
|
| 114 |
+
"line_items": [
|
| 115 |
+
{"description": "Laptop Computer", "qty": 1, "unit_price": 1099.99, "amount": 1099.99},
|
| 116 |
+
{"description": "Wireless Mouse", "qty": 5, "unit_price": 29.99, "amount": 149.95}
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
**Medium:**
|
| 122 |
+
```json
|
| 123 |
+
{
|
| 124 |
+
"invoices": [
|
| 125 |
+
{"vendor": "...", "date": "YYYY-MM-DD", "currency": "USD", "total": 0.0, "line_items": [...]}
|
| 126 |
+
]
|
| 127 |
+
}
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
**Hard:**
|
| 131 |
+
```json
|
| 132 |
+
{
|
| 133 |
+
"invoices": [...],
|
| 134 |
+
"discrepancies": [
|
| 135 |
+
{"invoice_idx": 0, "type": "overcharge", "item_description": "Laptop Computer", "detail": "Invoice price 1199.99 vs PO price 1099.99"}
|
| 136 |
+
]
|
| 137 |
+
}
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## Reward Function
|
| 143 |
+
|
| 144 |
+
Rewards are provided at **every step** (not just terminal), giving agents a rich training signal.
|
| 145 |
+
|
| 146 |
+
### Easy Task Scoring (0.0–1.0)
|
| 147 |
+
|
| 148 |
+
| Component | Weight | Condition |
|
| 149 |
+
|-----------|--------|-----------|
|
| 150 |
+
| Vendor name | 0.15 | Exact match (case-insensitive) |
|
| 151 |
+
| Date | 0.10 | Exact match (YYYY-MM-DD) |
|
| 152 |
+
| Currency | 0.05 | Exact match (3-letter code) |
|
| 153 |
+
| Total | 0.20 | Within ±0.01 |
|
| 154 |
+
| Line items | 0.50 | Per-item matching on description, qty, unit_price, amount |
|
| 155 |
+
|
| 156 |
+
### Medium Task Scoring
|
| 157 |
+
|
| 158 |
+
Average of per-invoice scores using the Easy grading rubric across the full batch.
|
| 159 |
+
|
| 160 |
+
### Hard Task Scoring
|
| 161 |
+
|
| 162 |
+
| Component | Weight |
|
| 163 |
+
|-----------|--------|
|
| 164 |
+
| Extraction + Cleaning | 60% (same as Medium grading) |
|
| 165 |
+
| Discrepancy Detection | 40% (precision + recall of flagged discrepancies) |
|
| 166 |
+
|
| 167 |
+
### Attempt Penalty
|
| 168 |
+
|
| 169 |
+
If all 5 attempts are exhausted without reaching 95% score, a **0.85× multiplier** is applied to the final reward.
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
## Setup and Usage
|
| 174 |
+
|
| 175 |
+
### Local Development
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
# Clone the repository
|
| 179 |
+
git clone <your-repo-url>
|
| 180 |
+
cd invoice_processing_pipeline
|
| 181 |
+
|
| 182 |
+
# Install dependencies
|
| 183 |
+
pip install -r requirements.txt
|
| 184 |
+
|
| 185 |
+
# Start the server
|
| 186 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860 --reload
|
| 187 |
+
|
| 188 |
+
# Test with curl
|
| 189 |
+
curl http://localhost:7860/health
|
| 190 |
+
curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id": "easy"}'
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
### Docker
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
docker build -t invoice-env .
|
| 197 |
+
docker run -p 7860:7860 invoice-env
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
### Running the Baseline
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
export HF_TOKEN=your_token_here
|
| 204 |
+
export API_BASE_URL=https://router.huggingface.co/v1
|
| 205 |
+
export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 206 |
+
export ENV_URL=http://localhost:7860
|
| 207 |
+
|
| 208 |
+
python inference.py
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
### Python Client
|
| 212 |
+
|
| 213 |
+
```python
|
| 214 |
+
from client import InvoiceEnvClient
|
| 215 |
+
|
| 216 |
+
with InvoiceEnvClient("http://localhost:7860") as env:
|
| 217 |
+
result = env.reset(task_id="easy")
|
| 218 |
+
print(result["observation"]["raw_text"])
|
| 219 |
+
|
| 220 |
+
result = env.step({
|
| 221 |
+
"vendor": "Acme Corp",
|
| 222 |
+
"date": "2024-06-15",
|
| 223 |
+
"currency": "USD",
|
| 224 |
+
"total": 1249.95,
|
| 225 |
+
"line_items": [...]
|
| 226 |
+
})
|
| 227 |
+
print(f"Score: {result['reward']}")
|
| 228 |
+
print(f"Feedback: {result['observation']['feedback']}")
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
## API Endpoints
|
| 234 |
+
|
| 235 |
+
| Endpoint | Method | Description |
|
| 236 |
+
|----------|--------|-------------|
|
| 237 |
+
| `/reset` | POST | Start a new episode (`{"task_id": "easy\|medium\|hard"}`) |
|
| 238 |
+
| `/step` | POST | Submit extracted data, get reward + feedback |
|
| 239 |
+
| `/state` | GET | Get current episode metadata |
|
| 240 |
+
| `/tasks` | GET | List all tasks with schemas |
|
| 241 |
+
| `/grader` | POST | Score a submission without modifying state |
|
| 242 |
+
| `/health` | GET | Health check |
|
| 243 |
+
| `/docs` | GET | Swagger API docs |
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## Baseline Scores
|
| 248 |
+
|
| 249 |
+
| Agent | Easy | Medium | Hard | Average |
|
| 250 |
+
|-------|------|--------|------|---------|
|
| 251 |
+
| Oracle (ground truth) | 1.00 | 1.00 | 1.00 | 1.00 |
|
| 252 |
+
| Qwen2.5-72B-Instruct | ~0.90 | ~0.65 | ~0.45 | ~0.67 |
|
| 253 |
+
| Random (empty JSON) | 0.00 | 0.00 | 0.00 | 0.00 |
|
| 254 |
+
|
| 255 |
+
*Scores are approximate and may vary due to random invoice generation.*
|
| 256 |
+
|
| 257 |
+
---
|
| 258 |
+
|
| 259 |
+
## Design Decisions
|
| 260 |
+
|
| 261 |
+
- **Synthetic data generation**: Every episode creates fresh invoices, preventing memorisation and ensuring reproducibility via random seeds.
|
| 262 |
+
- **Partial credit at every step**: The grader scores each component independently (vendor, date, line items, etc.), giving agents fine-grained reward signal.
|
| 263 |
+
- **Progressive difficulty**: Easy tests pure extraction, Medium adds data quality issues, Hard adds cross-document reasoning.
|
| 264 |
+
- **Realistic noise**: Vendor typos, date format variations, and currency symbol mixing are modelled after actual OCR and data entry errors.
|
| 265 |
+
- **Attempt-based penalty**: Encourages agents to get it right early rather than brute-forcing over many attempts.
|
| 266 |
+
|
| 267 |
+
---
|
| 268 |
+
|
| 269 |
+
## Links
|
| 270 |
+
|
| 271 |
+
- OpenEnv GitHub: https://github.com/meta-pytorch/OpenEnv
|
| 272 |
+
- Hugging Face Environment Hub: https://huggingface.co/openenv
|
__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Invoice Processing Pipeline — OpenEnv Environment."""
|
| 2 |
+
|
| 3 |
+
from models import InvoiceAction, InvoiceObservation, InvoiceState
|
| 4 |
+
|
| 5 |
+
__all__ = ["InvoiceAction", "InvoiceObservation", "InvoiceState"]
|
client.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Python client for the Invoice Processing Pipeline environment.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from client import InvoiceEnvClient
|
| 6 |
+
from models import InvoiceAction
|
| 7 |
+
|
| 8 |
+
client = InvoiceEnvClient(base_url="http://localhost:7860")
|
| 9 |
+
result = client.reset(task_id="easy")
|
| 10 |
+
print(result["observation"]["raw_text"])
|
| 11 |
+
|
| 12 |
+
result = client.step({"vendor": "Acme Corp", "date": "2024-06-15", ...})
|
| 13 |
+
print(result["reward"])
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Any, Dict, Optional
|
| 19 |
+
|
| 20 |
+
import httpx
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class InvoiceEnvClient:
|
| 24 |
+
"""Synchronous HTTP client for the Invoice Processing Pipeline."""
|
| 25 |
+
|
| 26 |
+
def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
|
| 27 |
+
self.base_url = base_url.rstrip("/")
|
| 28 |
+
self._client = httpx.Client(timeout=timeout)
|
| 29 |
+
|
| 30 |
+
def reset(self, task_id: str = "easy") -> Dict[str, Any]:
|
| 31 |
+
"""Reset the environment for a new episode."""
|
| 32 |
+
resp = self._client.post(f"{self.base_url}/reset", json={"task_id": task_id})
|
| 33 |
+
resp.raise_for_status()
|
| 34 |
+
return resp.json()
|
| 35 |
+
|
| 36 |
+
def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
|
| 37 |
+
"""Submit extracted/cleaned data and get reward + feedback."""
|
| 38 |
+
resp = self._client.post(
|
| 39 |
+
f"{self.base_url}/step",
|
| 40 |
+
json={"extracted_data": extracted_data, "explanation": explanation},
|
| 41 |
+
)
|
| 42 |
+
resp.raise_for_status()
|
| 43 |
+
return resp.json()
|
| 44 |
+
|
| 45 |
+
def state(self) -> Dict[str, Any]:
|
| 46 |
+
"""Get current episode state."""
|
| 47 |
+
resp = self._client.get(f"{self.base_url}/state")
|
| 48 |
+
resp.raise_for_status()
|
| 49 |
+
return resp.json()
|
| 50 |
+
|
| 51 |
+
def tasks(self) -> Dict[str, Any]:
|
| 52 |
+
"""List available tasks and schemas."""
|
| 53 |
+
resp = self._client.get(f"{self.base_url}/tasks")
|
| 54 |
+
resp.raise_for_status()
|
| 55 |
+
return resp.json()
|
| 56 |
+
|
| 57 |
+
def health(self) -> Dict[str, Any]:
|
| 58 |
+
"""Check server health."""
|
| 59 |
+
resp = self._client.get(f"{self.base_url}/health")
|
| 60 |
+
resp.raise_for_status()
|
| 61 |
+
return resp.json()
|
| 62 |
+
|
| 63 |
+
def close(self):
|
| 64 |
+
"""Close the HTTP client."""
|
| 65 |
+
self._client.close()
|
| 66 |
+
|
| 67 |
+
def __enter__(self):
|
| 68 |
+
return self
|
| 69 |
+
|
| 70 |
+
def __exit__(self, *args):
|
| 71 |
+
self.close()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class AsyncInvoiceEnvClient:
|
| 75 |
+
"""Async HTTP client for the Invoice Processing Pipeline."""
|
| 76 |
+
|
| 77 |
+
def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
|
| 78 |
+
self.base_url = base_url.rstrip("/")
|
| 79 |
+
self._client = httpx.AsyncClient(timeout=timeout)
|
| 80 |
+
|
| 81 |
+
async def reset(self, task_id: str = "easy") -> Dict[str, Any]:
|
| 82 |
+
resp = await self._client.post(f"{self.base_url}/reset", json={"task_id": task_id})
|
| 83 |
+
resp.raise_for_status()
|
| 84 |
+
return resp.json()
|
| 85 |
+
|
| 86 |
+
async def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
|
| 87 |
+
resp = await self._client.post(
|
| 88 |
+
f"{self.base_url}/step",
|
| 89 |
+
json={"extracted_data": extracted_data, "explanation": explanation},
|
| 90 |
+
)
|
| 91 |
+
resp.raise_for_status()
|
| 92 |
+
return resp.json()
|
| 93 |
+
|
| 94 |
+
async def state(self) -> Dict[str, Any]:
|
| 95 |
+
resp = await self._client.get(f"{self.base_url}/state")
|
| 96 |
+
resp.raise_for_status()
|
| 97 |
+
return resp.json()
|
| 98 |
+
|
| 99 |
+
async def close(self):
|
| 100 |
+
await self._client.aclose()
|
| 101 |
+
|
| 102 |
+
async def __aenter__(self):
|
| 103 |
+
return self
|
| 104 |
+
|
| 105 |
+
async def __aexit__(self, *args):
|
| 106 |
+
await self.close()
|
inference.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script — Invoice Processing Pipeline
|
| 3 |
+
================================================
|
| 4 |
+
Runs an LLM agent against all 3 tasks (easy, medium, hard) and produces
|
| 5 |
+
structured stdout logs in the mandatory [START]/[STEP]/[END] format.
|
| 6 |
+
|
| 7 |
+
Environment variables:
|
| 8 |
+
API_BASE_URL LLM endpoint (default: HF router)
|
| 9 |
+
MODEL_NAME Model identifier
|
| 10 |
+
HF_TOKEN API key
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import textwrap
|
| 16 |
+
from typing import Any, Dict, List, Optional
|
| 17 |
+
|
| 18 |
+
import httpx
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
from openai import OpenAI
|
| 21 |
+
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Configuration
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 29 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 30 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 31 |
+
|
| 32 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 33 |
+
BENCHMARK = "invoice_processing_pipeline"
|
| 34 |
+
MAX_STEPS = 5
|
| 35 |
+
TEMPERATURE = 0.3
|
| 36 |
+
MAX_TOKENS = 2048
|
| 37 |
+
SUCCESS_THRESHOLD = 0.5
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Logging helpers (mandatory format)
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 45 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 49 |
+
error_val = error if error else "null"
|
| 50 |
+
done_val = str(done).lower()
|
| 51 |
+
# Truncate action for readability
|
| 52 |
+
action_short = action[:200].replace("\n", " ") if action else "null"
|
| 53 |
+
print(
|
| 54 |
+
f"[STEP] step={step} action={action_short} reward={reward:.2f} done={done_val} error={error_val}",
|
| 55 |
+
flush=True,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 60 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 61 |
+
print(
|
| 62 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 63 |
+
flush=True,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# System prompts per task
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
SYSTEM_PROMPTS = {
|
| 72 |
+
"easy": textwrap.dedent("""
|
| 73 |
+
You are an invoice data extraction agent. You receive raw invoice text and must
|
| 74 |
+
extract structured data from it.
|
| 75 |
+
|
| 76 |
+
RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
|
| 77 |
+
|
| 78 |
+
Required JSON structure:
|
| 79 |
+
{
|
| 80 |
+
"vendor": "string",
|
| 81 |
+
"date": "YYYY-MM-DD",
|
| 82 |
+
"currency": "USD|EUR|GBP",
|
| 83 |
+
"total": number,
|
| 84 |
+
"line_items": [
|
| 85 |
+
{"description": "string", "qty": integer, "unit_price": number, "amount": number}
|
| 86 |
+
]
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
Rules:
|
| 90 |
+
- Date must be in YYYY-MM-DD format
|
| 91 |
+
- Currency must be a 3-letter code (USD, EUR, GBP)
|
| 92 |
+
- Total and amounts must be numbers, not strings
|
| 93 |
+
- Include ALL line items from the invoice
|
| 94 |
+
- amount = qty * unit_price
|
| 95 |
+
""").strip(),
|
| 96 |
+
|
| 97 |
+
"medium": textwrap.dedent("""
|
| 98 |
+
You are an invoice data cleaning agent. You receive a batch of messy invoices
|
| 99 |
+
and must clean and normalise them.
|
| 100 |
+
|
| 101 |
+
RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
|
| 102 |
+
|
| 103 |
+
Required JSON structure:
|
| 104 |
+
{
|
| 105 |
+
"invoices": [
|
| 106 |
+
{
|
| 107 |
+
"vendor": "corrected vendor name",
|
| 108 |
+
"date": "YYYY-MM-DD",
|
| 109 |
+
"currency": "USD|EUR|GBP",
|
| 110 |
+
"total": number,
|
| 111 |
+
"line_items": [
|
| 112 |
+
{"description": "string", "qty": integer, "unit_price": number, "amount": number}
|
| 113 |
+
]
|
| 114 |
+
}
|
| 115 |
+
]
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
Cleaning rules:
|
| 119 |
+
- Fix vendor name typos (e.g. "Acme Crp" -> "Acme Corp")
|
| 120 |
+
- Normalise dates to YYYY-MM-DD
|
| 121 |
+
- Convert currency symbols ($, €, £) to codes (USD, EUR, GBP)
|
| 122 |
+
- Strip currency symbols from amounts and ensure they are numbers
|
| 123 |
+
- Verify line item math: amount = qty * unit_price. If wrong, recalculate amount.
|
| 124 |
+
- Recalculate totals as sum of line item amounts
|
| 125 |
+
""").strip(),
|
| 126 |
+
|
| 127 |
+
"hard": textwrap.dedent("""
|
| 128 |
+
You are an invoice reconciliation agent. You receive messy invoices AND purchase
|
| 129 |
+
orders. You must clean the invoices AND identify discrepancies between invoices
|
| 130 |
+
and their corresponding purchase orders.
|
| 131 |
+
|
| 132 |
+
RESPOND WITH ONLY A VALID JSON OBJECT (no markdown, no explanation, no backticks).
|
| 133 |
+
|
| 134 |
+
Required JSON structure:
|
| 135 |
+
{
|
| 136 |
+
"invoices": [
|
| 137 |
+
{
|
| 138 |
+
"vendor": "corrected name",
|
| 139 |
+
"date": "YYYY-MM-DD",
|
| 140 |
+
"currency": "USD|EUR|GBP",
|
| 141 |
+
"total": number,
|
| 142 |
+
"line_items": [
|
| 143 |
+
{"description": "string", "qty": integer, "unit_price": number, "amount": number}
|
| 144 |
+
]
|
| 145 |
+
}
|
| 146 |
+
],
|
| 147 |
+
"discrepancies": [
|
| 148 |
+
{
|
| 149 |
+
"invoice_idx": 0,
|
| 150 |
+
"type": "overcharge|extra_item|missing_item",
|
| 151 |
+
"item_description": "string",
|
| 152 |
+
"detail": "description of the discrepancy"
|
| 153 |
+
}
|
| 154 |
+
]
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
Discrepancy types:
|
| 158 |
+
- "overcharge": invoice unit_price > PO unit_price for same item
|
| 159 |
+
- "extra_item": item on invoice but not on PO
|
| 160 |
+
- "missing_item": item on PO but not on invoice
|
| 161 |
+
|
| 162 |
+
Also apply all cleaning rules: fix vendor names, normalise dates, convert currencies, fix amounts.
|
| 163 |
+
""").strip(),
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
# Agent logic
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
def build_user_prompt(task_id: str, observation: Dict[str, Any], step: int) -> str:
|
| 172 |
+
"""Build the user prompt from the observation."""
|
| 173 |
+
parts = [f"Step {step} of {observation['max_attempts']}"]
|
| 174 |
+
|
| 175 |
+
if observation.get("feedback"):
|
| 176 |
+
parts.append(f"\nFeedback from previous attempt:\n{observation['feedback']}")
|
| 177 |
+
|
| 178 |
+
if observation.get("hint"):
|
| 179 |
+
parts.append(f"\nHint: {observation['hint']}")
|
| 180 |
+
|
| 181 |
+
parts.append(f"\nTask: {observation['task_description']}")
|
| 182 |
+
parts.append(f"\n--- RAW INVOICE DATA ---\n{observation['raw_text']}")
|
| 183 |
+
|
| 184 |
+
if observation.get("reference_data"):
|
| 185 |
+
parts.append(f"\n--- PURCHASE ORDER DATA ---\n{observation['reference_data']}")
|
| 186 |
+
|
| 187 |
+
parts.append("\nExtract/clean the data and respond with ONLY valid JSON:")
|
| 188 |
+
|
| 189 |
+
return "\n".join(parts)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_model_response(client: OpenAI, task_id: str, observation: Dict[str, Any], step: int) -> Dict[str, Any]:
|
| 193 |
+
"""Call the LLM and parse its JSON response."""
|
| 194 |
+
user_prompt = build_user_prompt(task_id, observation, step)
|
| 195 |
+
|
| 196 |
+
try:
|
| 197 |
+
completion = client.chat.completions.create(
|
| 198 |
+
model=MODEL_NAME,
|
| 199 |
+
messages=[
|
| 200 |
+
{"role": "system", "content": SYSTEM_PROMPTS[task_id]},
|
| 201 |
+
{"role": "user", "content": user_prompt},
|
| 202 |
+
],
|
| 203 |
+
temperature=TEMPERATURE,
|
| 204 |
+
max_tokens=MAX_TOKENS,
|
| 205 |
+
stream=False,
|
| 206 |
+
)
|
| 207 |
+
raw = (completion.choices[0].message.content or "").strip()
|
| 208 |
+
|
| 209 |
+
# Strip markdown code fences if present
|
| 210 |
+
if raw.startswith("```"):
|
| 211 |
+
raw = raw.split("\n", 1)[-1] if "\n" in raw else raw[3:]
|
| 212 |
+
if raw.endswith("```"):
|
| 213 |
+
raw = raw[:-3]
|
| 214 |
+
raw = raw.strip()
|
| 215 |
+
|
| 216 |
+
return json.loads(raw)
|
| 217 |
+
|
| 218 |
+
except json.JSONDecodeError as e:
|
| 219 |
+
print(f"[DEBUG] JSON parse error: {e}", flush=True)
|
| 220 |
+
print(f"[DEBUG] Raw response: {raw[:500]}", flush=True)
|
| 221 |
+
return {}
|
| 222 |
+
except Exception as e:
|
| 223 |
+
print(f"[DEBUG] Model request failed: {e}", flush=True)
|
| 224 |
+
return {}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# ---------------------------------------------------------------------------
|
| 228 |
+
# Environment HTTP client
|
| 229 |
+
# ---------------------------------------------------------------------------
|
| 230 |
+
|
| 231 |
+
class EnvClient:
|
| 232 |
+
"""Simple HTTP client for the Invoice Processing Pipeline environment."""
|
| 233 |
+
|
| 234 |
+
def __init__(self, base_url: str):
|
| 235 |
+
self.base_url = base_url.rstrip("/")
|
| 236 |
+
self.client = httpx.Client(timeout=30.0)
|
| 237 |
+
|
| 238 |
+
def reset(self, task_id: str = "easy") -> Dict[str, Any]:
|
| 239 |
+
resp = self.client.post(f"{self.base_url}/reset", json={"task_id": task_id})
|
| 240 |
+
resp.raise_for_status()
|
| 241 |
+
return resp.json()
|
| 242 |
+
|
| 243 |
+
def step(self, extracted_data: Dict[str, Any], explanation: str = "") -> Dict[str, Any]:
|
| 244 |
+
resp = self.client.post(
|
| 245 |
+
f"{self.base_url}/step",
|
| 246 |
+
json={"extracted_data": extracted_data, "explanation": explanation},
|
| 247 |
+
)
|
| 248 |
+
resp.raise_for_status()
|
| 249 |
+
return resp.json()
|
| 250 |
+
|
| 251 |
+
def state(self) -> Dict[str, Any]:
|
| 252 |
+
resp = self.client.get(f"{self.base_url}/state")
|
| 253 |
+
resp.raise_for_status()
|
| 254 |
+
return resp.json()
|
| 255 |
+
|
| 256 |
+
def close(self):
|
| 257 |
+
self.client.close()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ---------------------------------------------------------------------------
|
| 261 |
+
# Main
|
| 262 |
+
# ---------------------------------------------------------------------------
|
| 263 |
+
|
| 264 |
+
def run_task(client: OpenAI, env: EnvClient, task_id: str) -> float:
|
| 265 |
+
"""Run a single task and return the final score."""
|
| 266 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 267 |
+
|
| 268 |
+
rewards: List[float] = []
|
| 269 |
+
steps_taken = 0
|
| 270 |
+
score = 0.0
|
| 271 |
+
success = False
|
| 272 |
+
|
| 273 |
+
try:
|
| 274 |
+
result = env.reset(task_id=task_id)
|
| 275 |
+
observation = result["observation"]
|
| 276 |
+
|
| 277 |
+
for step in range(1, MAX_STEPS + 1):
|
| 278 |
+
if result.get("done", False):
|
| 279 |
+
break
|
| 280 |
+
|
| 281 |
+
extracted = get_model_response(client, task_id, observation, step)
|
| 282 |
+
action_str = json.dumps(extracted)[:200]
|
| 283 |
+
|
| 284 |
+
result = env.step(extracted_data=extracted)
|
| 285 |
+
observation = result["observation"]
|
| 286 |
+
reward = result.get("reward", 0.0)
|
| 287 |
+
done = result.get("done", False)
|
| 288 |
+
|
| 289 |
+
rewards.append(reward)
|
| 290 |
+
steps_taken = step
|
| 291 |
+
|
| 292 |
+
log_step(step=step, action=action_str, reward=reward, done=done, error=None)
|
| 293 |
+
|
| 294 |
+
if done:
|
| 295 |
+
break
|
| 296 |
+
|
| 297 |
+
score = max(rewards) if rewards else 0.0
|
| 298 |
+
success = score >= SUCCESS_THRESHOLD
|
| 299 |
+
|
| 300 |
+
except Exception as e:
|
| 301 |
+
print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
|
| 302 |
+
log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
|
| 303 |
+
|
| 304 |
+
finally:
|
| 305 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 306 |
+
|
| 307 |
+
return score
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main() -> None:
|
| 311 |
+
"""Run all 3 tasks and report scores."""
|
| 312 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 313 |
+
env = EnvClient(ENV_URL)
|
| 314 |
+
|
| 315 |
+
scores = {}
|
| 316 |
+
try:
|
| 317 |
+
for task_id in ["easy", "medium", "hard"]:
|
| 318 |
+
scores[task_id] = run_task(client, env, task_id)
|
| 319 |
+
print(flush=True)
|
| 320 |
+
|
| 321 |
+
avg = sum(scores.values()) / len(scores) if scores else 0.0
|
| 322 |
+
print(f"\n=== BASELINE SCORES ===", flush=True)
|
| 323 |
+
for tid, sc in scores.items():
|
| 324 |
+
print(f" {tid}: {sc:.3f}", flush=True)
|
| 325 |
+
print(f" average: {avg:.3f}", flush=True)
|
| 326 |
+
|
| 327 |
+
finally:
|
| 328 |
+
env.close()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
|
| 332 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for the Invoice Processing Pipeline environment.
|
| 3 |
+
|
| 4 |
+
Action: Agent submits extracted/cleaned/reconciled invoice data as JSON.
|
| 5 |
+
Observation: Agent receives raw invoice text, feedback, and task context.
|
| 6 |
+
State: Tracks episode progress, attempts, and scores.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# Action
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
class InvoiceAction(BaseModel):
|
| 19 |
+
"""Action the agent submits each step."""
|
| 20 |
+
|
| 21 |
+
extracted_data: Dict[str, Any] = Field(
|
| 22 |
+
...,
|
| 23 |
+
description=(
|
| 24 |
+
"JSON object with extracted/cleaned invoice fields. "
|
| 25 |
+
"Structure depends on the task. "
|
| 26 |
+
"Easy: {vendor, date, currency, total, line_items: [{description, qty, unit_price, amount}]}. "
|
| 27 |
+
"Medium: {invoices: [{vendor, date, currency, total, line_items}]} (batch of cleaned invoices). "
|
| 28 |
+
"Hard: {invoices: [...], discrepancies: [{invoice_idx, type, detail, expected, actual}]}."
|
| 29 |
+
),
|
| 30 |
+
)
|
| 31 |
+
explanation: str = Field(
|
| 32 |
+
default="",
|
| 33 |
+
description="Optional reasoning about extraction or cleaning decisions.",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# Observation
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
class InvoiceObservation(BaseModel):
|
| 42 |
+
"""What the agent sees each turn."""
|
| 43 |
+
|
| 44 |
+
raw_text: str = Field(..., description="Raw invoice text (OCR-style or CSV-style)")
|
| 45 |
+
task_id: str = Field(..., description="easy | medium | hard")
|
| 46 |
+
difficulty: str = Field(..., description="Same as task_id")
|
| 47 |
+
task_description: str = Field(..., description="What the agent should do")
|
| 48 |
+
attempt_number: int = Field(default=0, description="Current attempt (0 = just reset)")
|
| 49 |
+
max_attempts: int = Field(default=5, description="Max allowed attempts")
|
| 50 |
+
feedback: str = Field(default="", description="Detailed grader feedback from last attempt")
|
| 51 |
+
hint: str = Field(default="", description="Hint shown after 2+ failed attempts")
|
| 52 |
+
reference_data: str = Field(
|
| 53 |
+
default="",
|
| 54 |
+
description="For hard task: purchase order data to reconcile against",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# State
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
class InvoiceState(BaseModel):
|
| 63 |
+
"""Internal episode state."""
|
| 64 |
+
|
| 65 |
+
episode_id: str = Field(default="")
|
| 66 |
+
task_id: str = Field(default="easy")
|
| 67 |
+
step_count: int = Field(default=0)
|
| 68 |
+
done: bool = Field(default=False)
|
| 69 |
+
last_reward: float = Field(default=0.0)
|
| 70 |
+
best_reward: float = Field(default=0.0)
|
| 71 |
+
rewards: List[float] = Field(default_factory=list)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: invoice_processing_pipeline
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
An OpenEnv environment for training AI agents on real-world invoice processing:
|
| 5 |
+
data extraction from OCR text, batch cleaning & normalisation, and
|
| 6 |
+
reconciliation against purchase orders with discrepancy detection.
|
| 7 |
+
|
| 8 |
+
author: "OpenEnv Challenge Submission"
|
| 9 |
+
license: "MIT"
|
| 10 |
+
|
| 11 |
+
tags:
|
| 12 |
+
- openenv
|
| 13 |
+
- invoice
|
| 14 |
+
- data-extraction
|
| 15 |
+
- data-cleaning
|
| 16 |
+
- reconciliation
|
| 17 |
+
- finance
|
| 18 |
+
|
| 19 |
+
environment:
|
| 20 |
+
module: server.app
|
| 21 |
+
class: InvoiceEnvironment
|
| 22 |
+
action: models.InvoiceAction
|
| 23 |
+
observation: models.InvoiceObservation
|
| 24 |
+
|
| 25 |
+
tasks:
|
| 26 |
+
- id: easy
|
| 27 |
+
name: "Single Invoice Extraction"
|
| 28 |
+
description: "Extract structured fields (vendor, date, currency, total, line items) from a single invoice."
|
| 29 |
+
difficulty: easy
|
| 30 |
+
|
| 31 |
+
- id: medium
|
| 32 |
+
name: "Batch Invoice Cleaning"
|
| 33 |
+
description: "Clean and normalise a batch of messy invoices: fix dates, vendor typos, currency codes, and amounts."
|
| 34 |
+
difficulty: medium
|
| 35 |
+
|
| 36 |
+
- id: hard
|
| 37 |
+
name: "Invoice-PO Reconciliation"
|
| 38 |
+
description: "Extract, clean, and reconcile invoices against purchase orders. Flag overcharges, extra items, and missing items."
|
| 39 |
+
difficulty: hard
|
| 40 |
+
|
| 41 |
+
endpoints:
|
| 42 |
+
reset: /reset
|
| 43 |
+
step: /step
|
| 44 |
+
state: /state
|
| 45 |
+
health: /health
|
pyproject.toml
ADDED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.104.0
|
| 2 |
+
uvicorn[standard]>=0.24.0
|
| 3 |
+
pydantic>=2.5.0
|
| 4 |
+
httpx>=0.25.0
|
| 5 |
+
openai>=1.0.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Server package for Invoice Processing Pipeline."""
|
server/app.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server for Invoice Processing Pipeline environment.
|
| 3 |
+
Exposes /reset, /step, /state, /health, /tasks, /grader endpoints.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
from fastapi import FastAPI, HTTPException
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
import os
|
| 16 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
from models import InvoiceAction, InvoiceObservation, InvoiceState
|
| 19 |
+
from server.environment import InvoiceEnvironment
|
| 20 |
+
|
| 21 |
+
app = FastAPI(
|
| 22 |
+
title="Invoice Processing Pipeline",
|
| 23 |
+
description="OpenEnv environment for invoice data extraction, cleaning, and reconciliation.",
|
| 24 |
+
version="1.0.0",
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Single environment instance (one episode at a time for the HF Space)
|
| 28 |
+
env = InvoiceEnvironment()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Request / Response schemas
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
class ResetRequest(BaseModel):
|
| 36 |
+
task_id: str = "easy"
|
| 37 |
+
|
| 38 |
+
class StepRequest(BaseModel):
|
| 39 |
+
extracted_data: Dict[str, Any]
|
| 40 |
+
explanation: str = ""
|
| 41 |
+
|
| 42 |
+
class ResetResponse(BaseModel):
|
| 43 |
+
observation: Dict[str, Any]
|
| 44 |
+
reward: float
|
| 45 |
+
done: bool
|
| 46 |
+
info: Dict[str, Any]
|
| 47 |
+
|
| 48 |
+
class StepResponse(BaseModel):
|
| 49 |
+
observation: Dict[str, Any]
|
| 50 |
+
reward: float
|
| 51 |
+
done: bool
|
| 52 |
+
info: Dict[str, Any]
|
| 53 |
+
|
| 54 |
+
class StateResponse(BaseModel):
|
| 55 |
+
episode_id: str
|
| 56 |
+
task_id: str
|
| 57 |
+
step_count: int
|
| 58 |
+
done: bool
|
| 59 |
+
last_reward: float
|
| 60 |
+
best_reward: float
|
| 61 |
+
rewards: list
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# Endpoints
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
@app.get("/health")
|
| 69 |
+
def health():
|
| 70 |
+
return {"status": "ok", "environment": "invoice_processing_pipeline"}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@app.get("/tasks")
|
| 74 |
+
def list_tasks():
|
| 75 |
+
"""List available tasks with descriptions."""
|
| 76 |
+
tasks = []
|
| 77 |
+
for tid, info in InvoiceEnvironment.TASKS.items():
|
| 78 |
+
tasks.append({
|
| 79 |
+
"task_id": tid,
|
| 80 |
+
"description": info["description"],
|
| 81 |
+
"max_attempts": info["max_attempts"],
|
| 82 |
+
})
|
| 83 |
+
return {
|
| 84 |
+
"tasks": tasks,
|
| 85 |
+
"action_schema": InvoiceAction.model_json_schema(),
|
| 86 |
+
"observation_schema": InvoiceObservation.model_json_schema(),
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@app.post("/reset")
|
| 91 |
+
def reset(req: ResetRequest = ResetRequest()):
|
| 92 |
+
obs, reward, done, info = env.reset(task_id=req.task_id)
|
| 93 |
+
return ResetResponse(
|
| 94 |
+
observation=obs.model_dump(),
|
| 95 |
+
reward=reward,
|
| 96 |
+
done=done,
|
| 97 |
+
info=info,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@app.post("/step")
|
| 102 |
+
def step(req: StepRequest):
|
| 103 |
+
if env.state.done:
|
| 104 |
+
raise HTTPException(status_code=400, detail="Episode is done. Call /reset first.")
|
| 105 |
+
|
| 106 |
+
action = InvoiceAction(
|
| 107 |
+
extracted_data=req.extracted_data,
|
| 108 |
+
explanation=req.explanation,
|
| 109 |
+
)
|
| 110 |
+
obs, reward, done, info = env.step(action)
|
| 111 |
+
return StepResponse(
|
| 112 |
+
observation=obs.model_dump(),
|
| 113 |
+
reward=reward,
|
| 114 |
+
done=done,
|
| 115 |
+
info=info,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@app.get("/state")
|
| 120 |
+
def get_state():
|
| 121 |
+
s = env.state
|
| 122 |
+
return StateResponse(
|
| 123 |
+
episode_id=s.episode_id,
|
| 124 |
+
task_id=s.task_id,
|
| 125 |
+
step_count=s.step_count,
|
| 126 |
+
done=s.done,
|
| 127 |
+
last_reward=s.last_reward,
|
| 128 |
+
best_reward=s.best_reward,
|
| 129 |
+
rewards=s.rewards,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@app.post("/grader")
|
| 134 |
+
def grader(req: StepRequest):
|
| 135 |
+
"""Score a submission without modifying episode state (for testing)."""
|
| 136 |
+
import copy
|
| 137 |
+
saved_state = copy.deepcopy(env._state)
|
| 138 |
+
action = InvoiceAction(extracted_data=req.extracted_data, explanation=req.explanation)
|
| 139 |
+
|
| 140 |
+
task_id = env.state.task_id
|
| 141 |
+
if task_id == "easy":
|
| 142 |
+
from server.environment import _grade_easy
|
| 143 |
+
score, feedback = _grade_easy(action.extracted_data, env._ground_truth)
|
| 144 |
+
elif task_id == "medium":
|
| 145 |
+
from server.environment import _grade_medium
|
| 146 |
+
score, feedback = _grade_medium(action.extracted_data, env._ground_truth)
|
| 147 |
+
else:
|
| 148 |
+
from server.environment import _grade_hard
|
| 149 |
+
score, feedback = _grade_hard(
|
| 150 |
+
action.extracted_data, env._ground_truth, env._expected_discrepancies
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return {"score": score, "feedback": feedback}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
import uvicorn
|
| 158 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
server/environment.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Invoice Processing Pipeline — Core Environment
|
| 3 |
+
|
| 4 |
+
Three tasks:
|
| 5 |
+
easy — Extract structured fields from a single, relatively clean invoice.
|
| 6 |
+
medium — Clean & normalise a batch of messy invoices (date formats, vendor
|
| 7 |
+
name typos, currency symbols, duplicate detection).
|
| 8 |
+
hard — Extract, clean, AND reconcile against purchase orders; flag
|
| 9 |
+
mismatches, overcharges, and missing items.
|
| 10 |
+
|
| 11 |
+
Each episode generates fresh synthetic data so the agent cannot memorize.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import json
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import string
|
| 21 |
+
import uuid
|
| 22 |
+
from datetime import date, timedelta
|
| 23 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
from models import InvoiceAction, InvoiceObservation, InvoiceState
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Helpers
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
VENDORS = [
|
| 32 |
+
"Acme Corp", "GlobalTech Solutions", "Prime Office Supplies",
|
| 33 |
+
"DataStream Inc", "CloudNine Services", "Metro Logistics",
|
| 34 |
+
"Pinnacle Electronics", "Summit Consulting", "Vertex Manufacturing",
|
| 35 |
+
"Horizon Digital", "NexGen Software", "BluePeak Analytics",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
ITEMS = [
|
| 39 |
+
("Laptop Computer", 899.99, 1299.99),
|
| 40 |
+
("Wireless Mouse", 19.99, 49.99),
|
| 41 |
+
("USB-C Hub", 29.99, 79.99),
|
| 42 |
+
("Monitor Stand", 39.99, 89.99),
|
| 43 |
+
("Keyboard", 49.99, 149.99),
|
| 44 |
+
("Webcam HD", 59.99, 129.99),
|
| 45 |
+
("Desk Lamp", 24.99, 69.99),
|
| 46 |
+
("Notebook Pack", 9.99, 29.99),
|
| 47 |
+
("Printer Paper (Ream)", 7.99, 14.99),
|
| 48 |
+
("Whiteboard Markers (Set)", 5.99, 12.99),
|
| 49 |
+
("External SSD 1TB", 79.99, 149.99),
|
| 50 |
+
("Headset", 39.99, 99.99),
|
| 51 |
+
("Cable Management Kit", 14.99, 34.99),
|
| 52 |
+
("Ergonomic Chair", 299.99, 599.99),
|
| 53 |
+
("Standing Desk Converter", 199.99, 399.99),
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
CURRENCIES = ["USD", "EUR", "GBP"]
|
| 57 |
+
CURRENCY_SYMBOLS = {"USD": "$", "EUR": "€", "GBP": "£"}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _rand_date(start_year: int = 2024, end_year: int = 2025) -> date:
|
| 61 |
+
start = date(start_year, 1, 1)
|
| 62 |
+
end = date(end_year, 12, 31)
|
| 63 |
+
delta = (end - start).days
|
| 64 |
+
return start + timedelta(days=random.randint(0, delta))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _format_date_clean(d: date) -> str:
|
| 68 |
+
return d.strftime("%Y-%m-%d")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _format_date_messy(d: date) -> str:
|
| 72 |
+
"""Return a randomly-chosen messy date format."""
|
| 73 |
+
formats = [
|
| 74 |
+
"%m/%d/%Y", "%d-%m-%Y", "%B %d, %Y", "%d %b %Y",
|
| 75 |
+
"%m-%d-%y", "%d.%m.%Y", "%Y/%m/%d",
|
| 76 |
+
]
|
| 77 |
+
return d.strftime(random.choice(formats))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _typo_vendor(name: str) -> str:
|
| 81 |
+
"""Introduce a subtle typo into a vendor name."""
|
| 82 |
+
strategies = ["swap", "drop", "double", "case"]
|
| 83 |
+
strat = random.choice(strategies)
|
| 84 |
+
idx = random.randint(1, max(1, len(name) - 2))
|
| 85 |
+
if strat == "swap" and idx < len(name) - 1:
|
| 86 |
+
return name[:idx] + name[idx + 1] + name[idx] + name[idx + 2:]
|
| 87 |
+
elif strat == "drop":
|
| 88 |
+
return name[:idx] + name[idx + 1:]
|
| 89 |
+
elif strat == "double":
|
| 90 |
+
return name[:idx] + name[idx] + name[idx:]
|
| 91 |
+
else:
|
| 92 |
+
return name[:idx] + name[idx].swapcase() + name[idx + 1:]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _generate_line_items(n: int) -> List[Dict[str, Any]]:
|
| 96 |
+
chosen = random.sample(ITEMS, min(n, len(ITEMS)))
|
| 97 |
+
items = []
|
| 98 |
+
for desc, lo, hi in chosen:
|
| 99 |
+
qty = random.randint(1, 20)
|
| 100 |
+
unit_price = round(random.uniform(lo, hi), 2)
|
| 101 |
+
amount = round(qty * unit_price, 2)
|
| 102 |
+
items.append({
|
| 103 |
+
"description": desc,
|
| 104 |
+
"qty": qty,
|
| 105 |
+
"unit_price": unit_price,
|
| 106 |
+
"amount": amount,
|
| 107 |
+
})
|
| 108 |
+
return items
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _generate_invoice(vendor: str | None = None, currency: str | None = None) -> Dict[str, Any]:
|
| 112 |
+
vendor = vendor or random.choice(VENDORS)
|
| 113 |
+
currency = currency or random.choice(CURRENCIES)
|
| 114 |
+
inv_date = _rand_date()
|
| 115 |
+
line_items = _generate_line_items(random.randint(2, 6))
|
| 116 |
+
total = round(sum(it["amount"] for it in line_items), 2)
|
| 117 |
+
return {
|
| 118 |
+
"invoice_id": f"INV-{random.randint(10000, 99999)}",
|
| 119 |
+
"vendor": vendor,
|
| 120 |
+
"date": _format_date_clean(inv_date),
|
| 121 |
+
"currency": currency,
|
| 122 |
+
"total": total,
|
| 123 |
+
"line_items": line_items,
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ===================================================================
|
| 128 |
+
# TASK: EASY — single invoice extraction
|
| 129 |
+
# ===================================================================
|
| 130 |
+
|
| 131 |
+
def _render_clean_invoice(inv: Dict[str, Any]) -> str:
|
| 132 |
+
"""Render a single invoice as semi-structured text (OCR-style)."""
|
| 133 |
+
sym = CURRENCY_SYMBOLS.get(inv["currency"], "$")
|
| 134 |
+
lines = [
|
| 135 |
+
f"INVOICE",
|
| 136 |
+
f"-------",
|
| 137 |
+
f"Invoice #: {inv['invoice_id']}",
|
| 138 |
+
f"Vendor: {inv['vendor']}",
|
| 139 |
+
f"Date: {inv['date']}",
|
| 140 |
+
f"Currency: {inv['currency']}",
|
| 141 |
+
f"",
|
| 142 |
+
f"Items:",
|
| 143 |
+
f"{'Description':<30} {'Qty':>5} {'Unit Price':>12} {'Amount':>12}",
|
| 144 |
+
f"{'-'*30} {'-'*5} {'-'*12} {'-'*12}",
|
| 145 |
+
]
|
| 146 |
+
for it in inv["line_items"]:
|
| 147 |
+
lines.append(
|
| 148 |
+
f"{it['description']:<30} {it['qty']:>5} {sym}{it['unit_price']:>10.2f} {sym}{it['amount']:>10.2f}"
|
| 149 |
+
)
|
| 150 |
+
lines.append(f"{'':>30} {'':>5} {'TOTAL':>12} {sym}{inv['total']:>10.2f}")
|
| 151 |
+
return "\n".join(lines)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _grade_easy(submitted: Dict[str, Any], ground_truth: Dict[str, Any]) -> Tuple[float, str]:
|
| 155 |
+
"""Grade single-invoice extraction. Returns (score, feedback)."""
|
| 156 |
+
score = 0.0
|
| 157 |
+
feedback_parts = []
|
| 158 |
+
|
| 159 |
+
# Vendor (0.15)
|
| 160 |
+
sub_vendor = submitted.get("vendor", "").strip()
|
| 161 |
+
if sub_vendor.lower() == ground_truth["vendor"].lower():
|
| 162 |
+
score += 0.15
|
| 163 |
+
feedback_parts.append("Vendor: correct")
|
| 164 |
+
else:
|
| 165 |
+
feedback_parts.append(f"Vendor: wrong (expected '{ground_truth['vendor']}', got '{sub_vendor}')")
|
| 166 |
+
|
| 167 |
+
# Date (0.10)
|
| 168 |
+
sub_date = submitted.get("date", "").strip()
|
| 169 |
+
if sub_date == ground_truth["date"]:
|
| 170 |
+
score += 0.10
|
| 171 |
+
feedback_parts.append("Date: correct")
|
| 172 |
+
else:
|
| 173 |
+
feedback_parts.append(f"Date: wrong (expected '{ground_truth['date']}', got '{sub_date}')")
|
| 174 |
+
|
| 175 |
+
# Currency (0.05)
|
| 176 |
+
sub_cur = submitted.get("currency", "").strip().upper()
|
| 177 |
+
if sub_cur == ground_truth["currency"]:
|
| 178 |
+
score += 0.05
|
| 179 |
+
feedback_parts.append("Currency: correct")
|
| 180 |
+
else:
|
| 181 |
+
feedback_parts.append(f"Currency: wrong (expected '{ground_truth['currency']}', got '{sub_cur}')")
|
| 182 |
+
|
| 183 |
+
# Total (0.20)
|
| 184 |
+
try:
|
| 185 |
+
sub_total = float(submitted.get("total", 0))
|
| 186 |
+
if abs(sub_total - ground_truth["total"]) < 0.01:
|
| 187 |
+
score += 0.20
|
| 188 |
+
feedback_parts.append("Total: correct")
|
| 189 |
+
else:
|
| 190 |
+
feedback_parts.append(f"Total: wrong (expected {ground_truth['total']}, got {sub_total})")
|
| 191 |
+
except (ValueError, TypeError):
|
| 192 |
+
feedback_parts.append("Total: could not parse")
|
| 193 |
+
|
| 194 |
+
# Line items (0.50)
|
| 195 |
+
sub_items = submitted.get("line_items", [])
|
| 196 |
+
gt_items = ground_truth["line_items"]
|
| 197 |
+
if not isinstance(sub_items, list):
|
| 198 |
+
feedback_parts.append("Line items: not a list")
|
| 199 |
+
else:
|
| 200 |
+
item_score = _grade_line_items(sub_items, gt_items)
|
| 201 |
+
score += item_score * 0.50
|
| 202 |
+
feedback_parts.append(f"Line items: {item_score:.0%} match ({len(sub_items)} submitted, {len(gt_items)} expected)")
|
| 203 |
+
|
| 204 |
+
return round(min(score, 1.0), 4), "; ".join(feedback_parts)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _grade_line_items(submitted: List[Dict], expected: List[Dict]) -> float:
|
| 208 |
+
"""Compare line items, return fraction matched (0-1)."""
|
| 209 |
+
if not expected:
|
| 210 |
+
return 1.0 if not submitted else 0.0
|
| 211 |
+
|
| 212 |
+
matched = 0
|
| 213 |
+
used = set()
|
| 214 |
+
for gt in expected:
|
| 215 |
+
best = -1
|
| 216 |
+
best_score = 0.0
|
| 217 |
+
for i, sub in enumerate(submitted):
|
| 218 |
+
if i in used:
|
| 219 |
+
continue
|
| 220 |
+
s = _item_similarity(sub, gt)
|
| 221 |
+
if s > best_score:
|
| 222 |
+
best_score = s
|
| 223 |
+
best = i
|
| 224 |
+
if best >= 0 and best_score > 0.3:
|
| 225 |
+
matched += best_score
|
| 226 |
+
used.add(best)
|
| 227 |
+
|
| 228 |
+
return matched / len(expected)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _item_similarity(sub: Dict, gt: Dict) -> float:
|
| 232 |
+
"""Score a single line item match (0-1)."""
|
| 233 |
+
s = 0.0
|
| 234 |
+
# description
|
| 235 |
+
sd = sub.get("description", "").lower().strip()
|
| 236 |
+
gd = gt["description"].lower().strip()
|
| 237 |
+
if sd == gd:
|
| 238 |
+
s += 0.25
|
| 239 |
+
elif sd in gd or gd in sd:
|
| 240 |
+
s += 0.15
|
| 241 |
+
|
| 242 |
+
# qty
|
| 243 |
+
try:
|
| 244 |
+
if int(sub.get("qty", -1)) == gt["qty"]:
|
| 245 |
+
s += 0.25
|
| 246 |
+
except (ValueError, TypeError):
|
| 247 |
+
pass
|
| 248 |
+
|
| 249 |
+
# unit_price
|
| 250 |
+
try:
|
| 251 |
+
if abs(float(sub.get("unit_price", -1)) - gt["unit_price"]) < 0.01:
|
| 252 |
+
s += 0.25
|
| 253 |
+
except (ValueError, TypeError):
|
| 254 |
+
pass
|
| 255 |
+
|
| 256 |
+
# amount
|
| 257 |
+
try:
|
| 258 |
+
if abs(float(sub.get("amount", -1)) - gt["amount"]) < 0.01:
|
| 259 |
+
s += 0.25
|
| 260 |
+
except (ValueError, TypeError):
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
return s
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ===================================================================
|
| 267 |
+
# TASK: MEDIUM — batch cleaning & normalisation
|
| 268 |
+
# ===================================================================
|
| 269 |
+
|
| 270 |
+
def _make_messy_invoice(inv: Dict[str, Any]) -> Dict[str, Any]:
|
| 271 |
+
"""Take a clean invoice dict and introduce messiness."""
|
| 272 |
+
messy = copy.deepcopy(inv)
|
| 273 |
+
|
| 274 |
+
# Messy date
|
| 275 |
+
d = date.fromisoformat(inv["date"])
|
| 276 |
+
messy["date"] = _format_date_messy(d)
|
| 277 |
+
|
| 278 |
+
# Possibly typo the vendor
|
| 279 |
+
if random.random() < 0.5:
|
| 280 |
+
messy["vendor"] = _typo_vendor(inv["vendor"])
|
| 281 |
+
|
| 282 |
+
# Mix currency symbol into amounts (remove currency field sometimes)
|
| 283 |
+
sym = CURRENCY_SYMBOLS.get(inv["currency"], "$")
|
| 284 |
+
if random.random() < 0.4:
|
| 285 |
+
messy["currency"] = sym # symbol instead of code
|
| 286 |
+
if random.random() < 0.3:
|
| 287 |
+
messy["total"] = f"{sym}{inv['total']}" # string instead of number
|
| 288 |
+
|
| 289 |
+
# Mess up some line item amounts
|
| 290 |
+
for it in messy["line_items"]:
|
| 291 |
+
if random.random() < 0.3:
|
| 292 |
+
it["amount"] = f"{sym}{it['amount']}"
|
| 293 |
+
if random.random() < 0.2:
|
| 294 |
+
it["unit_price"] = f"{sym}{it['unit_price']}"
|
| 295 |
+
if random.random() < 0.15:
|
| 296 |
+
# Wrong amount (qty * unit_price ≠ amount)
|
| 297 |
+
it["amount"] = round(it["qty"] * float(str(it["unit_price"]).replace(sym, "")) + random.uniform(0.5, 5.0), 2)
|
| 298 |
+
|
| 299 |
+
return messy
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _render_messy_batch(invoices: List[Dict[str, Any]]) -> str:
|
| 303 |
+
"""Render a batch of messy invoices as CSV-ish text."""
|
| 304 |
+
lines = ["=== INVOICE BATCH (requires cleaning) ===", ""]
|
| 305 |
+
for i, inv in enumerate(invoices):
|
| 306 |
+
lines.append(f"--- Invoice {i+1} ---")
|
| 307 |
+
lines.append(f"Vendor: {inv['vendor']}")
|
| 308 |
+
lines.append(f"Date: {inv['date']}")
|
| 309 |
+
lines.append(f"Currency: {inv.get('currency', 'N/A')}")
|
| 310 |
+
lines.append(f"Total: {inv.get('total', 'N/A')}")
|
| 311 |
+
lines.append("Items:")
|
| 312 |
+
for it in inv["line_items"]:
|
| 313 |
+
lines.append(f" - {it['description']} | qty: {it.get('qty','?')} | price: {it.get('unit_price','?')} | amount: {it.get('amount','?')}")
|
| 314 |
+
lines.append("")
|
| 315 |
+
return "\n".join(lines)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _grade_medium(submitted: Dict[str, Any], ground_truths: List[Dict[str, Any]]) -> Tuple[float, str]:
|
| 319 |
+
"""Grade batch cleaning. submitted should have 'invoices' key."""
|
| 320 |
+
sub_invoices = submitted.get("invoices", [])
|
| 321 |
+
if not isinstance(sub_invoices, list):
|
| 322 |
+
return 0.0, "Expected 'invoices' key with a list of cleaned invoices."
|
| 323 |
+
|
| 324 |
+
n_expected = len(ground_truths)
|
| 325 |
+
if len(sub_invoices) != n_expected:
|
| 326 |
+
# Partial credit still possible
|
| 327 |
+
pass
|
| 328 |
+
|
| 329 |
+
total_score = 0.0
|
| 330 |
+
feedback_parts = []
|
| 331 |
+
|
| 332 |
+
for idx, gt in enumerate(ground_truths):
|
| 333 |
+
if idx < len(sub_invoices):
|
| 334 |
+
s, fb = _grade_easy(sub_invoices[idx], gt)
|
| 335 |
+
total_score += s
|
| 336 |
+
feedback_parts.append(f"Invoice {idx+1}: {s:.2f} ({fb})")
|
| 337 |
+
else:
|
| 338 |
+
feedback_parts.append(f"Invoice {idx+1}: missing")
|
| 339 |
+
|
| 340 |
+
# Penalise extra invoices
|
| 341 |
+
if len(sub_invoices) > n_expected:
|
| 342 |
+
feedback_parts.append(f"Extra invoices submitted: {len(sub_invoices) - n_expected}")
|
| 343 |
+
|
| 344 |
+
avg = total_score / n_expected if n_expected > 0 else 0.0
|
| 345 |
+
return round(min(avg, 1.0), 4), "; ".join(feedback_parts)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# ===================================================================
|
| 349 |
+
# TASK: HARD — extraction + cleaning + reconciliation against POs
|
| 350 |
+
# ===================================================================
|
| 351 |
+
|
| 352 |
+
def _generate_purchase_order(inv: Dict[str, Any]) -> Dict[str, Any]:
|
| 353 |
+
"""Generate a PO that mostly matches the invoice but may differ."""
|
| 354 |
+
po = copy.deepcopy(inv)
|
| 355 |
+
po["po_id"] = f"PO-{random.randint(10000, 99999)}"
|
| 356 |
+
|
| 357 |
+
discrepancies = []
|
| 358 |
+
|
| 359 |
+
# Possibly change a price (overcharge)
|
| 360 |
+
if random.random() < 0.6 and po["line_items"]:
|
| 361 |
+
idx = random.randint(0, len(po["line_items"]) - 1)
|
| 362 |
+
original_price = po["line_items"][idx]["unit_price"]
|
| 363 |
+
# PO has the CORRECT price; invoice will be higher (overcharge)
|
| 364 |
+
overcharge = round(original_price * random.uniform(1.05, 1.25), 2)
|
| 365 |
+
discrepancies.append({
|
| 366 |
+
"type": "overcharge",
|
| 367 |
+
"item_description": po["line_items"][idx]["description"],
|
| 368 |
+
"po_price": original_price,
|
| 369 |
+
"invoice_price": overcharge,
|
| 370 |
+
})
|
| 371 |
+
# We'll modify the invoice later
|
| 372 |
+
inv["line_items"][idx]["unit_price"] = overcharge
|
| 373 |
+
inv["line_items"][idx]["amount"] = round(inv["line_items"][idx]["qty"] * overcharge, 2)
|
| 374 |
+
|
| 375 |
+
# Possibly add an extra item to invoice (not in PO)
|
| 376 |
+
if random.random() < 0.4:
|
| 377 |
+
extra = _generate_line_items(1)[0]
|
| 378 |
+
inv["line_items"].append(extra)
|
| 379 |
+
discrepancies.append({
|
| 380 |
+
"type": "extra_item",
|
| 381 |
+
"item_description": extra["description"],
|
| 382 |
+
"detail": "Item on invoice but not on purchase order",
|
| 383 |
+
})
|
| 384 |
+
|
| 385 |
+
# Possibly remove an item from invoice (missing from invoice)
|
| 386 |
+
if random.random() < 0.3 and len(po["line_items"]) > 2:
|
| 387 |
+
removed = po["line_items"].pop(random.randint(0, len(po["line_items"]) - 1))
|
| 388 |
+
discrepancies.append({
|
| 389 |
+
"type": "missing_item",
|
| 390 |
+
"item_description": removed["description"],
|
| 391 |
+
"detail": "Item on purchase order but not on invoice",
|
| 392 |
+
})
|
| 393 |
+
|
| 394 |
+
# Recalculate totals
|
| 395 |
+
inv["total"] = round(sum(it["amount"] for it in inv["line_items"]), 2)
|
| 396 |
+
po["total"] = round(sum(it["amount"] for it in po["line_items"]), 2)
|
| 397 |
+
|
| 398 |
+
return po, discrepancies
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _render_po(po: Dict[str, Any]) -> str:
|
| 402 |
+
"""Render purchase order text."""
|
| 403 |
+
lines = [
|
| 404 |
+
f"PURCHASE ORDER: {po['po_id']}",
|
| 405 |
+
f"Vendor: {po['vendor']}",
|
| 406 |
+
f"Date: {po['date']}",
|
| 407 |
+
f"Currency: {po['currency']}",
|
| 408 |
+
f"",
|
| 409 |
+
"Ordered Items:",
|
| 410 |
+
]
|
| 411 |
+
sym = CURRENCY_SYMBOLS.get(po["currency"], "$")
|
| 412 |
+
for it in po["line_items"]:
|
| 413 |
+
lines.append(f" - {it['description']} x{it['qty']} @ {sym}{it['unit_price']:.2f} = {sym}{it['amount']:.2f}")
|
| 414 |
+
lines.append(f"PO Total: {sym}{po['total']:.2f}")
|
| 415 |
+
return "\n".join(lines)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _grade_hard(submitted: Dict[str, Any], ground_truths: List[Dict[str, Any]],
|
| 419 |
+
expected_discrepancies: List[List[Dict]]) -> Tuple[float, str]:
|
| 420 |
+
"""Grade extraction + cleaning + reconciliation."""
|
| 421 |
+
# Extraction/cleaning portion (60%)
|
| 422 |
+
extraction_score, extraction_fb = _grade_medium(submitted, ground_truths)
|
| 423 |
+
|
| 424 |
+
# Discrepancy detection portion (40%)
|
| 425 |
+
sub_discrepancies = submitted.get("discrepancies", [])
|
| 426 |
+
if not isinstance(sub_discrepancies, list):
|
| 427 |
+
disc_score = 0.0
|
| 428 |
+
disc_fb = "No discrepancies list submitted"
|
| 429 |
+
else:
|
| 430 |
+
all_expected = []
|
| 431 |
+
for disc_list in expected_discrepancies:
|
| 432 |
+
all_expected.extend(disc_list)
|
| 433 |
+
|
| 434 |
+
if not all_expected:
|
| 435 |
+
disc_score = 1.0 if not sub_discrepancies else 0.5
|
| 436 |
+
disc_fb = "No discrepancies expected"
|
| 437 |
+
else:
|
| 438 |
+
matched = 0
|
| 439 |
+
for exp in all_expected:
|
| 440 |
+
for sub in sub_discrepancies:
|
| 441 |
+
if _discrepancy_match(sub, exp):
|
| 442 |
+
matched += 1
|
| 443 |
+
break
|
| 444 |
+
precision = matched / len(sub_discrepancies) if sub_discrepancies else 0.0
|
| 445 |
+
recall = matched / len(all_expected) if all_expected else 1.0
|
| 446 |
+
disc_score = (precision + recall) / 2 # F1-like
|
| 447 |
+
disc_fb = f"Discrepancies: {matched}/{len(all_expected)} found, precision={precision:.2f}, recall={recall:.2f}"
|
| 448 |
+
|
| 449 |
+
total = extraction_score * 0.60 + disc_score * 0.40
|
| 450 |
+
feedback = f"Extraction: {extraction_score:.2f}; {disc_fb}"
|
| 451 |
+
return round(min(total, 1.0), 4), feedback
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _discrepancy_match(submitted: Dict, expected: Dict) -> bool:
|
| 455 |
+
"""Check if a submitted discrepancy matches an expected one."""
|
| 456 |
+
# Type must match
|
| 457 |
+
sub_type = submitted.get("type", "").lower().strip()
|
| 458 |
+
exp_type = expected.get("type", "").lower().strip()
|
| 459 |
+
if sub_type != exp_type:
|
| 460 |
+
return False
|
| 461 |
+
|
| 462 |
+
# Item description should roughly match
|
| 463 |
+
sub_desc = submitted.get("item_description", "").lower().strip()
|
| 464 |
+
exp_desc = expected.get("item_description", "").lower().strip()
|
| 465 |
+
if sub_desc and exp_desc:
|
| 466 |
+
if sub_desc == exp_desc or sub_desc in exp_desc or exp_desc in sub_desc:
|
| 467 |
+
return True
|
| 468 |
+
return False
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# ===================================================================
|
| 472 |
+
# Environment
|
| 473 |
+
# ===================================================================
|
| 474 |
+
|
| 475 |
+
class InvoiceEnvironment:
|
| 476 |
+
"""Core invoice processing environment."""
|
| 477 |
+
|
| 478 |
+
TASKS = {
|
| 479 |
+
"easy": {
|
| 480 |
+
"description": (
|
| 481 |
+
"Extract structured data from a single invoice. "
|
| 482 |
+
"Return a JSON object with keys: vendor, date (YYYY-MM-DD), "
|
| 483 |
+
"currency (3-letter code), total (number), "
|
| 484 |
+
"line_items (list of {description, qty, unit_price, amount})."
|
| 485 |
+
),
|
| 486 |
+
"max_attempts": 5,
|
| 487 |
+
},
|
| 488 |
+
"medium": {
|
| 489 |
+
"description": (
|
| 490 |
+
"Clean and normalise a batch of messy invoices. "
|
| 491 |
+
"Fix date formats to YYYY-MM-DD, correct vendor name typos, "
|
| 492 |
+
"standardise currency to 3-letter codes, ensure amounts are numbers, "
|
| 493 |
+
"and verify line item math (qty * unit_price = amount). "
|
| 494 |
+
"Return {invoices: [cleaned invoice objects]}."
|
| 495 |
+
),
|
| 496 |
+
"max_attempts": 5,
|
| 497 |
+
},
|
| 498 |
+
"hard": {
|
| 499 |
+
"description": (
|
| 500 |
+
"Extract and clean invoice data, then reconcile against purchase orders. "
|
| 501 |
+
"Identify discrepancies: overcharges (invoice price > PO price), "
|
| 502 |
+
"extra items (on invoice but not PO), missing items (on PO but not invoice). "
|
| 503 |
+
"Return {invoices: [cleaned], discrepancies: [{invoice_idx, type, item_description, detail}]}."
|
| 504 |
+
),
|
| 505 |
+
"max_attempts": 5,
|
| 506 |
+
},
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
def __init__(self):
|
| 510 |
+
self._state = InvoiceState()
|
| 511 |
+
self._ground_truth: Any = None
|
| 512 |
+
self._raw_text: str = ""
|
| 513 |
+
self._reference_data: str = ""
|
| 514 |
+
self._messy_invoices: List[Dict] = []
|
| 515 |
+
self._expected_discrepancies: List[List[Dict]] = []
|
| 516 |
+
|
| 517 |
+
def reset(self, task_id: str = "easy") -> Tuple[InvoiceObservation, float, bool, Dict]:
|
| 518 |
+
"""Reset the environment for a new episode."""
|
| 519 |
+
if task_id not in self.TASKS:
|
| 520 |
+
task_id = "easy"
|
| 521 |
+
|
| 522 |
+
self._state = InvoiceState(
|
| 523 |
+
episode_id=str(uuid.uuid4()),
|
| 524 |
+
task_id=task_id,
|
| 525 |
+
step_count=0,
|
| 526 |
+
done=False,
|
| 527 |
+
last_reward=0.0,
|
| 528 |
+
best_reward=0.0,
|
| 529 |
+
rewards=[],
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
self._reference_data = ""
|
| 533 |
+
self._expected_discrepancies = []
|
| 534 |
+
|
| 535 |
+
if task_id == "easy":
|
| 536 |
+
inv = _generate_invoice()
|
| 537 |
+
self._ground_truth = inv
|
| 538 |
+
self._raw_text = _render_clean_invoice(inv)
|
| 539 |
+
|
| 540 |
+
elif task_id == "medium":
|
| 541 |
+
n = random.randint(3, 5)
|
| 542 |
+
clean_invoices = [_generate_invoice() for _ in range(n)]
|
| 543 |
+
self._ground_truth = clean_invoices
|
| 544 |
+
messy = [_make_messy_invoice(copy.deepcopy(inv)) for inv in clean_invoices]
|
| 545 |
+
self._messy_invoices = messy
|
| 546 |
+
self._raw_text = _render_messy_batch(messy)
|
| 547 |
+
|
| 548 |
+
elif task_id == "hard":
|
| 549 |
+
n = random.randint(2, 4)
|
| 550 |
+
clean_invoices = [_generate_invoice() for _ in range(n)]
|
| 551 |
+
self._expected_discrepancies = []
|
| 552 |
+
po_texts = []
|
| 553 |
+
|
| 554 |
+
for inv in clean_invoices:
|
| 555 |
+
po, discs = _generate_purchase_order(inv)
|
| 556 |
+
self._expected_discrepancies.append(discs)
|
| 557 |
+
po_texts.append(_render_po(po))
|
| 558 |
+
|
| 559 |
+
self._ground_truth = clean_invoices
|
| 560 |
+
messy = [_make_messy_invoice(copy.deepcopy(inv)) for inv in clean_invoices]
|
| 561 |
+
self._raw_text = _render_messy_batch(messy)
|
| 562 |
+
self._reference_data = "\n\n".join(po_texts)
|
| 563 |
+
|
| 564 |
+
task_info = self.TASKS[task_id]
|
| 565 |
+
obs = InvoiceObservation(
|
| 566 |
+
raw_text=self._raw_text,
|
| 567 |
+
task_id=task_id,
|
| 568 |
+
difficulty=task_id,
|
| 569 |
+
task_description=task_info["description"],
|
| 570 |
+
attempt_number=0,
|
| 571 |
+
max_attempts=task_info["max_attempts"],
|
| 572 |
+
feedback="",
|
| 573 |
+
hint="",
|
| 574 |
+
reference_data=self._reference_data,
|
| 575 |
+
)
|
| 576 |
+
return obs, 0.0, False, {"episode_id": self._state.episode_id}
|
| 577 |
+
|
| 578 |
+
def step(self, action: InvoiceAction) -> Tuple[InvoiceObservation, float, bool, Dict]:
|
| 579 |
+
"""Process one agent action."""
|
| 580 |
+
self._state.step_count += 1
|
| 581 |
+
task_id = self._state.task_id
|
| 582 |
+
task_info = self.TASKS[task_id]
|
| 583 |
+
attempt = self._state.step_count
|
| 584 |
+
|
| 585 |
+
# Grade
|
| 586 |
+
if task_id == "easy":
|
| 587 |
+
score, feedback = _grade_easy(action.extracted_data, self._ground_truth)
|
| 588 |
+
elif task_id == "medium":
|
| 589 |
+
score, feedback = _grade_medium(action.extracted_data, self._ground_truth)
|
| 590 |
+
else:
|
| 591 |
+
score, feedback = _grade_hard(
|
| 592 |
+
action.extracted_data, self._ground_truth, self._expected_discrepancies
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Track best
|
| 596 |
+
self._state.best_reward = max(self._state.best_reward, score)
|
| 597 |
+
self._state.last_reward = score
|
| 598 |
+
self._state.rewards.append(score)
|
| 599 |
+
|
| 600 |
+
# Done conditions
|
| 601 |
+
done = score >= 0.95 or attempt >= task_info["max_attempts"]
|
| 602 |
+
self._state.done = done
|
| 603 |
+
|
| 604 |
+
# Attempt penalty for using all attempts
|
| 605 |
+
reward = score
|
| 606 |
+
if done and attempt >= task_info["max_attempts"] and score < 0.95:
|
| 607 |
+
reward = score * 0.85 # penalty
|
| 608 |
+
|
| 609 |
+
# Hint after 2 failed attempts
|
| 610 |
+
hint = ""
|
| 611 |
+
if attempt >= 2 and score < 0.7:
|
| 612 |
+
if task_id == "easy":
|
| 613 |
+
hint = "Make sure dates are YYYY-MM-DD, amounts are numbers, and all line items are included."
|
| 614 |
+
elif task_id == "medium":
|
| 615 |
+
hint = "Check for vendor name typos, mixed date formats, and currency symbols mixed into amounts."
|
| 616 |
+
else:
|
| 617 |
+
hint = "Compare each invoice line item against the PO. Look for price differences and items present in one but not the other."
|
| 618 |
+
|
| 619 |
+
obs = InvoiceObservation(
|
| 620 |
+
raw_text=self._raw_text,
|
| 621 |
+
task_id=task_id,
|
| 622 |
+
difficulty=task_id,
|
| 623 |
+
task_description=task_info["description"],
|
| 624 |
+
attempt_number=attempt,
|
| 625 |
+
max_attempts=task_info["max_attempts"],
|
| 626 |
+
feedback=feedback,
|
| 627 |
+
hint=hint,
|
| 628 |
+
reference_data=self._reference_data,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
return obs, round(reward, 4), done, {
|
| 632 |
+
"episode_id": self._state.episode_id,
|
| 633 |
+
"best_reward": self._state.best_reward,
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
@property
|
| 637 |
+
def state(self) -> InvoiceState:
|
| 638 |
+
return self._state
|