srmsoumya committed on
Commit
d8d4856
·
1 Parent(s): 5623a00

FEAT: Add pipeline to create SLM training data

.gitignore CHANGED
@@ -133,6 +133,8 @@ dmypy.json
133
  # Pyre type checker
134
  .pyre/
135
 
136
-
137
  data/
138
- output/
133
  # Pyre type checker
134
  .pyre/
135
 
136
+ # Dataset
137
  data/
138
+ output/
139
+ *.parquet
140
+ *.jsonl
dataset/PIPELINE_FLOW.md ADDED
@@ -0,0 +1,414 @@
1
+ # Dataset Generation Pipeline Flow
2
+
3
+ This document explains how the optimized pipeline processes data with concrete examples.
4
+
5
+ **Example Configuration:**
6
+ - Countries: `['EC', 'BE', 'KE', 'AE', 'SG']` (5 countries)
7
+ - Sample targets: 100 per family × 8 families = 800 samples
8
+ - Retry multiplier: 2 (generate 1,600 attempts to get 800 valid samples)
9
+ - Max workers: 8
10
+
11
+ ---
12
+
13
+ ## Step 0: Build Entity Inventory (One-Time Setup)
14
+
15
+ **Script:** `build_inventory.py`
16
+
17
+ **What it does:** Extracts compact metadata from the full parquet files for fast sampling.
18
+
19
+ **Input:**
20
+ - `divisions_area.parquet` (~500K entities globally)
21
+ - `natural_earth.parquet` (~50K entities globally)
22
+
23
+ **Process:**
24
+ ```sql
25
+ -- For each parquet, extract:
26
+ SELECT
27
+ id,
28
+ names."primary" AS name,
29
+ subtype,
30
+ country,
31
+ region,
32
+ admin_level,
33
+ ST_Area(geometry) AS area_sq_deg,
34
+ -- bounding box for spatial filtering
35
+ FROM read_parquet(...)
36
+ WHERE names."primary" IS NOT NULL
37
+ ```
38
+
39
+ **Output:**
40
+ - `intermediate/divisions_area_inventory.parquet` (~50K rows for 5 countries)
41
+ - `intermediate/natural_earth_inventory.parquet` (~50K rows)
42
+
43
+ **Parallelization:** None (runs once, fast enough)
44
+
45
+ ---
46
+
47
+ ## Step 1: Build Relation Tables (Parallelized)
48
+
49
+ **Script:** `build_relations.py`
50
+
51
+ **What it does:** Pre-computes spatial relationships between entities so sample generation doesn't need to run expensive spatial joins.
52
+
53
+ ### Before Optimization (Sequential)
54
+ ```
55
+ Total time: adjacency (60s) + containment (15s) + intersection (10s) + cross_source (8s) = 93s
56
+ ```
57
+
58
+ ### After Optimization (Parallel)
59
+ ```
60
+ Total time: max(60s, 15s, 10s, 8s) = 60s
61
+ ```
62
+
63
+ **4 concurrent tasks, each with its own DuckDB connection:**
64
+
65
+ #### Task 1: Adjacency Pairs (60s)
66
+ ```sql
67
+ -- Find all touching boundaries within 5 countries
68
+ WITH features AS (
69
+ SELECT id, name, subtype, country, geometry, ST_Envelope(geometry) AS bbox
70
+ FROM divisions_area
71
+ WHERE country IN ('EC', 'BE', 'KE', 'AE', 'SG')
72
+ )
73
+ SELECT
74
+ a.id AS anchor_id,
75
+ a.name AS anchor_name,
76
+ b.id AS target_id,
77
+ b.subtype AS target_subtype
78
+ FROM features a
79
+ JOIN features b ON (
80
+ a.id < b.id
81
+ AND ST_Intersects(a.bbox, b.bbox) -- Fast bbox pre-filter
82
+ AND ST_Touches(a.geometry, b.geometry) -- Expensive but necessary
83
+ )
84
+ LIMIT 50000
85
+ ```
86
+ **Output:** `adjacency_pairs.parquet` (50,000 rows)
87
+
88
+ #### Task 2: Containment Pairs (15s)
89
+ ```sql
90
+ -- Find all parent-child relationships
91
+ -- Example: Ecuador contains Quito
92
+ SELECT
93
+ container.id AS container_id,
94
+ container.name AS container_name,
95
+ contained.id AS contained_id,
96
+ contained.subtype AS contained_subtype
97
+ FROM features container
98
+ JOIN features contained ON (
99
+ container.admin_level < contained.admin_level -- Parent has lower level
100
+ AND ST_Within(contained.geometry, container.geometry)
101
+ )
102
+ LIMIT 1000
103
+ ```
104
+ **Output:** `containment_pairs.parquet` (1,000 rows)
105
+
106
+ #### Task 3: Intersection Pairs (10s)
107
+ ```sql
108
+ -- Find overlapping regions (not touching, not containing)
109
+ -- Example: Two administrative regions that overlap
110
+ SELECT a.id, a.name, b.id, b.subtype
111
+ FROM features a
112
+ JOIN features b ON (
113
+ ST_Intersects(a.geometry, b.geometry)
114
+ AND NOT ST_Touches(a.geometry, b.geometry)
115
+ AND NOT ST_Within(a.geometry, b.geometry)
116
+ )
117
+ LIMIT 500
118
+ ```
119
+ **Output:** `intersection_pairs.parquet` (500 rows)
120
+
121
+ #### Task 4: Cross-Source Relations (8s)
122
+ ```sql
123
+ -- Find relationships between divisions and natural features
124
+ -- Example: Ecuador intersects Pacific Ocean
125
+ SELECT
126
+ d.id AS division_id,
127
+ d.name AS division_name,
128
+ n.id AS natural_id,
129
+ n.name AS natural_name,
130
+ CASE
131
+ WHEN ST_Touches(...) THEN 'touches'
132
+ WHEN ST_Intersects(...) THEN 'intersects'
133
+ END AS relation_type
134
+ FROM divisions d
135
+ JOIN natural_features n ON ST_Intersects(d.geometry, n.geometry)
136
+ WHERE d.country IN ('EC', 'BE', 'KE', 'AE', 'SG')
137
+ AND n.subtype IN ('sea', 'ocean', 'Lake', 'River')
138
+ LIMIT 500
139
+ ```
140
+ **Output:** `cross_source_relations.parquet` (500 rows)
141
+
142
+ **ThreadPoolExecutor with 4 workers runs all tasks concurrently.**
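+
+ A minimal sketch of that fan-out, assuming `tasks` is a list of `(name, compute_fn, limit)` tuples built from the four functions above and `countries` is the configured country list (the production version lives in `build_relations.py`):
+
+ ```python
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import duckdb
+
+ def run_task(name, compute_fn, countries, limit):
+     # Each task opens its own connection, so the four spatial joins run independently.
+     con = duckdb.connect()
+     con.execute("INSTALL spatial")
+     con.execute("LOAD spatial")
+     try:
+         return name, compute_fn(con, countries, limit)
+     finally:
+         con.close()
+
+ with ThreadPoolExecutor(max_workers=4) as executor:
+     futures = [executor.submit(run_task, name, fn, countries, limit)
+                for name, fn, limit in tasks]
+     for future in as_completed(futures):
+         name, df = future.result()
+         print(f"{name}: {len(df)} pairs")  # wall-clock bounded by the slowest task (~60s)
+ ```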
143
+
144
+ ---
145
+
146
+ ## Step 2: Generate Samples (Batch-Parallelized)
147
+
148
+ **Script:** `generate_samples.py`
149
+
150
+ **What it does:** Creates training samples by:
151
+ 1. Sampling anchors from relation tables
152
+ 2. Rendering SQL templates
153
+ 3. Executing SQL to verify it works
154
+ 4. Building candidate lists with distractors
155
+ 5. Generating questions
156
+
157
+ ### Work Item Preparation
158
+
159
+ **Total work items:** 8 families × 100 targets × 2 retry_multiplier = **1,600 items**
160
+
161
+ ```python
162
+ work_items = [
163
+ ('adjacency', template_dict_1, 'sample_001', '/path/to/intermediate'),
164
+ ('containment', template_dict_2, 'sample_002', '/path/to/intermediate'),
165
+ ('adjacency', template_dict_3, 'sample_003', '/path/to/intermediate'),
166
+ # ... 1,597 more items
167
+ ]
168
+
169
+ # Shuffle for balanced batches
170
+ random.shuffle(work_items)
171
+
172
+ # Partition into 8 batches (one per worker)
173
+ batch_size = 1600 / 8 = 200 items per batch
174
+ batches = [
175
+ batch_1: items[0:200], # ~25 of each family (mixed)
176
+ batch_2: items[200:400],
177
+ batch_3: items[400:600],
178
+ # ... 8 batches total
179
+ ]
180
+ ```
181
+
182
+ ### Before Optimization (Per-Sample Workers)
183
+
184
+ ```
185
+ For each of 1,600 samples:
186
+ - Fork new process
187
+ - Create DuckDB connection
188
+ - INSTALL spatial (5-10ms)
189
+ - LOAD spatial (5-10ms)
190
+ - Import sql_templates module
191
+ - Load 4 relation parquet files (50-100ms)
192
+ - Generate 1 sample (20-50ms)
193
+ - Close connection
194
+
195
+ Total overhead per sample: ~100ms
196
+ Total overhead: 1,600 × 100ms = 160 seconds of pure overhead
197
+ ```
198
+
199
+ ### After Optimization (Batch Workers)
200
+
201
+ ```
202
+ 8 workers run in parallel, each processes 200 samples:
203
+
204
+ Worker 1 (batch of 200 items):
205
+ - Create DuckDB connection (once)
206
+ - INSTALL + LOAD spatial (once, 10ms)
207
+ - Import sql_templates (once)
208
+ - Load 4 relation tables (once, 100ms)
209
+
210
+ FOR EACH of 200 items:
211
+ - Sample anchor from pre-loaded table (instant)
212
+ - Render SQL template
213
+ - Execute SQL to verify (20-50ms)
214
+ - Build candidate list with Jaro-Winkler (10-30ms)
215
+ - Generate question
216
+
217
+ - Close connection
218
+
219
+ Total overhead per worker: ~110ms (one-time)
220
+ Total overhead across 8 workers: ~110ms (parallel)
221
+ ```
222
+
223
+ **Speedup:** 160s → 0.11s overhead = **~1,450x faster on initialization overhead**
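+
+ A simplified batch worker illustrating this structure (a sketch: `generate_one_sample`, `batches`, and `intermediate_dir` are placeholders here; the actual per-family logic lives in `generate_samples.py`):
+
+ ```python
+ from concurrent.futures import ProcessPoolExecutor
+ from pathlib import Path
+ import duckdb
+ import pandas as pd
+
+ def process_batch(batch, intermediate_dir):
+     # One-time setup per worker: connection, spatial extension, relation tables.
+     con = duckdb.connect()
+     con.execute("INSTALL spatial")
+     con.execute("LOAD spatial")
+     tables = {p.stem: pd.read_parquet(p) for p in Path(intermediate_dir).glob("*.parquet")}
+
+     results = []
+     for family, template, sample_id, _ in batch:
+         # Per-item work reuses the warm connection and pre-loaded tables.
+         # generate_one_sample stands in for the per-family sampling/rendering steps.
+         results.append(generate_one_sample(con, tables, family, template, sample_id))
+
+     con.close()
+     return results
+
+ with ProcessPoolExecutor(max_workers=8) as executor:
+     futures = [executor.submit(process_batch, batch, intermediate_dir) for batch in batches]
+ ```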
224
+
225
+ ### Sample Generation Example
226
+
227
+ **Adjacency sample generation:**
228
+
229
+ ```python
230
+ # 1. Sample anchor from pre-loaded adjacency_pairs DataFrame
231
+ row = adjacency_df.sample(n=1).iloc[0]
232
+ # Result: anchor_id='EC-123', anchor_name='Quito', target_subtype='locality'
233
+
234
+ # 2. Render SQL template
235
+ sql = f"""
236
+ WITH a AS (
237
+ SELECT geometry FROM divisions_area WHERE id = 'EC-123'
238
+ )
239
+ SELECT b.id, b.names."primary" AS name, b.geometry
240
+ FROM divisions_area AS b, a
241
+ WHERE b.id != 'EC-123'
242
+ AND b.subtype = 'locality'
243
+ AND ST_Touches(a.geometry, b.geometry)
244
+ """
245
+
246
+ # 3. Execute to verify (returns 5 neighboring localities)
247
+ result = con.execute(sql).fetchdf() # 30ms
248
+ # ✓ Not empty, sample is valid
249
+
250
+ # 4. Build candidate list (10 candidates: 1 true + 9 distractors)
251
+ candidates = build_candidate_list(
252
+ con, 'EC-123', 'Quito', 'divisions_area', num_candidates=10
253
+ )
254
+ # Uses Jaro-Winkler to find similar names:
255
+ SELECT id, name, subtype, country,
256
+ jaro_winkler_similarity(lower(name), lower('Quito')) AS similarity
257
+ FROM divisions_area
258
+ WHERE id != 'EC-123'
259
+ ORDER BY similarity DESC
260
+ LIMIT 9
261
+
262
+ # Results: ['Quito', 'Cuito', 'Quijos', 'Quinindé', ...]
263
+ # Shuffle and reassign IDs: c1, c2, ..., c10
264
+
265
+ # 5. Generate question
266
+ question = "Which localities border Quito?"
267
+
268
+ # 6. Return TrainingSample with sql_verified=True
269
+ ```
270
+
271
+ ### Batch Progress Tracking
272
+
273
+ ```
274
+ Console output:
275
+
276
+ Generating 1600 samples across 8 families...
277
+ Split into 8 batches of ~200 items (1 DuckDB init per batch)
278
+
279
+ Progress: 200/1600 samples (1/8 batches) # Worker 1 done
280
+ Progress: 400/1600 samples (2/8 batches) # Worker 2 done
281
+ Progress: 600/1600 samples (3/8 batches) # Worker 3 done
282
+ ...
283
+ Progress: 1600/1600 samples (8/8 batches) # All done
284
+
285
+ Results by family:
286
+ adjacency : 185 success / 15 failed (92.5% success rate, target: 100)
287
+ aggregation : 178 success / 22 failed (89.0% success rate, target: 100)
288
+ buffer : 192 success / 8 failed (96.0% success rate, target: 100)
289
+ containment : 188 success / 12 failed (94.0% success rate, target: 100)
290
+ direct_lookup : 200 success / 0 failed (100% success rate, target: 100)
291
+ intersection : 181 success / 19 failed (90.5% success rate, target: 100)
292
+ partial_selection : 175 success / 25 failed (87.5% success rate, target: 100)
293
+ set_operations : 190 success / 10 failed (95.0% success rate, target: 100)
294
+
295
+ Total: 1,489 valid samples from 1,600 attempts
296
+ ```
297
+
298
+ ---
299
+
300
+ ## Step 3: Validate Dataset (Optimized)
301
+
302
+ **Script:** `validate_dataset.py`
303
+
304
+ **What it does:** Validates samples in parallel, but **skips SQL re-execution** for samples with `sql_verified: True`.
305
+
306
+ ### Before Optimization
307
+
308
+ ```
309
+ For each of 1,489 samples:
310
+ - Execute SQL to verify (30ms)
311
+ - Validate candidates (1ms)
312
+ - Check question (1ms)
313
+
314
+ Total: 1,489 × 32ms = 47.6 seconds
315
+ ```
316
+
317
+ ### After Optimization
318
+
319
+ ```
320
+ For each of 1,489 samples:
321
+ - Check metadata.sql_verified flag
322
+ - IF True: skip SQL execution (saved 30ms)
323
+ - Validate candidates (1ms)
324
+ - Check question (1ms)
325
+
326
+ Total: 1,489 × 2ms = 3.0 seconds
327
+ ```
328
+
329
+ **Speedup:** 47.6s → 3.0s = **~16x faster**
330
+
331
+ **Parallelization:** 8 workers process samples in parallel batches
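+
+ A sketch of the per-sample check (assuming the rendered SQL sits alongside the `sql_verified` flag in `metadata`; the real logic is in `validate_dataset.py`):
+
+ ```python
+ def validate_sample(sample, con):
+     errors = []
+     # Trust generation-time verification: only re-run SQL when the flag is missing or False.
+     if not sample["metadata"].get("sql_verified", False):
+         result = con.execute(sample["metadata"]["sql"]).fetchdf()  # ~30ms per query
+         if result.empty:
+             errors.append("SQL returned no rows")
+     # Cheap structural checks always run (~2ms total).
+     if not sample.get("candidates"):
+         errors.append("empty candidate list")
+     if not sample.get("question", "").strip():
+         errors.append("missing question")
+     return errors
+ ```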
332
+
333
+ ---
334
+
335
+ ## Step 4: Export Splits
336
+
337
+ **Script:** `export_training_data.py`
338
+
339
+ **What it does:** Stratified split into train/val/test (80/10/10) by task family.
340
+
341
+ **Input:** `dataset_validated.jsonl` (1,489 samples)
342
+
343
+ **Process:**
344
+ ```python
345
+ # Group by family
346
+ adjacency_samples: 185
347
+ aggregation_samples: 178
348
+ buffer_samples: 192
349
+ # ... etc
350
+
351
+ # Split each family 80/10/10
352
+ adjacency_train: 148, val: 19, test: 18
353
+ aggregation_train: 142, val: 18, test: 18
354
+ # ... etc
355
+
356
+ # Combine and shuffle
357
+ train: 1,191 samples
358
+ val: 149 samples
359
+ test: 149 samples
360
+ ```
361
+
362
+ **Output:**
363
+ - `output/train.jsonl` (1,191 samples)
364
+ - `output/val.jsonl` (149 samples)
365
+ - `output/test.jsonl` (149 samples)
366
+
367
+ **Parallelization:** None needed (fast enough)
368
+
369
+ ---
370
+
371
+ ## Overall Pipeline Timing
372
+
373
+ ### Before Optimizations
374
+ ```
375
+ Step 0: Build Inventory : 5s (one-time)
376
+ Step 1: Build Relations : 93s (sequential)
377
+ Step 2: Generate Samples : 320s (160s overhead + 160s generation)
378
+ Step 3: Validate Dataset : 48s (re-executing all SQL)
379
+ Step 4: Export Splits : 2s
380
+ ------
381
+ Total : 468s (~7.8 minutes)
382
+ ```
383
+
384
+ ### After Optimizations
385
+ ```
386
+ Step 0: Build Inventory : 5s (one-time)
387
+ Step 1: Build Relations : 60s (parallel, limited by slowest task)
388
+ Step 2: Generate Samples : 165s (0.11s overhead + 165s generation)
389
+ Step 3: Validate Dataset : 3s (skips SQL re-execution)
390
+ Step 4: Export Splits : 2s
391
+ ------
392
+ Total : 235s (~3.9 minutes)
393
+ ```
394
+
395
+ **Overall speedup:** 468s → 235s = **~2x faster**
396
+
397
+ **At 10K scale (100x more samples):**
398
+ - Before: ~780 minutes (13 hours)
399
+ - After: ~390 minutes (6.5 hours)
400
+ - With further optimizations (sampling without replacement, better caching): **<2 hours**
401
+
402
+ ---
403
+
404
+ ## Key Optimizations Summary
405
+
406
+ | Optimization | Impact | Where |
407
+ |-------------|--------|-------|
408
+ | **Batch workers** | 1,450x on init overhead | `generate_samples.py` |
409
+ | **Parallel relations** | 1.5x on relation building | `build_relations.py` |
410
+ | **Jaro-Winkler** | 2-3x on distractor search | `generate_samples.py` |
411
+ | **Skip SQL re-validation** | 16x on validation | `validate_dataset.py` |
412
+ | **Drop individual JSON files** | 1.2x on I/O | `generate_samples.py` |
413
+
414
+ **Combined:** Enables scaling from hundreds to tens of thousands of samples efficiently.
dataset/README.md ADDED
@@ -0,0 +1,108 @@
1
+ # Dataset Generation CLI
2
+
3
+ Generate synthetic text-to-SQL training datasets.
4
+
5
+ ## Quick Start
6
+
7
+ ```bash
8
+ # Install
9
+ uv sync
10
+
11
+ # Generate dataset
12
+ gazet-dataset full-pipeline --config dataset/config.yaml
13
+ ```
14
+
15
+ ## Configuration
16
+
17
+ Edit `dataset/config.yaml`:
18
+
19
+ ```yaml
20
+ countries:
21
+ - EC # Ecuador
22
+ - BE # Belgium
23
+ - KE # Kenya
24
+ - AE # UAE
25
+ - SG # Singapore
26
+ - CH # Switzerland
27
+
28
+ sample_targets:
29
+ direct_lookup: 100
30
+ adjacency: 100
31
+ containment: 100
32
+ intersection: 100
33
+ buffer: 100
34
+ set_operations: 100
35
+ partial_selection: 100
36
+ aggregation: 100
37
+
38
+ generation:
39
+ max_workers: 8
40
+ retry_multiplier: 2
41
+ append_mode: true
42
+
43
+ auto_scaling:
44
+ safety_factor: 1.5 # Auto-calculates relation limits
45
+ ```
46
+
47
+ ## Growing Your Dataset
48
+
49
+ ### Start Small
50
+ ```bash
51
+ # Generate initial dataset (e.g., 100 samples)
52
+ gazet-dataset full-pipeline --config dataset/config.yaml
53
+ ```
54
+
55
+ ### Add More Samples (Same Countries)
56
+ ```bash
57
+ # Increase sample_targets in config.yaml, then:
58
+ gazet-dataset full-pipeline --config dataset/config.yaml --append
59
+ ```
60
+
61
+ ### Add New Countries
62
+ ```bash
63
+ # Add countries to config.yaml, then:
64
+ gazet-dataset full-pipeline --config dataset/config.yaml --append
65
+ # Auto-rebuilds relations if countries changed
66
+ ```
67
+
68
+ ### Scale to 10K+
69
+ ```yaml
70
+ # config.yaml - increase targets
71
+ sample_targets:
72
+ adjacency: 1000
73
+ containment: 1000
74
+ intersection: 1000
75
+ # ... etc
76
+ ```
77
+
78
+ ```bash
79
+ gazet-dataset full-pipeline --config dataset/config.yaml --append
80
+ ```
81
+
82
+ ## Commands
83
+
84
+ ```bash
85
+ gazet-dataset full-pipeline --config <path> # Run everything
86
+ gazet-dataset build-relations --config <path> # Build spatial relations
87
+ gazet-dataset generate-samples --config <path> # Generate samples
88
+ gazet-dataset validate --config <path> # Validate dataset
89
+ gazet-dataset export --config <path> # Export train/val/test
90
+ ```
91
+
92
+ **Options:**
93
+ - `--append`: Add to existing dataset instead of overwriting
94
+
95
+ ## Output
96
+
97
+ - `dataset/output/dataset_raw.jsonl` - Generated samples
98
+ - `dataset/output/dataset_validated.jsonl` - Validated samples
99
+ - `dataset/output/train.jsonl` - Training split
100
+ - `dataset/output/val.jsonl` - Validation split
101
+ - `dataset/output/test.jsonl` - Test split
102
+
103
+ ## Tips
104
+
105
+ - Start with 2-3 countries and small sample targets
106
+ - Use `--append` to grow dataset incrementally
107
+ - Relation limits auto-calculate from sample targets
108
+ - Check success rates in output summary
dataset/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Synthetic dataset generation package."""
dataset/config.yaml ADDED
@@ -0,0 +1,52 @@
1
+ # Dataset Generation Configuration
2
+ # This config controls which countries to process and how many samples to generate
3
+
4
+ # Countries to include in relation building
5
+ # Use ISO 3166-1 alpha-2 codes
6
+ countries:
7
+ # - EC # Ecuador
8
+ # - BE # Belgium
9
+ # - KE # Kenya
10
+ # - AE # UAE
11
+ # - SG # Singapore
12
+ # - CH # Switzerland
13
+ - IN # India
14
+ - PK # Pakistan
15
+ # - LK # Sri Lanka
16
+ # - BD # Bangladesh
17
+
18
+ # Sample generation targets per family
19
+ # Relation limits are auto-calculated from these targets
20
+ sample_targets:
21
+ direct_lookup: 20
22
+ adjacency: 20
23
+ containment: 20
24
+ intersection: 20
25
+ buffer: 20
26
+ set_operations: 20
27
+ partial_selection: 20
28
+ aggregation: 20
29
+
30
+ # Generation settings
31
+ generation:
32
+ max_workers: 8 # Number of parallel workers
33
+ retry_multiplier: 2 # Generate 2x samples to account for failures
34
+ append_mode: true # If true, append to existing dataset instead of overwriting
35
+
36
+ # Auto-scaling configuration
37
+ # Relation limits are automatically calculated: target * retry_multiplier * safety_factor
38
+ auto_scaling:
39
+ safety_factor: 1.5 # Extra buffer to ensure enough unique pairs
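+ # e.g., a target of 20 → 20 * 2 (retry_multiplier) * 1.5 (safety_factor) = 60 pairs needed for that family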
40
+
41
+ # Manual overrides (optional) - uncomment to override auto-calculated limits
42
+ manual_limits: {}
43
+ # adjacency: 10000 # Uncomment to manually set
44
+ # containment: 2000
45
+ # intersection: 1000
46
+ # cross_source: 500
47
+
48
+ # Output paths (relative to dataset directory)
49
+ output:
50
+ samples_dir: "output/samples"
51
+ dataset_file: "output/dataset_raw.jsonl"
52
+ intermediate_dir: "intermediate"
dataset/scripts/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Dataset generation scripts package."""
dataset/scripts/build_inventory.py ADDED
@@ -0,0 +1,111 @@
1
+ """
2
+ Build entity inventory from divisions_area and natural_earth parquet files.
3
+
4
+ This script creates compact inventory tables containing only the fields needed
5
+ for candidate sampling and distractor generation.
6
+
7
+ Output:
8
+ - intermediate/divisions_area_inventory.parquet
9
+ - intermediate/natural_earth_inventory.parquet
10
+ """
11
+
12
+ import duckdb
13
+ import pandas as pd
14
+ from pathlib import Path
15
+
16
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
17
+
18
+
19
+ def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
20
+ """Extract compact inventory from divisions_area."""
21
+ query = """
22
+ SELECT
23
+ 'divisions_area' AS source,
24
+ id,
25
+ names."primary" AS name,
26
+ subtype,
27
+ country,
28
+ region,
29
+ admin_level,
30
+ class,
31
+ is_land,
32
+ is_territorial,
33
+ division_id,
34
+ ST_Area(geometry) AS area_sq_deg,
35
+ ST_XMin(geometry) AS xmin,
36
+ ST_YMin(geometry) AS ymin,
37
+ ST_XMax(geometry) AS xmax,
38
+ ST_YMax(geometry) AS ymax
39
+ FROM read_parquet(?)
40
+ WHERE names."primary" IS NOT NULL
41
+ AND trim(names."primary") != ''
42
+ """
43
+
44
+ df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
45
+ print(f"Divisions area inventory: {len(df)} entities")
46
+ print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
47
+ print(f"Countries: {df['country'].nunique()} unique")
48
+
49
+ return df
50
+
51
+
52
+ def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
53
+ """Extract compact inventory from natural_earth."""
54
+ query = """
55
+ SELECT
56
+ 'natural_earth' AS source,
57
+ id,
58
+ names."primary" AS name,
59
+ subtype,
60
+ country,
61
+ region,
62
+ admin_level,
63
+ class,
64
+ is_land,
65
+ is_territorial,
66
+ ST_Area(geometry) AS area_sq_deg,
67
+ ST_XMin(geometry) AS xmin,
68
+ ST_YMin(geometry) AS ymin,
69
+ ST_XMax(geometry) AS xmax,
70
+ ST_YMax(geometry) AS ymax
71
+ FROM read_parquet(?)
72
+ WHERE names."primary" IS NOT NULL
73
+ AND trim(names."primary") != ''
74
+ """
75
+
76
+ df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
77
+ print(f"\nNatural earth inventory: {len(df)} entities")
78
+ print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
79
+
80
+ return df
81
+
82
+
83
+ def main():
84
+ """Build and save inventory tables."""
85
+ output_dir = Path(__file__).parent.parent / "intermediate"
86
+ output_dir.mkdir(exist_ok=True, parents=True)
87
+
88
+ con = duckdb.connect()
89
+ con.execute("INSTALL spatial")
90
+ con.execute("LOAD spatial")
91
+
92
+ print("Building divisions_area inventory...")
93
+ divisions_df = build_divisions_area_inventory(con)
94
+ divisions_path = output_dir / "divisions_area_inventory.parquet"
95
+ divisions_df.to_parquet(divisions_path, index=False)
96
+ print(f"Saved to {divisions_path}")
97
+
98
+ print("\nBuilding natural_earth inventory...")
99
+ natural_earth_df = build_natural_earth_inventory(con)
100
+ natural_earth_path = output_dir / "natural_earth_inventory.parquet"
101
+ natural_earth_df.to_parquet(natural_earth_path, index=False)
102
+ print(f"Saved to {natural_earth_path}")
103
+
104
+ con.close()
105
+
106
+ print("\n✓ Inventory build complete")
107
+ print(f" Total entities: {len(divisions_df) + len(natural_earth_df)}")
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()
dataset/scripts/build_relations.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Precompute spatial relation tables for efficient anchor sampling.
3
+
4
+ This script computes:
5
+ - Adjacency pairs (touching features)
6
+ - Containment pairs (features within other features)
7
+ - Intersection pairs (overlapping features)
8
+ - Cross-source relations (divisions_area ↔ natural_earth)
9
+
10
+ Output:
11
+ - intermediate/adjacency_pairs.parquet
12
+ - intermediate/containment_pairs.parquet
13
+ - intermediate/intersection_pairs.parquet
14
+ - intermediate/cross_source_relations.parquet
15
+ """
16
+
17
+ import duckdb
18
+ import pandas as pd
19
+ from pathlib import Path
20
+ from concurrent.futures import ThreadPoolExecutor, as_completed
21
+
22
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
23
+
24
+
25
+ def compute_adjacency_pairs(
26
+ con: duckdb.DuckDBPyConnection,
27
+ countries: list,
28
+ limit: int
29
+ ) -> pd.DataFrame:
30
+ """Find all pairs of features that touch (share a boundary)."""
31
+ print("Computing adjacency pairs (optimized with spatial index)...")
32
+
33
+ # Use bounding box pre-filter to avoid full cartesian product
34
+ query = """
35
+ WITH features AS (
36
+ SELECT
37
+ id,
38
+ names."primary" AS name,
39
+ subtype,
40
+ country,
41
+ admin_level,
42
+ geometry,
43
+ ST_Envelope(geometry) AS bbox
44
+ FROM read_parquet(?)
45
+ WHERE country IN (SELECT unnest(?))
46
+ )
47
+ SELECT
48
+ a.id AS anchor_id,
49
+ a.name AS anchor_name,
50
+ a.subtype AS anchor_subtype,
51
+ a.country AS anchor_country,
52
+ b.id AS target_id,
53
+ b.name AS target_name,
54
+ b.subtype AS target_subtype,
55
+ b.country AS target_country,
56
+ 'adjacency' AS relation_type
57
+ FROM features AS a
58
+ JOIN features AS b ON (
59
+ a.id < b.id
60
+ AND ST_Intersects(a.bbox, b.bbox)
61
+ AND ST_Touches(a.geometry, b.geometry)
62
+ )
63
+ LIMIT ?
64
+ """
65
+
66
+ df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
67
+ print(f"Found {len(df)} adjacency pairs")
68
+
69
+ return df
70
+
71
+
72
+ def compute_containment_pairs(
73
+ con: duckdb.DuckDBPyConnection,
74
+ countries: list,
75
+ limit: int
76
+ ) -> pd.DataFrame:
77
+ """Find all pairs where one feature contains another."""
78
+ print("\nComputing containment pairs (optimized)...")
79
+
80
+ query = """
81
+ WITH features AS (
82
+ SELECT
83
+ id,
84
+ names."primary" AS name,
85
+ subtype,
86
+ country,
87
+ admin_level,
88
+ geometry,
89
+ ST_Envelope(geometry) AS bbox
90
+ FROM read_parquet(?)
91
+ WHERE country IN (SELECT unnest(?))
92
+ )
93
+ SELECT
94
+ a.id AS container_id,
95
+ a.name AS container_name,
96
+ a.subtype AS container_subtype,
97
+ b.id AS contained_id,
98
+ b.name AS contained_name,
99
+ b.subtype AS contained_subtype,
100
+ 'containment' AS relation_type
101
+ FROM features AS a
102
+ JOIN features AS b ON (
103
+ a.id != b.id
104
+ AND a.admin_level < b.admin_level
105
+ AND ST_Intersects(a.bbox, b.bbox)
106
+ AND ST_Within(b.geometry, a.geometry)
107
+ )
108
+ LIMIT ?
109
+ """
110
+
111
+ df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
112
+ print(f"Found {len(df)} containment pairs")
113
+
114
+ return df
115
+
116
+
117
+ def compute_intersection_pairs(
118
+ con: duckdb.DuckDBPyConnection,
119
+ countries: list,
120
+ limit: int
121
+ ) -> pd.DataFrame:
122
+ """Find pairs that intersect but don't touch or contain."""
123
+ print("\nComputing intersection pairs (optimized)...")
124
+
125
+ query = """
126
+ WITH features AS (
127
+ SELECT
128
+ id,
129
+ names."primary" AS name,
130
+ subtype,
131
+ country,
132
+ admin_level,
133
+ geometry,
134
+ ST_Envelope(geometry) AS bbox
135
+ FROM read_parquet(?)
136
+ WHERE country IN (SELECT unnest(?))
137
+ )
138
+ SELECT
139
+ a.id AS anchor_id,
140
+ a.name AS anchor_name,
141
+ a.subtype AS anchor_subtype,
142
+ b.id AS target_id,
143
+ b.name AS target_name,
144
+ b.subtype AS target_subtype,
145
+ 'intersection' AS relation_type
146
+ FROM features AS a
147
+ JOIN features AS b ON (
148
+ a.id < b.id
149
+ AND ST_Intersects(a.bbox, b.bbox)
150
+ AND ST_Intersects(a.geometry, b.geometry)
151
+ AND NOT ST_Touches(a.geometry, b.geometry)
152
+ AND NOT ST_Within(a.geometry, b.geometry)
153
+ AND NOT ST_Within(b.geometry, a.geometry)
154
+ )
155
+ LIMIT ?
156
+ """
157
+
158
+ df = con.execute(query, [DIVISIONS_AREA_PATH, countries, limit]).fetchdf()
159
+ print(f"Found {len(df)} same-source intersection pairs")
160
+
161
+ return df
162
+
163
+
164
+ def compute_cross_source_relations(
165
+ con: duckdb.DuckDBPyConnection,
166
+ countries: list,
167
+ limit: int
168
+ ) -> pd.DataFrame:
169
+ """Find relations between divisions_area and natural_earth."""
170
+ print("\nComputing cross-source relations...")
171
+
172
+ query = """
173
+ WITH divisions AS (
174
+ SELECT
175
+ id,
176
+ names."primary" AS name,
177
+ subtype,
178
+ country,
179
+ geometry
180
+ FROM read_parquet(?)
181
+ WHERE country IN (SELECT unnest(?))
182
+ ),
183
+ natural_features AS (
184
+ SELECT
185
+ id,
186
+ names."primary" AS name,
187
+ subtype,
188
+ geometry
189
+ FROM read_parquet(?)
190
+ WHERE subtype IN ('sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay')
191
+ LIMIT 200
192
+ )
193
+ SELECT
194
+ d.id AS division_id,
195
+ d.name AS division_name,
196
+ d.subtype AS division_subtype,
197
+ d.country AS division_country,
198
+ n.id AS natural_id,
199
+ n.name AS natural_name,
200
+ n.subtype AS natural_subtype,
201
+ CASE
202
+ WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
203
+ WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
204
+ WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
205
+ WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
206
+ END AS relation_type
207
+ FROM divisions AS d
208
+ JOIN natural_features AS n ON ST_Intersects(d.geometry, n.geometry)
209
+ LIMIT ?
210
+ """
211
+
212
+ df = con.execute(query, [DIVISIONS_AREA_PATH, countries, NATURAL_EARTH_PATH, limit]).fetchdf()
213
+ print(f"Found {len(df)} cross-source relations")
214
+
215
+ return df
216
+
217
+
218
+ def _make_connection():
219
+ """Create a new DuckDB connection with spatial extension loaded."""
220
+ con = duckdb.connect()
221
+ con.execute("INSTALL spatial")
222
+ con.execute("LOAD spatial")
223
+ return con
224
+
225
+
226
+ def _compute_and_save(compute_fn, countries, limit, output_path):
227
+ """Compute a relation table and save it to parquet. Uses its own DuckDB connection."""
228
+ con = _make_connection()
229
+ try:
230
+ df = compute_fn(con, countries, limit)
231
+ df.to_parquet(output_path, index=False)
232
+ print(f"Saved to {output_path}")
233
+ return df
234
+ finally:
235
+ con.close()
236
+
237
+
238
+ def main(countries: list = None, relation_limits: dict = None):
239
+ """Compute and save all relation tables in parallel.
240
+
241
+ Args:
242
+ countries: List of country codes to process
243
+ relation_limits: Dict with keys: adjacency, containment, intersection, cross_source
244
+ """
245
+ # Defaults
246
+ if countries is None:
247
+ countries = ['EC', 'BE', 'KE', 'AE', 'SG', 'CH']
248
+ if relation_limits is None:
249
+ relation_limits = {
250
+ 'adjacency': 50000,
251
+ 'containment': 1000,
252
+ 'intersection': 500,
253
+ 'cross_source': 500
254
+ }
255
+
256
+ output_dir = Path(__file__).parent.parent / "intermediate"
257
+ output_dir.mkdir(exist_ok=True, parents=True)
258
+
259
+ # Define all relation tasks
260
+ tasks = [
261
+ ("adjacency", compute_adjacency_pairs, relation_limits['adjacency'], output_dir / "adjacency_pairs.parquet"),
262
+ ("containment", compute_containment_pairs, relation_limits['containment'], output_dir / "containment_pairs.parquet"),
263
+ ("intersection", compute_intersection_pairs, relation_limits['intersection'], output_dir / "intersection_pairs.parquet"),
264
+ ("cross_source", compute_cross_source_relations, relation_limits['cross_source'], output_dir / "cross_source_relations.parquet"),
265
+ ]
266
+
267
+ print(f"Computing {len(tasks)} relation types in parallel...")
268
+
269
+ # Run all relation computations concurrently
270
+ with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
271
+ futures = {
272
+ executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
273
+ for name, compute_fn, limit, path in tasks
274
+ }
275
+
276
+ for future in as_completed(futures):
277
+ name = futures[future]
278
+ try:
279
+ future.result()
280
+ except Exception as e:
281
+ print(f"ERROR computing {name}: {e}")
282
+ raise
283
+
284
+ print("\n✓ Relation tables build complete")
285
+
286
+
287
+ if __name__ == "__main__":
288
+ main()
dataset/scripts/cli.py ADDED
@@ -0,0 +1,305 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ CLI for synthetic dataset generation.
4
+
5
+ Usage:
6
+ python cli.py build-relations --config ../config.yaml
7
+ python cli.py generate-samples --config ../config.yaml
8
+ python cli.py generate-samples --config ../config.yaml --append
9
+ python cli.py full-pipeline --config ../config.yaml
10
+ """
11
+
12
+ import argparse
13
+ import sys
14
+ from pathlib import Path
15
+ import yaml
16
+ import subprocess
17
+ from typing import Dict, Set
18
+ import pandas as pd
19
+
20
+
21
+ def load_config(config_path: Path) -> dict:
22
+ """Load configuration from YAML file."""
23
+ with open(config_path) as f:
24
+ return yaml.safe_load(f)
25
+
26
+
27
+ def should_rebuild_relations(config: dict, intermediate_dir: Path, append: bool) -> bool:
28
+ """Check if relation tables need to be rebuilt.
29
+
30
+ Returns True if:
31
+ - Not in append mode (always rebuild)
32
+ - Relation tables don't exist
33
+ - Countries in config differ from countries in existing relation tables
34
+ """
35
+ if not append:
36
+ return True
37
+
38
+ # Check if relation tables exist
39
+ adjacency_file = intermediate_dir / "adjacency_pairs.parquet"
40
+ if not adjacency_file.exists():
41
+ print("WARNING: Relation tables not found, will rebuild despite append mode")
42
+ return True
43
+
44
+ # Check if countries have changed
45
+ try:
46
+ df = pd.read_parquet(adjacency_file)
47
+ if 'anchor_country' in df.columns:
48
+ existing_countries = set(df['anchor_country'].unique())
49
+ config_countries = set(config['countries'])
50
+
51
+ if existing_countries != config_countries:
52
+ print(f"WARNING: Countries changed:")
53
+ print(f" Previous: {sorted(existing_countries)}")
54
+ print(f" New: {sorted(config_countries)}")
55
+ print(f" Will rebuild relation tables to include new countries")
56
+ return True
57
+ else:
58
+ print(f"Countries unchanged: {sorted(config_countries)}")
59
+ return False
60
+ else:
61
+ # Can't determine countries, rebuild to be safe
62
+ print("WARNING: Cannot determine countries from existing tables, will rebuild")
63
+ return True
64
+ except Exception as e:
65
+ print(f"WARNING: Error reading existing relation tables: {e}")
66
+ print(" Will rebuild to be safe")
67
+ return True
68
+
69
+
70
+ def calculate_relation_limits(config: dict) -> Dict[str, int]:
71
+ """Auto-calculate relation limits based on sample targets."""
72
+ sample_targets = config['sample_targets']
73
+ retry_mult = config['generation']['retry_multiplier']
74
+ safety = config.get('auto_scaling', {}).get('safety_factor', 1.5)
75
+
76
+ # Map task families to relation types they need
77
+ family_to_relation = {
78
+ 'adjacency': 'adjacency',
79
+ 'containment': 'containment',
80
+ 'intersection': 'intersection',
81
+ 'buffer': 'adjacency', # Buffer uses adjacency pairs
82
+ 'set_operations': 'intersection', # Set ops use intersection pairs
83
+ 'partial_selection': 'containment', # Partial uses containment
84
+ 'aggregation': 'containment', # Aggregation uses containment
85
+ 'direct_lookup': None, # Uses inventory only
86
+ }
87
+
88
+ # Calculate required limits by summing needs per relation type
89
+ relation_needs = {}
90
+ for family, target in sample_targets.items():
91
+ relation_type = family_to_relation.get(family)
92
+ if relation_type:
93
+ needed = int(target * retry_mult * safety)
94
+ relation_needs[relation_type] = relation_needs.get(relation_type, 0) + needed
95
+
96
+ # Add cross-source (used by mixed-source partial selection)
97
+ partial_target = sample_targets.get('partial_selection', 0)
98
+ relation_needs['cross_source'] = int(partial_target * retry_mult * safety * 0.3)
99
+
100
+ # Apply manual overrides if specified
101
+ manual = config.get('auto_scaling', {}).get('manual_limits', {})
102
+ relation_needs.update(manual)
103
+
104
+ return relation_needs
105
+
106
+
107
+ def build_relations(config_path: Path):
108
+ """Run relation building with config."""
109
+ config = load_config(config_path)
110
+
111
+ # Auto-calculate relation limits
112
+ relation_limits = calculate_relation_limits(config)
113
+
114
+ print("=" * 60)
115
+ print("STEP 1: Building Relation Tables")
116
+ print("=" * 60)
117
+ print(f"Countries: {', '.join(config['countries'])}")
118
+ print(f"\nAuto-calculated relation limits:")
119
+ for rel_type, limit in relation_limits.items():
120
+ print(f" {rel_type:20s}: {limit:,}")
121
+ print()
122
+
123
+ # Import and run the relation builder
124
+ from dataset.scripts import build_relations
125
+
126
+ # Run with config parameters
127
+ build_relations.main(
128
+ countries=config['countries'],
129
+ relation_limits=relation_limits
130
+ )
131
+
132
+ print("\n✓ Relation tables built successfully")
133
+
134
+
135
+ def generate_samples(config_path: Path, append: bool = False):
136
+ """Run sample generation with config."""
137
+ config = load_config(config_path)
138
+
139
+ print("=" * 60)
140
+ print("STEP 2: Generating Samples")
141
+ print("=" * 60)
142
+ print(f"Targets: {config['sample_targets']}")
143
+ print(f"Workers: {config['generation']['max_workers']}")
144
+ print(f"Append mode: {append or config['generation']['append_mode']}")
145
+ print()
146
+
147
+ # Simple import - no number prefixes needed
148
+ from dataset.scripts import generate_samples as gs_module
149
+
150
+ # Override config values
151
+ gs_module.TARGET_COUNTS = config['sample_targets']
152
+ gs_module.MAX_WORKERS = config['generation']['max_workers']
153
+ gs_module.RETRY_MULTIPLIER = config['generation']['retry_multiplier']
154
+ gs_module.APPEND_MODE = append or config['generation']['append_mode']
155
+
156
+ # Run the main function
157
+ gs_module.main()
158
+
159
+ print("\n✓ Samples generated successfully")
160
+
161
+
162
+ def validate_dataset(config_path: Path):
163
+ """Run dataset validation."""
164
+ print("=" * 60)
165
+ print("STEP 3: Validating Dataset")
166
+ print("=" * 60)
167
+
168
+ script_dir = Path(__file__).parent
169
+ result = subprocess.run(
170
+ [sys.executable, str(script_dir / "validate_dataset.py")],
171
+ check=True
172
+ )
173
+
174
+ print("\n✓ Dataset validated successfully")
175
+
176
+
177
+ def export_dataset(config_path: Path):
178
+ """Run dataset export."""
179
+ print("=" * 60)
180
+ print("STEP 4: Exporting Dataset")
181
+ print("=" * 60)
182
+
183
+ script_dir = Path(__file__).parent
184
+ result = subprocess.run(
185
+ [sys.executable, str(script_dir / "export_training_data.py")],
186
+ check=True
187
+ )
188
+
189
+ print("\n✓ Dataset exported successfully")
190
+
191
+
192
+ def full_pipeline(config_path: Path, append: bool = False):
193
+ """Run the full pipeline."""
194
+ print("\n" + "=" * 60)
195
+ print("RUNNING FULL DATASET GENERATION PIPELINE")
196
+ print("=" * 60 + "\n")
197
+
198
+ config = load_config(config_path)
199
+
200
+ # Check if inventory exists, create if not
201
+ script_dir = Path(__file__).parent
202
+ intermediate_dir = script_dir.parent / "intermediate"
203
+ inventory_files = [
204
+ intermediate_dir / "divisions_area_inventory.parquet",
205
+ intermediate_dir / "natural_earth_inventory.parquet"
206
+ ]
207
+
208
+ inventory_missing = any(not f.exists() for f in inventory_files)
209
+
210
+ if inventory_missing:
211
+ print("=" * 60)
212
+ print("STEP 0: Building Entity Inventory")
213
+ print("=" * 60)
214
+ print("Inventory files not found. Building inventory...\n")
215
+
216
+ from dataset.scripts import build_inventory
217
+ build_inventory.main()
218
+
219
+ print("\n✓ Inventory built successfully\n")
220
+
221
+ # Check if we need to rebuild relations
222
+ need_rebuild = should_rebuild_relations(config, intermediate_dir, append)
223
+
224
+ if need_rebuild:
225
+ build_relations(config_path)
226
+ else:
227
+ print("Using existing relation tables (append mode, same countries)")
228
+
229
+ generate_samples(config_path, append=append)
230
+ validate_dataset(config_path)
231
+ export_dataset(config_path)
232
+
233
+ print("\n" + "=" * 60)
234
+ print("✓ PIPELINE COMPLETED SUCCESSFULLY")
235
+ print("=" * 60)
236
+
237
+
238
+ def main():
239
+ parser = argparse.ArgumentParser(
240
+ description="Synthetic dataset generation CLI",
241
+ formatter_class=argparse.RawDescriptionHelpFormatter,
242
+ epilog="""
243
+ Examples:
244
+ # Build relation tables only
245
+ python cli.py build-relations --config ../config.yaml
246
+
247
+ # Generate samples only
248
+ python cli.py generate-samples --config ../config.yaml
249
+
250
+ # Generate and append to existing dataset
251
+ python cli.py generate-samples --config ../config.yaml --append
252
+
253
+ # Run full pipeline
254
+ python cli.py full-pipeline --config ../config.yaml
255
+
256
+ # Run full pipeline in append mode (skip relation building)
257
+ python cli.py full-pipeline --config ../config.yaml --append
258
+ """
259
+ )
260
+
261
+ parser.add_argument(
262
+ 'command',
263
+ choices=['build-relations', 'generate-samples', 'validate', 'export', 'full-pipeline'],
264
+ help='Command to run'
265
+ )
266
+
267
+ parser.add_argument(
268
+ '--config',
269
+ type=Path,
270
+ required=True,
271
+ help='Path to config YAML file'
272
+ )
273
+
274
+ parser.add_argument(
275
+ '--append',
276
+ action='store_true',
277
+ help='Append to existing dataset instead of overwriting'
278
+ )
279
+
280
+ args = parser.parse_args()
281
+
282
+ # Validate config file exists
283
+ if not args.config.exists():
284
+ print(f"Error: Config file not found: {args.config}")
285
+ sys.exit(1)
286
+
287
+ # Run the appropriate command
288
+ try:
289
+ if args.command == 'build-relations':
290
+ build_relations(args.config)
291
+ elif args.command == 'generate-samples':
292
+ generate_samples(args.config, args.append)
293
+ elif args.command == 'validate':
294
+ validate_dataset(args.config)
295
+ elif args.command == 'export':
296
+ export_dataset(args.config)
297
+ elif args.command == 'full-pipeline':
298
+ full_pipeline(args.config, args.append)
299
+ except Exception as e:
300
+ print(f"\n✗ Error: {e}")
301
+ sys.exit(1)
302
+
303
+
304
+ if __name__ == "__main__":
305
+ main()
dataset/scripts/export_training_data.py ADDED
@@ -0,0 +1,191 @@
1
+ """
2
+ Export validated dataset to train/val/test splits.
3
+
4
+ This script:
5
+ 1. Loads validated samples
6
+ 2. Splits into train (80%), val (10%), test (10%)
7
+ 3. Ensures balanced splits across task families
8
+ 4. Exports to JSONL format
9
+
10
+ Output:
11
+ - output/train.jsonl (80% of samples)
12
+ - output/val.jsonl (10% of samples)
13
+ - output/test.jsonl (10% of samples)
14
+ """
15
+
16
+ import json
17
+ import random
18
+ from pathlib import Path
19
+ from typing import List, Dict, Any
20
+ from collections import defaultdict
21
+
22
+
23
+ def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
24
+ """Load samples from JSONL file."""
25
+ samples = []
26
+ with open(jsonl_path, 'r') as f:
27
+ for line in f:
28
+ samples.append(json.loads(line))
29
+ return samples
30
+
31
+
32
+ def stratified_split(
33
+ samples: List[Dict[str, Any]],
34
+ train_ratio: float = 0.8,
35
+ val_ratio: float = 0.1,
36
+ test_ratio: float = 0.1,
37
+ random_seed: int = 42
38
+ ) -> tuple[List[Dict], List[Dict], List[Dict]]:
39
+ """Split samples by task family to ensure balanced distribution."""
40
+
41
+ random.seed(random_seed)
42
+
43
+ # Group by task family
44
+ by_family = defaultdict(list)
45
+ for sample in samples:
46
+ family = sample['metadata']['task_family']
47
+ by_family[family].append(sample)
48
+
49
+ train_samples = []
50
+ val_samples = []
51
+ test_samples = []
52
+
53
+ # Split each family
54
+ for family, family_samples in by_family.items():
55
+ # Shuffle
56
+ random.shuffle(family_samples)
57
+
58
+ n = len(family_samples)
59
+ n_train = int(n * train_ratio)
60
+ n_val = int(n * val_ratio)
61
+
62
+ train_samples.extend(family_samples[:n_train])
63
+ val_samples.extend(family_samples[n_train:n_train + n_val])
64
+ test_samples.extend(family_samples[n_train + n_val:])
65
+
66
+ # Shuffle final splits
67
+ random.shuffle(train_samples)
68
+ random.shuffle(val_samples)
69
+ random.shuffle(test_samples)
70
+
71
+ return train_samples, val_samples, test_samples
72
+
73
+
74
+ def save_split(samples: List[Dict[str, Any]], output_path: Path):
75
+ """Save samples to JSONL file."""
76
+ with open(output_path, 'w') as f:
77
+ for sample in samples:
78
+ f.write(json.dumps(sample) + '\n')
79
+
80
+
81
+ def print_split_stats(split_name: str, samples: List[Dict[str, Any]]):
82
+ """Print statistics for a split."""
83
+ families = defaultdict(int)
84
+ for sample in samples:
85
+ family = sample['metadata']['task_family']
86
+ families[family] += 1
87
+
88
+ print(f"\n{split_name}:")
89
+ print(f" Total: {len(samples)}")
90
+ for family, count in sorted(families.items()):
91
+ print(f" {family:20s}: {count:3d}")
92
+
93
+
94
+ def print_country_stats(samples: List[Dict[str, Any]]):
95
+ """Print country distribution statistics."""
96
+ country_counts = defaultdict(int)
97
+
98
+ # Extract countries from selected/answer candidates only
99
+ for sample in samples:
100
+ selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
101
+ countries_in_sample = set()
102
+ for candidate in sample.get('candidates', []):
103
+ if candidate.get('candidate_id') in selected_ids:
104
+ country = candidate.get('country')
105
+ if country:
106
+ countries_in_sample.add(country)
107
+
108
+ # Count each unique country once per sample
109
+ for country in countries_in_sample:
110
+ country_counts[country] += 1
111
+
112
+ if not country_counts:
113
+ print("\nNo country information found in samples")
114
+ return
115
+
116
+ print(f"\nCOUNTRY DISTRIBUTION:")
117
+ print(f" Total unique countries: {len(country_counts)}")
118
+ print(f"\n Top countries by sample count:")
119
+
120
+ # Sort by count descending
121
+ sorted_countries = sorted(country_counts.items(), key=lambda x: x[1], reverse=True)
122
+
123
+ # Show top 20
124
+ for country, count in sorted_countries[:20]:
125
+ percentage = (count / len(samples)) * 100
126
+ print(f" {country:3s}: {count:4d} samples ({percentage:5.1f}%)")
127
+
128
+ if len(sorted_countries) > 20:
129
+ remaining = len(sorted_countries) - 20
130
+ remaining_count = sum(c for _, c in sorted_countries[20:])
131
+ print(f" ... and {remaining} more countries ({remaining_count} samples)")
132
+
133
+
134
+ def main():
135
+ """Export dataset splits."""
136
+
137
+ script_dir = Path(__file__).parent
138
+ output_dir = script_dir.parent / "output"
139
+
140
+ validated_file = output_dir / "dataset_validated.jsonl"
141
+ train_file = output_dir / "train.jsonl"
142
+ val_file = output_dir / "val.jsonl"
143
+ test_file = output_dir / "test.jsonl"
144
+
145
+ if not validated_file.exists():
146
+ print(f"Error: {validated_file} not found. Run 05_validate_dataset.py first.")
147
+ return
148
+
149
+ # Load validated samples
150
+ print("Loading validated samples...")
151
+ samples = load_samples(validated_file)
152
+ print(f"Loaded {len(samples)} samples")
153
+
154
+ # Split
155
+ print("\nSplitting dataset (80/10/10)...")
156
+ train_samples, val_samples, test_samples = stratified_split(samples)
157
+
158
+ # Save splits
159
+ print("\nSaving splits...")
160
+ save_split(train_samples, train_file)
161
+ save_split(val_samples, val_file)
162
+ save_split(test_samples, test_file)
163
+
164
+ print(f" Train: {train_file} ({len(train_samples)} samples)")
165
+ print(f" Val: {val_file} ({len(val_samples)} samples)")
166
+ print(f" Test: {test_file} ({len(test_samples)} samples)")
167
+
168
+ # Print statistics
169
+ print("\n" + "=" * 60)
170
+ print("SPLIT STATISTICS")
171
+ print("=" * 60)
172
+
173
+ print_split_stats("TRAIN", train_samples)
174
+ print_split_stats("VAL", val_samples)
175
+ print_split_stats("TEST", test_samples)
176
+
177
+ # Print country distribution
178
+ print("\n" + "=" * 60)
179
+ print("GEOGRAPHIC DISTRIBUTION")
180
+ print("=" * 60)
181
+ print_country_stats(samples)
182
+
183
+ print("\n✓ Export complete")
184
+ print(f"\nReady for training!")
185
+ print(f" Training data: {train_file}")
186
+ print(f" Validation data: {val_file}")
187
+ print(f" Test data: {test_file}")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main()
dataset/scripts/generate_samples.py ADDED
@@ -0,0 +1,1091 @@
1
+ """
2
+ Generate synthetic training samples for text-to-SQL task.
3
+
4
+ This script:
5
+ 1. Loads relation tables and entity inventories
6
+ 2. For each SQL template, samples valid anchors
7
+ 3. Renders and executes SQL to verify it works
8
+ 4. Builds candidate lists with controlled distractors
9
+ 5. Generates natural language questions using LLM
10
+ 6. Saves complete training samples
11
+
12
+ Output:
13
+ - output/samples/sample_*.json (individual samples)
14
+ - output/dataset_raw.jsonl (all samples)
15
+ """
16
+
17
+ import json
18
+ import random
19
+ import warnings
20
+ from pathlib import Path
21
+ from typing import List, Dict, Any, Optional
22
+ from concurrent.futures import ProcessPoolExecutor, as_completed
23
+ from functools import partial
24
+
25
+ import duckdb
26
+ import pandas as pd
27
+ from pydantic import BaseModel
28
+
29
+ # Suppress warnings
30
+ warnings.filterwarnings('ignore')
31
+
32
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
33
+
34
+ # Configurable parameters (can be overridden by CLI)
35
+ TARGET_COUNTS = None # Will be set in main() or by CLI
36
+ MAX_WORKERS = 8
37
+ RETRY_MULTIPLIER = 2
38
+ APPEND_MODE = False
39
+
40
+ # Import templates from same directory
41
+ from . import sql_templates
42
+ TEMPLATES = sql_templates.TEMPLATES
43
+ SQLTemplate = sql_templates.SQLTemplate
44
+ get_templates_by_family = sql_templates.get_templates_by_family
45
+
46
+
47
+ class Candidate(BaseModel):
48
+ """Candidate entity for grounding."""
49
+ candidate_id: str
50
+ source: str
51
+ id: str
52
+ name: str
53
+ subtype: Optional[str] = None
54
+ country: Optional[str] = None
55
+ region: Optional[str] = None
56
+ admin_level: Optional[int] = None
57
+ similarity: float = 0.0
58
+
59
+
60
+ class TrainingSample(BaseModel):
61
+ """Complete training sample."""
62
+ id: str
63
+ question: str
64
+ candidates: List[Candidate]
65
+ target: Dict[str, Any]
66
+ metadata: Dict[str, Any]
67
+
68
+
69
+ def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]:
70
+ """Load all precomputed relation tables."""
71
+ tables = {}
72
+
73
+ for file in intermediate_dir.glob("*.parquet"):
74
+ name = file.stem
75
+ tables[name] = pd.read_parquet(file)
76
+ if not quiet:
77
+ print(f" {name}: {len(tables[name])} rows")
78
+
79
+ return tables
80
+
81
+
82
+ def sample_adjacency_anchor(adjacency_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
83
+ """Sample a random adjacency pair."""
84
+ if adjacency_df.empty:
85
+ return None
86
+
87
+ row = adjacency_df.sample(n=1).iloc[0]
88
+ return {
89
+ 'anchor_id': row['anchor_id'],
90
+ 'anchor_name': row['anchor_name'],
91
+ 'anchor_subtype': row['anchor_subtype'],
92
+ 'anchor_country': row.get('anchor_country'), # May not exist in all tables
93
+ 'target_subtype': row.get('target_subtype')
94
+ }
95
+
96
+
97
+ def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
98
+ """Sample a random intersection pair."""
99
+ if intersection_df.empty:
100
+ return None
101
+
102
+ row = intersection_df.sample(n=1).iloc[0]
103
+ return {
104
+ 'anchor_id': row['anchor_id'],
105
+ 'anchor_name': row['anchor_name'],
106
+ 'anchor_subtype': row['anchor_subtype'],
107
+ 'target_id': row.get('target_id'),
108
+ 'target_name': row.get('target_name'),
109
+ 'target_subtype': row.get('target_subtype')
110
+ }
111
+
112
+
113
+ def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
114
+ """Sample a random containment pair."""
115
+ if containment_df.empty:
116
+ return None
117
+
118
+ row = containment_df.sample(n=1).iloc[0]
119
+ return {
120
+ 'container_id': row['container_id'],
121
+ 'container_name': row['container_name'],
122
+ 'container_subtype': row['container_subtype'],
123
+ 'contained_subtype': row['contained_subtype']
124
+ }
125
+
126
+
127
+ def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
128
+ """Sample a random cross-source relation."""
129
+ if cross_source_df.empty:
130
+ return None
131
+
132
+ row = cross_source_df.sample(n=1).iloc[0]
133
+ return {
134
+ 'division_id': row['division_id'],
135
+ 'division_name': row['division_name'],
136
+ 'division_subtype': row['division_subtype'],
137
+ 'natural_id': row['natural_id'],
138
+ 'natural_name': row['natural_name'],
139
+ 'natural_subtype': row['natural_subtype'],
140
+ 'relation_type': row['relation_type']
141
+ }
142
+
143
+
144
+ def build_candidate_list(
145
+ con: duckdb.DuckDBPyConnection,
146
+ anchor_id: str,
147
+ anchor_name: str,
148
+ anchor_source: str,
149
+ num_candidates: int = 10,
150
+ difficulty: str = "medium"
151
+ ) -> List[Candidate]:
152
+ """Build candidate list with true anchor + distractors."""
153
+
154
+ # Helper to convert pandas NA to None
155
+ def safe_get(row, key, default=None):
156
+ val = row.get(key, default)
157
+ return None if pd.isna(val) else val
158
+
159
+ # Get the true anchor
160
+ if anchor_source == "divisions_area":
161
+ query = """
162
+ SELECT
163
+ id,
164
+ names."primary" AS name,
165
+ subtype,
166
+ country,
167
+ region,
168
+ admin_level
169
+ FROM read_parquet(?)
170
+ WHERE id = ?
171
+ """
172
+ anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
173
+ else:
174
+ query = """
175
+ SELECT
176
+ id,
177
+ names."primary" AS name,
178
+ subtype
179
+ FROM read_parquet(?)
180
+ WHERE id = ?
181
+ """
182
+ anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
183
+
184
+ # Build true candidate
185
+ true_candidate = Candidate(
186
+ candidate_id="c1",
187
+ source=anchor_source,
188
+ id=anchor_id,
189
+ name=safe_get(anchor_row, 'name'),
190
+ subtype=safe_get(anchor_row, 'subtype'),
191
+ country=safe_get(anchor_row, 'country'),
192
+ region=safe_get(anchor_row, 'region'),
193
+ admin_level=safe_get(anchor_row, 'admin_level'),
194
+ similarity=1.0
195
+ )
196
+
197
+ # Build distractors based on difficulty
198
+ distractors = build_distractors(
199
+ con,
200
+ anchor_name,
201
+ anchor_source,
202
+ anchor_id,
203
+ num_candidates - 1,
204
+ difficulty
205
+ )
206
+
207
+ # Combine and shuffle
208
+ candidates = [true_candidate] + distractors
209
+ random.shuffle(candidates)
210
+
211
+ # Reassign candidate IDs after shuffling
212
+ for i, cand in enumerate(candidates, 1):
213
+ cand.candidate_id = f"c{i}"
214
+
215
+ return candidates
216
+
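+ # Usage sketch (ids and names below are hypothetical): with the spatial extension
+ # loaded on `con`, this returns the true entity plus fuzzy-name distractors,
+ # shuffled and re-labelled c1..cN:
+ #   candidates = build_candidate_list(con, anchor_id="div_123", anchor_name="Quito",
+ #                                     anchor_source="divisions_area", num_candidates=10)
+ #   [c.name for c in candidates]  # e.g. ["Quito", "Quito Canton", ...] in random order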
217
+
218
+ def build_distractors(
219
+ con: duckdb.DuckDBPyConnection,
220
+ anchor_name: str,
221
+ anchor_source: str,
222
+ exclude_id: str,
223
+ num_distractors: int,
224
+ difficulty: str
225
+ ) -> List[Candidate]:
226
+ """Build distractor candidates using fuzzy search."""
227
+
228
+ # Fuzzy search for similar names
229
+ if anchor_source == "divisions_area":
230
+ query = """
231
+ SELECT
232
+ id,
233
+ names."primary" AS name,
234
+ subtype,
235
+ country,
236
+ region,
237
+ admin_level,
238
+ jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
239
+ FROM read_parquet(?)
240
+ WHERE id != ?
241
+ AND names."primary" IS NOT NULL
242
+ ORDER BY similarity DESC
243
+ LIMIT ?
244
+ """
245
+ df = con.execute(query, [
246
+ anchor_name, DIVISIONS_AREA_PATH, exclude_id, num_distractors
247
+ ]).fetchdf()
248
+ source = "divisions_area"
249
+ else:
250
+ query = """
251
+ SELECT
252
+ id,
253
+ names."primary" AS name,
254
+ subtype,
255
+ jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
256
+ FROM read_parquet(?)
257
+ WHERE id != ?
258
+ AND names."primary" IS NOT NULL
259
+ ORDER BY similarity DESC
260
+ LIMIT ?
261
+ """
262
+ df = con.execute(query, [
263
+ anchor_name, NATURAL_EARTH_PATH, exclude_id, num_distractors
264
+ ]).fetchdf()
265
+ source = "natural_earth"
266
+
267
+ # Helper to convert pandas NA to None
268
+ def safe_get(row, key, default=None):
269
+ val = row.get(key, default)
270
+ return None if pd.isna(val) else val
271
+
272
+ distractors = []
273
+ for _, row in df.iterrows():
274
+ distractors.append(Candidate(
275
+ candidate_id="temp", # Will be reassigned
276
+ source=source,
277
+ id=row['id'],
278
+ name=safe_get(row, 'name'),
279
+ subtype=safe_get(row, 'subtype'),
280
+ country=safe_get(row, 'country'),
281
+ region=safe_get(row, 'region'),
282
+ admin_level=safe_get(row, 'admin_level'),
283
+ similarity=float(row['similarity'])
284
+ ))
285
+
286
+ return distractors
287
+
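+ # Note: jaro_winkler_similarity is DuckDB's built-in string-similarity function
+ # (1.0 for identical strings, lower for dissimilar ones), so the highest-ranked
+ # distractors are the closest-sounding names in the same source, e.g.
+ #   SELECT jaro_winkler_similarity('Quito', 'Quito');  -- 1.0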
288
+
289
+ def generate_adjacency_sample(
290
+ con: duckdb.DuckDBPyConnection,
291
+ adjacency_df: pd.DataFrame,
292
+ sample_id: str
293
+ ) -> Optional[TrainingSample]:
294
+ """Generate a sample for adjacency task."""
295
+
296
+ anchor = sample_adjacency_anchor(adjacency_df)
297
+ if not anchor:
298
+ return None
299
+
300
+ # Build SQL
301
+ sql = f"""WITH a AS (
302
+ SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}')
303
+ WHERE id = '{anchor['anchor_id']}'
304
+ )
305
+ SELECT b.id, b.names."primary" AS name, b.geometry
306
+ FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
307
+ WHERE b.id != '{anchor['anchor_id']}'
308
+ AND b.subtype = '{anchor['target_subtype']}'
309
+ AND ST_Touches(a.geometry, b.geometry)"""
310
+
311
+ # Execute to verify
312
+ try:
313
+ result = con.execute(sql).fetchdf()
314
+ if result.empty:
315
+ return None
316
+ except Exception as e:
317
+ print(f"SQL execution failed: {e}")
318
+ return None
319
+
320
+ # Build candidates
321
+ candidates = build_candidate_list(
322
+ con,
323
+ anchor['anchor_id'],
324
+ anchor['anchor_name'],
325
+ "divisions_area",
326
+ num_candidates=10,
327
+ difficulty="medium"
328
+ )
329
+
330
+ # Find which candidate is the true anchor
331
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['anchor_id']]
332
+
333
+ # Generate question
334
+ question = f"Which {anchor['target_subtype']}s border {anchor['anchor_name']}?"
335
+
336
+ return TrainingSample(
337
+ id=sample_id,
338
+ question=question,
339
+ candidates=candidates,
340
+ target={
341
+ "selected_candidates": selected_candidate_ids,
342
+ "sql": sql
343
+ },
344
+ metadata={
345
+ "task_family": "adjacency",
346
+ "sql_difficulty": "medium",
347
+ "grounding_difficulty": "medium",
348
+ "template_id": "adj_02",
349
+ "num_candidates": len(candidates),
350
+ "anchor_source": "divisions_area",
351
+ "sql_verified": True
352
+ }
353
+ )
354
+
355
+
356
+ def generate_containment_sample(
357
+ con: duckdb.DuckDBPyConnection,
358
+ containment_df: pd.DataFrame,
359
+ sample_id: str
360
+ ) -> Optional[TrainingSample]:
361
+ """Generate a sample for containment task."""
362
+
363
+ anchor = sample_containment_anchor(containment_df)
364
+ if not anchor:
365
+ return None
366
+
367
+ # Build SQL
368
+ sql = f"""WITH a AS (
369
+ SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}')
370
+ WHERE id = '{anchor['container_id']}'
371
+ )
372
+ SELECT b.id, b.names."primary" AS name, b.geometry
373
+ FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
374
+ WHERE b.id != '{anchor['container_id']}'
375
+ AND b.subtype = '{anchor['contained_subtype']}'
376
+ AND ST_Within(b.geometry, a.geometry)"""
377
+
378
+ # Execute to verify
379
+ try:
380
+ result = con.execute(sql).fetchdf()
381
+ if result.empty:
382
+ return None
383
+ except Exception as e:
384
+ print(f"SQL execution failed: {e}")
385
+ return None
386
+
387
+ # Build candidates
388
+ candidates = build_candidate_list(
389
+ con,
390
+ anchor['container_id'],
391
+ anchor['container_name'],
392
+ "divisions_area",
393
+ num_candidates=10,
394
+ difficulty="medium"
395
+ )
396
+
397
+ # Find which candidate is the true anchor
398
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['container_id']]
399
+
400
+ # Generate question
401
+ question = f"What {anchor['contained_subtype']}s are in {anchor['container_name']}?"
402
+
403
+ return TrainingSample(
404
+ id=sample_id,
405
+ question=question,
406
+ candidates=candidates,
407
+ target={
408
+ "selected_candidates": selected_candidate_ids,
409
+ "sql": sql
410
+ },
411
+ metadata={
412
+ "task_family": "containment",
413
+ "sql_difficulty": "medium",
414
+ "grounding_difficulty": "medium",
415
+ "template_id": "contain_01",
416
+ "num_candidates": len(candidates),
417
+ "anchor_source": "divisions_area",
418
+ "sql_verified": True
419
+ }
420
+ )
421
+
422
+
423
+ def sample_random_entity(
424
+ con: duckdb.DuckDBPyConnection,
425
+ inventory_df: pd.DataFrame,
426
+ source: str
427
+ ) -> Optional[Dict[str, Any]]:
428
+ """Sample a random entity from inventory."""
429
+ if inventory_df.empty:
430
+ return None
431
+
432
+ row = inventory_df.sample(n=1).iloc[0]
433
+ return {
434
+ 'id': row['id'],
435
+ 'name': row['name'],
436
+ 'subtype': row.get('subtype'),
437
+ 'country': row.get('country'),
438
+ 'source': source
439
+ }
440
+
441
+
442
+ def generate_template_based_sample(
443
+ con: duckdb.DuckDBPyConnection,
444
+ template: SQLTemplate,
445
+ tables: Dict[str, pd.DataFrame],
446
+ sample_id: str
447
+ ) -> Optional[TrainingSample]:
448
+ """Generate a sample based on a SQL template."""
449
+
450
+ # Sample anchor based on template requirements
451
+ if template.family == "direct_lookup":
452
+ # Just pick a random entity
453
+ if template.anchor_source == "divisions_area":
454
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
455
+ else:
456
+ anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
457
+
458
+ if not anchor:
459
+ return None
460
+
461
+ # Render SQL
462
+ sql = template.sql_template.format(
463
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
464
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
465
+ anchor_id=anchor['id']
466
+ )
467
+
468
+ # Build candidates
469
+ candidates = build_candidate_list(
470
+ con, anchor['id'], anchor['name'], anchor['source'],
471
+ num_candidates=10, difficulty="easy"
472
+ )
473
+
474
+ # Question
475
+ question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
476
+
477
+ elif template.family == "adjacency":
478
+ anchor = sample_adjacency_anchor(tables['adjacency_pairs'])
479
+ if not anchor:
480
+ return None
481
+
482
+ sql = template.sql_template.format(
483
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
484
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
485
+ anchor_id=anchor['anchor_id'],
486
+ target_subtype=anchor['target_subtype']
487
+ )
488
+
489
+ candidates = build_candidate_list(
490
+ con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
491
+ num_candidates=10, difficulty="medium"
492
+ )
493
+
494
+ question = random.choice(template.question_hints).format(
495
+ anchor_name=anchor['anchor_name'],
496
+ target_subtype=anchor['target_subtype']
497
+ )
498
+
499
+ elif template.family == "containment":
500
+ anchor = sample_containment_anchor(tables['containment_pairs'])
501
+ if not anchor:
502
+ return None
503
+
504
+ sql = template.sql_template.format(
505
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
506
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
507
+ anchor_id=anchor['container_id'],
508
+ target_subtype=anchor['contained_subtype']
509
+ )
510
+
511
+ candidates = build_candidate_list(
512
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
513
+ num_candidates=10, difficulty="medium"
514
+ )
515
+
516
+ question = random.choice(template.question_hints).format(
517
+ anchor_name=anchor['container_name'],
518
+ target_subtype=anchor['contained_subtype']
519
+ )
520
+
521
+ elif template.family == "intersection":
522
+ if template.anchor_source == "natural_earth":
523
+ anchor = sample_cross_source_anchor(tables['cross_source_relations'])
524
+ if not anchor:
525
+ return None
526
+
527
+ sql = template.sql_template.format(
528
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
529
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
530
+ anchor_id=anchor['natural_id'],
531
+ target_subtype='country'
532
+ )
533
+
534
+ candidates = build_candidate_list(
535
+ con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
536
+ num_candidates=10, difficulty="medium"
537
+ )
538
+
539
+ question = random.choice(template.question_hints).format(
540
+ anchor_name=anchor['natural_name'],
541
+ target_subtype='country'
542
+ )
543
+ else:
544
+ # Same-source intersection
545
+ anchor = sample_intersection_anchor(tables['intersection_pairs'])
546
+ if not anchor:
547
+ return None
548
+
549
+ # Use a generic subtype if not available
550
+ target_subtype = anchor.get('target_subtype') or 'region'
551
+
552
+ sql = template.sql_template.format(
553
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
554
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
555
+ anchor_id=anchor['anchor_id'],
556
+ target_subtype=target_subtype
557
+ )
558
+
559
+ candidates = build_candidate_list(
560
+ con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
561
+ num_candidates=10, difficulty="medium"
562
+ )
563
+
564
+ question = random.choice(template.question_hints).format(
565
+ anchor_name=anchor['anchor_name'],
566
+ target_subtype=target_subtype
567
+ )
568
+
569
+ elif template.family == "set_operations":
570
+ # Union of two entities
571
+ anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
572
+ anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
573
+
574
+ if not anchor1 or not anchor2:
575
+ return None
576
+
577
+ sql = template.sql_template.format(
578
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
579
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
580
+ anchor_id_1=anchor1['id'],
581
+ anchor_id_2=anchor2['id']
582
+ )
583
+
584
+ # Build candidates for both anchors
585
+ candidates1 = build_candidate_list(
586
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
587
+ num_candidates=5, difficulty="medium"
588
+ )
589
+ candidates2 = build_candidate_list(
590
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
591
+ num_candidates=5, difficulty="medium"
592
+ )
593
+
594
+ # Combine and deduplicate
595
+ candidates = candidates1 + candidates2
596
+ seen_ids = set()
597
+ unique_candidates = []
598
+ for c in candidates:
599
+ if c.id not in seen_ids:
600
+ unique_candidates.append(c)
601
+ seen_ids.add(c.id)
602
+ candidates = unique_candidates[:10]
603
+
604
+ # Reassign IDs
605
+ for i, c in enumerate(candidates, 1):
606
+ c.candidate_id = f"c{i}"
607
+
608
+ question = f"{anchor1['name']} and {anchor2['name']}"
609
+
610
+ elif template.family == "buffer":
611
+ # Buffer operations
612
+ if template.num_anchors == 1:
613
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
614
+ if not anchor:
615
+ return None
616
+
617
+ buffer_degrees = random.choice([0.1, 0.5, 1.0])
618
+
619
+ sql = template.sql_template.format(
620
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
621
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
622
+ anchor_id=anchor['id'],
623
+ buffer_degrees=buffer_degrees
624
+ )
625
+
626
+ candidates = build_candidate_list(
627
+ con, anchor['id'], anchor['name'], 'divisions_area',
628
+ num_candidates=10, difficulty="medium"
629
+ )
630
+
631
+ question = random.choice(template.question_hints).format(
632
+ anchor_name=anchor['name'],
633
+ buffer_degrees=buffer_degrees
634
+ )
635
+ else:
636
+ # Two anchor buffer
637
+ anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
638
+ anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
639
+
640
+ if not anchor1 or not anchor2:
641
+ return None
642
+
643
+ buffer_degrees = random.choice([0.1, 0.5])
644
+
645
+ sql = template.sql_template.format(
646
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
647
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
648
+ anchor_id_1=anchor1['id'],
649
+ anchor_id_2=anchor2['id'],
650
+ buffer_degrees=buffer_degrees
651
+ )
652
+
653
+ candidates1 = build_candidate_list(
654
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
655
+ num_candidates=5, difficulty="medium"
656
+ )
657
+ candidates2 = build_candidate_list(
658
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
659
+ num_candidates=5, difficulty="medium"
660
+ )
661
+
662
+ candidates = candidates1 + candidates2
663
+ seen_ids = set()
664
+ unique_candidates = []
665
+ for c in candidates:
666
+ if c.id not in seen_ids:
667
+ unique_candidates.append(c)
668
+ seen_ids.add(c.id)
669
+ candidates = unique_candidates[:10]
670
+
671
+ for i, c in enumerate(candidates, 1):
672
+ c.candidate_id = f"c{i}"
673
+
674
+ question = random.choice(template.question_hints).format(
675
+ anchor_1_name=anchor1['name'],
676
+ anchor_2_name=anchor2['name'],
677
+ buffer_degrees=buffer_degrees
678
+ )
679
+
680
+ elif template.family == "partial_selection":
681
+ # Partial selection (northern half, clipping, etc.)
682
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
683
+ if not anchor:
684
+ return None
685
+
686
+ if template.num_anchors == 1:
687
+ sql = template.sql_template.format(
688
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
689
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
690
+ anchor_id=anchor['id']
691
+ )
692
+
693
+ question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
694
+ else:
695
+ # Mixed source clipping
696
+ clip_feature = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
697
+ if not clip_feature:
698
+ return None
699
+
700
+ sql = template.sql_template.format(
701
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
702
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
703
+ anchor_id=anchor['id'],
704
+ clip_feature_id=clip_feature['id']
705
+ )
706
+
707
+ question = random.choice(template.question_hints).format(
708
+ anchor_name=anchor['name'],
709
+ clip_feature_name=clip_feature['name']
710
+ )
711
+
712
+ candidates = build_candidate_list(
713
+ con, anchor['id'], anchor['name'], 'divisions_area',
714
+ num_candidates=10, difficulty="hard"
715
+ )
716
+
717
+ elif template.family == "aggregation":
718
+ # Aggregation queries (e.g., largest N localities in a region)
719
+ top_n = random.choice([3, 5, 10])
720
+
721
+ # Check if this is a country-level query (agg_04, agg_05)
722
+ if template.template_id in ['agg_04', 'agg_05']:
723
+ # Country-level aggregation
724
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
725
+ if not anchor:
726
+ return None
727
+
728
+ country = anchor.get('country', 'EC')
729
+ target_subtype = random.choice(['locality', 'region'])
730
+
731
+ sql = template.sql_template.format(
732
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
733
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
734
+ country=country,
735
+ target_subtype=target_subtype,
736
+ top_n=top_n
737
+ )
738
+
739
+ candidates = build_candidate_list(
740
+ con, anchor['id'], anchor['name'], 'divisions_area',
741
+ num_candidates=10, difficulty="hard"
742
+ )
743
+
744
+ question = random.choice(template.question_hints).format(
745
+ top_n=top_n,
746
+ target_subtype=target_subtype,
747
+ country=country
748
+ )
749
+ else:
750
+ # Container-based aggregation (within a region)
751
+ anchor = sample_containment_anchor(tables['containment_pairs'])
752
+ if not anchor:
753
+ return None
754
+
755
+ target_subtype = anchor.get('contained_subtype', 'locality')
756
+
757
+ sql = template.sql_template.format(
758
+ DIVISIONS_AREA_PATH=DIVISIONS_AREA_PATH,
759
+ NATURAL_EARTH_PATH=NATURAL_EARTH_PATH,
760
+ anchor_id=anchor['container_id'],
761
+ target_subtype=target_subtype,
762
+ top_n=top_n
763
+ )
764
+
765
+ candidates = build_candidate_list(
766
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
767
+ num_candidates=10, difficulty="hard"
768
+ )
769
+
770
+ question = random.choice(template.question_hints).format(
771
+ top_n=top_n,
772
+ target_subtype=target_subtype,
773
+ anchor_name=anchor['container_name']
774
+ )
775
+
776
+ else:
777
+ # Skip unsupported families
778
+ return None
779
+
780
+ # Execute SQL to verify
781
+ try:
782
+ result = con.execute(sql).fetchdf()
783
+ if result.empty:
784
+ return None
785
+     except Exception:
786
+         # SQL failures are simply dropped here; the batch worker counts the attempt as failed
787
+ return None
788
+
789
+ # Find selected candidates
790
+     if template.family == "set_operations" or (template.family == "buffer" and template.num_anchors == 2):
791
+         # Two-anchor branches define anchor1/anchor2 (no single `anchor`), so both are targets
+         selected_candidate_ids = [c.candidate_id for c in candidates if c.id in [anchor1['id'], anchor2['id']]]
792
+ else:
793
+ anchor_id_to_find = anchor.get('anchor_id') or anchor.get('container_id') or anchor.get('natural_id') or anchor.get('id')
794
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
795
+
796
+ return TrainingSample(
797
+ id=sample_id,
798
+ question=question,
799
+ candidates=candidates,
800
+ target={
801
+ "selected_candidates": selected_candidate_ids,
802
+ "sql": sql
803
+ },
804
+ metadata={
805
+ "task_family": template.family,
806
+ "sql_difficulty": template.sql_difficulty,
807
+ "grounding_difficulty": "medium",
808
+ "template_id": template.template_id,
809
+ "num_candidates": len(candidates),
810
+ "anchor_source": template.anchor_source,
811
+ "sql_verified": True
812
+ }
813
+ )
814
+
815
+
816
+ def generate_cross_source_sample(
817
+ con: duckdb.DuckDBPyConnection,
818
+ cross_source_df: pd.DataFrame,
819
+ sample_id: str
820
+ ) -> Optional[TrainingSample]:
821
+ """Generate a sample for cross-source intersection task."""
822
+
823
+ anchor = sample_cross_source_anchor(cross_source_df)
824
+ if not anchor:
825
+ return None
826
+
827
+ # Build SQL (natural feature -> divisions)
828
+ sql = f"""WITH a AS (
829
+ SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}')
830
+ WHERE id = '{anchor['natural_id']}'
831
+ )
832
+ SELECT b.id, b.names."primary" AS name, b.geometry
833
+ FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a
834
+ WHERE b.subtype = 'country'
835
+ AND ST_Intersects(b.geometry, a.geometry)"""
836
+
837
+ # Execute to verify
838
+ try:
839
+ result = con.execute(sql).fetchdf()
840
+ if result.empty:
841
+ return None
842
+ except Exception as e:
843
+ print(f"SQL execution failed: {e}")
844
+ return None
845
+
846
+ # Build candidates for natural feature
847
+ candidates = build_candidate_list(
848
+ con,
849
+ anchor['natural_id'],
850
+ anchor['natural_name'],
851
+ "natural_earth",
852
+ num_candidates=10,
853
+ difficulty="medium"
854
+ )
855
+
856
+ # Find which candidate is the true anchor
857
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['natural_id']]
858
+
859
+ # Generate question
860
+ question = f"Which countries intersect the {anchor['natural_name']}?"
861
+
862
+ return TrainingSample(
863
+ id=sample_id,
864
+ question=question,
865
+ candidates=candidates,
866
+ target={
867
+ "selected_candidates": selected_candidate_ids,
868
+ "sql": sql
869
+ },
870
+ metadata={
871
+ "task_family": "intersection",
872
+ "sql_difficulty": "medium-hard",
873
+ "grounding_difficulty": "medium",
874
+ "template_id": "intersect_02",
875
+ "num_candidates": len(candidates),
876
+ "anchor_source": "natural_earth",
877
+ "sql_verified": True
878
+ }
879
+ )
880
+
881
+
882
+ def generate_sample_batch_worker(args):
883
+ """Worker function that processes a batch of work items with a single DuckDB connection.
884
+
885
+     Initializes DuckDB, the spatial extension, and the relation tables
886
+ ONCE per batch, then processes all items sequentially.
887
+ """
888
+ from pathlib import Path
889
+
890
+ work_items, intermediate_dir_str = args
891
+
892
+ # Convert string back to Path
893
+ intermediate_dir = Path(intermediate_dir_str)
894
+
895
+ # Initialize DuckDB ONCE for the entire batch
896
+ con = duckdb.connect()
897
+ con.execute("SET enable_progress_bar=false")
898
+ con.execute("INSTALL spatial")
899
+ con.execute("LOAD spatial")
900
+
901
+ # Load relation tables ONCE
902
+ tables = load_relation_tables(intermediate_dir, quiet=True)
903
+
904
+ # Process all items in batch
905
+ results = []
906
+ for family, template_dict, sample_id, _ in work_items:
907
+ # Reconstruct template from dict (sql_templates is already imported at module level)
908
+ template = sql_templates.SQLTemplate(**template_dict)
909
+ try:
910
+ sample = generate_template_based_sample(con, template, tables, sample_id)
911
+ if sample:
912
+ results.append((sample, family, template.template_id, None))
913
+ else:
914
+ results.append((None, family, template.template_id, "Empty result"))
915
+ except Exception as e:
916
+ results.append((None, family, template_dict.get('template_id', 'unknown'), str(e)))
917
+
918
+ con.close()
919
+ return results
920
+
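+ # Sketch of how a batch reaches this worker (values are illustrative; see main() below):
+ #   items = [("adjacency", template_dict, "sample_001", str(intermediate_dir)), ...]
+ #   results = generate_sample_batch_worker((items, str(intermediate_dir)))
+ #   # -> [(TrainingSample | None, family, template_id, error | None), ...]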
921
+
922
+ def main():
923
+ """Generate training samples."""
924
+ global TARGET_COUNTS, MAX_WORKERS, RETRY_MULTIPLIER, APPEND_MODE
925
+
926
+ # Setup paths
927
+ script_dir = Path(__file__).parent
928
+ intermediate_dir = script_dir.parent / "intermediate"
929
+ output_dir = script_dir.parent / "output"
930
+
931
+ output_dir.mkdir(exist_ok=True, parents=True)
932
+
933
+ # Load relation tables once to check availability
934
+ print("Loading relation tables...")
935
+ tables = load_relation_tables(intermediate_dir, quiet=False)
936
+
937
+ # Use configured target counts or defaults
938
+ if TARGET_COUNTS is None:
939
+ target_counts = {
940
+ 'direct_lookup': 100,
941
+ 'adjacency': 200,
942
+ 'containment': 100,
943
+ 'intersection': 150,
944
+ 'buffer': 100,
945
+ 'set_operations': 150,
946
+ 'partial_selection': 100,
947
+ 'aggregation': 100
948
+ }
949
+ else:
950
+ target_counts = TARGET_COUNTS
951
+
952
+ # Load existing samples if in append mode
953
+ existing_samples = []
954
+ existing_sample_ids = set()
955
+ jsonl_file = output_dir / "dataset_raw.jsonl"
956
+
957
+ if APPEND_MODE and jsonl_file.exists():
958
+ print(f"\nAppend mode: Loading existing samples from {jsonl_file}")
959
+ with open(jsonl_file, 'r') as f:
960
+ for line in f:
961
+ if line.strip():
962
+ sample_data = json.loads(line)
963
+ existing_samples.append(sample_data)
964
+ existing_sample_ids.add(sample_data['id'])
965
+ print(f" Found {len(existing_samples)} existing samples")
966
+
967
+ # Determine starting sample counter
968
+ max_existing_id = max([int(s['id'].split('_')[1]) for s in existing_samples if s['id'].startswith('sample_')], default=0)
969
+ sample_counter = max_existing_id + 1
970
+ else:
971
+ sample_counter = 1
972
+
973
+ # Prepare work items for parallel processing
974
+ work_items = []
975
+ starting_sample_counter = sample_counter # Track starting point for logging
976
+
977
+ for family, target_count in target_counts.items():
978
+ if target_count == 0:
979
+ continue
980
+
981
+ # Get templates for this family
982
+ family_templates = [t for t in TEMPLATES if t.family == family]
983
+ if not family_templates:
984
+ print(f"No templates found for {family}, skipping...")
985
+ continue
986
+
987
+ # Create work items (try retry_multiplier * target to account for failures)
988
+ for _ in range(target_count * RETRY_MULTIPLIER):
989
+ template = random.choice(family_templates)
990
+ # Convert template to dict for pickling
991
+ template_dict = {
992
+ 'template_id': template.template_id,
993
+ 'family': template.family,
994
+ 'sql_difficulty': template.sql_difficulty,
995
+ 'anchor_source': template.anchor_source,
996
+ 'num_anchors': template.num_anchors,
997
+ 'sql_template': template.sql_template,
998
+ 'question_hints': template.question_hints,
999
+ 'target_subtype': template.target_subtype,
1000
+ 'requires_buffer': template.requires_buffer,
1001
+ 'requires_aggregation': template.requires_aggregation
1002
+ }
1003
+ work_items.append((
1004
+ family,
1005
+ template_dict,
1006
+ f"sample_{sample_counter:03d}",
1007
+ str(intermediate_dir) # Convert Path to string for pickling
1008
+ ))
1009
+ sample_counter += 1
1010
+
1011
+ # Shuffle work items for balanced batches across families
1012
+ random.shuffle(work_items)
1013
+
1014
+ # Partition work items into batches (one per worker)
1015
+ num_workers = min(MAX_WORKERS, len(work_items))
1016
+ if num_workers == 0:
1017
+ print("No work items to process")
1018
+ return
1019
+ batch_size = (len(work_items) + num_workers - 1) // num_workers
1020
+ batches = []
1021
+ for i in range(0, len(work_items), batch_size):
1022
+ batch = work_items[i:i + batch_size]
1023
+ batches.append((batch, str(intermediate_dir)))
1024
+
1025
+ # Generate samples in parallel (one batch per worker)
1026
+ active_families = len([f for f in target_counts.values() if f > 0])
1027
+ print(f"\nGenerating {len(work_items)} samples across {active_families} families...")
1028
+ print(f" Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)")
1029
+ if APPEND_MODE and existing_samples:
1030
+ print(f"Appending: starting from sample_{starting_sample_counter:03d}")
1031
+
1032
+ all_samples = []
1033
+ family_progress = {f: {'success': 0, 'failed': 0} for f in target_counts.keys() if target_counts[f] > 0}
1034
+
1035
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
1036
+ # Submit one batch per worker
1037
+ futures = {executor.submit(generate_sample_batch_worker, batch): i for i, batch in enumerate(batches)}
1038
+
1039
+ # Collect results as batches complete
1040
+ batches_done = 0
1041
+ for future in as_completed(futures):
1042
+ try:
1043
+ batch_results = future.result()
1044
+ for sample, family, template_id, error in batch_results:
1045
+ if sample:
1046
+ all_samples.append(sample)
1047
+ family_progress[family]['success'] += 1
1048
+ else:
1049
+ family_progress[family]['failed'] += 1
1050
+ except Exception as e:
1051
+ print(f"\n Batch failed: {e}")
1052
+
1053
+ batches_done += 1
1054
+ total_done = sum(p['success'] + p['failed'] for p in family_progress.values())
1055
+ print(f"\r Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches) ", end='', flush=True)
1056
+
1057
+ print() # New line after progress
1058
+
1059
+ # Show distribution (keep all samples, no filtering)
1060
+ print("\nResults by family:")
1061
+ for family in sorted(family_progress.keys()):
1062
+ success = family_progress[family]['success']
1063
+ failed = family_progress[family]['failed']
1064
+ target = target_counts.get(family, 0)
1065
+ total = success + failed
1066
+ success_rate = (success / total * 100) if total > 0 else 0
1067
+ print(f" {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})")
1068
+
1069
+ # Save combined JSONL (skip individual JSON files for speed at scale)
1070
+ print(f"\nSaving {len(all_samples)} new samples...")
1071
+ if APPEND_MODE and existing_samples:
1072
+ # Append to existing dataset
1073
+ print(f"Appending to existing dataset ({len(existing_samples)} existing samples)")
1074
+ with open(jsonl_file, 'a') as f:
1075
+ for sample in all_samples:
1076
+ f.write(json.dumps(sample.model_dump()) + '\n')
1077
+ total_samples = len(existing_samples) + len(all_samples)
1078
+ else:
1079
+ # Overwrite dataset
1080
+ with open(jsonl_file, 'w') as f:
1081
+ for sample in all_samples:
1082
+ f.write(json.dumps(sample.model_dump()) + '\n')
1083
+ total_samples = len(all_samples)
1084
+
1085
+ print(f"\nGenerated {len(all_samples)} new samples")
1086
+ print(f"Total dataset size: {total_samples} samples")
1087
+ print(f" Dataset: {jsonl_file}")
1088
+
1089
+
1090
+ if __name__ == "__main__":
1091
+ main()
dataset/scripts/sql_templates.py ADDED
@@ -0,0 +1,317 @@
1
+ """
2
+ SQL template definitions for synthetic data generation.
3
+
4
+ Each template includes:
5
+ - Template ID
6
+ - Task family
7
+ - SQL difficulty level
8
+ - Required anchor types
9
+ - SQL template string with placeholders
10
+ - Question generation hints
11
+ """
12
+
13
+ from dataclasses import dataclass
14
+ from typing import List, Literal, Optional
15
+
16
+
17
+ @dataclass
18
+ class SQLTemplate:
19
+ """SQL template for synthetic data generation."""
20
+
21
+ template_id: str
22
+ family: str
23
+ sql_difficulty: Literal["easy", "medium", "medium-hard", "hard"]
24
+ anchor_source: Literal["divisions_area", "natural_earth", "mixed"]
25
+ num_anchors: int
26
+ sql_template: str
27
+ question_hints: List[str]
28
+     target_subtype: Optional[str] = None
29
+ requires_buffer: bool = False
30
+ requires_aggregation: bool = False
31
+
32
+
33
+ # Template catalog
34
+ TEMPLATES = [
35
+ # DIRECT LOOKUP (10 samples)
36
+ SQLTemplate(
37
+ template_id="lookup_01",
38
+ family="direct_lookup",
39
+ sql_difficulty="easy",
40
+ anchor_source="divisions_area",
41
+ num_anchors=1,
42
+ sql_template="""SELECT geometry, names."primary" AS name, id, subtype FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'""",
43
+ question_hints=["Show me {anchor_name}", "Get the geometry of {anchor_name}", "Find {anchor_name}"]
44
+ ),
45
+
46
+ SQLTemplate(
47
+ template_id="lookup_02",
48
+ family="direct_lookup",
49
+ sql_difficulty="easy",
50
+ anchor_source="natural_earth",
51
+ num_anchors=1,
52
+ sql_template="""SELECT geometry, names."primary" AS name, id, subtype FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}'""",
53
+ question_hints=["Show me the {anchor_name}", "Get {anchor_name}", "Find the {anchor_name}"]
54
+ ),
55
+
56
+ # ADJACENCY (20 samples)
57
+ SQLTemplate(
58
+ template_id="adj_01",
59
+ family="adjacency",
60
+ sql_difficulty="medium",
61
+ anchor_source="divisions_area",
62
+ num_anchors=1,
63
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND ST_Touches(a.geometry, b.geometry)""",
64
+ question_hints=["Which regions border {anchor_name}?", "What borders {anchor_name}?", "List places adjacent to {anchor_name}"]
65
+ ),
66
+
67
+ SQLTemplate(
68
+ template_id="adj_02",
69
+ family="adjacency",
70
+ sql_difficulty="medium",
71
+ anchor_source="divisions_area",
72
+ num_anchors=1,
73
+ target_subtype="region",
74
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Touches(a.geometry, b.geometry)""",
75
+ question_hints=["Which {target_subtype}s border {anchor_name}?", "What {target_subtype}s touch {anchor_name}?"]
76
+ ),
77
+
78
+ SQLTemplate(
79
+ template_id="adj_03",
80
+ family="adjacency",
81
+ sql_difficulty="medium",
82
+ anchor_source="divisions_area",
83
+ num_anchors=1,
84
+ target_subtype="sea",
85
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT n.id, n.names."primary" AS name, n.geometry FROM read_parquet('{NATURAL_EARTH_PATH}') AS n, a WHERE n.subtype = '{target_subtype}' AND ST_Touches(a.geometry, n.geometry)""",
86
+ question_hints=["Which {target_subtype}s touch {anchor_name}?", "What {target_subtype}s border {anchor_name}?"]
87
+ ),
88
+
89
+ # CONTAINMENT (15 samples)
90
+ SQLTemplate(
91
+ template_id="contain_01",
92
+ family="containment",
93
+ sql_difficulty="medium",
94
+ anchor_source="divisions_area",
95
+ num_anchors=1,
96
+ target_subtype="locality",
97
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Within(b.geometry, a.geometry)""",
98
+ question_hints=["What {target_subtype}s are in {anchor_name}?", "Which {target_subtype}s are within {anchor_name}?"]
99
+ ),
100
+
101
+ SQLTemplate(
102
+ template_id="contain_02",
103
+ family="containment",
104
+ sql_difficulty="medium",
105
+ anchor_source="divisions_area",
106
+ num_anchors=1,
107
+ target_subtype="country",
108
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Contains(b.geometry, a.geometry)""",
109
+ question_hints=["What {target_subtype} contains {anchor_name}?", "Which {target_subtype} is {anchor_name} in?"]
110
+ ),
111
+
112
+ SQLTemplate(
113
+ template_id="contain_03",
114
+ family="containment",
115
+ sql_difficulty="medium",
116
+ anchor_source="natural_earth",
117
+ num_anchors=1,
118
+ target_subtype="region",
119
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.subtype = '{target_subtype}' AND ST_Within(b.geometry, a.geometry)""",
120
+ question_hints=["Which {target_subtype}s are in the {anchor_name}?", "What {target_subtype}s fall within the {anchor_name}?"]
121
+ ),
122
+
123
+ # INTERSECTION (15 samples)
124
+ SQLTemplate(
125
+ template_id="intersect_01",
126
+ family="intersection",
127
+ sql_difficulty="medium-hard",
128
+ anchor_source="divisions_area",
129
+ num_anchors=1,
130
+ target_subtype="region",
131
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND b.subtype = '{target_subtype}' AND ST_Intersects(b.geometry, a.geometry)""",
132
+ question_hints=["Which {target_subtype}s intersect {anchor_name}?", "What {target_subtype}s overlap with {anchor_name}?"]
133
+ ),
134
+
135
+ SQLTemplate(
136
+ template_id="intersect_02",
137
+ family="intersection",
138
+ sql_difficulty="medium-hard",
139
+ anchor_source="natural_earth",
140
+ num_anchors=1,
141
+ target_subtype="country",
142
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.subtype = '{target_subtype}' AND ST_Intersects(b.geometry, a.geometry)""",
143
+ question_hints=["Which {target_subtype}s intersect the {anchor_name}?", "What {target_subtype}s touch the {anchor_name}?"]
144
+ ),
145
+
146
+ # BUFFER OPERATIONS (10 samples)
147
+ SQLTemplate(
148
+ template_id="buffer_01",
149
+ family="buffer",
150
+ sql_difficulty="hard",
151
+ anchor_source="divisions_area",
152
+ num_anchors=1,
153
+ requires_buffer=True,
154
+ sql_template="""WITH a AS (SELECT ST_Buffer(geometry, {buffer_degrees}) AS geom FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE b.id != '{anchor_id}' AND ST_Intersects(b.geometry, a.geom)""",
155
+ question_hints=["A {buffer_degrees} degree buffer around {anchor_name}", "Features within {buffer_degrees} degrees of {anchor_name}"]
156
+ ),
157
+
158
+ SQLTemplate(
159
+ template_id="buffer_02",
160
+ family="buffer",
161
+ sql_difficulty="hard",
162
+ anchor_source="divisions_area",
163
+ num_anchors=2,
164
+ requires_buffer=True,
165
+ sql_template="""WITH a AS (SELECT geometry AS g1 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id_1}'), b AS (SELECT geometry AS g2 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id_2}'), boundary AS (SELECT ST_Buffer(ST_Intersection(a.g1, b.g2), {buffer_degrees}) AS geom FROM a, b WHERE ST_Touches(a.g1, b.g2)) SELECT geom AS geometry FROM boundary""",
166
+ question_hints=["A {buffer_degrees} degree buffer around the border between {anchor_1_name} and {anchor_2_name}"]
167
+ ),
168
+
169
+ # SET OPERATIONS (15 samples)
170
+ SQLTemplate(
171
+ template_id="union_01",
172
+ family="set_operations",
173
+ sql_difficulty="medium-hard",
174
+ anchor_source="divisions_area",
175
+ num_anchors=2,
176
+ sql_template="""SELECT ST_Union_Agg(geometry) AS geometry, array_agg(names."primary") AS names FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id IN ('{anchor_id_1}', '{anchor_id_2}')""",
177
+ question_hints=["{anchor_1_name} and {anchor_2_name}", "The union of {anchor_1_name} and {anchor_2_name}"]
178
+ ),
179
+
180
+ # PARTIAL SELECTION (10 samples)
181
+ SQLTemplate(
182
+ template_id="partial_01",
183
+ family="partial_selection",
184
+ sql_difficulty="hard",
185
+ anchor_source="divisions_area",
186
+ num_anchors=1,
187
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), north_half AS (SELECT ST_MakeEnvelope(xmin, (ymin + ymax) / 2, xmax, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, nh.half_geom) AS geometry FROM a, north_half AS nh""",
188
+ question_hints=["The northern half of {anchor_name}", "Northern part of {anchor_name}"]
189
+ ),
190
+
191
+ SQLTemplate(
192
+ template_id="partial_02",
193
+ family="partial_selection",
194
+ sql_difficulty="hard",
195
+ anchor_source="divisions_area",
196
+ num_anchors=1,
197
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), south_half AS (SELECT ST_MakeEnvelope(xmin, ymin, xmax, (ymin + ymax) / 2) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, sh.half_geom) AS geometry FROM a, south_half AS sh""",
198
+ question_hints=["The southern half of {anchor_name}", "Southern part of {anchor_name}"]
199
+ ),
200
+
201
+ SQLTemplate(
202
+ template_id="partial_04",
203
+ family="partial_selection",
204
+ sql_difficulty="hard",
205
+ anchor_source="divisions_area",
206
+ num_anchors=1,
207
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), east_half AS (SELECT ST_MakeEnvelope((xmin + xmax) / 2, ymin, xmax, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, eh.half_geom) AS geometry FROM a, east_half AS eh""",
208
+ question_hints=["The eastern half of {anchor_name}", "Eastern part of {anchor_name}"]
209
+ ),
210
+
211
+ SQLTemplate(
212
+ template_id="partial_05",
213
+ family="partial_selection",
214
+ sql_difficulty="hard",
215
+ anchor_source="divisions_area",
216
+ num_anchors=1,
217
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), bbox AS (SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax, ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a), west_half AS (SELECT ST_MakeEnvelope(xmin, ymin, (xmin + xmax) / 2, ymax) AS half_geom FROM bbox) SELECT ST_Intersection(a.geometry, wh.half_geom) AS geometry FROM a, west_half AS wh""",
218
+ question_hints=["The western half of {anchor_name}", "Western part of {anchor_name}"]
219
+ ),
220
+
221
+ SQLTemplate(
222
+ template_id="partial_03",
223
+ family="partial_selection",
224
+ sql_difficulty="hard",
225
+ anchor_source="mixed",
226
+ num_anchors=2,
227
+ sql_template="""WITH a AS (SELECT geometry AS g1 FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}'), b AS (SELECT geometry AS g2 FROM read_parquet('{NATURAL_EARTH_PATH}') WHERE id = '{clip_feature_id}') SELECT ST_Intersection(a.g1, b.g2) AS geometry FROM a, b WHERE ST_Intersects(a.g1, b.g2)""",
228
+ question_hints=["The part of {anchor_name} that is in the {clip_feature_name}", "{anchor_name} within the {clip_feature_name}"]
229
+ ),
230
+
231
+ # AGGREGATION (5 samples)
232
+ SQLTemplate(
233
+ template_id="agg_01",
234
+ family="aggregation",
235
+ sql_difficulty="hard",
236
+ anchor_source="divisions_area",
237
+ num_anchors=1,
238
+ target_subtype="locality",
239
+ requires_aggregation=True,
240
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry, ST_Area(b.geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY area DESC LIMIT {top_n}""",
241
+ question_hints=["Top {top_n} largest {target_subtype}s in {anchor_name}", "Biggest {target_subtype}s in {anchor_name}", "{top_n} largest {target_subtype}s in {anchor_name}"]
242
+ ),
243
+
244
+ SQLTemplate(
245
+ template_id="agg_02",
246
+ family="aggregation",
247
+ sql_difficulty="hard",
248
+ anchor_source="divisions_area",
249
+ num_anchors=1,
250
+ target_subtype="locality",
251
+ requires_aggregation=True,
252
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry, ST_Area(b.geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY area ASC LIMIT {top_n}""",
253
+ question_hints=["Top {top_n} smallest {target_subtype}s in {anchor_name}", "Smallest {target_subtype}s in {anchor_name}", "{top_n} smallest {target_subtype}s in {anchor_name}"]
254
+ ),
255
+
256
+ SQLTemplate(
257
+ template_id="agg_03",
258
+ family="aggregation",
259
+ sql_difficulty="hard",
260
+ anchor_source="divisions_area",
261
+ num_anchors=1,
262
+ target_subtype="region",
263
+ requires_aggregation=True,
264
+ sql_template="""WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '{anchor_id}') SELECT b.id, b.names."primary" AS name, b.geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') AS b, a WHERE ST_Within(b.geometry, a.geometry) AND b.subtype = '{target_subtype}' ORDER BY RANDOM() LIMIT {top_n}""",
265
+ question_hints=["{top_n} random {target_subtype}s in {anchor_name}", "Any {top_n} {target_subtype}s in {anchor_name}"]
266
+ ),
267
+
268
+ SQLTemplate(
269
+ template_id="agg_04",
270
+ family="aggregation",
271
+ sql_difficulty="hard",
272
+ anchor_source="divisions_area",
273
+ num_anchors=1,
274
+ target_subtype="locality",
275
+ requires_aggregation=True,
276
+ sql_template="""SELECT id, names."primary" AS name, geometry, ST_Area(geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE country = '{country}' AND subtype = '{target_subtype}' ORDER BY area DESC LIMIT {top_n}""",
277
+ question_hints=["Top {top_n} largest {target_subtype}s in {country}", "{top_n} biggest {target_subtype}s in {country}"]
278
+ ),
279
+
280
+ SQLTemplate(
281
+ template_id="agg_05",
282
+ family="aggregation",
283
+ sql_difficulty="hard",
284
+ anchor_source="divisions_area",
285
+ num_anchors=1,
286
+ target_subtype="locality",
287
+ requires_aggregation=True,
288
+ sql_template="""SELECT id, names."primary" AS name, geometry, ST_Area(geometry) AS area FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE country = '{country}' AND subtype = '{target_subtype}' ORDER BY area ASC LIMIT {top_n}""",
289
+ question_hints=["Top {top_n} smallest {target_subtype}s in {country}", "{top_n} smallest {target_subtype}s in {country}"]
290
+ ),
291
+ ]
292
+
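+ # Rendering sketch: generate_samples.py fills the placeholders with str.format;
+ # the paths and id below are placeholders, not real data:
+ #   sql = TEMPLATES[0].sql_template.format(
+ #       DIVISIONS_AREA_PATH="data/divisions_area.parquet",
+ #       NATURAL_EARTH_PATH="data/natural_earth.parquet",
+ #       anchor_id="div_123",
+ #   )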
293
+
294
+ def get_templates_by_family(family: str) -> List[SQLTemplate]:
295
+ """Get all templates for a specific family."""
296
+ return [t for t in TEMPLATES if t.family == family]
297
+
298
+
299
+ def get_template_by_id(template_id: str) -> SQLTemplate:
300
+ """Get a specific template by ID."""
301
+ for t in TEMPLATES:
302
+ if t.template_id == template_id:
303
+ return t
304
+ raise ValueError(f"Template {template_id} not found")
305
+
306
+
307
+ if __name__ == "__main__":
308
+ # Print template summary
309
+ families = {}
310
+ for t in TEMPLATES:
311
+ families[t.family] = families.get(t.family, 0) + 1
312
+
313
+ print("SQL Template Catalog")
314
+ print("=" * 60)
315
+ for family, count in sorted(families.items()):
316
+ print(f"{family:20s}: {count:2d} templates")
317
+ print(f"{'TOTAL':20s}: {len(TEMPLATES):2d} templates")
dataset/scripts/validate_dataset.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Validate the generated dataset and report statistics.
3
+
4
+ This script:
5
+ 1. Loads all generated samples
6
+ 2. Validates SQL executability
7
+ 3. Checks candidate list quality
8
+ 4. Keeps only the samples that pass all checks
9
+ 5. Generates dataset statistics
11
+
12
+ Output:
13
+ - output/dataset_validated.jsonl
14
+ - output/dataset_stats.json
15
+ """
16
+
17
+ import json
18
+ from pathlib import Path
19
+ from typing import List, Dict, Any, Tuple, Optional
20
+ from collections import Counter
21
+ from concurrent.futures import ProcessPoolExecutor, as_completed
22
+
23
+ import duckdb
24
+ import pandas as pd
25
+
26
+
27
+ def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
28
+ """Load samples from JSONL file."""
29
+ samples = []
30
+ with open(jsonl_path, 'r') as f:
31
+ for line in f:
32
+ samples.append(json.loads(line))
33
+ return samples
34
+
35
+
36
+ def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
37
+ """Validate that SQL executes without error."""
38
+ try:
39
+ result = con.execute(sql).fetchdf()
40
+ if result.empty:
41
+ return False, "Empty result"
42
+ return True, "OK"
43
+ except Exception as e:
44
+ return False, str(e)
45
+
46
+
47
+ def validate_candidates(sample: Dict[str, Any]) -> tuple[bool, str]:
48
+ """Validate candidate list quality."""
49
+ candidates = sample['candidates']
50
+ selected = sample['target']['selected_candidates']
51
+
52
+ # Check we have candidates
53
+ if not candidates:
54
+ return False, "No candidates"
55
+
56
+ # Check selected candidates exist
57
+ candidate_ids = {c['candidate_id'] for c in candidates}
58
+ for sel_id in selected:
59
+ if sel_id not in candidate_ids:
60
+ return False, f"Selected candidate {sel_id} not in candidate list"
61
+
62
+ # Check for duplicates
63
+ ids = [c['id'] for c in candidates]
64
+ if len(ids) != len(set(ids)):
65
+ return False, "Duplicate candidates"
66
+
67
+ return True, "OK"
68
+
69
+
70
+ def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> tuple[bool, List[str]]:
71
+ """Validate a single sample. Returns (is_valid, list_of_issues)."""
72
+ issues = []
73
+
74
+ # Skip SQL re-execution if already verified during generation
75
+ if not sample.get('metadata', {}).get('sql_verified', False):
76
+ sql_valid, sql_msg = validate_sql(con, sample['target']['sql'])
77
+ if not sql_valid:
78
+ issues.append(f"SQL: {sql_msg}")
79
+
80
+ # Validate candidates
81
+ cand_valid, cand_msg = validate_candidates(sample)
82
+ if not cand_valid:
83
+ issues.append(f"Candidates: {cand_msg}")
84
+
85
+ # Check question exists
86
+ if not sample.get('question') or len(sample['question'].strip()) == 0:
87
+ issues.append("Empty question")
88
+
89
+ return len(issues) == 0, issues
90
+
91
+
92
+ def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str], Optional[Dict[str, Any]]]:
94
+     """Worker function for parallel validation. Returns (sample_id, is_valid, issues, sample_or_None)."""
94
+ # Each worker creates its own DuckDB connection
95
+ con = duckdb.connect()
96
+ con.execute("SET enable_progress_bar=false")
97
+ con.execute("INSTALL spatial")
98
+ con.execute("LOAD spatial")
99
+
100
+ try:
101
+ is_valid, issues = validate_sample(con, sample)
102
+ con.close()
103
+ return (sample['id'], is_valid, issues, sample if is_valid else None)
104
+ except Exception as e:
105
+ con.close()
106
+ return (sample['id'], False, [f"Validation error: {str(e)}"], None)
107
+
108
+
109
+ def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
110
+ """Compute dataset statistics."""
111
+
112
+ stats = {
113
+ 'total_samples': len(samples),
114
+ 'task_families': {},
115
+ 'sql_difficulty': {},
116
+ 'grounding_difficulty': {},
117
+ 'anchor_sources': {},
118
+ 'avg_candidates_per_sample': 0,
119
+ 'avg_question_length': 0,
120
+ 'countries_covered': set(),
121
+ 'subtypes_covered': set()
122
+ }
123
+
124
+ total_candidates = 0
125
+ total_question_length = 0
126
+
127
+ for sample in samples:
128
+ meta = sample['metadata']
129
+
130
+ # Count by family
131
+ family = meta['task_family']
132
+ stats['task_families'][family] = stats['task_families'].get(family, 0) + 1
133
+
134
+ # Count by SQL difficulty
135
+ sql_diff = meta['sql_difficulty']
136
+ stats['sql_difficulty'][sql_diff] = stats['sql_difficulty'].get(sql_diff, 0) + 1
137
+
138
+ # Count by grounding difficulty
139
+ ground_diff = meta['grounding_difficulty']
140
+ stats['grounding_difficulty'][ground_diff] = stats['grounding_difficulty'].get(ground_diff, 0) + 1
141
+
142
+ # Count by anchor source
143
+ anchor_src = meta['anchor_source']
144
+ stats['anchor_sources'][anchor_src] = stats['anchor_sources'].get(anchor_src, 0) + 1
145
+
146
+ # Candidates
147
+ total_candidates += len(sample['candidates'])
148
+
149
+ # Question length
150
+ total_question_length += len(sample['question'].split())
151
+
152
+ # Countries and subtypes (from selected/answer candidates only)
153
+ selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
154
+ for cand in sample['candidates']:
155
+ if cand['candidate_id'] in selected_ids:
156
+ if cand.get('country'):
157
+ stats['countries_covered'].add(cand['country'])
158
+ if cand.get('subtype'):
159
+ stats['subtypes_covered'].add(cand['subtype'])
160
+
161
+ stats['avg_candidates_per_sample'] = total_candidates / len(samples) if samples else 0
162
+ stats['avg_question_length'] = total_question_length / len(samples) if samples else 0
163
+ stats['countries_covered'] = sorted(list(stats['countries_covered']))
164
+ stats['subtypes_covered'] = sorted(list(stats['subtypes_covered']))
165
+
166
+ return stats
167
+
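+ # Illustrative output shape (numbers are made up):
+ #   {"total_samples": 800,
+ #    "task_families": {"adjacency": 180, "containment": 95, ...},
+ #    "sql_difficulty": {"medium": 420, ...},
+ #    "avg_candidates_per_sample": 9.7,
+ #    "countries_covered": ["AE", "BE", "EC", "KE", "SG"], ...}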
168
+
169
+ def main():
170
+ """Validate and analyze dataset."""
171
+
172
+ script_dir = Path(__file__).parent
173
+ output_dir = script_dir.parent / "output"
174
+
175
+ raw_file = output_dir / "dataset_raw.jsonl"
176
+ validated_file = output_dir / "dataset_validated.jsonl"
177
+ stats_file = output_dir / "dataset_stats.json"
178
+
179
+ if not raw_file.exists():
180
+ print(f"Error: {raw_file} not found. Run generate_samples.py first.")
181
+ return
182
+
183
+ # Load samples
184
+ print("Loading samples...")
185
+ samples = load_samples(raw_file)
186
+ print(f"Loaded {len(samples)} samples")
187
+
188
+ # Validate samples in parallel
189
+ print("\nValidating samples in parallel...")
190
+ valid_samples = []
191
+ invalid_samples = []
192
+
193
+ with ProcessPoolExecutor(max_workers=8) as executor:
194
+ # Submit all validation tasks
195
+ futures = {executor.submit(validate_sample_worker, sample): sample for sample in samples}
196
+
197
+ # Collect results as they complete
198
+ completed = 0
199
+ for future in as_completed(futures):
200
+ sample_id, is_valid, issues, validated_sample = future.result()
201
+
202
+ if is_valid:
203
+ valid_samples.append(validated_sample)
204
+ else:
205
+ invalid_samples.append((sample_id, issues))
206
+
207
+ completed += 1
208
+ if completed % 50 == 0 or completed == len(samples):
209
+ print(f"\r Progress: {completed}/{len(samples)} ", end='', flush=True)
210
+
211
+ print() # New line after progress
212
+
213
+ print(f"\nValidation results:")
214
+ print(f" Valid: {len(valid_samples)}")
215
+ print(f" Invalid: {len(invalid_samples)}")
216
+
217
+ if invalid_samples and len(invalid_samples) <= 20:
218
+ print("\nInvalid samples:")
219
+ for sample_id, issues in invalid_samples[:20]:
220
+ print(f" {sample_id}: {', '.join(issues)}")
221
+ elif invalid_samples:
222
+ print(f"\n{len(invalid_samples)} invalid samples (showing first 20):")
223
+ for sample_id, issues in invalid_samples[:20]:
224
+ print(f" {sample_id}: {', '.join(issues)}")
225
+
226
+ # Save validated samples
227
+ if valid_samples:
228
+ with open(validated_file, 'w') as f:
229
+ for sample in valid_samples:
230
+ f.write(json.dumps(sample) + '\n')
231
+ print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}")
232
+
233
+ # Compute statistics
234
+ print("\nComputing statistics...")
235
+ stats = compute_statistics(valid_samples)
236
+
237
+ # Save statistics
238
+ # Convert sets to lists for JSON serialization
239
+ stats_json = {k: (list(v) if isinstance(v, set) else v) for k, v in stats.items()}
240
+ with open(stats_file, 'w') as f:
241
+ json.dump(stats_json, f, indent=2)
242
+ print(f"Saved statistics to {stats_file}")
243
+
244
+ # Print summary
245
+ print("\n" + "=" * 60)
246
+ print("DATASET STATISTICS")
247
+ print("=" * 60)
248
+ print(f"\nTotal samples: {stats['total_samples']}")
249
+
250
+ print("\nTask families:")
251
+ for family, count in sorted(stats['task_families'].items()):
252
+ print(f" {family:20s}: {count:3d}")
253
+
254
+ print("\nSQL difficulty:")
255
+ for diff, count in sorted(stats['sql_difficulty'].items()):
256
+ print(f" {diff:20s}: {count:3d}")
257
+
258
+ print("\nGrounding difficulty:")
259
+ for diff, count in sorted(stats['grounding_difficulty'].items()):
260
+ print(f" {diff:20s}: {count:3d}")
261
+
262
+ print("\nAnchor sources:")
263
+ for src, count in sorted(stats['anchor_sources'].items()):
264
+ print(f" {src:20s}: {count:3d}")
265
+
266
+ print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}")
267
+ print(f"Average question length (words): {stats['avg_question_length']:.1f}")
268
+ print(f"Countries covered: {len(stats['countries_covered'])}")
269
+ print(f"Subtypes covered: {len(stats['subtypes_covered'])}")
270
+
271
+ print("\n✓ Validation complete")
272
+
273
+
274
+ if __name__ == "__main__":
275
+ main()
pyproject.toml CHANGED
@@ -19,5 +19,11 @@ dependencies = [
19
  ]
20
  optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
21
 
22
  [tool.hatch.build.targets.wheel]
23
- packages = ["src/gazet"]
19
  ]
20
  optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
21
 
22
+ [project.scripts]
23
+ gazet-dataset = "dataset.scripts.cli:main"
24
+
25
  [tool.hatch.build.targets.wheel]
26
+ packages = ["src/gazet", "dataset"]
27
+
28
+ [dependency-groups]
29
+ dataset = []
uv.lock CHANGED
The diff for this file is too large to render. See raw diff