FEAT: Add to create SLM training data
Files changed:
- .gitignore +4 -2
- dataset/PIPELINE_FLOW.md +414 -0
- dataset/README.md +108 -0
- dataset/__init__.py +1 -0
- dataset/config.yaml +52 -0
- dataset/scripts/__init__.py +1 -0
- dataset/scripts/build_inventory.py +111 -0
- dataset/scripts/build_relations.py +288 -0
- dataset/scripts/cli.py +305 -0
- dataset/scripts/export_training_data.py +191 -0
- dataset/scripts/generate_samples.py +1091 -0
- dataset/scripts/sql_templates.py +317 -0
- dataset/scripts/validate_dataset.py +275 -0
- pyproject.toml +7 -1
- uv.lock +0 -0
.gitignore
CHANGED

```diff
@@ -133,6 +133,8 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-
+# Dataset
 data/
-output/
+output/
+*.parquet
+*.jsonl
```
dataset/PIPELINE_FLOW.md
ADDED

@@ -0,0 +1,414 @@
# Dataset Generation Pipeline Flow

This document explains how the optimized pipeline processes data with concrete examples.

**Example Configuration:**
- Countries: `['EC', 'BE', 'KE', 'AE', 'SG']` (5 countries)
- Sample targets: 100 per family × 8 families = 800 samples
- Retry multiplier: 2 (generate 1,600 attempts to get 800 valid samples)
- Max workers: 8

---

## Step 0: Build Entity Inventory (One-Time Setup)

**Script:** `build_inventory.py`

**What it does:** Extracts compact metadata from the full parquet files for fast sampling.

**Input:**
- `divisions_area.parquet` (~500K entities globally)
- `natural_earth.parquet` (~50K entities globally)

**Process:**
```sql
-- For each parquet, extract:
SELECT
    id,
    names."primary" AS name,
    subtype,
    country,
    region,
    admin_level,
    ST_Area(geometry) AS area_sq_deg
    -- plus bounding box columns for spatial filtering
FROM read_parquet(...)
WHERE names."primary" IS NOT NULL
```

**Output:**
- `intermediate/divisions_area_inventory.parquet` (~50K rows for 5 countries)
- `intermediate/natural_earth_inventory.parquet` (~50K rows)

**Parallelization:** None (runs once, fast enough)

---
## Step 1: Build Relation Tables (Parallelized)

**Script:** `build_relations.py`

**What it does:** Pre-computes spatial relationships between entities so sample generation doesn't need to run expensive spatial joins.

### Before Optimization (Sequential)
```
Total time: adjacency (60s) + containment (15s) + intersection (10s) + cross_source (8s) = 93s
```

### After Optimization (Parallel)
```
Total time: max(60s, 15s, 10s, 8s) = 60s
```

**4 concurrent tasks, each with its own DuckDB connection:**

#### Task 1: Adjacency Pairs (60s)
```sql
-- Find all touching boundaries within 5 countries
WITH features AS (
    SELECT id, name, subtype, country, geometry, ST_Envelope(geometry) AS bbox
    FROM divisions_area
    WHERE country IN ('EC', 'BE', 'KE', 'AE', 'SG')
)
SELECT
    a.id AS anchor_id,
    a.name AS anchor_name,
    b.id AS target_id,
    b.subtype AS target_subtype
FROM features a
JOIN features b ON (
    a.id < b.id
    AND ST_Intersects(a.bbox, b.bbox)       -- Fast bbox pre-filter
    AND ST_Touches(a.geometry, b.geometry)  -- Expensive but necessary
)
LIMIT 50000
```
**Output:** `adjacency_pairs.parquet` (50,000 rows)

#### Task 2: Containment Pairs (15s)
```sql
-- Find all parent-child relationships
-- Example: Ecuador contains Quito
SELECT
    container.id AS container_id,
    container.name AS container_name,
    contained.id AS contained_id,
    contained.subtype AS contained_subtype
FROM features container
JOIN features contained ON (
    container.admin_level < contained.admin_level  -- Parent has lower level
    AND ST_Within(contained.geometry, container.geometry)
)
LIMIT 1000
```
**Output:** `containment_pairs.parquet` (1,000 rows)

#### Task 3: Intersection Pairs (10s)
```sql
-- Find overlapping regions (not touching, not containing)
-- Example: Two administrative regions that overlap
SELECT a.id, a.name, b.id, b.subtype
FROM features a
JOIN features b ON (
    ST_Intersects(a.geometry, b.geometry)
    AND NOT ST_Touches(a.geometry, b.geometry)
    AND NOT ST_Within(a.geometry, b.geometry)
)
LIMIT 500
```
**Output:** `intersection_pairs.parquet` (500 rows)

#### Task 4: Cross-Source Relations (8s)
```sql
-- Find relationships between divisions and natural features
-- Example: Ecuador intersects Pacific Ocean
SELECT
    d.id AS division_id,
    d.name AS division_name,
    n.id AS natural_id,
    n.name AS natural_name,
    CASE
        WHEN ST_Touches(...) THEN 'touches'
        WHEN ST_Intersects(...) THEN 'intersects'
    END AS relation_type
FROM divisions d
JOIN natural_features n ON ST_Intersects(d.geometry, n.geometry)
WHERE d.country IN ('EC', 'BE', 'KE', 'AE', 'SG')
  AND n.subtype IN ('sea', 'ocean', 'Lake', 'River')
LIMIT 500
```
**Output:** `cross_source_relations.parquet` (500 rows)

**ThreadPoolExecutor with 4 workers runs all tasks concurrently.**
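
A minimal sketch of that dispatch, using the function names from `build_relations.py` (the committed implementation, shown in full later in this diff, wraps this in `_compute_and_save` and also writes each table to parquet):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_task(compute_fn, countries, limit):
    # Each task opens its own DuckDB connection
    # (mirrors _compute_and_save in build_relations.py).
    con = _make_connection()  # duckdb.connect() + INSTALL/LOAD spatial
    try:
        return compute_fn(con, countries, limit)
    finally:
        con.close()

countries = ['EC', 'BE', 'KE', 'AE', 'SG']
task_fns = [
    (compute_adjacency_pairs, 50000),
    (compute_containment_pairs, 1000),
    (compute_intersection_pairs, 500),
    (compute_cross_source_relations, 500),
]
with ThreadPoolExecutor(max_workers=len(task_fns)) as executor:
    futures = [executor.submit(run_task, fn, countries, limit) for fn, limit in task_fns]
    for future in as_completed(futures):
        future.result()  # re-raises the first worker failure, if any
```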

---

## Step 2: Generate Samples (Batch-Parallelized)

**Script:** `generate_samples.py`

**What it does:** Creates training samples by:
1. Sampling anchors from relation tables
2. Rendering SQL templates
3. Executing SQL to verify it works
4. Building candidate lists with distractors
5. Generating questions

### Work Item Preparation

**Total work items:** 8 families × 100 targets × 2 retry_multiplier = **1,600 items**

```python
work_items = [
    ('adjacency', template_dict_1, 'sample_001', '/path/to/intermediate'),
    ('containment', template_dict_2, 'sample_002', '/path/to/intermediate'),
    ('adjacency', template_dict_3, 'sample_003', '/path/to/intermediate'),
    # ... 1,597 more items
]

# Shuffle for balanced batches
random.shuffle(work_items)

# Partition into 8 batches (one per worker): 1,600 / 8 = 200 items per batch
batch_size = len(work_items) // 8
batches = [work_items[i:i + batch_size] for i in range(0, len(work_items), batch_size)]
# batches[0] = items[0:200]    # ~25 of each family (mixed)
# batches[1] = items[200:400]
# ... 8 batches total
```

### Before Optimization (Per-Sample Workers)

```
For each of 1,600 samples:
  - Fork new process
  - Create DuckDB connection
  - INSTALL spatial (5-10ms)
  - LOAD spatial (5-10ms)
  - Import sql_templates module
  - Load 4 relation parquet files (50-100ms)
  - Generate 1 sample (20-50ms)
  - Close connection

Total overhead per sample: ~100ms
Total overhead: 1,600 × 100ms = 160 seconds of pure overhead
```

### After Optimization (Batch Workers)

```
8 workers run in parallel, each processes 200 samples:

Worker 1 (batch of 200 items):
  - Create DuckDB connection (once)
  - INSTALL + LOAD spatial (once, 10ms)
  - Import sql_templates (once)
  - Load 4 relation tables (once, 100ms)

  FOR EACH of 200 items:
    - Sample anchor from pre-loaded table (instant)
    - Render SQL template
    - Execute SQL to verify (20-50ms)
    - Build candidate list with Jaro-Winkler (10-30ms)
    - Generate question

  - Close connection

Total overhead per worker: ~110ms (one-time)
Total overhead across 8 workers: ~110ms (parallel)
```

**Speedup:** 160s → 0.11s overhead = **~1,450x faster on initialization overhead**
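
A minimal sketch of the batch-worker pattern described above (`generate_samples.py` is not reproduced in this view, so the helpers `load_relation_tables` and `generate_one_sample` are illustrative names, not the committed API):

```python
from concurrent.futures import ProcessPoolExecutor
import duckdb

def process_batch(batch):
    """One worker: pay the initialization cost once, then loop over its items."""
    con = duckdb.connect()
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")
    tables = load_relation_tables(con)  # illustrative: the 4 relation parquets, loaded once
    results = []
    try:
        for family, template, sample_id, intermediate_dir in batch:
            # illustrative helper covering steps 1-5 above
            results.append(generate_one_sample(con, tables, family, template, sample_id))
    finally:
        con.close()
    return results

with ProcessPoolExecutor(max_workers=8) as executor:
    all_samples = [sample
                   for batch_result in executor.map(process_batch, batches)
                   for sample in batch_result]
```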

### Sample Generation Example

**Adjacency sample generation:**

```python
# 1. Sample anchor from pre-loaded adjacency_pairs DataFrame
row = adjacency_df.sample(n=1).iloc[0]
# Result: anchor_id='EC-123', anchor_name='Quito', target_subtype='locality'

# 2. Render SQL template
sql = f"""
WITH a AS (
    SELECT geometry FROM divisions_area WHERE id = 'EC-123'
)
SELECT b.id, b.names."primary" AS name, b.geometry
FROM divisions_area AS b, a
WHERE b.id != 'EC-123'
  AND b.subtype = 'locality'
  AND ST_Touches(a.geometry, b.geometry)
"""

# 3. Execute to verify (returns 5 neighboring localities)
result = con.execute(sql).fetchdf()  # 30ms
# ✓ Not empty, sample is valid

# 4. Build candidate list (10 candidates: 1 true + 9 distractors)
candidates = build_candidate_list(
    con, 'EC-123', 'Quito', 'divisions_area', num_candidates=10
)
# Uses Jaro-Winkler to find similar names:
#   SELECT id, name, subtype, country,
#          jaro_winkler_similarity(lower(name), lower('Quito')) AS similarity
#   FROM divisions_area
#   WHERE id != 'EC-123'
#   ORDER BY similarity DESC
#   LIMIT 9
# Results: ['Quito', 'Cuito', 'Quijos', 'Quinindé', ...]
# Shuffle and reassign IDs: c1, c2, ..., c10

# 5. Generate question
question = "Which localities border Quito?"

# 6. Return TrainingSample with sql_verified=True
```

### Batch Progress Tracking

```
Console output:

Generating 1600 samples across 8 families...
Split into 8 batches of ~200 items (1 DuckDB init per batch)

Progress: 200/1600 samples (1/8 batches)   # Worker 1 done
Progress: 400/1600 samples (2/8 batches)   # Worker 2 done
Progress: 600/1600 samples (3/8 batches)   # Worker 3 done
...
Progress: 1600/1600 samples (8/8 batches)  # All done

Results by family:
  adjacency         : 185 success / 15 failed (92.5% success rate, target: 100)
  aggregation       : 178 success / 22 failed (89.0% success rate, target: 100)
  buffer            : 192 success /  8 failed (96.0% success rate, target: 100)
  containment       : 188 success / 12 failed (94.0% success rate, target: 100)
  direct_lookup     : 200 success /  0 failed (100% success rate, target: 100)
  intersection      : 181 success / 19 failed (90.5% success rate, target: 100)
  partial_selection : 175 success / 25 failed (87.5% success rate, target: 100)
  set_operations    : 190 success / 10 failed (95.0% success rate, target: 100)

Total: 1,489 valid samples from 1,600 attempts
```

---

## Step 3: Validate Dataset (Optimized)

**Script:** `validate_dataset.py`

**What it does:** Validates samples in parallel, but **skips SQL re-execution** for samples with `sql_verified: True`.

### Before Optimization

```
For each of 1,489 samples:
  - Execute SQL to verify (30ms)
  - Validate candidates (1ms)
  - Check question (1ms)

Total: 1,489 × 32ms = 47.6 seconds
```

### After Optimization

```
For each of 1,489 samples:
  - Check metadata.sql_verified flag
  - IF True: skip SQL execution (saved 30ms)
  - Validate candidates (1ms)
  - Check question (1ms)

Total: 1,489 × 2ms = 3.0 seconds
```

**Speedup:** 47.6s → 3.0s = **~16x faster**

**Parallelization:** 8 workers process samples in parallel batches
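
A minimal sketch of the skip logic (`validate_dataset.py` is not reproduced in this view; `execute_and_check_sql` and the exact metadata layout are assumptions based on the description above):

```python
def validate_sample(sample: dict, con) -> bool:
    """Validate one sample, skipping SQL re-execution when it was verified at generation time."""
    if not sample['metadata'].get('sql_verified', False):
        # Only unverified samples pay the ~30ms SQL execution cost.
        if not execute_and_check_sql(con, sample['sql']):  # assumed helper
            return False
    # Cheap structural checks (~2ms) always run.
    if len(sample['candidates']) != 10:   # 1 true answer + 9 distractors
        return False
    if not sample['question'].strip():
        return False
    return True
```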

---

## Step 4: Export Splits

**Script:** `export_training_data.py`

**What it does:** Stratified split into train/val/test (80/10/10) by task family.

**Input:** `dataset_validated.jsonl` (1,489 samples)

**Process:**
```
# Group by family
adjacency_samples: 185
aggregation_samples: 178
buffer_samples: 192
# ... etc

# Split each family 80/10/10
adjacency_train: 148, val: 19, test: 18
aggregation_train: 142, val: 18, test: 18
# ... etc

# Combine and shuffle
train: 1,191 samples
val: 149 samples
test: 149 samples
```

**Output:**
- `output/train.jsonl` (1,191 samples)
- `output/val.jsonl` (149 samples)
- `output/test.jsonl` (149 samples)

**Parallelization:** None needed (fast enough)

---

## Overall Pipeline Timing

### Before Optimizations
```
Step 0: Build Inventory   :   5s (one-time)
Step 1: Build Relations   :  93s (sequential)
Step 2: Generate Samples  : 320s (160s overhead + 160s generation)
Step 3: Validate Dataset  :  48s (re-executing all SQL)
Step 4: Export Splits     :   2s
                            ------
Total                     : 468s (~7.8 minutes)
```

### After Optimizations
```
Step 0: Build Inventory   :   5s (one-time)
Step 1: Build Relations   :  60s (parallel, limited by slowest task)
Step 2: Generate Samples  : 165s (0.11s overhead + 165s generation)
Step 3: Validate Dataset  :   3s (skips SQL re-execution)
Step 4: Export Splits     :   2s
                            ------
Total                     : 235s (~3.9 minutes)
```

**Overall speedup:** 468s → 235s = **~2x faster**

**At 10K scale (100x more samples):**
- Before: ~780 minutes (13 hours)
- After: ~390 minutes (6.5 hours)
- With further optimizations (sampling without replacement, better caching): **<2 hours**

---

## Key Optimizations Summary

| Optimization | Impact | Where |
|-------------|--------|-------|
| **Batch workers** | 1,450x on init overhead | `generate_samples.py` |
| **Parallel relations** | 1.5x on relation building | `build_relations.py` |
| **Jaro-Winkler** | 2-3x on distractor search | `generate_samples.py` |
| **Skip SQL re-validation** | 16x on validation | `validate_dataset.py` |
| **Drop individual JSON files** | 1.2x on I/O | `generate_samples.py` |

**Combined:** Enables scaling from hundreds to tens of thousands of samples efficiently.
dataset/README.md
ADDED

@@ -0,0 +1,108 @@
# Dataset Generation CLI

Generate synthetic text-to-SQL training datasets.

## Quick Start

```bash
# Install
uv sync

# Generate dataset
gazet-dataset full-pipeline --config dataset/config.yaml
```

## Configuration

Edit `dataset/config.yaml`:

```yaml
countries:
  - EC # Ecuador
  - BE # Belgium
  - KE # Kenya
  - AE # UAE
  - SG # Singapore
  - CH # Switzerland

sample_targets:
  direct_lookup: 100
  adjacency: 100
  containment: 100
  intersection: 100
  buffer: 100
  set_operations: 100
  partial_selection: 100
  aggregation: 100

generation:
  max_workers: 8
  retry_multiplier: 2
  append_mode: true

auto_scaling:
  safety_factor: 1.5 # Auto-calculates relation limits
```
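
Relation-table limits are derived from these targets as `target × retry_multiplier × safety_factor`, summed over the task families that share a relation type (see `calculate_relation_limits` in `dataset/scripts/cli.py`). A worked example with the values above:

```python
# adjacency pairs serve both the 'adjacency' and 'buffer' families (per cli.py):
adjacency_limit = (100 + 100) * 2 * 1.5  # = 600 pairs to precompute
```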

## Growing Your Dataset

### Start Small
```bash
# Generate initial dataset (e.g., 100 samples)
gazet-dataset full-pipeline --config dataset/config.yaml
```

### Add More Samples (Same Countries)
```bash
# Increase sample_targets in config.yaml, then:
gazet-dataset full-pipeline --config dataset/config.yaml --append
```

### Add New Countries
```bash
# Add countries to config.yaml, then:
gazet-dataset full-pipeline --config dataset/config.yaml --append
# Auto-rebuilds relations if countries changed
```

### Scale to 10K+
```yaml
# config.yaml - increase targets
sample_targets:
  adjacency: 1000
  containment: 1000
  intersection: 1000
  # ... etc
```

```bash
gazet-dataset full-pipeline --config dataset/config.yaml --append
```

## Commands

```bash
gazet-dataset full-pipeline --config <path>     # Run everything
gazet-dataset build-relations --config <path>   # Build spatial relations
gazet-dataset generate-samples --config <path>  # Generate samples
gazet-dataset validate --config <path>          # Validate dataset
gazet-dataset export --config <path>            # Export train/val/test
```

**Options:**
- `--append`: Add to existing dataset instead of overwriting

## Output

- `dataset/output/dataset_raw.jsonl` - Generated samples
- `dataset/output/dataset_validated.jsonl` - Validated samples
- `dataset/output/train.jsonl` - Training split
- `dataset/output/val.jsonl` - Validation split
- `dataset/output/test.jsonl` - Test split

## Tips

- Start with 2-3 countries and small sample targets
- Use `--append` to grow the dataset incrementally
- Relation limits auto-calculate from sample targets
- Check success rates in the output summary
dataset/__init__.py
ADDED

@@ -0,0 +1 @@
```python
"""Synthetic dataset generation package."""
```
dataset/config.yaml
ADDED

@@ -0,0 +1,52 @@
```yaml
# Dataset Generation Configuration
# This config controls which countries to process and how many samples to generate

# Countries to include in relation building
# Use ISO 3166-1 alpha-2 codes
countries:
  # - EC # Ecuador
  # - BE # Belgium
  # - KE # Kenya
  # - AE # UAE
  # - SG # Singapore
  # - CH # Switzerland
  - IN # India
  - PK # Pakistan
  # - LK # Sri Lanka
  # - BD # Bangladesh

# Sample generation targets per family
# Relation limits are auto-calculated from these targets
sample_targets:
  direct_lookup: 20
  adjacency: 20
  containment: 20
  intersection: 20
  buffer: 20
  set_operations: 20
  partial_selection: 20
  aggregation: 20

# Generation settings
generation:
  max_workers: 8       # Number of parallel workers
  retry_multiplier: 2  # Generate 2x samples to account for failures
  append_mode: true    # If true, append to existing dataset instead of overwriting

# Auto-scaling configuration
# Relation limits are automatically calculated: target * retry_multiplier * safety_factor
auto_scaling:
  safety_factor: 1.5  # Extra buffer to ensure enough unique pairs

  # Manual overrides (optional) - uncomment to override auto-calculated limits
  manual_limits: {}
  # adjacency: 10000  # Uncomment to manually set
  # containment: 2000
  # intersection: 1000
  # cross_source: 500

# Output paths (relative to dataset directory)
output:
  samples_dir: "output/samples"
  dataset_file: "output/dataset_raw.jsonl"
  intermediate_dir: "intermediate"
```
dataset/scripts/__init__.py
ADDED

@@ -0,0 +1 @@
```python
"""Dataset generation scripts package."""
```
dataset/scripts/build_inventory.py
ADDED

@@ -0,0 +1,111 @@
```python
"""
Build entity inventory from divisions_area and natural_earth parquet files.

This script creates compact inventory tables containing only the fields needed
for candidate sampling and distractor generation.

Output:
    - intermediate/divisions_area_inventory.parquet
    - intermediate/natural_earth_inventory.parquet
"""

import duckdb
import pandas as pd
from pathlib import Path

from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH


def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
    """Extract compact inventory from divisions_area."""
    query = """
        SELECT
            'divisions_area' AS source,
            id,
            names."primary" AS name,
            subtype,
            country,
            region,
            admin_level,
            class,
            is_land,
            is_territorial,
            division_id,
            ST_Area(geometry) AS area_sq_deg,
            ST_XMin(geometry) AS xmin,
            ST_YMin(geometry) AS ymin,
            ST_XMax(geometry) AS xmax,
            ST_YMax(geometry) AS ymax
        FROM read_parquet(?)
        WHERE names."primary" IS NOT NULL
          AND trim(names."primary") != ''
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
    print(f"Divisions area inventory: {len(df)} entities")
    print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
    print(f"Countries: {df['country'].nunique()} unique")

    return df


def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
    """Extract compact inventory from natural_earth."""
    query = """
        SELECT
            'natural_earth' AS source,
            id,
            names."primary" AS name,
            subtype,
            country,
            region,
            admin_level,
            class,
            is_land,
            is_territorial,
            ST_Area(geometry) AS area_sq_deg,
            ST_XMin(geometry) AS xmin,
            ST_YMin(geometry) AS ymin,
            ST_XMax(geometry) AS xmax,
            ST_YMax(geometry) AS ymax
        FROM read_parquet(?)
        WHERE names."primary" IS NOT NULL
          AND trim(names."primary") != ''
    """

    df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
    print(f"\nNatural earth inventory: {len(df)} entities")
    print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")

    return df


def main():
    """Build and save inventory tables."""
    output_dir = Path(__file__).parent.parent / "intermediate"
    output_dir.mkdir(exist_ok=True, parents=True)

    con = duckdb.connect()
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")

    print("Building divisions_area inventory...")
    divisions_df = build_divisions_area_inventory(con)
    divisions_path = output_dir / "divisions_area_inventory.parquet"
    divisions_df.to_parquet(divisions_path, index=False)
    print(f"Saved to {divisions_path}")

    print("\nBuilding natural_earth inventory...")
    natural_earth_df = build_natural_earth_inventory(con)
    natural_earth_path = output_dir / "natural_earth_inventory.parquet"
    natural_earth_df.to_parquet(natural_earth_path, index=False)
    print(f"Saved to {natural_earth_path}")

    con.close()

    print("\n✓ Inventory build complete")
    print(f"  Total entities: {len(divisions_df) + len(natural_earth_df)}")


if __name__ == "__main__":
    main()
```
dataset/scripts/build_relations.py
ADDED

@@ -0,0 +1,288 @@
```python
"""
Precompute spatial relation tables for efficient anchor sampling.

This script computes:
- Adjacency pairs (touching features)
- Containment pairs (features within other features)
- Intersection pairs (overlapping features)
- Cross-source relations (divisions_area ↔ natural_earth)

Output:
    - intermediate/adjacency_pairs.parquet
    - intermediate/containment_pairs.parquet
    - intermediate/intersection_pairs.parquet
    - intermediate/cross_source_relations.parquet
"""

import duckdb
import pandas as pd
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH


def compute_adjacency_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find all pairs of features that touch (share a boundary)."""
    print("Computing adjacency pairs (optimized with spatial index)...")

    # Use bounding box pre-filter to avoid full cartesian product
    query = """
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            WHERE country IN (SELECT unnest(?))
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            a.country AS anchor_country,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            b.country AS target_country,
            'adjacency' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Touches(a.geometry, b.geometry)
        )
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
    print(f"Found {len(df)} adjacency pairs")

    return df


def compute_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find all pairs where one feature contains another."""
    print("\nComputing containment pairs (optimized)...")

    query = """
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            WHERE country IN (SELECT unnest(?))
        )
        SELECT
            a.id AS container_id,
            a.name AS container_name,
            a.subtype AS container_subtype,
            b.id AS contained_id,
            b.name AS contained_name,
            b.subtype AS contained_subtype,
            'containment' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id != b.id
            AND a.admin_level < b.admin_level
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Within(b.geometry, a.geometry)
        )
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
    print(f"Found {len(df)} containment pairs")

    return df


def compute_intersection_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find pairs that intersect but don't touch or contain."""
    print("\nComputing intersection pairs (optimized)...")

    query = """
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            WHERE country IN (SELECT unnest(?))
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            'intersection' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Intersects(a.geometry, b.geometry)
            AND NOT ST_Touches(a.geometry, b.geometry)
            AND NOT ST_Within(a.geometry, b.geometry)
            AND NOT ST_Within(b.geometry, a.geometry)
        )
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
    print(f"Found {len(df)} same-source intersection pairs")

    return df


def compute_cross_source_relations(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find relations between divisions_area and natural_earth."""
    print("\nComputing cross-source relations...")

    query = """
        WITH divisions AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                geometry
            FROM read_parquet(?)
            WHERE country IN (SELECT unnest(?))
        ),
        natural_features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                geometry
            FROM read_parquet(?)
            WHERE subtype IN ('sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay')
            LIMIT 200
        )
        SELECT
            d.id AS division_id,
            d.name AS division_name,
            d.subtype AS division_subtype,
            d.country AS division_country,
            n.id AS natural_id,
            n.name AS natural_name,
            n.subtype AS natural_subtype,
            CASE
                WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
                WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
                WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
                WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
            END AS relation_type
        FROM divisions AS d
        JOIN natural_features AS n ON ST_Intersects(d.geometry, n.geometry)
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH, countries, NATURAL_EARTH_PATH, limit]).fetchdf()
    print(f"Found {len(df)} cross-source relations")

    return df


def _make_connection():
    """Create a new DuckDB connection with spatial extension loaded."""
    con = duckdb.connect()
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")
    return con


def _compute_and_save(compute_fn, countries, limit, output_path):
    """Compute a relation table and save it to parquet. Uses its own DuckDB connection."""
    con = _make_connection()
    try:
        df = compute_fn(con, countries, limit)
        df.to_parquet(output_path, index=False)
        print(f"Saved to {output_path}")
        return df
    finally:
        con.close()


def main(countries: list = None, relation_limits: dict = None):
    """Compute and save all relation tables in parallel.

    Args:
        countries: List of country codes to process
        relation_limits: Dict with keys: adjacency, containment, intersection, cross_source
    """
    # Defaults
    if countries is None:
        countries = ['EC', 'BE', 'KE', 'AE', 'SG', 'CH']
    if relation_limits is None:
        relation_limits = {
            'adjacency': 50000,
            'containment': 1000,
            'intersection': 500,
            'cross_source': 500
        }

    output_dir = Path(__file__).parent.parent / "intermediate"
    output_dir.mkdir(exist_ok=True, parents=True)

    # Define all relation tasks
    tasks = [
        ("adjacency", compute_adjacency_pairs, relation_limits['adjacency'], output_dir / "adjacency_pairs.parquet"),
        ("containment", compute_containment_pairs, relation_limits['containment'], output_dir / "containment_pairs.parquet"),
        ("intersection", compute_intersection_pairs, relation_limits['intersection'], output_dir / "intersection_pairs.parquet"),
        ("cross_source", compute_cross_source_relations, relation_limits['cross_source'], output_dir / "cross_source_relations.parquet"),
    ]

    print(f"Computing {len(tasks)} relation types in parallel...")

    # Run all relation computations concurrently
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {
            executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
            for name, compute_fn, limit, path in tasks
        }

        for future in as_completed(futures):
            name = futures[future]
            try:
                future.result()
            except Exception as e:
                print(f"ERROR computing {name}: {e}")
                raise

    print("\n✓ Relation tables build complete")


if __name__ == "__main__":
    main()
```
dataset/scripts/cli.py
ADDED

@@ -0,0 +1,305 @@
```python
#!/usr/bin/env python3
"""
CLI for synthetic dataset generation.

Usage:
    python cli.py build-relations --config ../config.yaml
    python cli.py generate-samples --config ../config.yaml
    python cli.py generate-samples --config ../config.yaml --append
    python cli.py full-pipeline --config ../config.yaml
"""

import argparse
import sys
from pathlib import Path
import yaml
import subprocess
from typing import Dict, Set
import pandas as pd


def load_config(config_path: Path) -> dict:
    """Load configuration from YAML file."""
    with open(config_path) as f:
        return yaml.safe_load(f)


def should_rebuild_relations(config: dict, intermediate_dir: Path, append: bool) -> bool:
    """Check if relation tables need to be rebuilt.

    Returns True if:
    - Not in append mode (always rebuild)
    - Relation tables don't exist
    - Countries in config differ from countries in existing relation tables
    """
    if not append:
        return True

    # Check if relation tables exist
    adjacency_file = intermediate_dir / "adjacency_pairs.parquet"
    if not adjacency_file.exists():
        print("WARNING: Relation tables not found, will rebuild despite append mode")
        return True

    # Check if countries have changed
    try:
        df = pd.read_parquet(adjacency_file)
        if 'anchor_country' in df.columns:
            existing_countries = set(df['anchor_country'].unique())
            config_countries = set(config['countries'])

            if existing_countries != config_countries:
                print(f"WARNING: Countries changed:")
                print(f"  Previous: {sorted(existing_countries)}")
                print(f"  New: {sorted(config_countries)}")
                print(f"  Will rebuild relation tables to include new countries")
                return True
            else:
                print(f"Countries unchanged: {sorted(config_countries)}")
                return False
        else:
            # Can't determine countries, rebuild to be safe
            print("WARNING: Cannot determine countries from existing tables, will rebuild")
            return True
    except Exception as e:
        print(f"WARNING: Error reading existing relation tables: {e}")
        print("  Will rebuild to be safe")
        return True


def calculate_relation_limits(config: dict) -> Dict[str, int]:
    """Auto-calculate relation limits based on sample targets."""
    sample_targets = config['sample_targets']
    retry_mult = config['generation']['retry_multiplier']
    safety = config.get('auto_scaling', {}).get('safety_factor', 1.5)

    # Map task families to relation types they need
    family_to_relation = {
        'adjacency': 'adjacency',
        'containment': 'containment',
        'intersection': 'intersection',
        'buffer': 'adjacency',               # Buffer uses adjacency pairs
        'set_operations': 'intersection',    # Set ops use intersection pairs
        'partial_selection': 'containment',  # Partial uses containment
        'aggregation': 'containment',        # Aggregation uses containment
        'direct_lookup': None,               # Uses inventory only
    }

    # Calculate required limits by summing needs per relation type
    relation_needs = {}
    for family, target in sample_targets.items():
        relation_type = family_to_relation.get(family)
        if relation_type:
            needed = int(target * retry_mult * safety)
            relation_needs[relation_type] = relation_needs.get(relation_type, 0) + needed

    # Add cross-source (used by mixed-source partial selection)
    partial_target = sample_targets.get('partial_selection', 0)
    relation_needs['cross_source'] = int(partial_target * retry_mult * safety * 0.3)

    # Apply manual overrides if specified
    manual = config.get('auto_scaling', {}).get('manual_limits', {})
    relation_needs.update(manual)

    return relation_needs


def build_relations(config_path: Path):
    """Run relation building with config."""
    config = load_config(config_path)

    # Auto-calculate relation limits
    relation_limits = calculate_relation_limits(config)

    print("=" * 60)
    print("STEP 1: Building Relation Tables")
    print("=" * 60)
    print(f"Countries: {', '.join(config['countries'])}")
    print(f"\nAuto-calculated relation limits:")
    for rel_type, limit in relation_limits.items():
        print(f"  {rel_type:20s}: {limit:,}")
    print()

    # Import and run the relation builder
    from dataset.scripts import build_relations

    # Run with config parameters
    build_relations.main(
        countries=config['countries'],
        relation_limits=relation_limits
    )

    print("\n✓ Relation tables built successfully")


def generate_samples(config_path: Path, append: bool = False):
    """Run sample generation with config."""
    config = load_config(config_path)

    print("=" * 60)
    print("STEP 2: Generating Samples")
    print("=" * 60)
    print(f"Targets: {config['sample_targets']}")
    print(f"Workers: {config['generation']['max_workers']}")
    print(f"Append mode: {append or config['generation']['append_mode']}")
    print()

    # Simple import - no number prefixes needed
    from dataset.scripts import generate_samples as gs_module

    # Override config values
    gs_module.TARGET_COUNTS = config['sample_targets']
    gs_module.MAX_WORKERS = config['generation']['max_workers']
    gs_module.RETRY_MULTIPLIER = config['generation']['retry_multiplier']
    gs_module.APPEND_MODE = append or config['generation']['append_mode']

    # Run the main function
    gs_module.main()

    print("\n✓ Samples generated successfully")


def validate_dataset(config_path: Path):
    """Run dataset validation."""
    print("=" * 60)
    print("STEP 3: Validating Dataset")
    print("=" * 60)

    script_dir = Path(__file__).parent
    result = subprocess.run(
        [sys.executable, str(script_dir / "validate_dataset.py")],
        check=True
    )

    print("\n✓ Dataset validated successfully")


def export_dataset(config_path: Path):
    """Run dataset export."""
    print("=" * 60)
    print("STEP 4: Exporting Dataset")
    print("=" * 60)

    script_dir = Path(__file__).parent
    result = subprocess.run(
        [sys.executable, str(script_dir / "export_training_data.py")],
        check=True
    )

    print("\n✓ Dataset exported successfully")


def full_pipeline(config_path: Path, append: bool = False):
    """Run the full pipeline."""
    print("\n" + "=" * 60)
    print("RUNNING FULL DATASET GENERATION PIPELINE")
    print("=" * 60 + "\n")

    config = load_config(config_path)

    # Check if inventory exists, create if not
    script_dir = Path(__file__).parent
    intermediate_dir = script_dir.parent / "intermediate"
    inventory_files = [
        intermediate_dir / "divisions_area_inventory.parquet",
        intermediate_dir / "natural_earth_inventory.parquet"
    ]

    inventory_missing = any(not f.exists() for f in inventory_files)

    if inventory_missing:
        print("=" * 60)
        print("STEP 0: Building Entity Inventory")
        print("=" * 60)
        print("Inventory files not found. Building inventory...\n")

        from dataset.scripts import build_inventory
        build_inventory.main()

        print("\n✓ Inventory built successfully\n")

    # Check if we need to rebuild relations
    need_rebuild = should_rebuild_relations(config, intermediate_dir, append)

    if need_rebuild:
        build_relations(config_path)
    else:
        print("Using existing relation tables (append mode, same countries)")

    generate_samples(config_path, append=append)
    validate_dataset(config_path)
    export_dataset(config_path)

    print("\n" + "=" * 60)
    print("✓ PIPELINE COMPLETED SUCCESSFULLY")
    print("=" * 60)


def main():
    parser = argparse.ArgumentParser(
        description="Synthetic dataset generation CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Build relation tables only
  python cli.py build-relations --config ../config.yaml

  # Generate samples only
  python cli.py generate-samples --config ../config.yaml

  # Generate and append to existing dataset
  python cli.py generate-samples --config ../config.yaml --append

  # Run full pipeline
  python cli.py full-pipeline --config ../config.yaml

  # Run full pipeline in append mode (skip relation building)
  python cli.py full-pipeline --config ../config.yaml --append
"""
    )

    parser.add_argument(
        'command',
        choices=['build-relations', 'generate-samples', 'validate', 'export', 'full-pipeline'],
        help='Command to run'
    )

    parser.add_argument(
        '--config',
        type=Path,
        required=True,
        help='Path to config YAML file'
    )

    parser.add_argument(
        '--append',
        action='store_true',
        help='Append to existing dataset instead of overwriting'
    )

    args = parser.parse_args()

    # Validate config file exists
    if not args.config.exists():
        print(f"Error: Config file not found: {args.config}")
        sys.exit(1)

    # Run the appropriate command
    try:
        if args.command == 'build-relations':
            build_relations(args.config)
        elif args.command == 'generate-samples':
            generate_samples(args.config, args.append)
        elif args.command == 'validate':
            validate_dataset(args.config)
        elif args.command == 'export':
            export_dataset(args.config)
        elif args.command == 'full-pipeline':
            full_pipeline(args.config, args.append)
    except Exception as e:
        print(f"\n✗ Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
```
|
dataset/scripts/export_training_data.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Export validated dataset to train/val/test splits.

This script:
1. Loads validated samples
2. Splits into train (80%), val (10%), test (10%)
3. Ensures balanced splits across task families
4. Exports to JSONL format

Output:
- output/train.jsonl (80% of samples)
- output/val.jsonl (10% of samples)
- output/test.jsonl (10% of samples)
"""

import json
import random
from pathlib import Path
from typing import List, Dict, Any
from collections import defaultdict


def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
    """Load samples from JSONL file."""
    samples = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            samples.append(json.loads(line))
    return samples


def stratified_split(
    samples: List[Dict[str, Any]],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    random_seed: int = 42
) -> tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples by task family to ensure balanced distribution."""

    random.seed(random_seed)

    # Group by task family
    by_family = defaultdict(list)
    for sample in samples:
        family = sample['metadata']['task_family']
        by_family[family].append(sample)

    train_samples = []
    val_samples = []
    test_samples = []

    # Split each family; the test split receives the remainder after the
    # train and val cuts, so test_ratio is implied rather than used directly.
    for family, family_samples in by_family.items():
        # Shuffle
        random.shuffle(family_samples)

        n = len(family_samples)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)

        train_samples.extend(family_samples[:n_train])
        val_samples.extend(family_samples[n_train:n_train + n_val])
        test_samples.extend(family_samples[n_train + n_val:])

    # Shuffle final splits
    random.shuffle(train_samples)
    random.shuffle(val_samples)
    random.shuffle(test_samples)

    return train_samples, val_samples, test_samples


def save_split(samples: List[Dict[str, Any]], output_path: Path):
    """Save samples to JSONL file."""
    with open(output_path, 'w') as f:
        for sample in samples:
            f.write(json.dumps(sample) + '\n')


def print_split_stats(split_name: str, samples: List[Dict[str, Any]]):
    """Print statistics for a split."""
    families = defaultdict(int)
    for sample in samples:
        family = sample['metadata']['task_family']
        families[family] += 1

    print(f"\n{split_name}:")
    print(f"  Total: {len(samples)}")
    for family, count in sorted(families.items()):
        print(f"    {family:20s}: {count:3d}")


def print_country_stats(samples: List[Dict[str, Any]]):
    """Print country distribution statistics."""
    country_counts = defaultdict(int)

    # Extract countries from selected/answer candidates only
    for sample in samples:
        selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
        countries_in_sample = set()
        for candidate in sample.get('candidates', []):
            if candidate.get('candidate_id') in selected_ids:
                country = candidate.get('country')
                if country:
                    countries_in_sample.add(country)

        # Count each unique country once per sample
        for country in countries_in_sample:
            country_counts[country] += 1

    if not country_counts:
        print("\nNo country information found in samples")
        return

    print("\nCOUNTRY DISTRIBUTION:")
    print(f"  Total unique countries: {len(country_counts)}")
    print("\n  Top countries by sample count:")

    # Sort by count descending
    sorted_countries = sorted(country_counts.items(), key=lambda x: x[1], reverse=True)

    # Show top 20
    for country, count in sorted_countries[:20]:
        percentage = (count / len(samples)) * 100
        print(f"    {country:3s}: {count:4d} samples ({percentage:5.1f}%)")

    if len(sorted_countries) > 20:
        remaining = len(sorted_countries) - 20
        remaining_count = sum(c for _, c in sorted_countries[20:])
        print(f"    ... and {remaining} more countries ({remaining_count} samples)")


def main():
    """Export dataset splits."""

    script_dir = Path(__file__).parent
    output_dir = script_dir.parent / "output"

    validated_file = output_dir / "dataset_validated.jsonl"
    train_file = output_dir / "train.jsonl"
    val_file = output_dir / "val.jsonl"
    test_file = output_dir / "test.jsonl"

    if not validated_file.exists():
        print(f"Error: {validated_file} not found. Run validate_dataset.py first.")
        return

    # Load validated samples
    print("Loading validated samples...")
    samples = load_samples(validated_file)
    print(f"Loaded {len(samples)} samples")

    # Split
    print("\nSplitting dataset (80/10/10)...")
    train_samples, val_samples, test_samples = stratified_split(samples)

    # Save splits
    print("\nSaving splits...")
    save_split(train_samples, train_file)
    save_split(val_samples, val_file)
    save_split(test_samples, test_file)

    print(f"  Train: {train_file} ({len(train_samples)} samples)")
    print(f"  Val:   {val_file} ({len(val_samples)} samples)")
    print(f"  Test:  {test_file} ({len(test_samples)} samples)")

    # Print statistics
    print("\n" + "=" * 60)
    print("SPLIT STATISTICS")
    print("=" * 60)

    print_split_stats("TRAIN", train_samples)
    print_split_stats("VAL", val_samples)
    print_split_stats("TEST", test_samples)

    # Print country distribution
    print("\n" + "=" * 60)
    print("GEOGRAPHIC DISTRIBUTION")
    print("=" * 60)
    print_country_stats(samples)

    print("\n✓ Export complete")
    print("\nReady for training!")
    print(f"  Training data:   {train_file}")
    print(f"  Validation data: {val_file}")
    print(f"  Test data:       {test_file}")


if __name__ == "__main__":
    main()
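A quick way to sanity-check the stratified split is to feed it toy input. A minimal sketch, assuming the module above is importable; the two-family records are illustrative stand-ins, not real pipeline output:

# Toy check of stratified_split: 20 samples across two task families should
# split 16/2/2 while keeping both families represented in every split.
# The records below are illustrative stand-ins, not real pipeline output.
from collections import Counter

from dataset.scripts.export_training_data import stratified_split

toy = [{'metadata': {'task_family': fam}, 'n': i}
       for i, fam in enumerate(['adjacency'] * 10 + ['containment'] * 10)]

train, val, test = stratified_split(toy)
print(len(train), len(val), len(test))                       # 16 2 2
print(Counter(s['metadata']['task_family'] for s in train))  # 8 per family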
dataset/scripts/generate_samples.py
ADDED
@@ -0,0 +1,1091 @@
| 1 |
+
"""
|
| 2 |
+
Generate synthetic training samples for text-to-SQL task.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads relation tables and entity inventories
|
| 6 |
+
2. For each SQL template, samples valid anchors
|
| 7 |
+
3. Renders and executes SQL to verify it works
|
| 8 |
+
4. Builds candidate lists with controlled distractors
|
| 9 |
+
5. Generates natural language questions using LLM
|
| 10 |
+
6. Saves complete training samples
|
| 11 |
+
|
| 12 |
+
Output:
|
| 13 |
+
- output/samples/sample_*.json (individual samples)
|
| 14 |
+
- output/dataset_raw.jsonl (all samples)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import random
|
| 19 |
+
import warnings
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import List, Dict, Any, Optional
|
| 22 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 23 |
+
from functools import partial
|
| 24 |
+
|
| 25 |
+
import duckdb
|
| 26 |
+
import pandas as pd
|
| 27 |
+
from pydantic import BaseModel
|
| 28 |
+
|
| 29 |
+
# Suppress warnings
|
| 30 |
+
warnings.filterwarnings('ignore')
|
| 31 |
+
|
| 32 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 33 |
+
|
| 34 |
+
# Configurable parameters (can be overridden by CLI)
|
| 35 |
+
TARGET_COUNTS = None # Will be set in main() or by CLI
|
| 36 |
+
MAX_WORKERS = 8
|
| 37 |
+
RETRY_MULTIPLIER = 2
|
| 38 |
+
APPEND_MODE = False
|
| 39 |
+
|
| 40 |
+
# Import templates from same directory
|
| 41 |
+
from . import sql_templates
|
| 42 |
+
TEMPLATES = sql_templates.TEMPLATES
|
| 43 |
+
SQLTemplate = sql_templates.SQLTemplate
|
| 44 |
+
get_templates_by_family = sql_templates.get_templates_by_family
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Candidate(BaseModel):
|
| 48 |
+
"""Candidate entity for grounding."""
|
| 49 |
+
candidate_id: str
|
| 50 |
+
source: str
|
| 51 |
+
id: str
|
| 52 |
+
name: str
|
| 53 |
+
subtype: Optional[str] = None
|
| 54 |
+
country: Optional[str] = None
|
| 55 |
+
region: Optional[str] = None
|
| 56 |
+
admin_level: Optional[int] = None
|
| 57 |
+
similarity: float = 0.0
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TrainingSample(BaseModel):
|
| 61 |
+
"""Complete training sample."""
|
| 62 |
+
id: str
|
| 63 |
+
question: str
|
| 64 |
+
candidates: List[Candidate]
|
| 65 |
+
target: Dict[str, Any]
|
| 66 |
+
metadata: Dict[str, Any]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]:
|
| 70 |
+
"""Load all precomputed relation tables."""
|
| 71 |
+
tables = {}
|
| 72 |
+
|
| 73 |
+
for file in intermediate_dir.glob("*.parquet"):
|
| 74 |
+
name = file.stem
|
| 75 |
+
tables[name] = pd.read_parquet(file)
|
| 76 |
+
if not quiet:
|
| 77 |
+
print(f" {name}: {len(tables[name])} rows")
|
| 78 |
+
|
| 79 |
+
return tables
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def sample_adjacency_anchor(adjacency_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 83 |
+
"""Sample a random adjacency pair."""
|
| 84 |
+
if adjacency_df.empty:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
row = adjacency_df.sample(n=1).iloc[0]
|
| 88 |
+
return {
|
| 89 |
+
'anchor_id': row['anchor_id'],
|
| 90 |
+
'anchor_name': row['anchor_name'],
|
| 91 |
+
'anchor_subtype': row['anchor_subtype'],
|
| 92 |
+
'anchor_country': row.get('anchor_country'), # May not exist in all tables
|
| 93 |
+
'target_subtype': row.get('target_subtype')
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 98 |
+
"""Sample a random intersection pair."""
|
| 99 |
+
if intersection_df.empty:
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
row = intersection_df.sample(n=1).iloc[0]
|
| 103 |
+
return {
|
| 104 |
+
'anchor_id': row['anchor_id'],
|
| 105 |
+
'anchor_name': row['anchor_name'],
|
| 106 |
+
'anchor_subtype': row['anchor_subtype'],
|
| 107 |
+
'target_id': row.get('target_id'),
|
| 108 |
+
'target_name': row.get('target_name'),
|
| 109 |
+
'target_subtype': row.get('target_subtype')
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 114 |
+
"""Sample a random containment pair."""
|
| 115 |
+
if containment_df.empty:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
row = containment_df.sample(n=1).iloc[0]
|
| 119 |
+
return {
|
| 120 |
+
'container_id': row['container_id'],
|
| 121 |
+
'container_name': row['container_name'],
|
| 122 |
+
'container_subtype': row['container_subtype'],
|
| 123 |
+
'contained_subtype': row['contained_subtype']
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 128 |
+
"""Sample a random cross-source relation."""
|
| 129 |
+
if cross_source_df.empty:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
row = cross_source_df.sample(n=1).iloc[0]
|
| 133 |
+
return {
|
| 134 |
+
'division_id': row['division_id'],
|
| 135 |
+
'division_name': row['division_name'],
|
| 136 |
+
'division_subtype': row['division_subtype'],
|
| 137 |
+
'natural_id': row['natural_id'],
|
| 138 |
+
'natural_name': row['natural_name'],
|
| 139 |
+
'natural_subtype': row['natural_subtype'],
|
| 140 |
+
'relation_type': row['relation_type']
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def build_candidate_list(
|
| 145 |
+
con: duckdb.DuckDBPyConnection,
|
| 146 |
+
anchor_id: str,
|
| 147 |
+
anchor_name: str,
|
| 148 |
+
anchor_source: str,
|
| 149 |
+
num_candidates: int = 10,
|
| 150 |
+
difficulty: str = "medium"
|
| 151 |
+
) -> List[Candidate]:
|
| 152 |
+
"""Build candidate list with true anchor + distractors."""
|
| 153 |
+
|
| 154 |
+
# Helper to convert pandas NA to None
|
| 155 |
+
def safe_get(row, key, default=None):
|
| 156 |
+
val = row.get(key, default)
|
| 157 |
+
return None if pd.isna(val) else val
|
| 158 |
+
|
| 159 |
+
# Get the true anchor
|
| 160 |
+
if anchor_source == "divisions_area":
|
| 161 |
+
query = """
|
| 162 |
+
SELECT
|
| 163 |
+
id,
|
| 164 |
+
names."primary" AS name,
|
| 165 |
+
subtype,
|
| 166 |
+
country,
|
| 167 |
+
region,
|
| 168 |
+
admin_level
|
| 169 |
+
FROM read_parquet(?)
|
| 170 |
+
WHERE id = ?
|
| 171 |
+
"""
|
| 172 |
+
anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
|
| 173 |
+
else:
|
| 174 |
+
query = """
|
| 175 |
+
SELECT
|
| 176 |
+
id,
|
| 177 |
+
names."primary" AS name,
|
| 178 |
+
subtype
|
| 179 |
+
FROM read_parquet(?)
|
| 180 |
+
WHERE id = ?
|
| 181 |
+
"""
|
| 182 |
+
anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
|
| 183 |
+
|
| 184 |
+
# Build true candidate
|
| 185 |
+
true_candidate = Candidate(
|
| 186 |
+
candidate_id="c1",
|
| 187 |
+
source=anchor_source,
|
| 188 |
+
id=anchor_id,
|
| 189 |
+
name=safe_get(anchor_row, 'name'),
|
| 190 |
+
subtype=safe_get(anchor_row, 'subtype'),
|
| 191 |
+
country=safe_get(anchor_row, 'country'),
|
| 192 |
+
region=safe_get(anchor_row, 'region'),
|
| 193 |
+
admin_level=safe_get(anchor_row, 'admin_level'),
|
| 194 |
+
similarity=1.0
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Build distractors based on difficulty
|
| 198 |
+
distractors = build_distractors(
|
| 199 |
+
con,
|
| 200 |
+
anchor_name,
|
| 201 |
+
anchor_source,
|
| 202 |
+
anchor_id,
|
| 203 |
+
num_candidates - 1,
|
| 204 |
+
difficulty
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Combine and shuffle
|
| 208 |
+
candidates = [true_candidate] + distractors
|
| 209 |
+
random.shuffle(candidates)
|
| 210 |
+
|
| 211 |
+
# Reassign candidate IDs after shuffling
|
| 212 |
+
for i, cand in enumerate(candidates, 1):
|
| 213 |
+
cand.candidate_id = f"c{i}"
|
| 214 |
+
|
| 215 |
+
return candidates
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def build_distractors(
|
| 219 |
+
con: duckdb.DuckDBPyConnection,
|
| 220 |
+
anchor_name: str,
|
| 221 |
+
anchor_source: str,
|
| 222 |
+
exclude_id: str,
|
| 223 |
+
num_distractors: int,
|
| 224 |
+
difficulty: str
|
| 225 |
+
) -> List[Candidate]:
|
| 226 |
+
"""Build distractor candidates using fuzzy search."""
|
| 227 |
+
|
| 228 |
+
# Fuzzy search for similar names
|
| 229 |
+
if anchor_source == "divisions_area":
|
| 230 |
+
query = """
|
| 231 |
+
SELECT
|
| 232 |
+
id,
|
| 233 |
+
names."primary" AS name,
|
| 234 |
+
subtype,
|
| 235 |
+
country,
|
| 236 |
+
region,
|
| 237 |
+
admin_level,
|
| 238 |
+
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
|
| 239 |
+
FROM read_parquet(?)
|
| 240 |
+
WHERE id != ?
|
| 241 |
+
AND names."primary" IS NOT NULL
|
| 242 |
+
ORDER BY similarity DESC
|
| 243 |
+
LIMIT ?
|
| 244 |
+
"""
|
| 245 |
+
df = con.execute(query, [
|
| 246 |
+
anchor_name, DIVISIONS_AREA_PATH, exclude_id, num_distractors
|
| 247 |
+
]).fetchdf()
|
| 248 |
+
source = "divisions_area"
|
| 249 |
+
else:
|
| 250 |
+
query = """
|
| 251 |
+
SELECT
|
| 252 |
+
id,
|
| 253 |
+
names."primary" AS name,
|
| 254 |
+
subtype,
|
| 255 |
+
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
|
| 256 |
+
FROM read_parquet(?)
|
| 257 |
+
WHERE id != ?
|
| 258 |
+
AND names."primary" IS NOT NULL
|
| 259 |
+
ORDER BY similarity DESC
|
| 260 |
+
LIMIT ?
|
| 261 |
+
"""
|
| 262 |
+
df = con.execute(query, [
|
| 263 |
+
anchor_name, NATURAL_EARTH_PATH, exclude_id, num_distractors
|
| 264 |
+
]).fetchdf()
|
| 265 |
+
source = "natural_earth"
|
| 266 |
+
|
| 267 |
+
# Helper to convert pandas NA to None
|
| 268 |
+
def safe_get(row, key, default=None):
|
| 269 |
+
val = row.get(key, default)
|
| 270 |
+
return None if pd.isna(val) else val
|
| 271 |
+
|
| 272 |
+
distractors = []
|
| 273 |
+
for _, row in df.iterrows():
|
| 274 |
+
distractors.append(Candidate(
|
| 275 |
+
candidate_id="temp", # Will be reassigned
|
| 276 |
+
source=source,
|
| 277 |
+
id=row['id'],
|
| 278 |
+
name=safe_get(row, 'name'),
|
| 279 |
+
subtype=safe_get(row, 'subtype'),
|
| 280 |
+
country=safe_get(row, 'country'),
|
| 281 |
+
region=safe_get(row, 'region'),
|
| 282 |
+
admin_level=safe_get(row, 'admin_level'),
|
| 283 |
+
similarity=float(row['similarity'])
|
| 284 |
+
))
|
| 285 |
+
|
| 286 |
+
return distractors
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def generate_adjacency_sample(
|
| 290 |
+
con: duckdb.DuckDBPyConnection,
|
| 291 |
+
adjacency_df: pd.DataFrame,
|
| 292 |
+
sample_id: str
|
| 293 |
+
) -> Optional[TrainingSample]:
|
| 294 |
+
"""Generate a sample for adjacency task."""
|
| 295 |
+
|
| 296 |
+
anchor = sample_adjacency_anchor(adjacency_df)
|
| 297 |
+
if not anchor:
|
| 298 |
+
return None
|
| 299 |
+
|
| 300 |
+
# Build SQL
|
| 301 |
+
sql = f"""WITH a AS (
|
| 302 |
+
SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}')
|
| 303 |
+
WHERE id = '{anchor['anchor_id']}'
|
| 304 |
+
)
|
| 305 |
+
SELECT b.id, b.names."primary" AS name, b.geometry
|
| 306 |
+
FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
|
| 307 |
+
WHERE b.id != '{anchor['anchor_id']}'
|
| 308 |
+
AND b.subtype = '{anchor['target_subtype']}'
|
| 309 |
+
AND ST_Touches(a.geometry, b.geometry)"""
|
| 310 |
+
|
| 311 |
+
# Execute to verify
|
| 312 |
+
try:
|
| 313 |
+
result = con.execute(sql).fetchdf()
|
| 314 |
+
if result.empty:
|
| 315 |
+
return None
|
| 316 |
+
except Exception as e:
|
| 317 |
+
print(f"SQL execution failed: {e}")
|
| 318 |
+
return None
|
| 319 |
+
|
| 320 |
+
# Build candidates
|
| 321 |
+
candidates = build_candidate_list(
|
| 322 |
+
con,
|
| 323 |
+
anchor['anchor_id'],
|
| 324 |
+
anchor['anchor_name'],
|
| 325 |
+
"divisions_area",
|
| 326 |
+
num_candidates=10,
|
| 327 |
+
difficulty="medium"
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Find which candidate is the true anchor
|
| 331 |
+
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['anchor_id']]
|
| 332 |
+
|
| 333 |
+
# Generate question
|
| 334 |
+
question = f"Which {anchor['target_subtype']}s border {anchor['anchor_name']}?"
|
| 335 |
+
|
| 336 |
+
return TrainingSample(
|
| 337 |
+
id=sample_id,
|
| 338 |
+
question=question,
|
| 339 |
+
candidates=candidates,
|
| 340 |
+
target={
|
| 341 |
+
"selected_candidates": selected_candidate_ids,
|
| 342 |
+
"sql": sql
|
| 343 |
+
},
|
| 344 |
+
metadata={
|
| 345 |
+
"task_family": "adjacency",
|
| 346 |
+
"sql_difficulty": "medium",
|
| 347 |
+
"grounding_difficulty": "medium",
|
| 348 |
+
"template_id": "adj_02",
|
| 349 |
+
"num_candidates": len(candidates),
|
| 350 |
+
"anchor_source": "divisions_area",
|
| 351 |
+
"sql_verified": True
|
| 352 |
+
}
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def generate_containment_sample(
|
| 357 |
+
con: duckdb.DuckDBPyConnection,
|
| 358 |
+
containment_df: pd.DataFrame,
|
| 359 |
+
sample_id: str
|
| 360 |
+
) -> Optional[TrainingSample]:
|
| 361 |
+
"""Generate a sample for containment task."""
|
| 362 |
+
|
| 363 |
+
anchor = sample_containment_anchor(containment_df)
|
| 364 |
+
if not anchor:
|
| 365 |
+
return None
|
| 366 |
+
|
| 367 |
+
# Build SQL
|
| 368 |
+
sql = f"""WITH a AS (
|
| 369 |
+
SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}')
|
| 370 |
+
WHERE id = '{anchor['container_id']}'
|
| 371 |
+
)
|
| 372 |
+
SELECT b.id, b.names."primary" AS name, b.geometry
|
| 373 |
+
FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
|
| 374 |
+
WHERE b.id != '{anchor['container_id']}'
|
| 375 |
+
AND b.subtype = '{anchor['contained_subtype']}'
|
| 376 |
+
AND ST_Within(b.geometry, a.geometry)"""
|
| 377 |
+
|
| 378 |
+
# Execute to verify
|
| 379 |
+
try:
|
| 380 |
+
result = con.execute(sql).fetchdf()
|
| 381 |
+
if result.empty:
|
| 382 |
+
return None
|
| 383 |
+
except Exception as e:
|
| 384 |
+
print(f"SQL execution failed: {e}")
|
| 385 |
+
return None
|
| 386 |
+
|
| 387 |
+
# Build candidates
|
| 388 |
+
candidates = build_candidate_list(
|
| 389 |
+
con,
|
| 390 |
+
anchor['container_id'],
|
| 391 |
+
anchor['container_name'],
|
| 392 |
+
"divisions_area",
|
| 393 |
+
num_candidates=10,
|
| 394 |
+
difficulty="medium"
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# Find which candidate is the true anchor
|
| 398 |
+
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['container_id']]
|
| 399 |
+
|
| 400 |
+
# Generate question
|
| 401 |
+
question = f"What {anchor['contained_subtype']}s are in {anchor['container_name']}?"
|
| 402 |
+
|
| 403 |
+
return TrainingSample(
|
| 404 |
+
id=sample_id,
|
| 405 |
+
question=question,
|
| 406 |
+
candidates=candidates,
|
| 407 |
+
target={
|
| 408 |
+
"selected_candidates": selected_candidate_ids,
|
| 409 |
+
"sql": sql
|
| 410 |
+
},
|
| 411 |
+
metadata={
|
| 412 |
+
"task_family": "containment",
|
| 413 |
+
"sql_difficulty": "medium",
|
| 414 |
+
"grounding_difficulty": "medium",
|
| 415 |
+
"template_id": "contain_01",
|
| 416 |
+
"num_candidates": len(candidates),
|
| 417 |
+
"anchor_source": "divisions_area",
|
| 418 |
+
"sql_verified": True
|
| 419 |
+
}
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def sample_random_entity(
|
| 424 |
+
con: duckdb.DuckDBPyConnection,
|
| 425 |
+
inventory_df: pd.DataFrame,
|
| 426 |
+
source: str
|
| 427 |
+
) -> Optional[Dict[str, Any]]:
|
| 428 |
+
"""Sample a random entity from inventory."""
|
| 429 |
+
if inventory_df.empty:
|
| 430 |
+
return None
|
| 431 |
+
|
| 432 |
+
row = inventory_df.sample(n=1).iloc[0]
|
| 433 |
+
return {
|
| 434 |
+
'id': row['id'],
|
| 435 |
+
'name': row['name'],
|
| 436 |
+
'subtype': row.get('subtype'),
|
| 437 |
+
'country': row.get('country'),
|
| 438 |
+
'source': source
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def generate_template_based_sample(
|
| 443 |
+
con: duckdb.DuckDBPyConnection,
|
| 444 |
+
template: SQLTemplate,
|
| 445 |
+
tables: Dict[str, pd.DataFrame],
|
| 446 |
+
sample_id: str
|
| 447 |
+
) -> Optional[TrainingSample]:
|
| 448 |
+
"""Generate a sample based on a SQL template."""
|
| 449 |
+
|
| 450 |
+
# Sample anchor based on template requirements
|
| 451 |
+
if template.family == "direct_lookup":
|
| 452 |
+
# Just pick a random entity
|
| 453 |
+
if template.anchor_source == "divisions_area":
|
| 454 |
+
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 455 |
+
else:
|
| 456 |
+
anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
|
| 457 |
+
|
| 458 |
+
if not anchor:
|
| 459 |
+
return None
|
| 460 |
+
|
| 461 |
+
# Render SQL
|
| 462 |
+
sql = template.sql_template.format(
|
| 463 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 464 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 465 |
+
anchor_id=anchor['id']
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# Build candidates
|
| 469 |
+
candidates = build_candidate_list(
|
| 470 |
+
con, anchor['id'], anchor['name'], anchor['source'],
|
| 471 |
+
num_candidates=10, difficulty="easy"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Question
|
| 475 |
+
question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
|
| 476 |
+
|
| 477 |
+
elif template.family == "adjacency":
|
| 478 |
+
anchor = sample_adjacency_anchor(tables['adjacency_pairs'])
|
| 479 |
+
if not anchor:
|
| 480 |
+
return None
|
| 481 |
+
|
| 482 |
+
sql = template.sql_template.format(
|
| 483 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 484 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 485 |
+
anchor_id=anchor['anchor_id'],
|
| 486 |
+
target_subtype=anchor['target_subtype']
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
candidates = build_candidate_list(
|
| 490 |
+
con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
|
| 491 |
+
num_candidates=10, difficulty="medium"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
question = random.choice(template.question_hints).format(
|
| 495 |
+
anchor_name=anchor['anchor_name'],
|
| 496 |
+
target_subtype=anchor['target_subtype']
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
elif template.family == "containment":
|
| 500 |
+
anchor = sample_containment_anchor(tables['containment_pairs'])
|
| 501 |
+
if not anchor:
|
| 502 |
+
return None
|
| 503 |
+
|
| 504 |
+
sql = template.sql_template.format(
|
| 505 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 506 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 507 |
+
anchor_id=anchor['container_id'],
|
| 508 |
+
target_subtype=anchor['contained_subtype']
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
candidates = build_candidate_list(
|
| 512 |
+
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
|
| 513 |
+
num_candidates=10, difficulty="medium"
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
question = random.choice(template.question_hints).format(
|
| 517 |
+
anchor_name=anchor['container_name'],
|
| 518 |
+
target_subtype=anchor['contained_subtype']
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
elif template.family == "intersection":
|
| 522 |
+
if template.anchor_source == "natural_earth":
|
| 523 |
+
anchor = sample_cross_source_anchor(tables['cross_source_relations'])
|
| 524 |
+
if not anchor:
|
| 525 |
+
return None
|
| 526 |
+
|
| 527 |
+
sql = template.sql_template.format(
|
| 528 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 529 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 530 |
+
anchor_id=anchor['natural_id'],
|
| 531 |
+
target_subtype='country'
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
candidates = build_candidate_list(
|
| 535 |
+
con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
|
| 536 |
+
num_candidates=10, difficulty="medium"
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
question = random.choice(template.question_hints).format(
|
| 540 |
+
anchor_name=anchor['natural_name'],
|
| 541 |
+
target_subtype='country'
|
| 542 |
+
)
|
| 543 |
+
else:
|
| 544 |
+
# Same-source intersection
|
| 545 |
+
anchor = sample_intersection_anchor(tables['intersection_pairs'])
|
| 546 |
+
if not anchor:
|
| 547 |
+
return None
|
| 548 |
+
|
| 549 |
+
# Use a generic subtype if not available
|
| 550 |
+
target_subtype = anchor.get('target_subtype') or 'region'
|
| 551 |
+
|
| 552 |
+
sql = template.sql_template.format(
|
| 553 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 554 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 555 |
+
anchor_id=anchor['anchor_id'],
|
| 556 |
+
target_subtype=target_subtype
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
candidates = build_candidate_list(
|
| 560 |
+
con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
|
| 561 |
+
num_candidates=10, difficulty="medium"
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
question = random.choice(template.question_hints).format(
|
| 565 |
+
anchor_name=anchor['anchor_name'],
|
| 566 |
+
target_subtype=target_subtype
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
elif template.family == "set_operations":
|
| 570 |
+
# Union of two entities
|
| 571 |
+
anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 572 |
+
anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 573 |
+
|
| 574 |
+
if not anchor1 or not anchor2:
|
| 575 |
+
return None
|
| 576 |
+
|
| 577 |
+
sql = template.sql_template.format(
|
| 578 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 579 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 580 |
+
anchor_id_1=anchor1['id'],
|
| 581 |
+
anchor_id_2=anchor2['id']
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# Build candidates for both anchors
|
| 585 |
+
candidates1 = build_candidate_list(
|
| 586 |
+
con, anchor1['id'], anchor1['name'], 'divisions_area',
|
| 587 |
+
num_candidates=5, difficulty="medium"
|
| 588 |
+
)
|
| 589 |
+
candidates2 = build_candidate_list(
|
| 590 |
+
con, anchor2['id'], anchor2['name'], 'divisions_area',
|
| 591 |
+
num_candidates=5, difficulty="medium"
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
# Combine and deduplicate
|
| 595 |
+
candidates = candidates1 + candidates2
|
| 596 |
+
seen_ids = set()
|
| 597 |
+
unique_candidates = []
|
| 598 |
+
for c in candidates:
|
| 599 |
+
if c.id not in seen_ids:
|
| 600 |
+
unique_candidates.append(c)
|
| 601 |
+
seen_ids.add(c.id)
|
| 602 |
+
candidates = unique_candidates[:10]
|
| 603 |
+
|
| 604 |
+
# Reassign IDs
|
| 605 |
+
for i, c in enumerate(candidates, 1):
|
| 606 |
+
c.candidate_id = f"c{i}"
|
| 607 |
+
|
| 608 |
+
question = f"{anchor1['name']} and {anchor2['name']}"
|
| 609 |
+
|
| 610 |
+
elif template.family == "buffer":
|
| 611 |
+
# Buffer operations
|
| 612 |
+
if template.num_anchors == 1:
|
| 613 |
+
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 614 |
+
if not anchor:
|
| 615 |
+
return None
|
| 616 |
+
|
| 617 |
+
buffer_degrees = random.choice([0.1, 0.5, 1.0])
|
| 618 |
+
|
| 619 |
+
sql = template.sql_template.format(
|
| 620 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 621 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 622 |
+
anchor_id=anchor['id'],
|
| 623 |
+
buffer_degrees=buffer_degrees
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
candidates = build_candidate_list(
|
| 627 |
+
con, anchor['id'], anchor['name'], 'divisions_area',
|
| 628 |
+
num_candidates=10, difficulty="medium"
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
question = random.choice(template.question_hints).format(
|
| 632 |
+
anchor_name=anchor['name'],
|
| 633 |
+
buffer_degrees=buffer_degrees
|
| 634 |
+
)
|
| 635 |
+
else:
|
| 636 |
+
# Two anchor buffer
|
| 637 |
+
anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 638 |
+
anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 639 |
+
|
| 640 |
+
if not anchor1 or not anchor2:
|
| 641 |
+
return None
|
| 642 |
+
|
| 643 |
+
buffer_degrees = random.choice([0.1, 0.5])
|
| 644 |
+
|
| 645 |
+
sql = template.sql_template.format(
|
| 646 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 647 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 648 |
+
anchor_id_1=anchor1['id'],
|
| 649 |
+
anchor_id_2=anchor2['id'],
|
| 650 |
+
buffer_degrees=buffer_degrees
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
candidates1 = build_candidate_list(
|
| 654 |
+
con, anchor1['id'], anchor1['name'], 'divisions_area',
|
| 655 |
+
num_candidates=5, difficulty="medium"
|
| 656 |
+
)
|
| 657 |
+
candidates2 = build_candidate_list(
|
| 658 |
+
con, anchor2['id'], anchor2['name'], 'divisions_area',
|
| 659 |
+
num_candidates=5, difficulty="medium"
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
candidates = candidates1 + candidates2
|
| 663 |
+
seen_ids = set()
|
| 664 |
+
unique_candidates = []
|
| 665 |
+
for c in candidates:
|
| 666 |
+
if c.id not in seen_ids:
|
| 667 |
+
unique_candidates.append(c)
|
| 668 |
+
seen_ids.add(c.id)
|
| 669 |
+
candidates = unique_candidates[:10]
|
| 670 |
+
|
| 671 |
+
for i, c in enumerate(candidates, 1):
|
| 672 |
+
c.candidate_id = f"c{i}"
|
| 673 |
+
|
| 674 |
+
question = random.choice(template.question_hints).format(
|
| 675 |
+
anchor_1_name=anchor1['name'],
|
| 676 |
+
anchor_2_name=anchor2['name'],
|
| 677 |
+
buffer_degrees=buffer_degrees
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
elif template.family == "partial_selection":
|
| 681 |
+
# Partial selection (northern half, clipping, etc.)
|
| 682 |
+
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 683 |
+
if not anchor:
|
| 684 |
+
return None
|
| 685 |
+
|
| 686 |
+
if template.num_anchors == 1:
|
| 687 |
+
sql = template.sql_template.format(
|
| 688 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 689 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 690 |
+
anchor_id=anchor['id']
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
|
| 694 |
+
else:
|
| 695 |
+
# Mixed source clipping
|
| 696 |
+
clip_feature = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
|
| 697 |
+
if not clip_feature:
|
| 698 |
+
return None
|
| 699 |
+
|
| 700 |
+
sql = template.sql_template.format(
|
| 701 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 702 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 703 |
+
anchor_id=anchor['id'],
|
| 704 |
+
clip_feature_id=clip_feature['id']
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
question = random.choice(template.question_hints).format(
|
| 708 |
+
anchor_name=anchor['name'],
|
| 709 |
+
clip_feature_name=clip_feature['name']
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
candidates = build_candidate_list(
|
| 713 |
+
con, anchor['id'], anchor['name'], 'divisions_area',
|
| 714 |
+
num_candidates=10, difficulty="hard"
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
elif template.family == "aggregation":
|
| 718 |
+
# Aggregation queries (e.g., largest N localities in a region)
|
| 719 |
+
top_n = random.choice([3, 5, 10])
|
| 720 |
+
|
| 721 |
+
# Check if this is a country-level query (agg_04, agg_05)
|
| 722 |
+
if template.template_id in ['agg_04', 'agg_05']:
|
| 723 |
+
# Country-level aggregation
|
| 724 |
+
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 725 |
+
if not anchor:
|
| 726 |
+
return None
|
| 727 |
+
|
| 728 |
+
country = anchor.get('country', 'EC')
|
| 729 |
+
target_subtype = random.choice(['locality', 'region'])
|
| 730 |
+
|
| 731 |
+
sql = template.sql_template.format(
|
| 732 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 733 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 734 |
+
country=country,
|
| 735 |
+
target_subtype=target_subtype,
|
| 736 |
+
top_n=top_n
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
candidates = build_candidate_list(
|
| 740 |
+
con, anchor['id'], anchor['name'], 'divisions_area',
|
| 741 |
+
num_candidates=10, difficulty="hard"
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
question = random.choice(template.question_hints).format(
|
| 745 |
+
top_n=top_n,
|
| 746 |
+
target_subtype=target_subtype,
|
| 747 |
+
country=country
|
| 748 |
+
)
|
| 749 |
+
else:
|
| 750 |
+
# Container-based aggregation (within a region)
|
| 751 |
+
anchor = sample_containment_anchor(tables['containment_pairs'])
|
| 752 |
+
if not anchor:
|
| 753 |
+
return None
|
| 754 |
+
|
| 755 |
+
target_subtype = anchor.get('contained_subtype', 'locality')
|
| 756 |
+
|
| 757 |
+
sql = template.sql_template.format(
|
| 758 |
+
DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
|
| 759 |
+
NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
|
| 760 |
+
anchor_id=anchor['container_id'],
|
| 761 |
+
target_subtype=target_subtype,
|
| 762 |
+
top_n=top_n
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
candidates = build_candidate_list(
|
| 766 |
+
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
|
| 767 |
+
num_candidates=10, difficulty="hard"
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
question = random.choice(template.question_hints).format(
|
| 771 |
+
top_n=top_n,
|
| 772 |
+
target_subtype=target_subtype,
|
| 773 |
+
anchor_name=anchor['container_name']
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
else:
|
| 777 |
+
# Skip unsupported families
|
| 778 |
+
return None
|
| 779 |
+
|
| 780 |
+
# Execute SQL to verify
|
| 781 |
+
try:
|
| 782 |
+
result = con.execute(sql).fetchdf()
|
| 783 |
+
if result.empty:
|
| 784 |
+
return None
|
| 785 |
+
except Exception as e:
|
| 786 |
+
# Errors are tracked in worker return, no need to print
|
| 787 |
+
return None
|
| 788 |
+
|
| 789 |
+
# Find selected candidates
|
| 790 |
+
if template.family == "set_operations":
|
| 791 |
+
selected_candidate_ids = [c.candidate_id for c in candidates if c.id in [anchor1['id'], anchor2['id']]]
|
| 792 |
+
else:
|
| 793 |
+
anchor_id_to_find = anchor.get('anchor_id') or anchor.get('container_id') or anchor.get('natural_id') or anchor.get('id')
|
| 794 |
+
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
|
| 795 |
+
|
| 796 |
+
return TrainingSample(
|
| 797 |
+
id=sample_id,
|
| 798 |
+
question=question,
|
| 799 |
+
candidates=candidates,
|
| 800 |
+
target={
|
| 801 |
+
"selected_candidates": selected_candidate_ids,
|
| 802 |
+
"sql": sql
|
| 803 |
+
},
|
| 804 |
+
metadata={
|
| 805 |
+
"task_family": template.family,
|
| 806 |
+
"sql_difficulty": template.sql_difficulty,
|
| 807 |
+
"grounding_difficulty": "medium",
|
| 808 |
+
"template_id": template.template_id,
|
| 809 |
+
"num_candidates": len(candidates),
|
| 810 |
+
"anchor_source": template.anchor_source,
|
| 811 |
+
"sql_verified": True
|
| 812 |
+
}
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def generate_cross_source_sample(
|
| 817 |
+
con: duckdb.DuckDBPyConnection,
|
| 818 |
+
cross_source_df: pd.DataFrame,
|
| 819 |
+
sample_id: str
|
| 820 |
+
) -> Optional[TrainingSample]:
|
| 821 |
+
"""Generate a sample for cross-source intersection task."""
|
| 822 |
+
|
| 823 |
+
anchor = sample_cross_source_anchor(cross_source_df)
|
| 824 |
+
if not anchor:
|
| 825 |
+
return None
|
| 826 |
+
|
| 827 |
+
# Build SQL (natural feature -> divisions)
|
| 828 |
+
sql = f"""WITH a AS (
|
| 829 |
+
SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}')
|
| 830 |
+
WHERE id = '{anchor['natural_id']}'
|
| 831 |
+
)
|
| 832 |
+
SELECT b.id, b.names."primary" AS name, b.geometry
|
| 833 |
+
FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
|
| 834 |
+
WHERE b.subtype = 'country'
|
| 835 |
+
AND ST_Intersects(b.geometry, a.geometry)"""
|
| 836 |
+
|
| 837 |
+
# Execute to verify
|
| 838 |
+
try:
|
| 839 |
+
result = con.execute(sql).fetchdf()
|
| 840 |
+
if result.empty:
|
| 841 |
+
return None
|
| 842 |
+
except Exception as e:
|
| 843 |
+
print(f"SQL execution failed: {e}")
|
| 844 |
+
return None
|
| 845 |
+
|
| 846 |
+
# Build candidates for natural feature
|
| 847 |
+
candidates = build_candidate_list(
|
| 848 |
+
con,
|
| 849 |
+
anchor['natural_id'],
|
| 850 |
+
anchor['natural_name'],
|
| 851 |
+
"natural_earth",
|
| 852 |
+
num_candidates=10,
|
| 853 |
+
difficulty="medium"
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
# Find which candidate is the true anchor
|
| 857 |
+
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['natural_id']]
|
| 858 |
+
|
| 859 |
+
# Generate question
|
| 860 |
+
question = f"Which countries intersect the {anchor['natural_name']}?"
|
| 861 |
+
|
| 862 |
+
return TrainingSample(
|
| 863 |
+
id=sample_id,
|
| 864 |
+
question=question,
|
| 865 |
+
candidates=candidates,
|
| 866 |
+
target={
|
| 867 |
+
"selected_candidates": selected_candidate_ids,
|
| 868 |
+
"sql": sql
|
| 869 |
+
},
|
| 870 |
+
metadata={
|
| 871 |
+
"task_family": "intersection",
|
| 872 |
+
"sql_difficulty": "medium-hard",
|
| 873 |
+
"grounding_difficulty": "medium",
|
| 874 |
+
"template_id": "intersect_02",
|
| 875 |
+
"num_candidates": len(candidates),
|
| 876 |
+
"anchor_source": "natural_earth",
|
| 877 |
+
"sql_verified": True
|
| 878 |
+
}
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def generate_sample_batch_worker(args):
|
| 883 |
+
"""Worker function that processes a batch of work items with a single DuckDB connection.
|
| 884 |
+
|
| 885 |
+
Initializes DuckDB, spatial extension, templates module, and relation tables
|
| 886 |
+
ONCE per batch, then processes all items sequentially.
|
| 887 |
+
"""
|
| 888 |
+
from pathlib import Path
|
| 889 |
+
|
| 890 |
+
work_items, intermediate_dir_str = args
|
| 891 |
+
|
| 892 |
+
# Convert string back to Path
|
| 893 |
+
intermediate_dir = Path(intermediate_dir_str)
|
| 894 |
+
|
| 895 |
+
# Initialize DuckDB ONCE for the entire batch
|
| 896 |
+
con = duckdb.connect()
|
| 897 |
+
con.execute("SET enable_progress_bar=false")
|
| 898 |
+
con.execute("INSTALL spatial")
|
| 899 |
+
con.execute("LOAD spatial")
|
| 900 |
+
|
| 901 |
+
# Load relation tables ONCE
|
| 902 |
+
tables = load_relation_tables(intermediate_dir, quiet=True)
|
| 903 |
+
|
| 904 |
+
# Process all items in batch
|
| 905 |
+
results = []
|
| 906 |
+
for family, template_dict, sample_id, _ in work_items:
|
| 907 |
+
# Reconstruct template from dict (sql_templates is already imported at module level)
|
| 908 |
+
template = sql_templates.SQLTemplate(**template_dict)
|
| 909 |
+
try:
|
| 910 |
+
sample = generate_template_based_sample(con, template, tables, sample_id)
|
| 911 |
+
if sample:
|
| 912 |
+
results.append((sample, family, template.template_id, None))
|
| 913 |
+
else:
|
| 914 |
+
results.append((None, family, template.template_id, "Empty result"))
|
| 915 |
+
except Exception as e:
|
| 916 |
+
results.append((None, family, template_dict.get('template_id', 'unknown'), str(e)))
|
| 917 |
+
|
| 918 |
+
con.close()
|
| 919 |
+
return results
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
def main():
|
| 923 |
+
"""Generate training samples."""
|
| 924 |
+
global TARGET_COUNTS, MAX_WORKERS, RETRY_MULTIPLIER, APPEND_MODE
|
| 925 |
+
|
| 926 |
+
# Setup paths
|
| 927 |
+
script_dir = Path(__file__).parent
|
| 928 |
+
intermediate_dir = script_dir.parent / "intermediate"
|
| 929 |
+
output_dir = script_dir.parent / "output"
|
| 930 |
+
|
| 931 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
| 932 |
+
|
| 933 |
+
# Load relation tables once to check availability
|
| 934 |
+
print("Loading relation tables...")
|
| 935 |
+
tables = load_relation_tables(intermediate_dir, quiet=False)
|
| 936 |
+
|
| 937 |
+
# Use configured target counts or defaults
|
| 938 |
+
if TARGET_COUNTS is None:
|
| 939 |
+
target_counts = {
|
| 940 |
+
'direct_lookup': 100,
|
| 941 |
+
'adjacency': 200,
|
| 942 |
+
'containment': 100,
|
| 943 |
+
'intersection': 150,
|
| 944 |
+
'buffer': 100,
|
| 945 |
+
'set_operations': 150,
|
| 946 |
+
'partial_selection': 100,
|
| 947 |
+
'aggregation': 100
|
| 948 |
+
}
|
| 949 |
+
else:
|
| 950 |
+
target_counts = TARGET_COUNTS
|
| 951 |
+
|
| 952 |
+
# Load existing samples if in append mode
|
| 953 |
+
existing_samples = []
|
| 954 |
+
existing_sample_ids = set()
|
| 955 |
+
jsonl_file = output_dir / "dataset_raw.jsonl"
|
| 956 |
+
|
| 957 |
+
if APPEND_MODE and jsonl_file.exists():
|
| 958 |
+
print(f"\nAppend mode: Loading existing samples from {jsonl_file}")
|
| 959 |
+
with open(jsonl_file, 'r') as f:
|
| 960 |
+
for line in f:
|
| 961 |
+
if line.strip():
|
| 962 |
+
sample_data = json.loads(line)
|
| 963 |
+
existing_samples.append(sample_data)
|
| 964 |
+
existing_sample_ids.add(sample_data['id'])
|
| 965 |
+
print(f" Found {len(existing_samples)} existing samples")
|
| 966 |
+
|
| 967 |
+
# Determine starting sample counter
|
| 968 |
+
max_existing_id = max([int(s['id'].split('_')[1]) for s in existing_samples if s['id'].startswith('sample_')], default=0)
|
| 969 |
+
sample_counter = max_existing_id + 1
|
| 970 |
+
else:
|
| 971 |
+
sample_counter = 1
|
| 972 |
+
|
| 973 |
+
# Prepare work items for parallel processing
|
| 974 |
+
work_items = []
|
| 975 |
+
starting_sample_counter = sample_counter # Track starting point for logging
|
| 976 |
+
|
| 977 |
+
for family, target_count in target_counts.items():
|
| 978 |
+
if target_count == 0:
|
| 979 |
+
continue
|
| 980 |
+
|
| 981 |
+
# Get templates for this family
|
| 982 |
+
family_templates = [t for t in TEMPLATES if t.family == family]
|
| 983 |
+
if not family_templates:
|
| 984 |
+
print(f"No templates found for {family}, skipping...")
|
| 985 |
+
continue
|
| 986 |
+
|
| 987 |
+
# Create work items (try retry_multiplier * target to account for failures)
|
| 988 |
+
for _ in range(target_count * RETRY_MULTIPLIER):
|
| 989 |
+
template = random.choice(family_templates)
|
| 990 |
+
# Convert template to dict for pickling
|
| 991 |
+
template_dict = {
|
| 992 |
+
'template_id': template.template_id,
|
| 993 |
+
'family': template.family,
|
| 994 |
+
'sql_difficulty': template.sql_difficulty,
|
| 995 |
+
'anchor_source': template.anchor_source,
|
| 996 |
+
'num_anchors': template.num_anchors,
|
| 997 |
+
'sql_template': template.sql_template,
|
| 998 |
+
'question_hints': template.question_hints,
|
| 999 |
+
'target_subtype': template.target_subtype,
|
| 1000 |
+
'requires_buffer': template.requires_buffer,
|
| 1001 |
+
'requires_aggregation': template.requires_aggregation
|
| 1002 |
+
}
|
| 1003 |
+
work_items.append((
|
| 1004 |
+
family,
|
| 1005 |
+
template_dict,
|
| 1006 |
+
f"sample_{sample_counter:03d}",
|
| 1007 |
+
str(intermediate_dir) # Convert Path to string for pickling
|
| 1008 |
+
))
|
| 1009 |
+
sample_counter += 1
|
| 1010 |
+
|
| 1011 |
+
# Shuffle work items for balanced batches across families
|
| 1012 |
+
random.shuffle(work_items)
|
| 1013 |
+
|
| 1014 |
+
# Partition work items into batches (one per worker)
|
| 1015 |
+
num_workers = min(MAX_WORKERS, len(work_items))
|
| 1016 |
+
if num_workers == 0:
|
| 1017 |
+
print("No work items to process")
|
| 1018 |
+
return
|
| 1019 |
+
    # Split work items into one batch per worker so each process pays the
    # DuckDB initialization cost only once.
    batch_size = (len(work_items) + num_workers - 1) // num_workers
    batches = []
    for i in range(0, len(work_items), batch_size):
        batch = work_items[i:i + batch_size]
        batches.append((batch, str(intermediate_dir)))

    # Generate samples in parallel (one batch per worker)
    active_families = len([f for f in target_counts.values() if f > 0])
    print(f"\nGenerating {len(work_items)} samples across {active_families} families...")
    print(f"  Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)")
    if APPEND_MODE and existing_samples:
        print(f"Appending: starting from sample_{starting_sample_counter:03d}")

    all_samples = []
    family_progress = {f: {'success': 0, 'failed': 0} for f in target_counts.keys() if target_counts[f] > 0}

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit one batch per worker
        futures = {executor.submit(generate_sample_batch_worker, batch): i for i, batch in enumerate(batches)}

        # Collect results as batches complete
        batches_done = 0
        for future in as_completed(futures):
            try:
                batch_results = future.result()
                for sample, family, template_id, error in batch_results:
                    if sample:
                        all_samples.append(sample)
                        family_progress[family]['success'] += 1
                    else:
                        family_progress[family]['failed'] += 1
            except Exception as e:
                print(f"\n  Batch failed: {e}")

            batches_done += 1
            total_done = sum(p['success'] + p['failed'] for p in family_progress.values())
            print(f"\r  Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches)  ", end='', flush=True)

    print()  # New line after progress

    # Show distribution (keep all samples, no filtering)
    print("\nResults by family:")
    for family in sorted(family_progress.keys()):
        success = family_progress[family]['success']
        failed = family_progress[family]['failed']
        target = target_counts.get(family, 0)
        total = success + failed
        success_rate = (success / total * 100) if total > 0 else 0
        print(f"  {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})")

    # Save combined JSONL (skip individual JSON files for speed at scale)
    print(f"\nSaving {len(all_samples)} new samples...")
    if APPEND_MODE and existing_samples:
        # Append to existing dataset
        print(f"Appending to existing dataset ({len(existing_samples)} existing samples)")
        with open(jsonl_file, 'a') as f:
            for sample in all_samples:
                f.write(json.dumps(sample.model_dump()) + '\n')
        total_samples = len(existing_samples) + len(all_samples)
    else:
        # Overwrite dataset
        with open(jsonl_file, 'w') as f:
            for sample in all_samples:
                f.write(json.dumps(sample.model_dump()) + '\n')
        total_samples = len(all_samples)

    print(f"\nGenerated {len(all_samples)} new samples")
    print(f"Total dataset size: {total_samples} samples")
    print(f"  Dataset: {jsonl_file}")


if __name__ == "__main__":
    main()
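The ceiling division above puts every work item into exactly one batch and caps the batch count at `num_workers`. A standalone sketch of the arithmetic, using hypothetical counts rather than real pipeline data:

# Standalone sketch of the batch-splitting arithmetic above, with
# hypothetical counts (390 work items, 8 workers).
work_items = list(range(390))
num_workers = 8

batch_size = (len(work_items) + num_workers - 1) // num_workers   # ceil(390 / 8) = 49
batches = [work_items[i:i + batch_size] for i in range(0, len(work_items), batch_size)]

assert len(batches) == 8                     # one batch per worker
assert sum(len(b) for b in batches) == 390   # every item assigned exactly once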
dataset/scripts/sql_templates.py
ADDED
@@ -0,0 +1,317 @@
"""
SQL template definitions for synthetic data generation.

Each template includes:
- Template ID
- Task family
- SQL difficulty level
- Required anchor types
- SQL template string with placeholders
- Question generation hints
"""

from dataclasses import dataclass
from typing import List, Literal, Optional


@dataclass
class SQLTemplate:
    """SQL template for synthetic data generation."""

    template_id: str
    family: str
    sql_difficulty: Literal["easy", "medium", "medium-hard", "hard"]
    anchor_source: Literal["divisions_area", "natural_earth", "mixed"]
    num_anchors: int
    sql_template: str
    question_hints: List[str]
    target_subtype: Optional[str] = None
    requires_buffer: bool = False
    requires_aggregation: bool = False


# Template catalog
TEMPLATES = [
    # DIRECT LOOKUP (10 samples)
    SQLTemplate(
        template_id="lookup_01",
        family="direct_lookup",
        sql_difficulty="easy",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""SELECT geometry, names."primary" AS name, id, subtype FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'""",
        question_hints=["Show me {anchor_name}", "Get the geometry of {anchor_name}", "Find {anchor_name}"]
    ),

    SQLTemplate(
        template_id="lookup_02",
        family="direct_lookup",
        sql_difficulty="easy",
        anchor_source="natural_earth",
        num_anchors=1,
        sql_template="""SELECT geometry, names."primary" AS name, id, subtype FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}'""",
        question_hints=["Show me the {anchor_name}", "Get {anchor_name}", "Find the {anchor_name}"]
    ),

    # ADJACENCY (20 samples)
    SQLTemplate(
        template_id="adj_01",
        family="adjacency",
        sql_difficulty="medium",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND ST_Touches(a.geometry, b.geometry)""",
        question_hints=["Which regions border {anchor_name}?", "What borders {anchor_name}?", "List places adjacent to {anchor_name}"]
    ),

    SQLTemplate(
        template_id="adj_02",
        family="adjacency",
        sql_difficulty="medium",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="region",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Touches(a.geometry, b.geometry)""",
        question_hints=["Which {target_subtype}s border {anchor_name}?", "What {target_subtype}s touch {anchor_name}?"]
    ),

    SQLTemplate(
        template_id="adj_03",
        family="adjacency",
        sql_difficulty="medium",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="sea",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT n.id, n.names."primary" AS name, n.geometry FROM read_parquet('{NATURAL_EARTH_PATH}') AS n, a WHERE n.subtype = '{target_subtype}' AND ST_Touches(a.geometry, n.geometry)""",
        question_hints=["Which {target_subtype}s touch {anchor_name}?", "What {target_subtype}s border {anchor_name}?"]
    ),

    # CONTAINMENT (15 samples)
    SQLTemplate(
        template_id="contain_01",
        family="containment",
        sql_difficulty="medium",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="locality",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Within(b.geometry, a.geometry)""",
        question_hints=["What {target_subtype}s are in {anchor_name}?", "Which {target_subtype}s are within {anchor_name}?"]
    ),

    SQLTemplate(
        template_id="contain_02",
        family="containment",
        sql_difficulty="medium",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="country",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Contains(b.geometry, a.geometry)""",
        question_hints=["What {target_subtype} contains {anchor_name}?", "Which {target_subtype} is {anchor_name} in?"]
    ),

    SQLTemplate(
        template_id="contain_03",
        family="containment",
        sql_difficulty="medium",
        anchor_source="natural_earth",
        num_anchors=1,
        target_subtype="region",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.subtype = '{target_subtype}' AND ST_Within(b.geometry, a.geometry)""",
        question_hints=["Which {target_subtype}s are in the {anchor_name}?", "What {target_subtype}s fall within the {anchor_name}?"]
    ),

    # INTERSECTION (15 samples)
    SQLTemplate(
        template_id="intersect_01",
        family="intersection",
        sql_difficulty="medium-hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="region",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Intersects(b.geometry, a.geometry)""",
        question_hints=["Which {target_subtype}s intersect {anchor_name}?", "What {target_subtype}s overlap with {anchor_name}?"]
    ),

    SQLTemplate(
        template_id="intersect_02",
        family="intersection",
        sql_difficulty="medium-hard",
        anchor_source="natural_earth",
        num_anchors=1,
        target_subtype="country",
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.subtype = '{target_subtype}' AND ST_Intersects(b.geometry, a.geometry)""",
        question_hints=["Which {target_subtype}s intersect the {anchor_name}?", "What {target_subtype}s touch the {anchor_name}?"]
    ),

    # BUFFER OPERATIONS (10 samples)
    SQLTemplate(
        template_id="buffer_01",
        family="buffer",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        requires_buffer=True,
        sql_template="""WITH a AS (SELECT ST_Buffer(geometry, {buffer_degrees}) AS geom FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND ST_Intersects(b.geometry, a.geom)""",
        question_hints=["A {buffer_degrees} degree buffer around {anchor_name}", "Features within {buffer_degrees} degrees of {anchor_name}"]
    ),

    SQLTemplate(
        template_id="buffer_02",
        family="buffer",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=2,
        requires_buffer=True,
        sql_template="""WITH a AS (SELECT geometry AS g1 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id_1}'), b AS (SELECT geometry AS g2 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id_2}'), boundary AS (SELECT ST_Buffer(ST_Intersection(a.g1, b.g2), {buffer_degrees}) AS geom FROM a, b WHERE ST_Touches(a.g1, b.g2)) SELECT geom AS geometry FROM boundary""",
        question_hints=["A {buffer_degrees} degree buffer around the border between {anchor_1_name} and {anchor_2_name}"]
    ),

    # SET OPERATIONS (15 samples)
    SQLTemplate(
        template_id="union_01",
        family="set_operations",
        sql_difficulty="medium-hard",
        anchor_source="divisions_area",
        num_anchors=2,
        sql_template="""SELECT ST_Union_Agg(geometry) AS geometry, array_agg(names."primary") AS names FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id IN ('{anchor_id_1}', '{anchor_id_2}')""",
        question_hints=["{anchor_1_name} and {anchor_2_name}", "The union of {anchor_1_name} and {anchor_2_name}"]
    ),

    # PARTIAL SELECTION (10 samples)
    SQLTemplate(
        template_id="partial_01",
        family="partial_selection",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), north_half AS (SELECT ST_MakeEnvelope(xmin, (ymin + ymax) / 2, xmax, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, nh.half_geom) AS geometry FROM a, north_half AS nh""",
        question_hints=["The northern half of {anchor_name}", "Northern part of {anchor_name}"]
    ),

    SQLTemplate(
        template_id="partial_02",
        family="partial_selection",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), south_half AS (SELECT ST_MakeEnvelope(xmin, ymin, xmax, (ymin + ymax) / 2) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, sh.half_geom) AS geometry FROM a, south_half AS sh""",
        question_hints=["The southern half of {anchor_name}", "Southern part of {anchor_name}"]
    ),

    SQLTemplate(
        template_id="partial_04",
        family="partial_selection",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), east_half AS (SELECT ST_MakeEnvelope((xmin + xmax) / 2, ymin, xmax, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, eh.half_geom) AS geometry FROM a, east_half AS eh""",
        question_hints=["The eastern half of {anchor_name}", "Eastern part of {anchor_name}"]
    ),

    SQLTemplate(
        template_id="partial_05",
        family="partial_selection",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), west_half AS (SELECT ST_MakeEnvelope(xmin, ymin, (xmin + xmax) / 2, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, wh.half_geom) AS geometry FROM a, west_half AS wh""",
        question_hints=["The western half of {anchor_name}", "Western part of {anchor_name}"]
    ),

    SQLTemplate(
        template_id="partial_03",
        family="partial_selection",
        sql_difficulty="hard",
        anchor_source="mixed",
        num_anchors=2,
        sql_template="""WITH a AS (SELECT geometry AS g1 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), b AS (SELECT geometry AS g2 FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{clip_feature_id}') SELECT ST_Intersection(a.g1, b.g2) AS geometry FROM a, b WHERE ST_Intersects(a.g1, b.g2)""",
        question_hints=["The part of {anchor_name} that is in the {clip_feature_name}", "{anchor_name} within the {clip_feature_name}"]
    ),

    # AGGREGATION (5 samples)
    SQLTemplate(
        template_id="agg_01",
        family="aggregation",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="locality",
        requires_aggregation=True,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry, ST_Area(b.geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY area DESC LIMIT {top_n}""",
        question_hints=["Top {top_n} largest {target_subtype}s in {anchor_name}", "Biggest {target_subtype}s in {anchor_name}", "{top_n} largest {target_subtype}s in {anchor_name}"]
    ),

    SQLTemplate(
        template_id="agg_02",
        family="aggregation",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="locality",
        requires_aggregation=True,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry, ST_Area(b.geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY area ASC LIMIT {top_n}""",
        question_hints=["Top {top_n} smallest {target_subtype}s in {anchor_name}", "Smallest {target_subtype}s in {anchor_name}", "{top_n} smallest {target_subtype}s in {anchor_name}"]
    ),

    SQLTemplate(
        template_id="agg_03",
        family="aggregation",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="region",
        requires_aggregation=True,
        sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY RANDOM() LIMIT {top_n}""",
        question_hints=["{top_n} random {target_subtype}s in {anchor_name}", "Any {top_n} {target_subtype}s in {anchor_name}"]
    ),

    SQLTemplate(
        template_id="agg_04",
        family="aggregation",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="locality",
        requires_aggregation=True,
        sql_template="""SELECT id, names."primary" AS name, geometry, ST_Area(geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE country = '{country}' AND subtype = '{target_subtype}' ORDER BY area DESC LIMIT {top_n}""",
        question_hints=["Top {top_n} largest {target_subtype}s in {country}", "{top_n} biggest {target_subtype}s in {country}"]
    ),

    SQLTemplate(
        template_id="agg_05",
        family="aggregation",
        sql_difficulty="hard",
        anchor_source="divisions_area",
        num_anchors=1,
        target_subtype="locality",
        requires_aggregation=True,
        sql_template="""SELECT id, names."primary" AS name, geometry, ST_Area(geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE country = '{country}' AND subtype = '{target_subtype}' ORDER BY area ASC LIMIT {top_n}""",
        question_hints=["Top {top_n} smallest {target_subtype}s in {country}", "{top_n} smallest {target_subtype}s in {country}"]
    ),
]


def get_templates_by_family(family: str) -> List[SQLTemplate]:
    """Get all templates for a specific family."""
    return [t for t in TEMPLATES if t.family == family]


def get_template_by_id(template_id: str) -> SQLTemplate:
    """Get a specific template by ID."""
    for t in TEMPLATES:
        if t.template_id == template_id:
            return t
    raise ValueError(f"Template {template_id} not found")


if __name__ == "__main__":
    # Print template summary
    families = {}
    for t in TEMPLATES:
        families[t.family] = families.get(t.family, 0) + 1

    print("SQL Template Catalog")
    print("=" * 60)
    for family, count in sorted(families.items()):
        print(f"{family:20s}: {count:2d} templates")
    print(f"{'TOTAL':20s}: {len(TEMPLATES):2d} templates")
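Because every brace in a `sql_template` is a placeholder (the quoted `names."primary"` column contains none), plain `str.format` is enough to instantiate a template. A minimal sketch, with a hypothetical parquet path and anchor id:

# Minimal sketch: instantiate lookup_01 via str.format. The parquet path
# and anchor id below are hypothetical placeholders, not real data.
tpl = get_template_by_id("lookup_01")
sql = tpl.sql_template.format(
    DIVISIONS_AREA_PATH="data/divisions_area.parquet",
    anchor_id="0123456789abcdef",
)
print(sql)
# -> SELECT geometry, names."primary" AS name, id, subtype
#    FROM read_parquet('data/divisions_area.parquet') WHERE id = '0123456789abcdef'
# (emitted as a single line; wrapped here for readability)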
dataset/scripts/validate_dataset.py
ADDED
@@ -0,0 +1,275 @@
"""
Validate and balance the generated dataset.

This script:
1. Loads all generated samples
2. Validates SQL executability
3. Checks candidate list quality
4. Balances across task families and difficulty
5. Removes duplicates
6. Generates dataset statistics

Output:
- output/dataset_validated.jsonl
- output/dataset_stats.json
"""

import json
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed

import duckdb


def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
    """Load samples from JSONL file."""
    samples = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            samples.append(json.loads(line))
    return samples


def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> Tuple[bool, str]:
    """Validate that SQL executes without error and returns at least one row."""
    try:
        result = con.execute(sql).fetchdf()
        if result.empty:
            return False, "Empty result"
        return True, "OK"
    except Exception as e:
        return False, str(e)


def validate_candidates(sample: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate candidate list quality."""
    candidates = sample['candidates']
    selected = sample['target']['selected_candidates']

    # Check we have candidates
    if not candidates:
        return False, "No candidates"

    # Check selected candidates exist
    candidate_ids = {c['candidate_id'] for c in candidates}
    for sel_id in selected:
        if sel_id not in candidate_ids:
            return False, f"Selected candidate {sel_id} not in candidate list"

    # Check for duplicate feature ids
    ids = [c['id'] for c in candidates]
    if len(ids) != len(set(ids)):
        return False, "Duplicate candidates"

    return True, "OK"


def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Validate a single sample. Returns (is_valid, list_of_issues)."""
    issues = []

    # Skip SQL re-execution if already verified during generation
    if not sample.get('metadata', {}).get('sql_verified', False):
        sql_valid, sql_msg = validate_sql(con, sample['target']['sql'])
        if not sql_valid:
            issues.append(f"SQL: {sql_msg}")

    # Validate candidates
    cand_valid, cand_msg = validate_candidates(sample)
    if not cand_valid:
        issues.append(f"Candidates: {cand_msg}")

    # Check question exists
    if not sample.get('question') or len(sample['question'].strip()) == 0:
        issues.append("Empty question")

    return len(issues) == 0, issues


def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str], Optional[Dict[str, Any]]]:
    """Worker for parallel validation. Returns (sample_id, is_valid, issues, sample_or_None)."""
    # Each worker creates its own DuckDB connection
    con = duckdb.connect()
    con.execute("SET enable_progress_bar=false")
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")

    try:
        is_valid, issues = validate_sample(con, sample)
        con.close()
        return (sample['id'], is_valid, issues, sample if is_valid else None)
    except Exception as e:
        con.close()
        return (sample['id'], False, [f"Validation error: {str(e)}"], None)


def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute dataset statistics."""

    stats = {
        'total_samples': len(samples),
        'task_families': {},
        'sql_difficulty': {},
        'grounding_difficulty': {},
        'anchor_sources': {},
        'avg_candidates_per_sample': 0,
        'avg_question_length': 0,
        'countries_covered': set(),
        'subtypes_covered': set()
    }

    total_candidates = 0
    total_question_length = 0

    for sample in samples:
        meta = sample['metadata']

        # Count by family
        family = meta['task_family']
        stats['task_families'][family] = stats['task_families'].get(family, 0) + 1

        # Count by SQL difficulty
        sql_diff = meta['sql_difficulty']
        stats['sql_difficulty'][sql_diff] = stats['sql_difficulty'].get(sql_diff, 0) + 1

        # Count by grounding difficulty
        ground_diff = meta['grounding_difficulty']
        stats['grounding_difficulty'][ground_diff] = stats['grounding_difficulty'].get(ground_diff, 0) + 1

        # Count by anchor source
        anchor_src = meta['anchor_source']
        stats['anchor_sources'][anchor_src] = stats['anchor_sources'].get(anchor_src, 0) + 1

        # Candidates
        total_candidates += len(sample['candidates'])

        # Question length
        total_question_length += len(sample['question'].split())

        # Countries and subtypes (from selected/answer candidates only)
        selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
        for cand in sample['candidates']:
            if cand['candidate_id'] in selected_ids:
                if cand.get('country'):
                    stats['countries_covered'].add(cand['country'])
                if cand.get('subtype'):
                    stats['subtypes_covered'].add(cand['subtype'])

    stats['avg_candidates_per_sample'] = total_candidates / len(samples) if samples else 0
    stats['avg_question_length'] = total_question_length / len(samples) if samples else 0
    stats['countries_covered'] = sorted(stats['countries_covered'])
    stats['subtypes_covered'] = sorted(stats['subtypes_covered'])

    return stats


def main():
    """Validate and analyze dataset."""

    script_dir = Path(__file__).parent
    output_dir = script_dir.parent / "output"

    raw_file = output_dir / "dataset_raw.jsonl"
    validated_file = output_dir / "dataset_validated.jsonl"
    stats_file = output_dir / "dataset_stats.json"

    if not raw_file.exists():
        print(f"Error: {raw_file} not found. Run generate_samples.py first.")
        return

    # Load samples
    print("Loading samples...")
    samples = load_samples(raw_file)
    print(f"Loaded {len(samples)} samples")

    # Validate samples in parallel
    print("\nValidating samples in parallel...")
    valid_samples = []
    invalid_samples = []

    with ProcessPoolExecutor(max_workers=8) as executor:
        # Submit all validation tasks
        futures = {executor.submit(validate_sample_worker, sample): sample for sample in samples}

        # Collect results as they complete
        completed = 0
        for future in as_completed(futures):
            sample_id, is_valid, issues, validated_sample = future.result()

            if is_valid:
                valid_samples.append(validated_sample)
            else:
                invalid_samples.append((sample_id, issues))

            completed += 1
            if completed % 50 == 0 or completed == len(samples):
                print(f"\r  Progress: {completed}/{len(samples)}  ", end='', flush=True)

    print()  # New line after progress

    print("\nValidation results:")
    print(f"  Valid:   {len(valid_samples)}")
    print(f"  Invalid: {len(invalid_samples)}")

    if invalid_samples:
        print(f"\n{len(invalid_samples)} invalid samples (showing first 20):")
        for sample_id, issues in invalid_samples[:20]:
            print(f"  {sample_id}: {', '.join(issues)}")

    # Save validated samples
    if valid_samples:
        with open(validated_file, 'w') as f:
            for sample in valid_samples:
                f.write(json.dumps(sample) + '\n')
        print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}")

    # Compute statistics
    print("\nComputing statistics...")
    stats = compute_statistics(valid_samples)

    # Save statistics (convert any remaining sets to lists for JSON serialization)
    stats_json = {k: (sorted(v) if isinstance(v, set) else v) for k, v in stats.items()}
    with open(stats_file, 'w') as f:
        json.dump(stats_json, f, indent=2)
    print(f"Saved statistics to {stats_file}")

    # Print summary
    print("\n" + "=" * 60)
    print("DATASET STATISTICS")
    print("=" * 60)
    print(f"\nTotal samples: {stats['total_samples']}")

    print("\nTask families:")
    for family, count in sorted(stats['task_families'].items()):
        print(f"  {family:20s}: {count:3d}")

    print("\nSQL difficulty:")
    for diff, count in sorted(stats['sql_difficulty'].items()):
        print(f"  {diff:20s}: {count:3d}")

    print("\nGrounding difficulty:")
    for diff, count in sorted(stats['grounding_difficulty'].items()):
        print(f"  {diff:20s}: {count:3d}")

    print("\nAnchor sources:")
    for src, count in sorted(stats['anchor_sources'].items()):
        print(f"  {src:20s}: {count:3d}")

    print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}")
    print(f"Average question length (words): {stats['avg_question_length']:.1f}")
    print(f"Countries covered: {len(stats['countries_covered'])}")
    print(f"Subtypes covered: {len(stats['subtypes_covered'])}")

    print("\n✓ Validation complete")


if __name__ == "__main__":
    main()
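Note that `validate_sql` treats an empty result set as a failure, not only SQL errors. A quick sketch of all three outcomes against an in-memory connection (trivial queries stand in for the real parquet-backed spatial SQL; it assumes the functions above are importable):

# Quick sketch of validate_sql's three outcomes using stand-in queries.
import duckdb

con = duckdb.connect()
print(validate_sql(con, "SELECT * FROM range(3)"))       # (True, 'OK')
print(validate_sql(con, "SELECT * FROM range(0)"))       # (False, 'Empty result')
print(validate_sql(con, "SELECT * FROM no_such_table"))  # (False, '<DuckDB error message>')
con.close()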
pyproject.toml
CHANGED
@@ -19,5 +19,11 @@ dependencies = [
 ]
 optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
 
+[project.scripts]
+gazet-dataset = "dataset.scripts.cli:main"
+
 [tool.hatch.build.targets.wheel]
-packages = ["src/gazet"]
+packages = ["src/gazet", "dataset"]
+
+[dependency-groups]
+dataset = []
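The `[project.scripts]` entry exposes the pipeline CLI as a `gazet-dataset` command once the package is installed. For quick local iteration, the same entry point can be invoked directly; a sketch, assuming `main` takes no arguments as the script mapping implies:

# Hypothetical driver: calls the same function the gazet-dataset
# console script maps to, without installing the package.
from dataset.scripts.cli import main

if __name__ == "__main__":
    main()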
uv.lock
CHANGED
The diff for this file is too large to render.
See raw diff