webVishnu committed
Commit 77445cb · 0 Parent(s)
.gitignore ADDED
@@ -0,0 +1,12 @@
+ testing/
+ training/
+ checkpoints/
+ eval_outputs/
+ infer_outputs/
+ logs/
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.pyw
+ *.pyz
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
README.md ADDED
@@ -0,0 +1,621 @@
+ # Desert Semantic Segmentation
+
+ End-to-end **semantic segmentation** for **off-road / desert** scenes: every pixel is classified into one of several terrain / object categories. The pipeline is built for **synthetic RGB + mask** data, **PyTorch**, **[segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch)** (SMP), **Albumentations**, and hackathon-style iteration (strong baselines, IoU-driven checkpoints, optional EMA / TTA / ONNX).
+
+ ---
+
+ ## Table of contents
+
+ 1. [What this project does](#1-what-this-project-does)
+ 2. [Problem statement and goals](#2-problem-statement-and-goals)
+ 3. [Dataset layout and assumptions](#3-dataset-layout-and-assumptions)
+ 4. [Label format (critical)](#4-label-format-critical)
+ 5. [Repository structure](#5-repository-structure)
+ 6. [Configuration (`default.yaml`)](#6-configuration-defaultyaml)
+ 7. [High-level architecture](#7-high-level-architecture)
+ 8. [Data pipeline (detailed)](#8-data-pipeline-detailed)
+ 9. [Model](#9-model)
+ 10. [Loss functions](#10-loss-functions)
+ 11. [Metrics](#11-metrics)
+ 12. [Training loop](#12-training-loop)
+ 13. [Validation and evaluation scripts](#13-validation-and-evaluation-scripts)
+ 14. [Inference (testing folder, sliding window, TTA, ONNX)](#14-inference-testing-folder-sliding-window-tta-onnx)
+ 15. [Checkpoints and artifacts](#15-checkpoints-and-artifacts)
+ 16. [How to run (commands)](#16-how-to-run-commands)
+ 17. [Interactive demo (Gradio)](#17-interactive-demo-gradio)
+ 18. [Tests](#18-tests)
+ 19. [Dependencies and environment notes](#19-dependencies-and-environment-notes)
+ 20. [Design decisions and limitations](#20-design-decisions-and-limitations)
+ 21. [Extending the project](#21-extending-the-project)
+ 22. [Flowcharts](#22-flowcharts)
+
+ ---
+
+ ## 1. What this project does
+
+ - **Input:** RGB color images (`Color_Images`).
+ - **Supervision:** Per-pixel class masks (`Segmentation`) aligned by **filename** with the RGB image.
+ - **Output:** A trained neural network that predicts a **class index per pixel** on validation, held-out **testing** images (no labels in repo), or any folder of images you point inference at.
+ - **Primary quality metric:** **mean Intersection-over-Union (mIoU)** on the validation set, plus **per-class IoU**, **frequency-weighted IoU (fwIoU)**, and a **confusion matrix**.
+
+ ---
+
+ ## 2. Problem statement and goals
+
+ | Goal | How we address it |
+ |------|-------------------|
+ | Accurate pixel-wise classification | DeepLabV3+ with ImageNet-pretrained encoder; CE + Dice loss; class-frequency weights |
+ | Robustness (synthetic → harder real domains) | Strong photometric + mild “desert-like” augmentations (sun flare, shadow, blur, noise, JPEG) |
+ | Class imbalance | Inverse log-frequency weights with a **cap**; rare-class-biased random crops |
+ | Stable training | AdamW, cosine decay with **warmup**, gradient clipping, optional **EMA** |
+ | Fast iteration | YAML-driven config; SMP for one-line model construction; scripts for train / eval / infer |
+ | Deployment story | Optional **ONNX** export; inference timing written to `latency.txt` |
+
+ **Note:** The original hackathon plan also mentioned **SegFormer-B2** as a balanced option. This codebase’s **default** is **DeepLabV3+ + ResNet-50**. UNet and FPN are supported in code; SegFormer is **not** implemented as a separate architecture in `models/factory.py` (you can experiment with **MiT** encoders under DeepLabV3+ if SMP supports your chosen encoder name).
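+
+ As a rough illustration only (not code from this repo's `models/factory.py`), swapping in a MiT encoder through SMP would look like the sketch below. Whether `mit_b2` is accepted under `DeepLabV3Plus` depends on your installed SMP version, so treat the encoder name as an assumption to verify:
+
+ ```python
+ import segmentation_models_pytorch as smp
+
+ # Hypothetical experiment: a MiT (SegFormer-family) encoder under DeepLabV3+.
+ # Supported encoder names vary by SMP release; this may raise on older versions.
+ model = smp.DeepLabV3Plus(
+     encoder_name="mit_b2",        # assumption: available in your SMP build
+     encoder_weights="imagenet",
+     in_channels=3,
+     classes=10,                   # number of classes from the codec
+ )
+ ```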
+
+ ---
+
+ ## 3. Dataset layout and assumptions
+
+ All paths in config are **relative to the workspace root** (`--root` on the CLI, or the repo root by default).
+
+ ```text
+ <root>/
+   training/
+     train/
+       Color_Images/   # RGB training inputs
+       Segmentation/   # Training masks (same filenames as Color_Images)
+     val/
+       Color_Images/   # RGB validation inputs
+       Segmentation/   # Validation masks
+   testing/
+     Color_Images/     # Unlabeled images for final inference / demo
+ ```
+
+ **Pairing rule:** For each split, every file in `Color_Images` must have a mask with the **same basename** in `Segmentation`. The dataset constructor raises if a mask is missing.
+
+ **Typical image size in this workspace:** RGB and masks are often **960×540** (masks are single-channel uint16 PNGs). Training uses **512×512** crops; validation pads to a **512×512** canvas for batching.
+
+ ---
+
+ ## 4. Label format (critical)
+
+ ### 4.1 What the masks are
+
+ - Masks are read as **2D arrays** (single channel).
+ - In this dataset they behave as **`I;16` (16-bit unsigned)** semantic IDs: pixel values are **not** 0, 1, 2, …
+   They are **dataset-specific raw IDs**, e.g. `100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000`.
+
+ ### 4.2 Mapping raw IDs → training indices
+
+ The class `RawMaskCodec` in `desert_segmentation/data/mask_encoding.py`:
+
+ 1. Builds a **lookup table (LUT)** spanning `0 … max(raw_ids)`, initialized to the sentinel `255`.
+ 2. Maps each legal raw ID to a contiguous index **`0 … num_classes-1`** (uint8 for Albumentations compatibility).
+ 3. **Raises** if any pixel is not in the configured `raw_ids` list (unknown pixels would map to the sentinel `255` in the LUT and trigger an error).
+
+ **Why this matters:** Using the wrong mapping (or treating masks as 8-bit class indices) silently destroys learning.
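+
+ A minimal usage sketch of the codec (it mirrors `build_codec_from_config` and `encode_mask` in `desert_segmentation/data/mask_encoding.py`):
+
+ ```python
+ import numpy as np
+ from desert_segmentation.data.mask_encoding import build_codec_from_config
+
+ codec = build_codec_from_config([100, 200, 10000], ["id_100", "id_200", "id_10000"])
+ raw = np.array([[100, 200], [10000, 100]], dtype=np.uint16)  # toy 2x2 raw mask
+ encoded, unknown_frac = codec.encode_mask(raw)               # uint8 indices 0..2
+ assert encoded.tolist() == [[0, 1], [2, 0]]
+ # Any pixel value outside raw_ids raises ValueError instead of training on garbage.
+ ```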
+
+ ### 4.3 Ignore index (255)
+
+ - **Training:** `ShiftScaleRotate` can introduce border pixels on the mask; those are filled with **`ignore_index` (255)**. Cross-entropy and Dice **ignore** those pixels.
+ - **Validation:** `PadIfNeeded` pads the mask with **255** so square tensors align; metrics and loss skip those pixels.
+
+ ### 4.4 Class names
+
+ `class_names` in YAML are **display labels** (e.g. `id_100`, …). Replace them with semantic names (e.g. `sky`, `sand`) when you have the official ontology from the dataset provider.
+
+ ---
+
+ ## 5. Repository structure
+
+ ```text
+ codewizard 2.0/
+   README.md                 # This file
+   requirements.txt          # Python dependencies
+   requirements-demo.txt     # Optional: Gradio demo
+   desert_segmentation/      # Importable package
+     __init__.py
+     configs/
+       default.yaml          # Single source of truth for paths & hyperparameters
+     data/
+       dataset.py            # SegmentationDataset (pairing, crop, rare bias)
+       transforms.py         # Albumentations train/val pipelines
+       mask_encoding.py      # RawMaskCodec + build_codec_from_config
+     models/
+       factory.py            # SMP: DeepLabV3+, UNet, FPN
+     losses/
+       combined.py           # CE, weighted CE, focal, CE+Dice + weight helper
+     metrics/
+       iou.py                # Confusion matrix, IoU, mIoU, fwIoU
+     train/
+       trainer.py            # Main training loop (AMP, EMA, scheduler, checkpoints)
+       evaluate.py           # Batched validation metric pass
+     infer/
+       predict.py            # Sliding window, TTA, folder inference, ONNX export
+     utils/
+       config.py             # YAML load + path resolution
+       seed.py               # Reproducibility
+       logging_utils.py      # Logging setup
+       freq.py               # Scan mask folders for class frequencies
+       viz.py                # Colorization + overlay + triplet PNG export
+     demo/
+       inference_ui.py       # Gradio helpers: legend HTML, validation, composites
+   scripts/
+     train.py                # CLI: train from config
+     eval.py                 # CLI: val metrics + confusion + visualization PNGs
+     eval_summary.py         # CLI: mIoU (all + valid-GT), fwIoU, accuracies, GT counts, per-class table (+ JSON)
+     infer.py                # CLI: run on testing/ or export ONNX
+     demo_gradio.py          # CLI: browser upload demo (Gradio)
+   tests/
+     test_mask_encoding.py   # Unit tests for codec / unknown pixels
+ ```
+
+ **Scripts** add the repo root to `sys.path` so you can run them without installing the package as a wheel.
+
+ ---
+
+ ## 6. Configuration (`default.yaml`)
+
+ Key sections (see `desert_segmentation/configs/default.yaml` for the full file):
+
+ | Section | Purpose |
+ |---------|---------|
+ | `root` | Base path for resolving relative data paths (overridden by `--root` in scripts) |
+ | `data.*` | Relative dirs for train/val images and masks, test images, `raw_ids`, `class_names`, `crop_size`, `rare_class_crop_prob`, `weighted_sampler`, `weighted_sampler_eps`, `ignore_index` |
+ | `model.*` | `architecture` (`deeplabv3plus` \| `unet` \| `fpn`), `encoder_name`, `encoder_weights` |
+ | `train.*` | `batch_size`, `epochs`, `lr`, `weight_decay`, `warmup_ratio`, `amp`, `gradient_clip`, `seed`, `checkpoint_dir`, `log_interval`, `early_stop_patience` |
+ | `loss.*` | `name` (`ce` \| `weighted_ce` \| `ce_dice` \| `focal_ce` \| `focal_ce_dice`), `dice_weight`, `label_smoothing` (CE modes only), `class_weight_cap`, `focal_gamma` |
+ | `augmentation.strong` | Enables extra sun flare + shadow blocks in training |
+ | `ema.*` | Optional exponential moving average of weights for evaluation |
+ | `inference.*` | `tile_size`, `overlap` (for sliding window), `tta_flip`, `batch_size` (reserved for future batching) |
+
+ ---
+
+ ## 7. High-level architecture
+
+ ```mermaid
+ flowchart TB
+   subgraph inputs [Inputs]
+     RGB[RGB images]
+     GT[Ground truth masks]
+   end
+   subgraph prep [Preprocessing]
+     Codec[RawMaskCodec LUT]
+     Crop[Train: random 512 crop with rare bias]
+     ValPad[Val: resize longest side then pad to 512]
+     Aug[Albumentations geom plus color]
+   end
+   subgraph model [Model SMP]
+     DL[DeepLabV3Plus default]
+   end
+   subgraph train [Training]
+     Loss[CE plus Dice with class weights]
+     Opt[AdamW plus cosine warmup LR]
+     AMP[AMP if CUDA]
+     EMA[EMA optional]
+     CKPT[Best mIoU checkpoint]
+   end
+   subgraph out [Outputs]
+     Metrics[mIoU per class IoU fwIoU confusion]
+     Viz[Overlays triplets]
+     ONNX[Optional ONNX]
+   end
+   RGB --> Codec
+   GT --> Codec
+   Codec --> Crop
+   Codec --> ValPad
+   Crop --> Aug
+   Aug --> DL
+   ValPad --> DL
+   DL --> Loss
+   Loss --> Opt
+   Opt --> AMP
+   Opt --> EMA
+   DL --> Metrics
+   Metrics --> CKPT
+   Metrics --> Viz
+   DL --> ONNX
+ ```
+
+ ---
+
+ ## 8. Data pipeline (detailed)
+
+ ### 8.1 `SegmentationDataset` (`data/dataset.py`)
+
+ 1. **List images** in `images_dir` with extensions: `.png`, `.jpg`, `.jpeg`, `.bmp`, `.tif`, `.tiff`.
+ 2. **Verify** each image has a mask with the same filename in `masks_dir`.
+ 3. **Load RGB** with Pillow → `HxWx3` uint8.
+ 4. **Load mask** as numpy 2D → cast to `uint16` → **`codec.encode_mask`** → `HxW` uint8 with values `0 … C-1` (padding may add 255 later in transforms).
+
+ **Train mode (`mode="train"`):**
+
+ - **`_random_crop_bias_rare`:** Extract a **`crop_size × crop_size`** patch.
+   - With probability `rare_class_crop_prob` (default **0.35**), pick the **rarest class** in that image (by histogram) and center the crop on a random pixel of that class (if any exist).
+   - Otherwise pick a uniformly random center.
+   - If the image is smaller than the crop, **zero-pad** the image and **255-pad** the mask (ignore regions).
+
+ **Val mode (`mode="val"`):**
+
+ - No random crop in the dataset; the **full** image goes to Albumentations.
+
+ ### 8.2 Transforms (`data/transforms.py`)
+
+ **Train (`build_train_transforms`):**
+
+ - **Geometric:** `HorizontalFlip`, `ShiftScaleRotate` (shift, scale, ±10° rotation) with `mask_value=ignore_index` on borders.
+ - **Photometric:** brightness/contrast, hue/sat/value, Gaussian blur, Gaussian noise, JPEG compression simulation, RGB shift.
+ - **If `augmentation.strong`:** `RandomSunFlare`, `RandomShadow` (desert-relevant appearance stress).
+ - **Normalize:** ImageNet mean/std.
+ - **`ToTensorV2`:** Image → `float` tensor `CHW`; the mask passes through and is converted to `long` downstream in `__getitem__`.
+
+ **Val (`build_val_transforms`):**
+
+ - `LongestMaxSize(crop_size)` then `PadIfNeeded(crop_size, crop_size)` with **mask pad = 255** (ignored in loss/metrics).
+
+ ### 8.3 Class frequency estimation (`utils/freq.py`)
+
+ Before training, `scripts/train.py` calls **`estimate_pixel_frequencies`** over **all** training mask files (`max_files` is configurable in code; the train script uses the full corpus). This yields a normalized per-class frequency vector, which is used to build **class weights**.
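+
+ The exact helper lives in `utils/freq.py`; below is a rough equivalent of what the scan computes (a sketch under an assumed signature, not the repo function verbatim):
+
+ ```python
+ import numpy as np
+ from pathlib import Path
+ from typing import Optional
+ from PIL import Image
+
+ def pixel_frequencies(mask_dir: Path, codec, max_files: Optional[int] = None) -> np.ndarray:
+     """Normalized per-class pixel frequency over encoded training masks (sketch)."""
+     counts = np.zeros(codec.num_classes, dtype=np.int64)
+     for fp in sorted(mask_dir.glob("*.png"))[:max_files]:
+         enc, _ = codec.encode_mask(np.array(Image.open(fp)).astype(np.uint16))
+         counts += np.bincount(enc.reshape(-1), minlength=codec.num_classes)
+     return counts / max(int(counts.sum()), 1)
+ ```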
+
+ ---
+
+ ## 9. Model
+
+ **Factory:** `desert_segmentation/models/factory.py`
+
+ | `architecture` | SMP class | Notes |
+ |----------------|-----------|--------|
+ | `deeplabv3plus` (default) | `smp.DeepLabV3Plus` | Mainline; strong decoder + atrous spatial pyramid |
+ | `unet` | `smp.Unet` | Classic encoder–decoder with skips |
+ | `fpn` | `smp.FPN` | Feature pyramid neck |
+
+ **Default encoder:** `resnet50` with `encoder_weights: imagenet`.
+
+ **Forward:** Input batch `N×3×H×W` → logits `N×C×H×W` where `C = num_classes`.
+
+ ---
+
+ ## 10. Loss functions
+
+ **File:** `desert_segmentation/losses/combined.py`
+
+ **Modes (`loss.name`):**
+
+ | Mode | Description |
+ |------|-------------|
+ | `ce` | Plain cross-entropy, unweighted |
+ | `weighted_ce` | Cross-entropy with a per-class `weight` tensor |
+ | `ce_dice` | `CE(weighted) + dice_weight * multiclass_Dice_loss` |
+ | `focal_ce` | Focal-modulated CE; optional class weights on pixels |
+ | `focal_ce_dice` (default in `default.yaml`) | `focal_ce` + `dice_weight * multiclass_Dice_loss` (same class weights in the focal term) |
+
+ **Shared options:**
+
+ - **`ignore_index`:** Pixels with label 255 are masked out of CE / focal / Dice.
+ - **`label_smoothing`:** Applied to **CE-based** modes (`ce`, `weighted_ce`, `ce_dice`) only; not used in `focal_ce` / `focal_ce_dice`.
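+
+ For reference, a standard focal-CE formulation over valid pixels (a hedged sketch; `losses/combined.py` is the source of truth for the exact reduction and weighting):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def focal_ce_sketch(logits, target, gamma=2.0, weight=None, ignore_index=255):
+     """Per-pixel CE scaled by (1 - p_t)**gamma, averaged over non-ignored pixels."""
+     ce = F.cross_entropy(logits, target, weight=weight,
+                          ignore_index=ignore_index, reduction="none")  # (N, H, W)
+     pt = torch.exp(-ce)                  # ~ probability assigned to the true class
+     focal = ((1.0 - pt) ** gamma) * ce
+     valid = target != ignore_index
+     return focal[valid].mean()
+ ```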
+
+ **Class weights (`compute_class_weights_from_freq`):**
+
+ 1. Start from the per-class pixel frequency `freq` on the training masks.
+ 2. `w ∝ 1 / log(freq + ε)`, normalized by the mean.
+ 3. Clamp the ratio `w / median(w)` to **`class_weight_cap`** (default **15**) so rare classes do not explode the loss. A sketch follows this list.
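+
+ In sketch form (assuming an ENet-style constant inside the log so the weights stay positive; the repo's `compute_class_weights_from_freq` may differ in these details):
+
+ ```python
+ import numpy as np
+
+ def class_weights_sketch(freq: np.ndarray, cap: float = 15.0, c: float = 1.02) -> np.ndarray:
+     """Inverse log-frequency weights, mean-normalized, ratio-to-median clamped to [1/cap, cap]."""
+     w = 1.0 / np.log(c + freq)                     # rarer class -> larger weight
+     w = w / w.mean()                               # normalize by the mean
+     med = float(np.median(w))
+     return np.clip(w / med, 1.0 / cap, cap) * med  # cap the spread around the median
+ ```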
+
+ ---
+
+ ## 11. Metrics
+
+ **File:** `desert_segmentation/metrics/iou.py`
+
+ 1. **Confusion matrix** `C×C` (the implementation uses `idx = tgt * C + pred` then `bincount`; rows correspond to the **ground-truth class**, columns to the **predicted class**).
+ 2. **Per-class IoU:**
+    \(\text{IoU}_k = \frac{TP_k}{TP_k + FP_k + FN_k}\)
+    with `TP_k = CM[k,k]` and row/column sums giving FN/FP.
+ 3. **mIoU:** Mean of per-class IoU over finite entries.
+ 4. **fwIoU (frequency-weighted IoU):** \(\sum_k \text{IoU}_k \cdot p_k\), where \(p_k\) is the empirical frequency of class \(k\) in the ground-truth pixels (row marginals).
+
+ **Note:** The docstring in `compute_confusion` says “pred rows, target columns”; the actual indexing is `tgt * C + pred` after flattening, i.e. target = row, prediction = column. A sketch of the accumulation follows.
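+
+ (A compact sketch mirroring the indexing described above; `metrics/iou.py` is authoritative.)
+
+ ```python
+ import numpy as np
+
+ def confusion(pred: np.ndarray, tgt: np.ndarray, C: int, ignore_index: int = 255) -> np.ndarray:
+     valid = tgt != ignore_index
+     idx = tgt[valid].astype(np.int64) * C + pred[valid].astype(np.int64)
+     return np.bincount(idx, minlength=C * C).reshape(C, C)  # rows = GT, cols = prediction
+
+ def per_class_iou(cm: np.ndarray) -> np.ndarray:
+     tp = np.diag(cm).astype(np.float64)
+     fn = cm.sum(axis=1) - tp   # GT is k, predicted as something else
+     fp = cm.sum(axis=0) - tp   # predicted k, GT is something else
+     return tp / np.maximum(tp + fp + fn, 1e-9)
+ ```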
+
+ ---
+
+ ## 12. Training loop
+
+ **File:** `desert_segmentation/train/trainer.py`
+
+ **Optimizer:** AdamW on all parameters.
+
+ **Learning rate:** `LambdaLR` with:
+
+ - **Linear warmup** for `warmup_ratio` of total optimizer steps (default **8%**).
+ - **Cosine** decay from 1.0 down to `min_ratio` = **0.01** (implemented in `_warmup_cosine_lambda`; sketched below).
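+
+ A sketch of the multiplier (constants match the text above; the repo's `_warmup_cosine_lambda` may differ slightly in edge handling):
+
+ ```python
+ import math
+
+ def warmup_cosine_lambda(step: int, total_steps: int,
+                          warmup_ratio: float = 0.08, min_ratio: float = 0.01) -> float:
+     """LR multiplier for torch.optim.lr_scheduler.LambdaLR."""
+     warmup = max(1, int(total_steps * warmup_ratio))
+     if step < warmup:
+         return (step + 1) / warmup                      # linear warmup to 1.0
+     t = (step - warmup) / max(1, total_steps - warmup)  # 0 -> 1 over the decay phase
+     return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * t))
+ ```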
+
+ **AMP (mixed precision):**
+
+ - Enabled only if `train.amp` is true **and** `torch.cuda.is_available()`.
+ - Uses `torch.cuda.amp.autocast` + `GradScaler` when on CUDA.
+ - On **CPU**, AMP is off; training uses a standard FP32 backward pass (no scaler).
+
+ **Gradient clipping:** Global norm clip when `gradient_clip > 0` (default **1.0**). A sketch of the AMP step with clipping follows.
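+
+ In outline, the CUDA branch follows the standard AMP recipe (a sketch, not the trainer verbatim; `model`, `criterion`, `optimizer`, and `scaler` are assumed to exist):
+
+ ```python
+ import torch
+
+ def amp_train_step(model, batch, criterion, optimizer, scaler, device,
+                    clip: float = 1.0, use_amp: bool = True) -> float:
+     """One training step: autocast forward, scaled backward, clip, step (sketch)."""
+     optimizer.zero_grad(set_to_none=True)
+     with torch.cuda.amp.autocast(enabled=use_amp):
+         logits = model(batch["image"].to(device))
+         loss = criterion(logits, batch["mask"].to(device))
+     scaler.scale(loss).backward()
+     scaler.unscale_(optimizer)                   # clip the true (unscaled) gradients
+     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
+     scaler.step(optimizer)
+     scaler.update()
+     return float(loss.detach())
+ ```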
+
+ **EMA (optional):**
+
+ - If `ema.enabled`, after each optimizer step the code maintains a **shadow weight** copy per trainable parameter, with exponential decay **0.999** by default.
+ - **Each epoch:** Training weights are **deep-copied**; **EMA weights are copied into the model** for validation only; then the training snapshot is **restored** so optimization continues from the non-EMA weights. In sketch form:
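+
+ (A minimal sketch; the trainer's exact bookkeeping may differ.)
+
+ ```python
+ import copy
+ import torch
+
+ @torch.no_grad()
+ def ema_update(shadow: dict, model: torch.nn.Module, decay: float = 0.999) -> None:
+     """shadow <- decay * shadow + (1 - decay) * param, per trainable parameter.
+     Assumes shadow was initialized as {name: p.detach().clone()}."""
+     for name, p in model.named_parameters():
+         if p.requires_grad:
+             shadow[name].mul_(decay).add_(p.detach(), alpha=1.0 - decay)
+
+ # Around validation (sketch of the swap described above):
+ #   snapshot = copy.deepcopy(model.state_dict())              # keep training weights
+ #   model.load_state_dict({**model.state_dict(), **shadow}, strict=False)
+ #   ... run validation ...
+ #   model.load_state_dict(snapshot)                           # resume from non-EMA weights
+ ```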
+
+ **Checkpointing:**
+
+ - Every epoch: `checkpoints/last.pt` (model, optional EMA dict, optimizer, full config, class names).
+ - **Best validation mIoU:** `checkpoints/best.pt` (adds `miou`, `per_class_iou`).
+
+ **Early stopping:** If validation mIoU does not improve for `early_stop_patience` epochs (default **12**), training stops.
+
+ **Optional smoke flags (`scripts/train.py`):**
+
+ - `--epochs N` — override the epoch count.
+ - `--max_train_batches K` — stop each training epoch after `K` batches (debug only; the scheduler still advances per batch).
+
+ **Logging:** `checkpoints/history.json` lists per-epoch `miou` and `fw_iou`.
+
+ ---
+
+ ## 13. Validation and evaluation scripts
+
+ **Core loop:** `desert_segmentation/train/evaluate.py` runs the model in `eval()` mode, accumulates the confusion matrix via `IoUMetrics`, and returns a metrics dict.
+
+ **CLI:** `scripts/eval.py`
+
+ 1. Loads the config and builds the validation dataset (same codec and val transforms as training).
+ 2. Loads the checkpoint from `--checkpoint`.
+ 3. **Weight loading priority:** If an `ema` dict exists in the checkpoint, **EMA tensors are copied into the parameters** for evaluation; otherwise the `state_dict` from the `model` key is used.
+ 4. Runs the full val loader → logs **mIoU**, **fwIoU**, per-class IoU.
+ 5. Writes:
+    - `eval_outputs/metrics.json` (or `--out_dir`)
+    - `confusion.npy`
+    - Up to `--max_viz` side-by-side **RGB | GT | Pred** PNGs (`save_triplet` in `utils/viz.py`), with ImageNet denormalization for the RGB panels.
+
+ ---
+
+ ## 14. Inference (testing folder, sliding window, TTA, ONNX)
+
+ **CLI:** `scripts/infer.py`
+
+ ### 14.1 Folder inference
+
+ - Reads `testing/Color_Images` (or whatever `data.test_images` points to).
+ - Loads the checkpoint with the same **EMA-first** rule as eval.
+ - For each image:
+   - If **both** height and width ≤ `tile_size` (512): single forward pass.
+   - Else: **sliding window** with stride `tile_size * (1 - overlap)` (default overlap **0.25** → stride **384**).
+     - Pads the image with **reflect** padding so the tile grid covers the corners; crops back to the original size.
+     - Accumulates **per-class scores** weighted by a **2D Gaussian** (`sigma ∝ tile/3`) so tile borders blend smoothly; the final prediction is the **`argmax` over classes** per pixel.
+
+ ### 14.2 Test-time augmentation (TTA)
+
+ If `inference.tta_flip` is true: logits = **0.5 × (logits(x) + unflip(logits(flip(x))))**, flipping horizontally.
+
+ ### 14.3 Outputs
+
+ Under `--out_dir` (default `infer_outputs/`):
+
+ - `pred_<filename>` — color overlay (prediction tinted on the RGB).
+ - `triplet_<filename>` — **RGB | blank or GT | Pred** strip (the test set has no GT, so the middle panel is zeros in the current `save_triplet` usage).
+ - `latency.txt` — mean milliseconds per image and the device string.
+
+ ### 14.4 ONNX
+
+ `python scripts/infer.py --checkpoint ... --onnx model.onnx` calls `export_onnx`: builds the model on **CPU**, uses a dummy input `1×3×512×512`, and runs `torch.onnx.export` with dynamic axes for batch and spatial size.
+
+ ---
+
+ ## 15. Checkpoints and artifacts
+
+ | Artifact | Contents |
+ |----------|----------|
+ | `checkpoints/best.pt` | `model`, `ema` (optional), `miou`, `per_class_iou`, `config`, `class_names` |
+ | `checkpoints/last.pt` | Latest epoch snapshot + optimizer |
+ | `checkpoints/history.json` | List of `{epoch, miou, fw_iou}` |
+ | `eval_outputs/*` | `metrics.json`, `confusion.npy`, visualization PNGs |
+ | `infer_outputs/*` | Overlays, triplets, `latency.txt` |
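+
+ The keys above can be inspected directly (`weights_only=False` is needed on newer PyTorch because the checkpoint stores plain Python objects such as the config dict):
+
+ ```python
+ import torch
+
+ ckpt = torch.load("checkpoints/best.pt", map_location="cpu", weights_only=False)
+ print(sorted(ckpt.keys()))         # e.g. ['class_names', 'config', 'ema', 'miou', ...]
+ print("best mIoU:", ckpt["miou"])
+ for name, iou in zip(ckpt["class_names"], ckpt["per_class_iou"]):
+     print(f"{name}: {iou:.4f}")
+ ```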
+
+ ---
+
+ ## 16. How to run (commands)
+
+ From the repository root (adjust paths if yours differ).
+
+ ### 16.1 Install
+
+ ```powershell
+ python -m pip install -r requirements.txt
+ ```
+
+ ### 16.2 Train
+
+ ```powershell
+ $env:PYTHONPATH="."
+ python scripts\train.py --root "d:\codewizard 2.0"
+ ```
+
+ Optional:
+
+ ```powershell
+ python scripts\train.py --root "d:\codewizard 2.0" --config desert_segmentation\configs\default.yaml --epochs 5 --max_train_batches 50
+ ```
+
+ **Imbalanced classes (optional YAML):** set `loss.name` to `focal_ce_dice` for focal + Dice; tune `class_weight_cap`, `rare_class_crop_prob`, and/or `data.weighted_sampler: true` to oversample train images that contain rare classes (this scans all train masks once at startup and can take a minute on large sets).
+
+ ### 16.3 Evaluate (validation)
+
+ ```powershell
+ python scripts\eval.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir eval_outputs
+ ```
+
+ **Metric summary (no PNGs):** `scripts\eval_summary.py` prints **mIoU (all classes)** and **mIoU (classes with val GT)** (the latter ignores absent classes, so it is easier to interpret on sparse val labels), **fwIoU**, **global / mean class accuracy**, **val GT pixel counts per class**, and a **per-class IoU / recall** table. It runs the same validation forward pass as `eval.py`. Optional: `--json-out eval_summary.json` (includes `miou_valid_gt_classes`, `val_gt_pixel_counts`).
+
+ ```powershell
+ python scripts\eval_summary.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --json-out eval_summary.json
+ ```
+
+ To print only the **mIoU** and **per-class IoU** stored inside the checkpoint (no GPU eval): `python scripts\eval_summary.py --from-checkpoint-only --checkpoint checkpoints\best.pt`
+
+ ### 16.4 Infer on `testing/Color_Images`
+
+ ```powershell
+ python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir infer_outputs --limit 20
+ ```
+
+ ### 16.5 Export ONNX
+
+ ```powershell
+ python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --onnx model.onnx
+ ```
+
+ ---
+
+ ## 17. Interactive demo (Gradio)
+
+ Upload an RGB image in the browser and get a **colored class mask**, an **overlay**, a **side-by-side strip** (RGB | mask | overlay), a **fixed legend** (colors match `palette()` in training), **inference time**, and **dominant classes** (pixel histogram). Uses the same path as CLI inference: `_load_model_for_inference` and `predict_image` in `desert_segmentation/infer/predict.py` (EMA weights are preferred when present in the checkpoint).
+
+ **Install** (base + demo extras):
+
+ ```powershell
+ python -m pip install -r requirements.txt -r requirements-demo.txt
+ ```
+
+ **Run** (from the repo root; the model loads **once** at startup — look for a log line `Model ready`):
+
+ ```powershell
+ $env:PYTHONPATH="."
+ python scripts\demo_gradio.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt
+ ```
+
+ **CLI flags:** `--host` (default `127.0.0.1`), `--port` (default `7860`), `--share` (temporary public Gradio link), `--max-side`, `--max-megapixels` (reject huge uploads before inference).
+
+ **Environment variables** (optional defaults when the flags are omitted):
+
+ | Variable | Purpose |
+ |----------|---------|
+ | `ROOT` | Workspace root (same as `--root`) |
+ | `CHECKPOINT_PATH` | Path to `best.pt` (relative paths resolve under `ROOT`) |
+
+ **Advanced panel:** TTA on/off, tile overlap slider, tile size slider (256–2048, step 64). Overrides are passed into `predict_image` only; the checkpoint file is not modified.
+
+ **v1 limitations:** No per-pixel **confidence heatmap** for full sliding-window runs (only the `argmax` is returned from `predict_image`). See the plan follow-up to add logits fusion if needed.
+
+ **Windows:** Use backslashes or quoted paths as above; the first launch may be slow while dependencies initialize.
+
+ **Follow-ups (not in v1):** full-resolution **confidence** heatmap (needs a logits path in `predict.py`); **ZIP** batch upload; **two-checkpoint** comparison UI; client-side **ONNX** inference.
+
+ ---
+
+ ## 18. Tests
+
+ ```powershell
+ python -m pytest tests\test_mask_encoding.py -q
+ ```
+
+ Covers:
+
+ - Round-trip **raw mask ↔ class indices** for known IDs.
+ - **Unknown raw pixel** raises `ValueError`.
+ - LUT correctness for each configured raw ID.
+
+ ---
+
+ ## 19. Dependencies and environment notes
+
+ **`requirements.txt`:**
+
+ - `torch`, `torchvision`, `numpy`, `Pillow`, `PyYAML`
+ - `albumentations` pinned to `<1.5` to reduce optional native build issues on some Windows setups
+ - `segmentation-models-pytorch` (SMP)
+ - `tqdm`, `pytest`
+ - Optional demo: `requirements-demo.txt` adds **Gradio**
+
+ **Windows:** `scripts/train.py` and `scripts/eval.py` set `num_workers=0` for the `DataLoader` on NT to avoid multiprocessing friction.
+
+ **SMP pretrained weights:** The first run may download encoder weights (e.g. ResNet-50 ImageNet) via SMP / Hugging Face hubs, depending on the SMP version.
+
+ ---
+
+ ## 20. Design decisions and limitations
+
+ | Topic | Decision / limitation |
+ |-------|------------------------|
+ | Mask modes | **16-bit raw IDs** supported via LUT; **P-mode palette** and **RGB color masks** are *not* auto-detected in this codebase—extend `mask_encoding.py` if your dataset uses them |
+ | SegFormer | **Not** a separate `architecture` enum; the plan mentioned SegFormer-B2 as an alternative—it would require additional factory code or a supported SMP encoder |
+ | Val resolution | Images are **letterboxed** to 512×512 for batching; mIoU treats padded regions as ignore—fine for a hackathon; for publication-grade eval consider sliding-window validation too |
+ | Inference fusion | Overlapping tiles add **Gaussian-weighted class scores** (softmax probabilities in code) into an accumulator; the final label is the **`argmax` over the accumulated scores** (feathered overlap fusion). A per-pixel `weight` tensor is also accumulated in code for possible future normalization extensions |
+ | Poly LR / sync BN | **Not** implemented (cosine + warmup only) |
+ | Ensemble | **Not** implemented (single model + optional EMA) |
+
+ ---
+
+ ## 21. Extending the project
+
+ 1. **New classes / raw IDs:** Edit `data.raw_ids` and `data.class_names` in YAML; the frequency rescan happens automatically in `train.py`.
+ 2. **UNet / FPN:** Set `model.architecture` to `unet` or `fpn`; pick a valid `encoder_name` for SMP.
+ 3. **Larger encoder:** e.g. `encoder_name: resnet101` for DeepLabV3+.
+ 4. **Loss ablation:** Set `loss.name` to `ce`, `weighted_ce`, `focal_ce`, or `focal_ce_dice`; tune `dice_weight`, `label_smoothing`, `class_weight_cap`.
+ 5. **Stronger aug:** Add Albumentations ops in `transforms.py` (keep `additional_targets={"mask": "mask"}` for paired geometry).
+
+ ---
+
+ ## 22. Flowcharts
+
+ ### 22.1 Training epoch (simplified)
+
+ ```mermaid
+ flowchart TD
+   start[Start epoch]
+   trainLoop[For each batch]
+   fwd[Forward logits]
+   lossStep[Compute loss CE plus Dice]
+   backward[Backward plus clip]
+   stepOpt[Optimizer step plus scheduler step]
+   emaUp[Update EMA if enabled]
+   endTrain[End train batches]
+   snap[Snapshot model weights]
+   applyEMA[Copy EMA into model if enabled]
+   valRun[Run validation mIoU]
+   restore[Restore snapshot weights]
+   better{New best mIoU?}
+   saveBest[Save best.pt]
+   early{Patience exceeded?}
+   stop[Stop training]
+   start --> trainLoop
+   trainLoop --> fwd --> lossStep --> backward --> stepOpt --> emaUp
+   emaUp --> trainLoop
+   trainLoop --> endTrain
+   endTrain --> snap --> applyEMA --> valRun --> restore --> better
+   better -->|yes| saveBest --> early
+   better -->|no| early
+   early -->|yes| stop
+   early -->|no| start
+ ```
+
+ ### 22.2 Inference on large images
+
+ ```mermaid
+ flowchart LR
+   img[Input RGB HxW]
+   pad[Reflect pad to tile grid]
+   tiles[For each tile]
+   fwdT[Forward logits optional TTA]
+   g[Multiply by Gaussian feather]
+   acc[Accumulate class score maps]
+   argmax[Argmax over classes]
+   cropBack[Crop to original HxW]
+   img --> pad --> tiles --> fwdT --> g --> acc --> argmax --> cropBack
+ ```
+
+ ---
+
+ ## Acknowledgments
+
+ - **segmentation_models_pytorch** (Pavel Iakubovskii and contributors) for modular segmentation architectures.
+ - **Albumentations** for fast, paired image–mask augmentations.
+
+ ---
+
+ *Generated to document the implementation in this repository as of the README authoring date. For the original hackathon planning narrative, see the separate plan document (not stored in this repo’s `README`).*
desert_segmentation/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """Desert semantic segmentation training and inference package."""
+
+ __version__ = "0.1.0"
desert_segmentation/configs/default.yaml ADDED
@@ -0,0 +1,77 @@
+ # Paths are relative to `root` unless absolute.
+ root: "."
+
+ data:
+   train_images: "training/train/Color_Images"
+   train_masks: "training/train/Segmentation"
+   val_images: "training/val/Color_Images"
+   val_masks: "training/val/Segmentation"
+   test_images: "testing/Color_Images"
+
+   # Raw uint16 label IDs (must match PNG values) and display names
+   raw_ids: [100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000]
+   class_names:
+     - "id_100"
+     - "id_200"
+     - "id_300"
+     - "id_500"
+     - "id_550"
+     - "id_600"
+     - "id_700"
+     - "id_800"
+     - "id_7100"
+     - "id_10000"
+
+   crop_size: 512
+   num_workers: 4
+   # Prefer crops containing underrepresented classes (probability 0–1). Higher = more
+   # training crops centered on rare-class pixels (see SegmentationDataset).
+   rare_class_crop_prob: 0.35
+   # Oversample images that contain rare classes (scans train masks at startup).
+   weighted_sampler: false
+   weighted_sampler_eps: 1.0e-6
+   ignore_index: 255
+
+ model:
+   architecture: "deeplabv3plus"
+   encoder_name: "resnet50"
+   encoder_weights: "imagenet"
+   # Alternative: mit_b2 with deeplabv3plus if supported by SMP
+   # encoder_name: "mit_b2"
+
+ train:
+   batch_size: 4
+   epochs: 40
+   lr: 0.0003
+   weight_decay: 0.0005
+   warmup_ratio: 0.08
+   amp: true
+   gradient_clip: 1.0
+   seed: 42
+   checkpoint_dir: "checkpoints"
+   log_interval: 20
+   early_stop_patience: 12
+
+ loss:
+   # ce | weighted_ce | ce_dice | focal_ce | focal_ce_dice
+   name: "focal_ce_dice"
+   dice_weight: 0.5
+   # Used only for CE-based modes (ce, weighted_ce, ce_dice). Ignored for focal_ce / focal_ce_dice.
+   label_smoothing: 0.05
+   # Inverse log-frequency class weights; ratio clamped to [1/cap, cap] vs median. Typical cap 5–25;
+   # higher = stronger upweight for rare classes (watch for instability).
+   class_weight_cap: 15.0
+   focal_gamma: 2.0
+
+ augmentation:
+   strong: true
+
+ ema:
+   enabled: true
+   decay: 0.999
+
+ inference:
+   tile_size: 512
+   overlap: 0.25
+   tta_flip: true
+   batch_size: 1
desert_segmentation/data/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from desert_segmentation.data.dataset import SegmentationDataset
+ from desert_segmentation.data.mask_encoding import RawMaskCodec
+
+ __all__ = ["SegmentationDataset", "RawMaskCodec"]
desert_segmentation/data/dataset.py ADDED
@@ -0,0 +1,129 @@
+ """Image / mask dataset with optional rare-class biased cropping."""
+
+ from __future__ import annotations
+
+ import os
+ import random
+ from pathlib import Path
+ from typing import Callable, List, Optional, Sequence, Tuple
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+ from desert_segmentation.data.mask_encoding import RawMaskCodec
+
+
+ def _list_images(dir_path: Path) -> List[str]:
+     exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
+     return sorted(
+         f for f in os.listdir(dir_path) if Path(f).suffix.lower() in exts
+     )
+
+
+ class SegmentationDataset(Dataset):
+     def __init__(
+         self,
+         images_dir: Path,
+         masks_dir: Path,
+         codec: RawMaskCodec,
+         transform: Optional[Callable] = None,
+         mode: str = "train",
+         crop_size: int = 512,
+         rare_class_crop_prob: float = 0.35,
+         ignore_index: int = 255,
+         seed: int = 42,
+     ) -> None:
+         self.images_dir = Path(images_dir)
+         self.masks_dir = Path(masks_dir)
+         self.codec = codec
+         self.transform = transform
+         self.mode = mode
+         self.crop_size = crop_size
+         self.rare_class_crop_prob = rare_class_crop_prob if mode == "train" else 0.0
+         self.ignore_index = ignore_index
+         self._rng = random.Random(seed)
+
+         names = _list_images(self.images_dir)
+         self._pairs: List[Tuple[str, str]] = []
+         for n in names:
+             mp = self.masks_dir / n
+             if not mp.is_file():
+                 raise FileNotFoundError(f"Missing mask for {n}: {mp}")
+             self._pairs.append((str(self.images_dir / n), str(mp)))
+
+         if not self._pairs:
+             raise RuntimeError(f"No images in {self.images_dir}")
+
+     def __len__(self) -> int:
+         return len(self._pairs)
+
+     @property
+     def image_names(self) -> List[str]:
+         """Basenames aligned with dataset indices (for weighted sampling)."""
+         return [Path(p[0]).name for p in self._pairs]
+
+     def _load_pair(self, ip: str, mp: str) -> Tuple[np.ndarray, np.ndarray]:
+         image = np.array(Image.open(ip).convert("RGB"))
+         raw_mask = np.array(Image.open(mp))
+         if raw_mask.ndim == 2:
+             enc, _ = self.codec.encode_mask(raw_mask.astype(np.uint16))
+         else:
+             raise ValueError(f"Expected single-channel mask, got shape {raw_mask.shape}")
+         return image, enc
+
+     def _random_crop_bias_rare(
+         self, image: np.ndarray, mask: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         h, w = image.shape[:2]
+         ch, cw = self.crop_size, self.crop_size
+         if h < ch or w < cw:
+             pad_h = max(0, ch - h)
+             pad_w = max(0, cw - w)
+             image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
+             mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=self.ignore_index)
+             h, w = image.shape[:2]
+
+         if self.mode == "train" and self._rng.random() < self.rare_class_crop_prob:
+             hist, _ = np.histogram(mask.flatten(), bins=self.codec.num_classes, range=(0, self.codec.num_classes))
+             rare = int(np.argmin(hist))
+             ys, xs = np.where(mask == rare)
+             if len(xs) > 0:
+                 idx = self._rng.randrange(len(xs))
+                 cx, cy = int(xs[idx]), int(ys[idx])
+             else:
+                 cx, cy = w // 2, h // 2
+         else:
+             cx, cy = self._rng.randrange(w), self._rng.randrange(h)
+
+         x0 = np.clip(cx - cw // 2, 0, w - cw)
+         y0 = np.clip(cy - ch // 2, 0, h - ch)
+         return image[y0 : y0 + ch, x0 : x0 + cw], mask[y0 : y0 + ch, x0 : x0 + cw]
+
+     def __getitem__(self, idx: int) -> dict:
+         ip, mp = self._pairs[idx]
+         image, mask = self._load_pair(ip, mp)
+
+         if self.mode == "train":
+             image, mask = self._random_crop_bias_rare(image, mask)
+
+         if self.transform is not None:
+             t = self.transform(image=image, mask=mask)
+             image = t["image"]
+             mask = t["mask"]
+
+         if isinstance(mask, torch.Tensor):
+             mask_t = mask
+         else:
+             mask_t = torch.from_numpy(np.asarray(mask))
+         if mask_t.dtype in (torch.float32, torch.float16):
+             mask_t = (mask_t * 255.0).round().clamp(0, 255).long()
+         else:
+             mask_t = mask_t.long()
+
+         return {
+             "image": image,
+             "mask": mask_t,
+             "path": ip,
+         }
desert_segmentation/data/mask_encoding.py ADDED
@@ -0,0 +1,90 @@
+ """Decode 16-bit raw mask values to contiguous class indices and back."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Sequence, Tuple
+
+ import numpy as np
+
+
+ @dataclass(frozen=True)
+ class RawMaskCodec:
+     """Maps dataset-specific raw label IDs (e.g. uint16 PNG values) to 0..num_classes-1."""
+
+     raw_ids: Tuple[int, ...]
+     class_names: Tuple[str, ...]
+
+     def __post_init__(self) -> None:
+         if len(self.raw_ids) != len(self.class_names):
+             raise ValueError("raw_ids and class_names must have the same length")
+         if len(set(self.raw_ids)) != len(self.raw_ids):
+             raise ValueError("raw_ids must be unique")
+
+     @property
+     def num_classes(self) -> int:
+         return len(self.raw_ids)
+
+     @property
+     def raw_to_index(self) -> Dict[int, int]:
+         return {r: i for i, r in enumerate(self.raw_ids)}
+
+     @property
+     def index_to_raw(self) -> Dict[int, int]:
+         return {i: r for i, r in enumerate(self.raw_ids)}
+
+     def _build_lut(self) -> np.ndarray:
+         max_id = max(self.raw_ids)
+         lut = np.full(max_id + 1, 255, dtype=np.uint8)
+         for i, rid in enumerate(self.raw_ids):
+             lut[rid] = i
+         return lut
+
+     def encode_mask(self, raw: np.ndarray) -> Tuple[np.ndarray, float]:
+         """Map raw uint16 labels to uint8 class indices 0..C-1. Returns (encoded, unknown_fraction)."""
+         if raw.ndim != 2:
+             raise ValueError(f"Expected HxW mask, got shape {raw.shape}")
+         lut = self._build_lut()
+         if int(raw.max()) >= lut.size:
+             raise ValueError(f"Mask value {int(raw.max())} exceeds LUT; extend raw_ids in config.")
+         out = lut[raw.astype(np.int64, copy=False)]
+         unknown_frac = float((out == 255).mean())
+         if unknown_frac > 0:
+             bad = out == 255
+             raise ValueError(
+                 f"Unknown mask pixels: {unknown_frac:.6f} of image. "
+                 f"Unique unknown raw values: {np.unique(raw[bad])[:16]}"
+             )
+         return out.astype(np.uint8), unknown_frac
+
+     def decode_to_raw(self, class_indices: np.ndarray) -> np.ndarray:
+         """Map class indices back to raw dataset IDs (for visualization/export)."""
+         arr = np.asarray(class_indices)
+         raw = np.zeros_like(arr, dtype=np.uint16)
+         for i, rid in enumerate(self.raw_ids):
+             raw[arr == i] = rid
+         return raw
+
+
+ def build_codec_from_config(raw_id_list: Sequence[int], names: Sequence[str]) -> RawMaskCodec:
+     pairs = sorted(zip(raw_id_list, names), key=lambda x: x[0])
+     r, n = zip(*pairs)
+     return RawMaskCodec(raw_ids=tuple(int(x) for x in r), class_names=tuple(str(x) for x in n))
+
+
+ def default_desert_codec() -> RawMaskCodec:
+     """Default codec for this workspace: 10 classes with fixed raw IDs (see dataset scan)."""
+     raw_ids = (100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000)
+     names = (
+         "id_100",
+         "id_200",
+         "id_300",
+         "id_500",
+         "id_550",
+         "id_600",
+         "id_700",
+         "id_800",
+         "id_7100",
+         "id_10000",
+     )
+     return RawMaskCodec(raw_ids=raw_ids, class_names=names)
desert_segmentation/data/transforms.py ADDED
@@ -0,0 +1,90 @@
+ """Albumentations pipelines for images and class masks."""
+
+ from __future__ import annotations
+
+ from typing import Any, Tuple
+
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+
+ def _base_normalize() -> A.Normalize:
+     return A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+
+
+ def build_train_transforms(
+     crop_size: int,
+     strong: bool = True,
+     ignore_index: int = 255,
+ ) -> A.Compose:
+     """Spatial crops are applied in `SegmentationDataset` (with rare-class bias)."""
+     del crop_size
+     geometric: list[Any] = [
+         A.HorizontalFlip(p=0.5),
+         A.ShiftScaleRotate(
+             shift_limit=0.02,
+             scale_limit=0.12,
+             rotate_limit=10,
+             border_mode=0,
+             mask_value=ignore_index,
+             p=0.55,
+         ),
+     ]
+     color: list[Any] = [
+         A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.55),
+         A.HueSaturationValue(hue_shift_limit=14, sat_shift_limit=22, val_shift_limit=14, p=0.4),
+         A.GaussianBlur(blur_limit=(3, 5), p=0.22),
+         A.GaussNoise(var_limit=(8.0, 48.0), p=0.25),
+         A.ImageCompression(quality_lower=70, quality_upper=100, p=0.25),
+         A.RGBShift(r_shift_limit=18, g_shift_limit=18, b_shift_limit=18, p=0.28),
+     ]
+     if strong:
+         color.extend(
+             [
+                 A.RandomSunFlare(
+                     flare_roi=(0.45, 0.0, 1.0, 0.42),
+                     angle_lower=0.4,
+                     p=0.12,
+                 ),
+                 A.RandomShadow(
+                     shadow_roi=(0, 0.42, 1, 1),
+                     num_shadows_lower=1,
+                     num_shadows_upper=2,
+                     p=0.16,
+                 ),
+             ]
+         )
+     return A.Compose(
+         geometric + color + [_base_normalize(), ToTensorV2()],
+         additional_targets={"mask": "mask"},
+     )
+
+
+ def build_val_transforms(
+     crop_size: int,
+     ignore_index: int = 255,
+ ) -> A.Compose:
+     return A.Compose(
+         [
+             A.LongestMaxSize(max_size=crop_size),
+             A.PadIfNeeded(
+                 min_height=crop_size,
+                 min_width=crop_size,
+                 border_mode=0,
+                 value=0,
+                 mask_value=ignore_index,
+             ),
+             _base_normalize(),
+             ToTensorV2(),
+         ],
+         additional_targets={"mask": "mask"},
+     )
+
+
+ def apply_transform(
+     transform: A.Compose,
+     image,
+     mask,
+ ) -> Tuple[Any, Any]:
+     out = transform(image=image, mask=mask)
+     return out["image"], out["mask"]
desert_segmentation/demo/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from desert_segmentation.demo.inference_ui import (
+     build_legend_rows,
+     dominant_classes_markdown,
+     legend_table_html,
+     side_by_side_strip,
+     validate_rgb_array,
+ )
+
+ __all__ = [
+     "build_legend_rows",
+     "dominant_classes_markdown",
+     "legend_table_html",
+     "side_by_side_strip",
+     "validate_rgb_array",
+ ]
desert_segmentation/demo/inference_ui.py ADDED
@@ -0,0 +1,95 @@
+ """Helpers for Gradio / web demo: legend, validation, composites."""
+
+ from __future__ import annotations
+
+ import html
+ from typing import Any, Dict, List, Sequence, Tuple
+
+ import numpy as np
+
+
+ from desert_segmentation.utils.viz import palette
+
+
+ def validate_rgb_array(
+     arr: np.ndarray,
+     max_side: int = 4096,
+     max_megapixels: float = 16.0,
+ ) -> None:
+     """Raises ValueError with a user-facing message if invalid or too large."""
+     if arr is None:
+         raise ValueError("No image provided.")
+     if not isinstance(arr, np.ndarray):
+         arr = np.asarray(arr)
+     if arr.ndim != 3 or arr.shape[2] != 3:
+         raise ValueError(f"Expected RGB image HxWx3, got shape {getattr(arr, 'shape', None)}")
+     h, w = arr.shape[0], arr.shape[1]
+     if h < 1 or w < 1:
+         raise ValueError("Image is empty.")
+     if max(h, w) > max_side:
+         raise ValueError(f"Image too large: max side is {max_side}px (got {h}x{w}).")
+     mp = (h * w) / 1_000_000.0
+     if mp > max_megapixels:
+         raise ValueError(f"Image too large: max {max_megapixels} megapixels (got {mp:.1f} MP).")
+
+
+ def build_legend_rows(class_names: Sequence[str], num_classes: int, seed: int = 42) -> Tuple[List[Dict[str, Any]], np.ndarray]:
+     """Returns list of {index, name, hex, r, g, b} and color table (same seed as viz.palette)."""
+     colors = palette(num_classes, seed=seed)
+     rows: List[Dict[str, Any]] = []
+     for i, name in enumerate(class_names):
+         r, g, b = (int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2]))
+         rows.append(
+             {
+                 "index": i,
+                 "name": str(name),
+                 "hex": f"#{r:02x}{g:02x}{b:02x}",
+                 "r": r,
+                 "g": g,
+                 "b": b,
+             }
+         )
+     return rows, colors
+
+
+ def legend_table_html(rows: Sequence[Dict[str, Any]]) -> str:
+     """Small HTML table with color swatches for Gradio gr.HTML."""
+     parts = [
+         "<table style='border-collapse:collapse;font-size:14px'>",
+         "<thead><tr><th>Swatch</th><th>#</th><th>Name</th><th>Hex</th></tr></thead><tbody>",
+     ]
+     for row in rows:
+         sw = f"background-color:{row['hex']};width:32px;height:22px;border:1px solid #888"
+         safe_name = html.escape(str(row["name"]))
+         parts.append(
+             f"<tr><td><div style='{sw}'></div></td>"
+             f"<td>{row['index']}</td><td>{safe_name}</td><td><code>{row['hex']}</code></td></tr>"
+         )
+     parts.append("</tbody></table>")
+     return "".join(parts)
+
+
+ def dominant_classes_markdown(pred: np.ndarray, class_names: Sequence[str], top_k: int = 3) -> str:
+     flat = pred.reshape(-1).astype(np.int64, copy=False)
+     n = len(class_names)
+     counts = np.bincount(flat, minlength=n)
+     total = int(counts.sum())
+     if total == 0:
+         return "_No pixels._"
+     order = np.argsort(-counts)
+     lines: List[str] = []
+     for i in order[:top_k]:
+         c = int(counts[i])
+         if c == 0:
+             continue
+         pct = 100.0 * c / total
+         name = class_names[i] if i < len(class_names) else str(i)
+         lines.append(f"- **{name}** (class {i}): **{pct:.1f}%**")
+     return "\n".join(lines) if lines else "_No dominant classes._"
+
+
+ def side_by_side_strip(rgb: np.ndarray, mask_rgb: np.ndarray, overlay_rgb: np.ndarray, gap: int = 8) -> np.ndarray:
+     """Horizontal strip: RGB | colored mask | overlay."""
+     h, w = rgb.shape[:2]
+     gap_arr = np.zeros((h, gap, 3), dtype=np.uint8)
+     return np.concatenate([rgb, gap_arr, mask_rgb, gap_arr, overlay_rgb], axis=1)
desert_segmentation/infer/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from desert_segmentation.infer.predict import predict_image, predict_folder
+
+ __all__ = ["predict_image", "predict_folder"]
desert_segmentation/infer/predict.py ADDED
@@ -0,0 +1,210 @@
1
+ """Sliding-window inference with optional horizontal-flip TTA and ONNX export helper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import time
8
+ from pathlib import Path
9
+ from typing import List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+
17
+ from desert_segmentation.data.mask_encoding import RawMaskCodec, build_codec_from_config
18
+ from desert_segmentation.models.factory import create_model
19
+ from desert_segmentation.utils.viz import blend_overlay, colorize_mask, palette, save_triplet
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def _gaussian_2d(h: int, w: int) -> np.ndarray:
25
+ yy, xx = np.ogrid[:h, :w]
26
+ cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
27
+ sig = min(h, w) / 3.0
28
+ g = np.exp(-(((yy - cy) ** 2 + (xx - cx) ** 2) / (2.0 * sig**2)))
29
+ return g.astype(np.float32)
30
+
31
+
32
+ def _preprocess(
33
+ rgb: np.ndarray,
34
+ mean: Tuple[float, float, float],
35
+ std: Tuple[float, float, float],
36
+ ) -> torch.Tensor:
37
+ # Keep float32 end-to-end: np.array(mean) defaults to float64 and would upcast x → conv2d dtype mismatch.
38
+ x = rgb.astype(np.float32, copy=False) / 255.0
39
+ m = np.asarray(mean, dtype=np.float32).reshape(1, 1, 3)
40
+ s = np.asarray(std, dtype=np.float32).reshape(1, 1, 3)
41
+ x = (x - m) / s
42
+ t = torch.from_numpy(np.ascontiguousarray(x)).permute(2, 0, 1).unsqueeze(0)
43
+ return t.float()
44
+
45
+
46
+ @torch.no_grad()
47
+ def _forward_logits(
48
+ model: nn.Module,
49
+ x: torch.Tensor,
50
+ device: torch.device,
51
+ tta_flip: bool,
52
+ ) -> torch.Tensor:
53
+ logits = model(x)
54
+ if not tta_flip:
55
+ return logits
56
+ xf = torch.flip(x, dims=[3])
57
+ lf = model(xf)
58
+ lf = torch.flip(lf, dims=[3])
59
+ return (logits + lf) * 0.5
60
+
61
+
62
+ def _tile_starts(length: int, tile: int, stride: int) -> List[int]:
63
+ if length <= tile:
64
+ return [0]
65
+ last_pos = length - tile
66
+ starts = list(range(0, last_pos + 1, stride))
67
+ if not starts:
68
+ return [0]
69
+ if starts[-1] != last_pos:
70
+ starts.append(last_pos)
71
+ return sorted(set(starts))
72
+
73
+
74
+ @torch.no_grad()
75
+ def predict_image(
76
+ model: nn.Module,
77
+ image_np: np.ndarray,
78
+ device: torch.device,
79
+ tile_size: int,
80
+ overlap: float,
81
+ tta_flip: bool,
82
+ mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
83
+ std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
84
+ ) -> np.ndarray:
85
+ """Returns HxW int class map."""
86
+ h, w = image_np.shape[:2]
87
+ if h <= tile_size and w <= tile_size:
88
+ t = _preprocess(image_np, mean, std).to(device)
89
+ logits = _forward_logits(model, t, device, tta_flip)
90
+ return logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int64)
91
+
92
+ stride = max(1, int(tile_size * (1.0 - overlap)))
93
+ g = _gaussian_2d(tile_size, tile_size)
94
+
95
+ n_ty = len(_tile_starts(h, tile_size, stride))
96
+ n_tx = len(_tile_starts(w, tile_size, stride))
97
+ H_pad = (n_ty - 1) * stride + tile_size
98
+ W_pad = (n_tx - 1) * stride + tile_size
99
+ pad_h = max(0, H_pad - h)
100
+ pad_w = max(0, W_pad - w)
101
+ img_p = np.pad(image_np, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
102
+ H, W = img_p.shape[:2]
103
+
104
+ t0 = _preprocess(img_p[0:tile_size, 0:tile_size], mean, std).to(device)
105
+ logits0 = _forward_logits(model, t0, device, tta_flip)
106
+ num_classes = int(logits0.shape[1])
107
+ acc = np.zeros((num_classes, H, W), dtype=np.float32)
108
+ weight = np.zeros((H, W), dtype=np.float32)
109
+
110
+ for y in _tile_starts(H, tile_size, stride):
111
+ for x in _tile_starts(W, tile_size, stride):
112
+ tile = img_p[y : y + tile_size, x : x + tile_size]
113
+ t = _preprocess(tile, mean, std).to(device)
114
+ logits = _forward_logits(model, t, device, tta_flip)
115
+ probs = torch.softmax(logits, dim=1)
116
+ ls = probs.squeeze(0).cpu().numpy()
117
+ acc[:, y : y + tile_size, x : x + tile_size] += ls * g
118
+ weight[y : y + tile_size, x : x + tile_size] += g
119
+
120
+ pred = np.argmax(acc, axis=0).astype(np.int64)
121
+ return pred[:h, :w]
122
+
123
+
124
+ def _load_model_for_inference(
125
+ checkpoint_path: Path,
126
+ device: torch.device,
127
+ ) -> Tuple[nn.Module, dict, RawMaskCodec]:
128
+ try:
129
+ ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
130
+ except TypeError:
131
+ ckpt = torch.load(checkpoint_path, map_location=device)
132
+ cfg = ckpt["config"]
133
+ raw_ids = cfg["data"]["raw_ids"]
134
+ names = ckpt.get("class_names") or tuple(cfg["data"].get("class_names") or ())
135
+ if not names:
136
+ names = tuple(str(x) for x in raw_ids)
137
+ codec = build_codec_from_config(raw_ids, names)
138
+ model = create_model(cfg["model"], num_classes=codec.num_classes).to(device)
139
+ if ckpt.get("model") is not None:
140
+ model.load_state_dict(ckpt["model"])
141
+ if ckpt.get("ema") is not None:
142
+ for n, p in model.named_parameters():
143
+ if n in ckpt["ema"]:
144
+ p.data.copy_(ckpt["ema"][n].to(device))
145
+ model.eval()
146
+ return model, cfg, codec
147
+
148
+
149
+ @torch.no_grad()
150
+ def predict_folder(
151
+ checkpoint_path: Path,
152
+ image_dir: Path,
153
+ out_dir: Path,
154
+ device: Optional[torch.device] = None,
155
+ limit: Optional[int] = None,
156
+ ) -> None:
157
+ device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
158
+ model, cfg, codec = _load_model_for_inference(checkpoint_path, device)
159
+ icfg = cfg.get("inference") or {}
160
+ tile_size = int(icfg.get("tile_size", 512))
161
+ overlap = float(icfg.get("overlap", 0.25))
162
+ tta = bool(icfg.get("tta_flip", True))
163
+
164
+ out_dir.mkdir(parents=True, exist_ok=True)
165
+ colors = palette(codec.num_classes)
166
+ names = sorted(f for f in os.listdir(image_dir) if f.lower().endswith((".png", ".jpg", ".jpeg")))
167
+ if limit is not None:
168
+ names = names[:limit]
169
+ times: List[float] = []
170
+ for name in tqdm(names, desc="infer"):
171
+ ip = image_dir / name
172
+ rgb = np.array(Image.open(ip).convert("RGB"))
173
+ t0 = time.perf_counter()
174
+ pred = predict_image(model, rgb, device, tile_size, overlap, tta)
175
+ times.append(time.perf_counter() - t0)
176
+ overlay = blend_overlay(rgb, colorize_mask(pred, colors))
177
+ Image.fromarray(overlay).save(out_dir / f"pred_{name}")
178
+ save_triplet(out_dir / f"triplet_{name}", rgb, None, pred, colors)
179
+
180
+ if times:
181
+ mean_ms = float(np.mean(times) * 1000.0)
182
+ logger.info("mean inference time: %.2f ms (device=%s)", mean_ms, device)
183
+ with (out_dir / "latency.txt").open("w", encoding="utf-8") as f:
184
+ f.write(f"mean_ms_per_image={mean_ms:.4f}\n")
185
+ f.write(f"device={device}\n")
186
+
187
+
188
+ def export_onnx(
189
+ checkpoint_path: Path,
190
+ out_onnx: Path,
191
+ height: int = 512,
192
+ width: int = 512,
193
+ opset: int = 17,
194
+ ) -> None:
195
+ device = torch.device("cpu")
196
+ model, _, _ = _load_model_for_inference(checkpoint_path, device)
197
+ model.eval()
198
+ dummy = torch.randn(1, 3, height, width, device=device)
199
+ torch.onnx.export(
200
+ model,
201
+ dummy,
202
+ str(out_onnx),
203
+ input_names=["input"],
204
+ output_names=["logits"],
205
+ opset_version=opset,
206
+ dynamic_axes={
207
+ "input": {0: "batch", 2: "height", 3: "width"},
208
+ "logits": {0: "batch", 2: "h", 3: "w"},
209
+ },
210
+ )
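
A standalone sketch of the tile-grid arithmetic used by `_tile_starts` above may help when reading `predict_image`; the sizes below are illustrative only:

```python
from typing import List


def tile_starts(length: int, tile: int, stride: int) -> List[int]:
    # Same logic as _tile_starts: regular stride, plus a final tile flush with the edge.
    if length <= tile:
        return [0]
    last = length - tile
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return sorted(set(starts))


# A 1200 px side with 512 px tiles at 25% overlap (stride = int(512 * 0.75) = 384):
print(tile_starts(1200, 512, 384))  # [0, 384, 688] -- overlapping windows cover every pixel
```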
desert_segmentation/losses/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from desert_segmentation.losses.combined import build_loss
2
+
3
+ __all__ = ["build_loss"]
desert_segmentation/losses/combined.py ADDED
@@ -0,0 +1,143 @@
1
+ """Segmentation losses: CE, weighted CE, focal, Dice, and combinations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def _ce(
13
+ logits: torch.Tensor,
14
+ target: torch.Tensor,
15
+ weight: Optional[torch.Tensor],
16
+ ignore_index: int,
17
+ label_smoothing: float,
18
+ ) -> torch.Tensor:
19
+ return F.cross_entropy(
20
+ logits,
21
+ target,
22
+ weight=weight,
23
+ ignore_index=ignore_index,
24
+ label_smoothing=label_smoothing,
25
+ )
26
+
27
+
28
+ def _focal_ce(
29
+ logits: torch.Tensor,
30
+ target: torch.Tensor,
31
+ gamma: float,
32
+ weight: Optional[torch.Tensor],
33
+ ignore_index: int,
34
+ ) -> torch.Tensor:
35
+ log_probs = F.log_softmax(logits, dim=1)
36
+ probs = log_probs.exp()
37
+ valid = target != ignore_index
38
+ tgt_clamped = target.clone()
40
+ tgt_clamped[~valid] = 0
41
+ log_pt = log_probs.gather(1, tgt_clamped.unsqueeze(1)).squeeze(1)
42
+ pt = probs.gather(1, tgt_clamped.unsqueeze(1)).squeeze(1)
43
+ focal = (1 - pt) ** gamma * (-log_pt)
44
+ if weight is not None:
45
+ focal = focal * weight[tgt_clamped]
46
+ focal = focal * valid.float()
47
+ return focal.sum() / (valid.float().sum().clamp_min(1.0))
48
+
49
+
50
+ def _multiclass_dice(
51
+ logits: torch.Tensor,
52
+ target: torch.Tensor,
53
+ ignore_index: int,
54
+ eps: float = 1e-6,
55
+ ) -> torch.Tensor:
56
+ probs = F.softmax(logits, dim=1)
57
+ _, c, _, _ = probs.shape  # batch dimension unused
58
+ tgt = target
59
+ valid = tgt != ignore_index
60
+ dice_losses = []
61
+ for k in range(c):
62
+ pk = probs[:, k]
63
+ tk = (tgt == k).float()
64
+ m = valid.float()
65
+ pk, tk = pk * m, tk * m
66
+ inter = (pk * tk).sum(dim=(1, 2))
67
+ denom = pk.sum(dim=(1, 2)) + tk.sum(dim=(1, 2)) + eps
68
+ dice = 1.0 - (2.0 * inter + eps) / denom
69
+ dice_losses.append(dice.mean())
70
+ return torch.stack(dice_losses).mean()
71
+
72
+
73
+ class CombinedSegLoss(nn.Module):
74
+ def __init__(
75
+ self,
76
+ mode: str,
77
+ num_classes: int,
78
+ ignore_index: int = 255,
79
+ class_weights: Optional[torch.Tensor] = None,
80
+ dice_weight: float = 0.5,
81
+ label_smoothing: float = 0.05,
82
+ focal_gamma: float = 2.0,
83
+ ) -> None:
84
+ super().__init__()
85
+ self.mode = mode
86
+ self.num_classes = num_classes
87
+ self.ignore_index = ignore_index
88
+ self.register_buffer("class_weights", class_weights if class_weights is not None else torch.ones(num_classes))
89
+ self.dice_weight = dice_weight
90
+ self.label_smoothing = label_smoothing
91
+ self.focal_gamma = focal_gamma
92
+
93
+ def forward(self, logits: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, dict]:
94
+ w = self.class_weights
95
+ if self.mode == "ce":
96
+ loss = _ce(logits, target, None, self.ignore_index, self.label_smoothing)
97
+ elif self.mode == "weighted_ce":
98
+ loss = _ce(logits, target, w, self.ignore_index, self.label_smoothing)
99
+ elif self.mode == "focal_ce":
100
+ loss = _focal_ce(logits, target, self.focal_gamma, w, self.ignore_index)
101
+ elif self.mode == "ce_dice":
102
+ ce = _ce(logits, target, w, self.ignore_index, self.label_smoothing)
103
+ dice = _multiclass_dice(logits, target, self.ignore_index)
104
+ loss = ce + self.dice_weight * dice
105
+ elif self.mode == "focal_ce_dice":
106
+ focal = _focal_ce(logits, target, self.focal_gamma, w, self.ignore_index)
107
+ dice = _multiclass_dice(logits, target, self.ignore_index)
108
+ loss = focal + self.dice_weight * dice
109
+ else:
110
+ raise ValueError(f"Unknown loss mode {self.mode}")
111
+ return loss, {"loss": float(loss.detach().cpu())}
112
+
113
+
114
+ def build_loss(
115
+ loss_cfg: dict,
116
+ num_classes: int,
117
+ class_weights: Optional[torch.Tensor],
118
+ ignore_index: int,
119
+ ) -> CombinedSegLoss:
120
+ mode = loss_cfg.get("name", "ce_dice")
121
+ return CombinedSegLoss(
122
+ mode=mode,
123
+ num_classes=num_classes,
124
+ ignore_index=ignore_index,
125
+ class_weights=class_weights,
126
+ dice_weight=float(loss_cfg.get("dice_weight", 0.5)),
127
+ label_smoothing=float(loss_cfg.get("label_smoothing", 0.0)),
128
+ focal_gamma=float(loss_cfg.get("focal_gamma", 2.0)),
129
+ )
130
+
131
+
132
+ def compute_class_weights_from_freq(
133
+ freq: torch.Tensor,
134
+ cap: float = 15.0,
135
+ eps: float = 1e-6,
136
+ ) -> torch.Tensor:
137
+ """Inverse log frequency with mean normalization and per-class cap on max/min ratio."""
138
+ w = 1.0 / torch.log(freq + eps)
139
+ w = w / w.mean()
140
+ ratio = w / w.median()
141
+ ratio = torch.clamp(ratio, max=cap)
142
+ w = ratio * w.median()
143
+ return w
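
To make the weighting rule concrete, here is a hedged toy run of the same arithmetic as `compute_class_weights_from_freq`; the class fractions are invented for illustration:

```python
import torch

freq = torch.tensor([0.70, 0.25, 0.04, 0.01])   # hypothetical per-class pixel fractions
w = 1.0 / torch.log(1.02 + freq + 1e-6)         # damped inverse log frequency (ENet-style)
w = w / w.mean()
ratio = torch.clamp(w / w.median(), max=15.0)   # cap weights at 15x the median
w = ratio * w.median()
print(w)  # approx tensor([0.13, 0.29, 1.20, 2.37]): rare classes up-weighted, but bounded
```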
desert_segmentation/metrics/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from desert_segmentation.metrics.iou import (
2
+ IoUMetrics,
3
+ compute_confusion,
4
+ confusion_to_accuracy_metrics,
5
+ gt_pixel_counts,
6
+ valid_class_miou_from_confusion,
7
+ )
8
+
9
+ __all__ = [
10
+ "IoUMetrics",
11
+ "compute_confusion",
12
+ "confusion_to_accuracy_metrics",
13
+ "gt_pixel_counts",
14
+ "valid_class_miou_from_confusion",
15
+ ]
desert_segmentation/metrics/iou.py ADDED
@@ -0,0 +1,143 @@
1
+ """Per-class IoU, mIoU, frequency-weighted IoU, confusion matrix."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def compute_confusion(
12
+ logits: torch.Tensor,
13
+ target: torch.Tensor,
14
+ num_classes: int,
15
+ ignore_index: int = 255,
16
+ ) -> torch.Tensor:
17
+ """Accumulate confusion matrix (pred rows, target columns) — shape CxC."""
18
+ pred = logits.argmax(dim=1).view(-1)
19
+ tgt = target.view(-1)
20
+ valid = tgt != ignore_index
21
+ pred = pred[valid]
22
+ tgt = tgt[valid]
23
+ if pred.numel() == 0:
24
+ return torch.zeros(num_classes, num_classes, dtype=torch.int64, device=logits.device)
25
+ idx = tgt * num_classes + pred
26
+ cm = torch.bincount(idx, minlength=num_classes * num_classes).reshape(num_classes, num_classes)
27
+ return cm
28
+
29
+
30
+ def confusion_to_accuracy_metrics(
31
+ cm: Union[np.ndarray, torch.Tensor],
32
+ eps: float = 1e-12,
33
+ ) -> Dict[str, float | np.ndarray]:
34
+ """Pixel accuracies from confusion ``cm[gt_i, pred_j]`` (same layout as ``IoUMetrics``).
35
+
36
+ - **global_pixel_accuracy:** ``trace(cm) / sum(cm)`` — fraction of pixels correct.
37
+ - **mean_class_accuracy:** mean of per-class **recall** ``cm[k,k] / sum_j cm[k,j]`` over
38
+ classes with at least one ground-truth pixel (ignores empty rows).
39
+
40
+ Returns ``per_class_recall`` aligned with class index for optional reporting.
41
+ """
42
+ if isinstance(cm, torch.Tensor):
43
+ cm = cm.detach().cpu().numpy()
44
+ cm = np.asarray(cm, dtype=np.float64)
45
+ total = cm.sum()
46
+ if total <= eps:
47
+ z = np.zeros(cm.shape[0], dtype=np.float64)
48
+ return {
49
+ "global_pixel_accuracy": 0.0,
50
+ "mean_class_accuracy": 0.0,
51
+ "per_class_recall": z,
52
+ }
53
+ trace = np.trace(cm)
54
+ global_acc = float(trace / total)
55
+ row_sums = cm.sum(axis=1)
56
+ diag = np.diag(cm)
57
+ with np.errstate(divide="ignore", invalid="ignore"):
58
+ per_class_recall = np.where(row_sums > eps, diag / np.maximum(row_sums, eps), np.nan)
59
+ present = row_sums > eps
60
+ mean_class_acc = (
61
+ float(np.nanmean(per_class_recall[present])) if np.any(present) else 0.0
62
+ )
63
+ return {
64
+ "global_pixel_accuracy": global_acc,
65
+ "mean_class_accuracy": mean_class_acc,
66
+ "per_class_recall": per_class_recall,
67
+ }
68
+
69
+
70
+ def gt_pixel_counts(cm: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
71
+ """Ground-truth pixel counts per class: ``sum_j cm[gt_k, pred_j]`` (row sums)."""
72
+ if isinstance(cm, torch.Tensor):
73
+ cm = cm.detach().cpu().numpy()
74
+ cm = np.asarray(cm, dtype=np.float64)
75
+ return np.sum(cm, axis=1).astype(np.int64)
76
+
77
+
78
+ def valid_class_miou_from_confusion(
79
+ cm: Union[np.ndarray, torch.Tensor],
80
+ eps: float = 1e-6,
81
+ ) -> float:
82
+ """Mean IoU over classes that have at least one ground-truth pixel on the val set.
83
+
84
+ Unlike full mIoU (mean over all classes, often many zeros when a class is absent from
85
+ val GT), this only averages **finite** per-class IoU values for rows with ``GT > 0``.
86
+ Returns ``0.0`` if no class has any GT pixels.
87
+ """
88
+ if isinstance(cm, torch.Tensor):
89
+ cm = cm.detach().cpu().numpy()
90
+ cm = np.asarray(cm, dtype=np.float64)
91
+ diag = np.diag(cm)
92
+ rows = cm.sum(axis=1)
93
+ cols = cm.sum(axis=0)
94
+ union = rows + cols - diag + eps
95
+ with np.errstate(divide="ignore", invalid="ignore"):
96
+ iou = diag / union
97
+ present = rows > 0
98
+ if not np.any(present):
99
+ return 0.0
100
+ finite = present & np.isfinite(iou)
101
+ if not np.any(finite):
102
+ return 0.0
103
+ return float(np.mean(iou[finite]))
104
+
105
+
106
+ def confusion_to_iou(cm: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
107
+ """Returns per-class IoU, mean IoU, frequency-weighted IoU."""
108
+ diag = torch.diag(cm).float()
109
+ rows = cm.sum(dim=1).float()
110
+ cols = cm.sum(dim=0).float()
111
+ union = rows + cols - diag + 1e-6
112
+ iou = diag / union
113
+ miou = iou[torch.isfinite(iou)].mean().item()
114
+ freq = rows / (rows.sum() + 1e-6)  # weight by ground-truth class frequency (row sums), per the standard fwIoU definition
115
+ fw_iou = (iou * freq).sum().item()
116
+ return iou, miou, fw_iou
117
+
118
+
119
+ class IoUMetrics:
120
+ def __init__(self, num_classes: int, ignore_index: int = 255, device: Optional[torch.device] = None):
121
+ self.num_classes = num_classes
122
+ self.ignore_index = ignore_index
123
+ self.device = device or torch.device("cpu")
124
+ self.reset()
125
+
126
+ def reset(self) -> None:
127
+ self._cm = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int64, device=self.device)
128
+
129
+ @torch.no_grad()
130
+ def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
131
+ logits = logits.to(self.device)
132
+ target = target.to(self.device)
133
+ self._cm += compute_confusion(logits, target, self.num_classes, self.ignore_index).to(self.device)
134
+
135
+ def compute(self) -> Dict[str, float | np.ndarray]:
136
+ cm = self._cm.cpu()
137
+ iou, miou, fw_iou = confusion_to_iou(cm)
138
+ return {
139
+ "per_class_iou": iou.numpy(),
140
+ "miou": miou,
141
+ "fw_iou": fw_iou,
142
+ "confusion": cm.numpy(),
143
+ }
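
A tiny self-check of the confusion-matrix layout (ground-truth rows, prediction columns) that every helper in this module assumes:

```python
import torch

num_classes = 3
logits = torch.tensor([[[[5.0]], [[0.0]], [[0.0]]]])  # shape 1x3x1x1, argmax predicts class 0
target = torch.tensor([[[1]]])                        # the single ground-truth pixel is class 1

pred = logits.argmax(dim=1).view(-1)
idx = target.view(-1) * num_classes + pred            # same indexing as compute_confusion
cm = torch.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)
print(cm[1, 0].item())  # 1: one class-1 pixel counted as predicted class 0
```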
desert_segmentation/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from desert_segmentation.models.factory import create_model
2
+
3
+ __all__ = ["create_model"]
desert_segmentation/models/factory.py ADDED
@@ -0,0 +1,37 @@
1
+ """Build segmentation models via segmentation_models_pytorch."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict
6
+
7
+ import segmentation_models_pytorch as smp
8
+ import torch.nn as nn
9
+
10
+
11
+ def create_model(model_cfg: Dict[str, Any], num_classes: int) -> nn.Module:
12
+ arch = (model_cfg.get("architecture") or "deeplabv3plus").lower()
13
+ encoder_name = model_cfg.get("encoder_name", "resnet50")
14
+ encoder_weights = model_cfg.get("encoder_weights", "imagenet")
15
+
16
+ if arch == "deeplabv3plus":
17
+ return smp.DeepLabV3Plus(
18
+ encoder_name=encoder_name,
19
+ encoder_weights=encoder_weights,
20
+ in_channels=3,
21
+ classes=num_classes,
22
+ )
23
+ if arch == "unet":
24
+ return smp.Unet(
25
+ encoder_name=encoder_name,
26
+ encoder_weights=encoder_weights,
27
+ in_channels=3,
28
+ classes=num_classes,
29
+ )
30
+ if arch == "fpn":
31
+ return smp.FPN(
32
+ encoder_name=encoder_name,
33
+ encoder_weights=encoder_weights,
34
+ in_channels=3,
35
+ classes=num_classes,
36
+ )
37
+ raise ValueError(f"Unknown architecture: {arch}")
desert_segmentation/train/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from desert_segmentation.train.evaluate import evaluate
2
+ from desert_segmentation.train.trainer import train
3
+
4
+ __all__ = ["evaluate", "train"]
desert_segmentation/train/evaluate.py ADDED
@@ -0,0 +1,30 @@
1
+ """Validation loop and metric aggregation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from desert_segmentation.metrics.iou import IoUMetrics
12
+
13
+
14
+ @torch.no_grad()
15
+ def evaluate(
16
+ model: torch.nn.Module,
17
+ loader: DataLoader,
18
+ device: torch.device,
19
+ num_classes: int,
20
+ ignore_index: int = 255,
21
+ desc: str = "val",
22
+ ) -> dict:
23
+ model.eval()
24
+ metrics = IoUMetrics(num_classes=num_classes, ignore_index=ignore_index, device=device)
25
+ for batch in tqdm(loader, desc=desc, leave=False):
26
+ images = batch["image"].to(device, non_blocking=True)
27
+ masks = batch["mask"].to(device, non_blocking=True)
28
+ logits = model(images)
29
+ metrics.update(logits, masks)
30
+ return metrics.compute()
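
A minimal smoke test of `evaluate` with a stand-in model and a synthetic two-sample dataset; everything here is made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from desert_segmentation.train.evaluate import evaluate


class ToySet(Dataset):
    def __len__(self) -> int:
        return 2

    def __getitem__(self, i: int) -> dict:
        return {"image": torch.randn(3, 32, 32), "mask": torch.randint(0, 4, (32, 32))}


model = torch.nn.Conv2d(3, 4, kernel_size=1)  # stand-in for a real segmentation network
metrics = evaluate(model, DataLoader(ToySet(), batch_size=2), torch.device("cpu"), num_classes=4)
print(metrics["miou"], metrics["confusion"].shape)  # a float and a (4, 4) matrix
```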
desert_segmentation/train/trainer.py ADDED
@@ -0,0 +1,205 @@
1
+ """Training loop with AMP, cosine+warmup, EMA, best-mIoU checkpointing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import json
7
+ import logging
8
+ import math
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.cuda.amp import GradScaler, autocast  # deprecated alias of torch.amp in newer PyTorch; kept for torch>=2.0 compatibility
15
+ from torch.optim import AdamW
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ from torch.utils.data import DataLoader
18
+ from tqdm import tqdm
19
+
20
+ from desert_segmentation.train.evaluate import evaluate
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class ModelEMA:
26
+ """Exponential moving average of model parameters."""
27
+
28
+ def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
29
+ self.decay = decay
30
+ self.shadow: Dict[str, torch.Tensor] = {}
31
+ self._collect(model)
32
+
33
+ @torch.no_grad()
34
+ def _collect(self, model: nn.Module) -> None:
35
+ for n, p in model.named_parameters():
36
+ if p.requires_grad:
37
+ self.shadow[n] = p.detach().clone()
38
+
39
+ @torch.no_grad()
40
+ def update(self, model: nn.Module) -> None:
41
+ for n, p in model.named_parameters():
42
+ if not p.requires_grad:
43
+ continue
44
+ self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
45
+
46
+ @torch.no_grad()
47
+ def copy_to(self, model: nn.Module) -> None:
48
+ for n, p in model.named_parameters():
49
+ if n in self.shadow:
50
+ p.data.copy_(self.shadow[n])
51
+
52
+ def _warmup_cosine_lambda(
53
+ total_steps: int,
54
+ warmup_steps: int,
55
+ min_ratio: float = 0.01,
56
+ ) -> Any:
57
+ def lr_lambda(step: int) -> float:
58
+ if step < warmup_steps:
59
+ return float(step + 1) / float(max(1, warmup_steps))
60
+ progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
61
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * progress))
62
+
63
+ return lr_lambda
64
+
65
+
66
+ def train(
67
+ model: nn.Module,
68
+ train_loader: DataLoader,
69
+ val_loader: DataLoader,
70
+ criterion: nn.Module,
71
+ device: torch.device,
72
+ cfg: Dict[str, Any],
73
+ num_classes: int,
74
+ ignore_index: int,
75
+ checkpoint_dir: Path,
76
+ class_names: tuple[str, ...],
77
+ max_train_batches: Optional[int] = None,
78
+ ) -> Dict[str, Any]:
79
+ tcfg = cfg["train"]
80
+ epochs = int(tcfg["epochs"])
81
+ lr = float(tcfg["lr"])
82
+ wd = float(tcfg["weight_decay"])
83
+ amp_enabled = bool(tcfg.get("amp", True)) and torch.cuda.is_available()
84
+ clip = float(tcfg.get("gradient_clip", 0.0))
85
+ warmup_ratio = float(tcfg.get("warmup_ratio", 0.08))
86
+ patience = int(tcfg.get("early_stop_patience", 20))
87
+ log_interval = int(tcfg.get("log_interval", 20))
88
+
89
+ ema_cfg = cfg.get("ema") or {}
90
+ use_ema = bool(ema_cfg.get("enabled", False))
91
+ ema_decay = float(ema_cfg.get("decay", 0.999))
92
+ ema: Optional[ModelEMA] = ModelEMA(model, decay=ema_decay) if use_ema else None
93
+
94
+ opt = AdamW(model.parameters(), lr=lr, weight_decay=wd)
95
+ steps_per_epoch = max(1, len(train_loader))
96
+ total_steps = steps_per_epoch * epochs
97
+ warmup_steps = max(1, int(total_steps * warmup_ratio))
98
+ sched = LambdaLR(opt, _warmup_cosine_lambda(total_steps, warmup_steps))
99
+ scaler: Optional[GradScaler] = GradScaler() if amp_enabled else None
100
+
101
+ best_miou = -1.0
102
+ bad_epochs = 0
103
+ history: list = []
104
+
105
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
106
+ best_path = checkpoint_dir / "best.pt"
107
+ last_path = checkpoint_dir / "last.pt"
108
+
109
+ global_step = 0
110
+ for epoch in range(1, epochs + 1):
111
+ model.train()
112
+ running = 0.0
113
+ n_log = 0
114
+ pbar = tqdm(train_loader, desc=f"train {epoch}/{epochs}")
115
+ for batch_idx, batch in enumerate(pbar):
116
+ images = batch["image"].to(device, non_blocking=True)
117
+ masks = batch["mask"].to(device, non_blocking=True)
118
+ opt.zero_grad(set_to_none=True)
119
+ with autocast(enabled=amp_enabled):
120
+ logits = model(images)
121
+ loss, _ = criterion(logits, masks)
122
+ if scaler is not None:
123
+ scaler.scale(loss).backward()
124
+ if clip > 0:
125
+ scaler.unscale_(opt)
126
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
127
+ scaler.step(opt)
128
+ scaler.update()
129
+ else:
130
+ loss.backward()
131
+ if clip > 0:
132
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
133
+ opt.step()
134
+ sched.step()
135
+ global_step += 1
136
+ if ema is not None:
137
+ ema.update(model)
138
+ running += float(loss.detach().cpu())
139
+ n_log += 1
140
+ if global_step % log_interval == 0:
141
+ pbar.set_postfix(loss=f"{running / max(n_log, 1):.4f}")
142
+ running = 0.0
143
+ n_log = 0
144
+ if max_train_batches is not None and (batch_idx + 1) >= max_train_batches:
145
+ break
146
+
147
+ backup = copy.deepcopy(model.state_dict())
148
+ if ema is not None:
149
+ ema.copy_to(model)
150
+ val_metrics = evaluate(
151
+ model,
152
+ val_loader,
153
+ device,
154
+ num_classes=num_classes,
155
+ ignore_index=ignore_index,
156
+ desc=f"val {epoch}",
157
+ )
158
+ model.load_state_dict(backup)
159
+
160
+ miou = float(val_metrics["miou"])
161
+ row = {"epoch": epoch, "miou": miou, "fw_iou": float(val_metrics["fw_iou"])}
162
+ history.append(row)
163
+ logger.info(
164
+ "epoch %s | val mIoU=%.4f fwIoU=%.4f",
165
+ epoch,
166
+ miou,
167
+ row["fw_iou"],
168
+ )
169
+
170
+ torch.save(
171
+ {
172
+ "epoch": epoch,
173
+ "model": model.state_dict(),
174
+ "ema": ema.shadow if ema is not None else None,
175
+ "optimizer": opt.state_dict(),
176
+ "config": cfg,
177
+ "class_names": class_names,
178
+ },
179
+ last_path,
180
+ )
181
+
182
+ if miou > best_miou:
183
+ best_miou = miou
184
+ bad_epochs = 0
185
+ save_payload = {
186
+ "epoch": epoch,
187
+ "model": model.state_dict(),
188
+ "ema": ema.shadow if ema is not None else None,
189
+ "miou": miou,
190
+ "per_class_iou": val_metrics["per_class_iou"].tolist(),
191
+ "config": cfg,
192
+ "class_names": class_names,
193
+ }
194
+ torch.save(save_payload, best_path)
195
+ logger.info("saved new best checkpoint mIoU=%.4f -> %s", miou, best_path)
196
+ else:
197
+ bad_epochs += 1
198
+ if bad_epochs >= patience:
199
+ logger.info("early stopping at epoch %s (no improvement %s epochs)", epoch, patience)
200
+ break
201
+
202
+ with (checkpoint_dir / "history.json").open("w", encoding="utf-8") as f:
203
+ json.dump(history, f, indent=2)
204
+
205
+ return {"best_miou": best_miou, "best_path": str(best_path), "history": history}
desert_segmentation/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from desert_segmentation.utils.seed import set_seed
2
+
3
+ __all__ = ["set_seed"]
desert_segmentation/utils/config.py ADDED
@@ -0,0 +1,37 @@
1
+ """Load YAML config and resolve paths relative to workspace root."""
2
+
3
+ from __future__ import annotations
4
+
6
+ from pathlib import Path
7
+ from typing import Any, Dict
8
+
9
+ import yaml
10
+
11
+
12
+ def load_config(path: Path | str, root: Path | None = None) -> Dict[str, Any]:
13
+ path = Path(path)
14
+ with path.open("r", encoding="utf-8") as f:
15
+ cfg = yaml.safe_load(f)
16
+ if root is None:
17
+ root = Path(cfg.get("root", ".")).resolve()
18
+ else:
19
+ root = Path(root).resolve()
20
+ cfg["root"] = str(root)
21
+ return cfg
22
+
23
+
24
+ def resolve_path(root: Path, *parts: str) -> Path:
25
+ return (root / Path(*parts)).resolve()
26
+
27
+
28
+ def get_paths(cfg: Dict[str, Any]) -> Dict[str, Path]:
29
+ root = Path(cfg["root"])
30
+ d = cfg["data"]
31
+ return {
32
+ "train_images": resolve_path(root, d["train_images"]),
33
+ "train_masks": resolve_path(root, d["train_masks"]),
34
+ "val_images": resolve_path(root, d["val_images"]),
35
+ "val_masks": resolve_path(root, d["val_masks"]),
36
+ "test_images": resolve_path(root, d["test_images"]),
37
+ }
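
Usage sketch for `load_config` with a throwaway YAML file; the path value is hypothetical:

```python
import tempfile
from pathlib import Path

from desert_segmentation.utils.config import load_config

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("data:\n  train_images: training/train/Color_Images\n")

cfg = load_config(f.name, root=Path("."))
print(cfg["root"])                   # absolute workspace root
print(cfg["data"]["train_images"])   # still relative; get_paths resolves it against root
```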
desert_segmentation/utils/freq.py ADDED
@@ -0,0 +1,66 @@
1
+ """Estimate class pixel frequencies from mask files (fast path for loss weighting)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from typing import List, Sequence
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+
13
+ from desert_segmentation.data.mask_encoding import RawMaskCodec
14
+
15
+
16
+ def list_masks(dir_path: Path) -> List[str]:
17
+ return sorted(f for f in os.listdir(dir_path) if f.lower().endswith(".png"))
18
+
19
+
20
+ @torch.no_grad()
21
+ def estimate_pixel_frequencies(
22
+ masks_dir: Path,
23
+ codec: RawMaskCodec,
24
+ max_files: int | None = 800,
25
+ ) -> torch.Tensor:
26
+ paths = list_masks(masks_dir)
27
+ if max_files is not None:
28
+ paths = paths[:max_files]
29
+ counts = np.zeros(codec.num_classes, dtype=np.int64)
30
+ for name in paths:
31
+ raw = np.array(Image.open(masks_dir / name))
32
+ enc, _ = codec.encode_mask(raw.astype(np.uint16))
33
+ for c in range(codec.num_classes):
34
+ counts[c] += int((enc == c).sum())
35
+ freq = counts.astype(np.float64) / max(counts.sum(), 1)
36
+ return torch.tensor(freq, dtype=torch.float32)
37
+
38
+
39
+ def per_image_sampling_weights(
40
+ masks_dir: Path,
41
+ image_basenames: Sequence[str],
42
+ codec: RawMaskCodec,
43
+ freq: torch.Tensor,
44
+ eps: float = 1e-6,
45
+ ) -> torch.DoubleTensor:
46
+ """Weights for ``WeightedRandomSampler``: upweight images containing rare classes.
47
+
48
+ For each mask, ``w_i = sum_{c : n_{ic}>0} 1 / (freq[c] + eps)``, then weights are
49
+ scaled to mean 1.0. ``image_basenames`` must match the order of
50
+ ``SegmentationDataset`` indices (same filenames as train pairs).
51
+ """
52
+ masks_dir = Path(masks_dir)
53
+ f = freq.detach().cpu().numpy().astype(np.float64)
54
+ raw_weights = np.zeros(len(image_basenames), dtype=np.float64)
55
+ for i, name in enumerate(image_basenames):
56
+ raw = np.array(Image.open(masks_dir / name))
57
+ enc, _ = codec.encode_mask(raw.astype(np.uint16))
58
+ present = np.zeros(codec.num_classes, dtype=bool)
59
+ for c in range(codec.num_classes):
60
+ present[c] = bool((enc == c).any())
61
+ raw_weights[i] = sum(1.0 / (f[c] + eps) for c in range(codec.num_classes) if present[c])
62
+ m = raw_weights.mean()
63
+ if m <= 0:
64
+ return torch.ones(len(image_basenames), dtype=torch.double)
65
+ scaled = raw_weights / m
66
+ return torch.tensor(scaled, dtype=torch.double)
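
A hedged toy example of the rare-class up-weighting rule from the `per_image_sampling_weights` docstring; the class fractions are invented:

```python
import numpy as np

freq = np.array([0.80, 0.19, 0.01])  # hypothetical dataset-level class fractions
# Image A contains only class 0; image B contains classes 0 and 2 (class 2 is rare).
w_a = 1.0 / (freq[0] + 1e-6)
w_b = 1.0 / (freq[0] + 1e-6) + 1.0 / (freq[2] + 1e-6)
raw = np.array([w_a, w_b])
print(raw / raw.mean())  # approx [0.02, 1.98]: image B is drawn about 80x more often than A
```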
desert_segmentation/utils/logging_utils.py ADDED
@@ -0,0 +1,11 @@
1
+ import logging
2
+ import sys
3
+
4
+
5
+ def setup_logging(level: int = logging.INFO) -> None:
6
+ logging.basicConfig(
7
+ level=level,
8
+ format="%(asctime)s [%(levelname)s] %(message)s",
9
+ datefmt="%H:%M:%S",
10
+ stream=sys.stdout,
11
+ )
desert_segmentation/utils/seed.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def set_seed(seed: int) -> None:
9
+ random.seed(seed)
10
+ np.random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed_all(seed)
13
+ os.environ["PYTHONHASHSEED"] = str(seed)
desert_segmentation/utils/viz.py ADDED
@@ -0,0 +1,60 @@
1
+ """Color overlays and side-by-side panels for segmentation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import List, Tuple
7
+
8
+ import numpy as np
9
+ from PIL import Image  # ImageDraw / ImageFont unused: titles are not drawn in this version
10
+
11
+
12
+ def palette(num_classes: int, seed: int = 42) -> np.ndarray:
13
+ rng = np.random.default_rng(seed)
14
+ colors = rng.integers(32, 256, size=(num_classes, 3), dtype=np.uint8)
15
+ colors[0] = np.array([128, 128, 128], dtype=np.uint8)
16
+ return colors
17
+
18
+
19
+ def colorize_mask(mask: np.ndarray, colors: np.ndarray) -> np.ndarray:
20
+ """mask HxW int 0..C-1 -> RGB uint8"""
21
+ m = mask.clip(0, len(colors) - 1)
22
+ return colors[m]
23
+
24
+
25
+ def blend_overlay(
26
+ image_rgb: np.ndarray,
27
+ colored_mask: np.ndarray,
28
+ alpha: float = 0.55,
29
+ ) -> np.ndarray:
30
+ return (image_rgb.astype(np.float32) * (1 - alpha) + colored_mask.astype(np.float32) * alpha).clip(
31
+ 0, 255
32
+ ).astype(np.uint8)
33
+
34
+
35
+ def save_triplet(
36
+ out_path: Path,
37
+ rgb: np.ndarray,
38
+ gt: np.ndarray | None,
39
+ pred: np.ndarray,
40
+ class_colors: np.ndarray,
41
+ titles: Tuple[str, str, str] = ("RGB", "GT", "Pred"),
42
+ ) -> None:
43
+ h, w = rgb.shape[:2]
44
+ panels: List[np.ndarray] = [rgb]
45
+ if gt is not None:
46
+ panels.append(blend_overlay(rgb, colorize_mask(gt, class_colors)))
47
+ else:
48
+ panels.append(np.zeros_like(rgb))
49
+ panels.append(blend_overlay(rgb, colorize_mask(pred, class_colors)))
50
+
51
+ # Optional text strip (simple border)
52
+ gap = 8
53
+ total_w = w * len(panels) + gap * (len(panels) - 1)
54
+ canvas = np.zeros((h, total_w, 3), dtype=np.uint8)
55
+ x = 0
56
+ for p in panels:
57
+ canvas[:, x : x + w] = p
58
+ x += w + gap
59
+ Image.fromarray(canvas).save(out_path)
60
+
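
A quick sketch tying the three viz helpers together on a 2x2 dummy mask:

```python
import numpy as np

from desert_segmentation.utils.viz import blend_overlay, colorize_mask, palette

colors = palette(num_classes=4)
mask = np.array([[0, 1], [2, 3]], dtype=np.int64)
rgb = np.full((2, 2, 3), 200, dtype=np.uint8)
overlay = blend_overlay(rgb, colorize_mask(mask, colors))
print(overlay.dtype, overlay.shape)  # uint8 (2, 2, 3)
```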
eval_summary.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "checkpoint": "D:\\codewizard 2.0\\checkpoints\\best.pt",
3
+ "val_dir": "D:\\codewizard 2.0\\training\\val\\Color_Images",
4
+ "num_val_samples": 317,
5
+ "miou": 0.07851162552833557,
6
+ "miou_all_classes": 0.07851162552833557,
7
+ "miou_valid_gt_classes": 0.07851162270064849,
8
+ "fw_iou": 0.3974744379520416,
9
+ "global_pixel_accuracy": 0.448105526939844,
10
+ "mean_class_accuracy": 0.15349954460756882,
11
+ "per_class_iou": {
12
+ "id_100": 0.0,
13
+ "id_200": 0.0,
14
+ "id_300": 0.25856709480285645,
15
+ "id_500": 0.0,
16
+ "id_550": 0.0,
17
+ "id_600": 0.0,
18
+ "id_700": 0.0,
19
+ "id_800": 0.0,
20
+ "id_7100": 0.0,
21
+ "id_10000": 0.5265491604804993
22
+ },
23
+ "per_class_recall": {
24
+ "id_100": 0.0,
25
+ "id_200": 0.0,
26
+ "id_300": 0.7182908230723474,
27
+ "id_500": 0.0,
28
+ "id_550": 0.0,
29
+ "id_600": 0.0,
30
+ "id_700": 0.0,
31
+ "id_800": 0.0,
32
+ "id_7100": 0.0,
33
+ "id_10000": 0.8167046230033408
34
+ },
35
+ "val_gt_pixel_counts": {
36
+ "id_100": 1902003,
37
+ "id_200": 2808908,
38
+ "id_300": 9019195,
39
+ "id_500": 512976,
40
+ "id_550": 1976309,
41
+ "id_600": 1138100,
42
+ "id_700": 30968,
43
+ "id_800": 566002,
44
+ "id_7100": 11074438,
45
+ "id_10000": 17714653
46
+ }
47
+ }
requirements-demo.txt ADDED
@@ -0,0 +1,3 @@
1
+ # Optional: interactive Gradio demo (install alongside requirements.txt)
2
+ # pip install -r requirements.txt -r requirements-demo.txt
3
+ gradio>=4.44.0,<6
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ numpy>=1.24.0
4
+ Pillow>=10.0.0
5
+ PyYAML>=6.0
6
+ # Pin avoids optional native build deps (e.g. stringzilla) on some Windows/Python setups
7
+ albumentations>=1.3.1,<1.5
8
+ segmentation-models-pytorch>=0.3.3
9
+ tqdm>=4.66.0
10
+ pytest>=7.4.0
scripts/demo_gradio.py ADDED
@@ -0,0 +1,219 @@
1
+ #!/usr/bin/env python3
2
+ """Gradio demo: upload RGB image, get colored mask, overlay, legend, and timing."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+ import sys
10
+ import time
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Tuple
13
+
14
+ ROOT = Path(__file__).resolve().parents[1]
15
+ if str(ROOT) not in sys.path:
16
+ sys.path.insert(0, str(ROOT))
17
+
18
+ import gradio as gr
19
+ import numpy as np
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from desert_segmentation.demo.inference_ui import (
24
+ build_legend_rows,
25
+ dominant_classes_markdown,
26
+ legend_table_html,
27
+ side_by_side_strip,
28
+ validate_rgb_array,
29
+ )
30
+ from desert_segmentation.infer.predict import _load_model_for_inference, predict_image
31
+ from desert_segmentation.utils.viz import blend_overlay, colorize_mask
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ _STATE: Dict[str, Any] = {}
36
+
37
+
38
+ def _to_uint8_rgb(arr: Any) -> np.ndarray:
39
+ if arr is None:
40
+ raise gr.Error("Please upload an image.")
41
+ if isinstance(arr, Image.Image):
42
+ arr = np.array(arr.convert("RGB"))
43
+ a = np.asarray(arr)
44
+ if a.ndim == 2:
45
+ raise gr.Error("Expected a color RGB image, got grayscale.")
46
+ if a.ndim == 3 and a.shape[2] == 4:
47
+ a = a[:, :, :3]
48
+ if a.ndim != 3 or a.shape[2] != 3:
49
+ raise gr.Error(f"Expected HxWx3 RGB image, got shape {a.shape}.")
50
+ if np.issubdtype(a.dtype, np.floating) and float(a.max()) <= 1.0 + 1e-6:
51
+ a = (np.clip(a, 0.0, 1.0) * 255.0).round().astype(np.uint8)
52
+ elif a.dtype != np.uint8:
53
+ a = np.clip(a, 0, 255).astype(np.uint8)
54
+ return np.ascontiguousarray(a)
55
+
56
+
57
+ def _init_state(checkpoint: Path, device: torch.device) -> None:
58
+ global _STATE
59
+ if _STATE:
60
+ return
61
+ logger.info("Loading checkpoint: %s", checkpoint)
62
+ model, cfg, codec = _load_model_for_inference(checkpoint, device)
63
+ icfg = cfg.get("inference") or {}
64
+ legend_rows, colors = build_legend_rows(codec.class_names, codec.num_classes, seed=42)
65
+ _STATE.update(
66
+ {
67
+ "model": model,
68
+ "cfg": cfg,
69
+ "codec": codec,
70
+ "device": device,
71
+ "icfg": icfg,
72
+ "legend_rows": legend_rows,
73
+ "colors": colors,
74
+ "legend_html_static": legend_table_html(legend_rows),
75
+ },
76
+ )
77
+ logger.info(
78
+ "Model ready | classes=%s | device=%s | default tile=%s overlap=%s tta=%s",
79
+ codec.num_classes,
80
+ device,
81
+ icfg.get("tile_size", 512),
82
+ icfg.get("overlap", 0.25),
83
+ icfg.get("tta_flip", True),
84
+ )
85
+
86
+
87
+ def _run(
88
+ image_input: Any,
89
+ use_tta: bool,
90
+ overlap: float,
91
+ tile_size: float,
92
+ max_side: int,
93
+ max_megapixels: float,
94
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str, str]:
95
+ rgb = _to_uint8_rgb(image_input)
96
+ try:
97
+ validate_rgb_array(rgb, max_side=max_side, max_megapixels=max_megapixels)
98
+ except ValueError as e:
99
+ raise gr.Error(str(e)) from e
100
+
101
+ st = _STATE
102
+ model = st["model"]
103
+ device = st["device"]
104
+ icfg = st["icfg"]
105
+ codec = st["codec"]
106
+ colors = st["colors"]
107
+
108
+ tile = int(round(float(tile_size))) if tile_size is not None else int(icfg.get("tile_size", 512))
109
+ tile = max(256, min(tile, 2048))
110
+ ov = float(overlap)
111
+ ov = max(0.0, min(ov, 0.5))
112
+
113
+ t0 = time.perf_counter()
114
+ pred = predict_image(model, rgb, device, tile, ov, bool(use_tta))
115
+ ms = (time.perf_counter() - t0) * 1000.0
116
+
117
+ colored = colorize_mask(pred, colors)
118
+ overlay = blend_overlay(rgb, colored)
119
+ strip = side_by_side_strip(rgb, colored, overlay)
120
+
121
+ dev_str = str(device)
122
+ if device.type == "cpu":
123
+ dev_str += " (CPU mode — slower than GPU)"
124
+
125
+ stats = (
126
+ f"**Inference:** {ms:.1f} ms \n"
127
+ f"**Device:** {dev_str} \n"
128
+ f"**Tile size:** {tile} | **Overlap:** {ov:.2f} | **TTA:** {use_tta}"
129
+ )
130
+ dominant = "### Dominant classes in this image\n" + dominant_classes_markdown(pred, codec.class_names, top_k=3)
131
+
132
+ return rgb, colored, overlay, strip, stats, dominant
133
+
134
+
135
+ def main() -> None:
136
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
137
+
138
+ parser = argparse.ArgumentParser(description="Gradio demo for desert semantic segmentation")
139
+ parser.add_argument("--root", type=str, default=os.environ.get("ROOT"), help="Workspace root (default: repo root or env ROOT)")
140
+ parser.add_argument(
141
+ "--checkpoint",
142
+ type=str,
143
+ default=os.environ.get("CHECKPOINT_PATH"),
144
+ help="Path to best.pt (default: env CHECKPOINT_PATH or <root>/checkpoints/best.pt)",
145
+ )
146
+ parser.add_argument("--host", type=str, default="127.0.0.1")
147
+ parser.add_argument("--port", type=int, default=7860)
148
+ parser.add_argument("--share", action="store_true", help="Create a temporary public Gradio link")
149
+ parser.add_argument("--max-side", type=int, default=4096)
150
+ parser.add_argument("--max-megapixels", type=float, default=16.0)
151
+ args = parser.parse_args()
152
+
153
+ root = Path(args.root or ROOT).resolve()
154
+ ckpt_arg = args.checkpoint or str(root / "checkpoints" / "best.pt")
155
+ ckpt = Path(ckpt_arg)
156
+ if not ckpt.is_absolute():
157
+ ckpt = (root / ckpt).resolve()
158
+ if not ckpt.is_file():
159
+ raise SystemExit(f"Checkpoint not found: {ckpt}")
160
+
161
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
162
+ _init_state(ckpt, device)
163
+
164
+ icfg = _STATE["icfg"]
165
+ def_tta = bool(icfg.get("tta_flip", True))
166
+ def_ov = float(icfg.get("overlap", 0.25))
167
+ def_tile = float(icfg.get("tile_size", 512))
168
+
169
+ intro = """## Desert semantic segmentation demo
170
+
171
+ This is **semantic segmentation**: each pixel is assigned one of several **classes** (terrain, vegetation, sky, etc.).
172
+ It is **not** bounding-box object detection.
173
+
174
+ **How to read the outputs:**
175
+ - **Colored mask:** each color is one class (see legend).
176
+ - **Overlay:** prediction blended on your photo.
177
+ - **Strip:** original | mask | side-by-side for screenshots.
178
+
179
+ _Confidence heatmaps for full-resolution sliding windows are not in this demo (v1); see README._
180
+ """
181
+
182
+ cpu_note = ""
183
+ if device.type == "cpu":
184
+ cpu_note = "\n\n> Running on **CPU** — expect slower inference. Use a CUDA GPU for best speed.\n"
185
+
186
+ with gr.Blocks(title="Desert segmentation", theme=gr.themes.Soft()) as demo:
187
+ gr.Markdown(intro + cpu_note)
188
+ inp = gr.Image(type="numpy", label="Upload RGB image", sources=["upload"])
189
+ with gr.Accordion("Advanced", open=False):
190
+ use_tta = gr.Checkbox(label="TTA (horizontal flip average)", value=def_tta)
191
+ overlap = gr.Slider(0.0, 0.5, value=def_ov, step=0.05, label="Tile overlap")
192
+ tile_sz = gr.Slider(256, 2048, value=int(def_tile), step=64, label="Tile size (pixels)")
193
+ run_btn = gr.Button("Run segmentation", variant="primary")
194
+
195
+ with gr.Row():
196
+ out_orig = gr.Image(label="Input", type="numpy")
197
+ out_mask = gr.Image(label="Colored class mask", type="numpy")
198
+ out_overlay = gr.Image(label="Overlay", type="numpy")
199
+ out_strip = gr.Image(label="RGB | mask | overlay", type="numpy")
200
+ stats_md = gr.Markdown("")
201
+ dominant_md = gr.Markdown("")
202
+ gr.Markdown("### Class legend (fixed palette)")
203
+ gr.HTML(_STATE["legend_html_static"])
204
+
205
+ def _fn(img, tta, ov, ts):
206
+ return _run(img, tta, ov, ts, args.max_side, args.max_megapixels)
207
+
208
+ run_btn.click(
209
+ fn=_fn,
210
+ inputs=[inp, use_tta, overlap, tile_sz],
211
+ outputs=[out_orig, out_mask, out_overlay, out_strip, stats_md, dominant_md],
212
+ )
213
+
214
+ logger.info("Launching Gradio on http://%s:%s", args.host, args.port)
215
+ demo.launch(server_name=args.host, server_port=args.port, share=args.share)
216
+
217
+
218
+ if __name__ == "__main__":
219
+ main()
scripts/eval.py ADDED
@@ -0,0 +1,146 @@
1
+ #!/usr/bin/env python3
2
+ """Run validation, print metrics, save confusion matrix and overlays."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import logging
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ ROOT = Path(__file__).resolve().parents[1]
14
+ if str(ROOT) not in sys.path:
15
+ sys.path.insert(0, str(ROOT))
16
+
17
+ import numpy as np
18
+ import torch
20
+ from torch.utils.data import DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from desert_segmentation.data.dataset import SegmentationDataset
24
+ from desert_segmentation.data.mask_encoding import build_codec_from_config
25
+ from desert_segmentation.data.transforms import build_val_transforms
26
+ from desert_segmentation.models.factory import create_model
27
+ from desert_segmentation.train.evaluate import evaluate
28
+ from desert_segmentation.utils.config import get_paths, load_config
29
+ from desert_segmentation.utils.logging_utils import setup_logging
30
+ from desert_segmentation.utils.seed import set_seed
31
+ from desert_segmentation.utils.viz import palette, save_triplet
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ def main() -> None:
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--config", type=str, default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"))
39
+ parser.add_argument("--checkpoint", type=str, required=True)
40
+ parser.add_argument("--root", type=str, default=None)
41
+ parser.add_argument("--out_dir", type=str, default="eval_outputs")
42
+ parser.add_argument("--max_viz", type=int, default=24)
43
+ args = parser.parse_args()
44
+
45
+ root = Path(args.root or ROOT).resolve()
46
+ cfg = load_config(args.config, root=root)
47
+ setup_logging()
48
+ set_seed(int(cfg["train"]["seed"]))
49
+
50
+ paths = get_paths(cfg)
51
+ raw_ids = cfg["data"]["raw_ids"]
52
+ names = tuple(cfg["data"]["class_names"])
53
+ codec = build_codec_from_config(raw_ids, names)
54
+ ignore_index = int(cfg["data"].get("ignore_index", 255))
55
+ crop_size = int(cfg["data"]["crop_size"])
56
+
57
+ val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)
58
+ val_ds = SegmentationDataset(
59
+ paths["val_images"],
60
+ paths["val_masks"],
61
+ codec=codec,
62
+ transform=val_tf,
63
+ mode="val",
64
+ crop_size=crop_size,
65
+ rare_class_crop_prob=0.0,
66
+ ignore_index=ignore_index,
67
+ seed=int(cfg["train"]["seed"]),
68
+ )
69
+
70
+ nw = 0 if os.name == "nt" else int(cfg["data"].get("num_workers", 4))
71
+ val_loader = DataLoader(
72
+ val_ds,
73
+ batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
74
+ shuffle=False,
75
+ num_workers=nw,
76
+ pin_memory=torch.cuda.is_available(),
77
+ )
78
+
79
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
+ try:
81
+ ckpt = torch.load(Path(args.checkpoint), map_location=device, weights_only=False)
82
+ except TypeError:
83
+ ckpt = torch.load(Path(args.checkpoint), map_location=device)
84
+ cfg_ck = ckpt["config"]
85
+ model = create_model(cfg_ck["model"], num_classes=codec.num_classes).to(device)
86
+ if ckpt.get("ema") is not None:
87
+ for n, p in model.named_parameters():
88
+ if n in ckpt["ema"]:
89
+ p.data.copy_(ckpt["ema"][n].to(device))
90
+ else:
91
+ model.load_state_dict(ckpt["model"])
92
+ model.eval()
93
+
94
+ metrics = evaluate(model, val_loader, device, num_classes=codec.num_classes, ignore_index=ignore_index)
95
+ logger.info("mIoU=%.4f fwIoU=%.4f", metrics["miou"], metrics["fw_iou"])
96
+ per = metrics["per_class_iou"]
97
+ for i, name in enumerate(codec.class_names):
98
+ logger.info(" %s IoU=%.4f", name, float(per[i]))
99
+
100
+ out_dir = Path(args.out_dir)
101
+ if not out_dir.is_absolute():
102
+ out_dir = root / out_dir
103
+ out_dir.mkdir(parents=True, exist_ok=True)
104
+ with (out_dir / "metrics.json").open("w", encoding="utf-8") as f:
105
+ json.dump(
106
+ {
107
+ "miou": float(metrics["miou"]),
108
+ "fw_iou": float(metrics["fw_iou"]),
109
+ "per_class_iou": {codec.class_names[i]: float(per[i]) for i in range(len(codec.class_names))},
110
+ },
111
+ f,
112
+ indent=2,
113
+ )
114
+ np.save(out_dir / "confusion.npy", metrics["confusion"])
115
+
116
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
117
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
118
+ colors = palette(codec.num_classes)
119
+ n = 0
120
+ with torch.no_grad():
121
+ for batch in tqdm(val_loader, desc="viz"):
122
+ images = batch["image"].to(device)
123
+ masks = batch["mask"].to(device)
124
+ logits = model(images)
125
+ pred = logits.argmax(dim=1).cpu().numpy()
126
+ gt = masks.cpu().numpy()
127
+ for b in range(images.shape[0]):
128
+ if n >= args.max_viz:
129
+ break
130
+ t = images[b].cpu().permute(1, 2, 0).numpy()
131
+ rgb = (t * std + mean) * 255.0
132
+ rgb = np.clip(rgb, 0, 255).astype(np.uint8)
133
+ save_triplet(
134
+ out_dir / f"val_{n:04d}.png",
135
+ rgb,
136
+ gt[b],
137
+ pred[b],
138
+ colors,
139
+ )
140
+ n += 1
141
+ if n >= args.max_viz:
142
+ break
143
+
144
+
145
+ if __name__ == "__main__":
146
+ main()
scripts/eval_summary.py ADDED
@@ -0,0 +1,259 @@
1
+ #!/usr/bin/env python3
2
+ """Print segmentation metrics: mIoU (all classes + valid-GT-only), fwIoU, accuracies, GT counts.
3
+
4
+ Runs a full validation pass by default (same setup as ``scripts/eval.py``). With
5
+ ``--from-checkpoint-only``, only prints metrics stored inside the checkpoint file
6
+ (mIoU and per-class IoU when present); full metrics require a validation forward pass."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import json
12
+ import math
13
+ import os
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List
17
+
18
+ ROOT = Path(__file__).resolve().parents[1]
19
+ if str(ROOT) not in sys.path:
20
+ sys.path.insert(0, str(ROOT))
21
+
22
+ import torch
23
+ from torch.utils.data import DataLoader
24
+
25
+ from desert_segmentation.data.dataset import SegmentationDataset
26
+ from desert_segmentation.data.mask_encoding import build_codec_from_config
27
+ from desert_segmentation.data.transforms import build_val_transforms
28
+ from desert_segmentation.metrics.iou import (
29
+ confusion_to_accuracy_metrics,
30
+ gt_pixel_counts,
31
+ valid_class_miou_from_confusion,
32
+ )
33
+ from desert_segmentation.models.factory import create_model
34
+ from desert_segmentation.train.evaluate import evaluate
35
+ from desert_segmentation.utils.config import get_paths, load_config
36
+ from desert_segmentation.utils.logging_utils import setup_logging
37
+ from desert_segmentation.utils.seed import set_seed
38
+
39
+
40
+ def _load_checkpoint(path: Path, device: torch.device) -> Dict[str, Any]:
41
+ try:
42
+ return torch.load(path, map_location=device, weights_only=False)
43
+ except TypeError:
44
+ return torch.load(path, map_location=device)
45
+
46
+
47
+ def _print_table(rows: List[List[str]]) -> None:
48
+ widths = [max(len(rows[i][c]) for i in range(len(rows))) for c in range(len(rows[0]))]
49
+ for row in rows:
50
+ line = " ".join(row[c].ljust(widths[c]) for c in range(len(row)))
51
+ print(line)
52
+
53
+
54
+ def run_from_checkpoint_only(ckpt_path: Path) -> int:
55
+ ckpt = _load_checkpoint(ckpt_path, torch.device("cpu"))
56
+ print(f"Checkpoint: {ckpt_path.resolve()}")
57
+ print()
58
+ if "miou" in ckpt:
59
+ print(f" mIoU (stored): {float(ckpt['miou']):.6f}")
60
+ else:
61
+ print(" mIoU: (not stored in this file)")
62
+ names = ckpt.get("class_names")
63
+ per = ckpt.get("per_class_iou")
64
+ if per is not None and names is not None:
65
+ print(" Per-class IoU (stored):")
66
+ for i, name in enumerate(names):
67
+ print(f" [{i}] {name}: {float(per[i]):.6f}")
68
+ elif per is not None:
69
+ print(" Per-class IoU (stored):")
70
+ for i, v in enumerate(per):
71
+ print(f" [{i}]: {float(v):.6f}")
72
+ else:
73
+ print(" Per-class IoU: (not stored in this file)")
74
+ print()
75
+ print(
76
+ "Note: fwIoU, global pixel accuracy, and mean class accuracy are not saved in "
77
+ "checkpoints. Run without --from-checkpoint-only to compute them on the val set."
78
+ )
79
+ return 0
80
+
81
+
82
+ def run_full_eval(args: argparse.Namespace) -> int:
83
+ root = Path(args.root or ROOT).resolve()
84
+ cfg = load_config(args.config, root=root)
85
+ setup_logging()
86
+ set_seed(int(cfg["train"]["seed"]))
87
+
88
+ paths = get_paths(cfg)
89
+ raw_ids = cfg["data"]["raw_ids"]
90
+ names = tuple(cfg["data"]["class_names"])
91
+ codec = build_codec_from_config(raw_ids, names)
92
+ ignore_index = int(cfg["data"].get("ignore_index", 255))
93
+ crop_size = int(cfg["data"]["crop_size"])
94
+
95
+ val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)
96
+ val_ds = SegmentationDataset(
97
+ paths["val_images"],
98
+ paths["val_masks"],
99
+ codec=codec,
100
+ transform=val_tf,
101
+ mode="val",
102
+ crop_size=crop_size,
103
+ rare_class_crop_prob=0.0,
104
+ ignore_index=ignore_index,
105
+ seed=int(cfg["train"]["seed"]),
106
+ )
107
+
108
+ nw = 0 if os.name == "nt" else int(cfg["data"].get("num_workers", 4))
109
+ val_loader = DataLoader(
110
+ val_ds,
111
+ batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
112
+ shuffle=False,
113
+ num_workers=nw,
114
+ pin_memory=torch.cuda.is_available(),
115
+ )
116
+
117
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
118
+ ckpt_path = Path(args.checkpoint)
119
+ ckpt = _load_checkpoint(ckpt_path, device)
120
+ cfg_ck = ckpt["config"]
121
+ model = create_model(cfg_ck["model"], num_classes=codec.num_classes).to(device)
122
+ if ckpt.get("ema") is not None:
123
+ for n, p in model.named_parameters():
124
+ if n in ckpt["ema"]:
125
+ p.data.copy_(ckpt["ema"][n].to(device))
126
+ else:
127
+ model.load_state_dict(ckpt["model"])
128
+ model.eval()
129
+
130
+ metrics = evaluate(
131
+ model,
132
+ val_loader,
133
+ device,
134
+ num_classes=codec.num_classes,
135
+ ignore_index=ignore_index,
136
+ desc="eval_summary",
137
+ )
138
+ cm = metrics["confusion"]
139
+ acc = confusion_to_accuracy_metrics(cm)
140
+ miou_valid = float(valid_class_miou_from_confusion(cm))
141
+ gt_counts = gt_pixel_counts(cm)
142
+
143
+ miou = float(metrics["miou"])
144
+ fw_iou = float(metrics["fw_iou"])
145
+ gpa = float(acc["global_pixel_accuracy"])
146
+ mca = float(acc["mean_class_accuracy"])
147
+ per_iou = metrics["per_class_iou"]
148
+ per_rec = acc["per_class_recall"]
149
+
150
+ def _rec_str(i: int) -> str:
151
+ v = float(per_rec[i])
152
+ if math.isnan(v):
153
+ return "n/a"
154
+ return f"{v:.6f}"
155
+
156
+ print()
157
+ print(f"Checkpoint: {ckpt_path.resolve()}")
158
+ print(f"Val images: {paths['val_images']}")
159
+ print(f"Val samples: {len(val_ds)}")
160
+ print()
161
+ print(" mIoU (all classes): {:.6f}".format(miou))
162
+ print(" mIoU (classes w/ GT): {:.6f}".format(miou_valid))
163
+ print(" Frequency-weighted IoU: {:.6f}".format(fw_iou))
164
+ print(" Global pixel accuracy: {:.6f}".format(gpa))
165
+ print(" Mean class accuracy: {:.6f}".format(mca))
166
+ print(" (mean of per-class recall over classes with GT pixels)")
167
+ print()
168
+ table: List[List[str]] = [["cls", "name", "IoU", "recall"]]
169
+ for i, name in enumerate(codec.class_names):
170
+ table.append(
171
+ [
172
+ str(i),
173
+ name,
174
+ f"{float(per_iou[i]):.6f}",
175
+ _rec_str(i),
176
+ ]
177
+ )
178
+ _print_table(table)
179
+ print()
180
+ print(" Val GT pixels per class (full val set):")
181
+ for i, name in enumerate(codec.class_names):
182
+ print(f" [{i}] {name}: {int(gt_counts[i])}")
183
+ print()
184
+
185
+ payload = {
186
+ "checkpoint": str(ckpt_path.resolve()),
187
+ "val_dir": str(paths["val_images"]),
188
+ "num_val_samples": len(val_ds),
189
+ "miou": miou,
190
+ "miou_all_classes": miou,
191
+ "miou_valid_gt_classes": miou_valid,
192
+ "fw_iou": fw_iou,
193
+ "global_pixel_accuracy": gpa,
194
+ "mean_class_accuracy": mca,
195
+ "per_class_iou": {codec.class_names[i]: float(per_iou[i]) for i in range(len(codec.class_names))},
196
+ "per_class_recall": {
197
+ codec.class_names[i]: (None if math.isnan(float(per_rec[i])) else float(per_rec[i]))
198
+ for i in range(len(codec.class_names))
199
+ },
200
+ "val_gt_pixel_counts": {codec.class_names[i]: int(gt_counts[i]) for i in range(len(codec.class_names))},
201
+ }
202
+
203
+ if args.json_out:
204
+ out = Path(args.json_out)
205
+ if not out.is_absolute():
206
+ out = root / out
207
+ out.parent.mkdir(parents=True, exist_ok=True)
208
+ with out.open("w", encoding="utf-8") as f:
209
+ json.dump(payload, f, indent=2)
210
+ print(f"Wrote {out}")
211
+
212
+ return 0
213
+
214
+
215
+ def main() -> None:
216
+ parser = argparse.ArgumentParser(description="Segmentation metric summary (val set).")
217
+ parser.add_argument(
218
+ "--checkpoint",
219
+ type=str,
220
+ default=None,
221
+ help="Path to .pt checkpoint (default: <root>/checkpoints/best.pt)",
222
+ )
223
+ parser.add_argument(
224
+ "--config",
225
+ type=str,
226
+ default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"),
227
+ )
228
+ parser.add_argument("--root", type=str, default=None, help="Workspace root (defaults to repo root)")
229
+ parser.add_argument(
230
+ "--from-checkpoint-only",
231
+ action="store_true",
232
+ help="Only print mIoU/per-class IoU stored in the file (no forward pass).",
233
+ )
234
+ parser.add_argument(
235
+ "--json-out",
236
+ type=str,
237
+ default=None,
238
+ help="Optional path to write full metrics JSON (relative to --root unless absolute).",
239
+ )
240
+ args = parser.parse_args()
241
+ root = Path(args.root or ROOT).resolve()
242
+ ck_path = Path(args.checkpoint) if args.checkpoint else root / "checkpoints" / "best.pt"
243
+
244
+ if args.from_checkpoint_only:
245
+ if not ck_path.is_file():
246
+ print(f"Error: checkpoint not found: {ck_path}", file=sys.stderr)
247
+ sys.exit(1)
248
+ sys.exit(run_from_checkpoint_only(ck_path))
249
+
250
+ args.checkpoint = str(ck_path)
251
+ args.root = str(root)
252
+ if not ck_path.is_file():
253
+ print(f"Error: checkpoint not found: {ck_path}", file=sys.stderr)
254
+ sys.exit(1)
255
+ sys.exit(run_full_eval(args))
256
+
257
+
258
+ if __name__ == "__main__":
259
+ main()
scripts/infer.py ADDED
@@ -0,0 +1,49 @@
+ #!/usr/bin/env python3
+ """Run inference on testing/Color_Images; optional ONNX export."""
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import sys
+ from pathlib import Path
+
+ ROOT = Path(__file__).resolve().parents[1]
+ if str(ROOT) not in sys.path:
+     sys.path.insert(0, str(ROOT))
+
+ from desert_segmentation.infer.predict import export_onnx, predict_folder
+ from desert_segmentation.utils.config import get_paths, load_config
+ from desert_segmentation.utils.logging_utils import setup_logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"))
+     parser.add_argument("--checkpoint", type=str, required=True)
+     parser.add_argument("--root", type=str, default=None)
+     parser.add_argument("--out_dir", type=str, default="infer_outputs")
+     parser.add_argument("--limit", type=int, default=None)
+     parser.add_argument("--onnx", type=str, default=None, help="If set, export ONNX to this path and exit")
+     args = parser.parse_args()
+
+     root = Path(args.root or ROOT).resolve()
+     cfg = load_config(args.config, root=root)
+     setup_logging()
+
+     if args.onnx:
+         export_onnx(Path(args.checkpoint), Path(args.onnx))
+         logger.info("exported ONNX to %s", args.onnx)
+         return
+
+     paths = get_paths(cfg)
+     out_dir = Path(args.out_dir)
+     if not out_dir.is_absolute():
+         out_dir = root / out_dir
+     predict_folder(Path(args.checkpoint), paths["test_images"], out_dir, limit=args.limit)
+
+
+ if __name__ == "__main__":
+     main()
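The two entry points above can also be driven programmatically. The signatures (`export_onnx(checkpoint, onnx_path)` and `predict_folder(checkpoint, images_dir, out_dir, limit=...)`) are taken from the calls in the script; the concrete paths below are placeholders:

```python
from pathlib import Path

from desert_segmentation.infer.predict import export_onnx, predict_folder

ckpt = Path("checkpoints/best.pt")  # placeholder; any trained .pt works

# Equivalent of: python scripts/infer.py --checkpoint ... --limit 8
predict_folder(ckpt, Path("testing/Color_Images"), Path("infer_outputs"), limit=8)

# Equivalent of the --onnx branch above.
export_onnx(ckpt, Path("model.onnx"))
```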
scripts/train.py ADDED
@@ -0,0 +1,166 @@
+ #!/usr/bin/env python3
+ """Train semantic segmentation model from YAML config."""
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import os
+ import sys
+ from pathlib import Path
+
+ ROOT = Path(__file__).resolve().parents[1]
+ if str(ROOT) not in sys.path:
+     sys.path.insert(0, str(ROOT))
+
+ import torch
+ from torch.utils.data import DataLoader, WeightedRandomSampler
+
+ from desert_segmentation.data.dataset import SegmentationDataset
+ from desert_segmentation.data.mask_encoding import build_codec_from_config
+ from desert_segmentation.data.transforms import build_train_transforms, build_val_transforms
+ from desert_segmentation.losses.combined import build_loss, compute_class_weights_from_freq
+ from desert_segmentation.models.factory import create_model
+ from desert_segmentation.train.trainer import train
+ from desert_segmentation.utils.config import get_paths, load_config
+ from desert_segmentation.utils.freq import estimate_pixel_frequencies, per_image_sampling_weights
+ from desert_segmentation.utils.logging_utils import setup_logging
+ from desert_segmentation.utils.seed import set_seed
+
+ logger = logging.getLogger(__name__)
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--config",
+         type=str,
+         default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"),
+     )
+     parser.add_argument("--root", type=str, default=None, help="Workspace root (defaults to repo root)")
+     parser.add_argument("--epochs", type=int, default=None, help="Override epochs (smoke tests)")
+     parser.add_argument("--max_train_batches", type=int, default=None, help="Stop each epoch after N batches (smoke tests)")
+     args = parser.parse_args()
+
+     root = Path(args.root or ROOT).resolve()
+     cfg = load_config(args.config, root=root)
+     if args.epochs is not None:
+         cfg["train"]["epochs"] = int(args.epochs)
+     setup_logging()
+     set_seed(int(cfg["train"]["seed"]))
+
+     paths = get_paths(cfg)
+     raw_ids = cfg["data"]["raw_ids"]
+     names = tuple(cfg["data"]["class_names"])
+     codec = build_codec_from_config(raw_ids, names)
+     ignore_index = int(cfg["data"].get("ignore_index", 255))
+     crop_size = int(cfg["data"]["crop_size"])
+
+     train_tf = build_train_transforms(
+         crop_size=crop_size,
+         strong=bool(cfg.get("augmentation", {}).get("strong", True)),
+         ignore_index=ignore_index,
+     )
+     val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)
+
+     train_ds = SegmentationDataset(
+         paths["train_images"],
+         paths["train_masks"],
+         codec=codec,
+         transform=train_tf,
+         mode="train",
+         crop_size=crop_size,
+         rare_class_crop_prob=float(cfg["data"].get("rare_class_crop_prob", 0.35)),
+         ignore_index=ignore_index,
+         seed=int(cfg["train"]["seed"]),
+     )
+     val_ds = SegmentationDataset(
+         paths["val_images"],
+         paths["val_masks"],
+         codec=codec,
+         transform=val_tf,
+         mode="val",
+         crop_size=crop_size,
+         rare_class_crop_prob=0.0,
+         ignore_index=ignore_index,
+         seed=int(cfg["train"]["seed"]),
+     )
+
+     nw = int(cfg["data"].get("num_workers", 4))
+     if os.name == "nt":
+         nw = 0
+
+     val_loader = DataLoader(
+         val_ds,
+         batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
+         shuffle=False,
+         num_workers=nw,
+         pin_memory=torch.cuda.is_available(),
+     )
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = create_model(cfg["model"], num_classes=codec.num_classes).to(device)
+
+     freq = estimate_pixel_frequencies(paths["train_masks"], codec, max_files=None)
+     cap = float(cfg.get("loss", {}).get("class_weight_cap", 15.0))
+     class_w = compute_class_weights_from_freq(freq, cap=cap).to(device)
+     logger.info("class pixel frequencies (train masks): %s", freq.tolist())
+
+     use_weighted_sampler = bool(cfg.get("data", {}).get("weighted_sampler", False))
+     sampler: WeightedRandomSampler | None = None
+     if use_weighted_sampler:
+         eps = float(cfg.get("data", {}).get("weighted_sampler_eps", 1e-6))
+         logger.info("computing per-image sampling weights (scanning train masks)...")
+         sample_w = per_image_sampling_weights(
+             paths["train_masks"],
+             train_ds.image_names,
+             codec,
+             freq,
+             eps=eps,
+         )
+         sampler = WeightedRandomSampler(
+             sample_w,
+             num_samples=len(train_ds),
+             replacement=True,
+             generator=torch.Generator().manual_seed(int(cfg["train"]["seed"])),
+         )
+
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=int(cfg["train"]["batch_size"]),
+         shuffle=sampler is None,
+         sampler=sampler,
+         num_workers=nw,
+         pin_memory=torch.cuda.is_available(),
+         drop_last=True,
+     )
+
+     criterion = build_loss(
+         cfg.get("loss", {}),
+         num_classes=codec.num_classes,
+         class_weights=class_w,
+         ignore_index=ignore_index,
+     ).to(device)
+
+     ckpt_dir = Path(cfg["train"]["checkpoint_dir"])
+     if not ckpt_dir.is_absolute():
+         ckpt_dir = root / ckpt_dir
+
+     out = train(
+         model,
+         train_loader,
+         val_loader,
+         criterion,
+         device,
+         cfg,
+         num_classes=codec.num_classes,
+         ignore_index=ignore_index,
+         checkpoint_dir=ckpt_dir,
+         class_names=codec.class_names,
+         max_train_batches=args.max_train_batches,
+     )
+     logger.info("finished best_mIoU=%s path=%s", out["best_miou"], out["best_path"])
+
+
+ if __name__ == "__main__":
+     main()
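For reference, these are the config keys `scripts/train.py` actually reads. The sketch below is assembled from the `cfg[...]` accesses above; every value is an illustrative placeholder rather than the repo's real `default.yaml` (the class names in particular are invented):

```python
cfg = {
    "data": {
        "raw_ids": [100, 7100, 10000],      # raw mask ids, in class order
        "class_names": ["class_a", "class_b", "class_c"],  # placeholder names
        "ignore_index": 255,
        "crop_size": 512,
        "num_workers": 4,
        "rare_class_crop_prob": 0.35,
        "weighted_sampler": False,
        "weighted_sampler_eps": 1e-6,
    },
    "augmentation": {"strong": True},
    "loss": {"class_weight_cap": 15.0},     # plus whatever build_loss reads
    "model": {},                            # forwarded as-is to create_model
    "train": {
        "seed": 42,
        "epochs": 80,                       # placeholder
        "batch_size": 8,                    # placeholder
        "val_batch_size": 8,                # optional; falls back to batch_size
        "checkpoint_dir": "checkpoints",
    },
}
```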
tests/test_confusion_metrics.py ADDED
@@ -0,0 +1,50 @@
+ """Tests for accuracy metrics derived from confusion matrices."""
+
+ import numpy as np
+
+ from desert_segmentation.metrics.iou import (
+     confusion_to_accuracy_metrics,
+     valid_class_miou_from_confusion,
+ )
+
+
+ def test_perfect_confusion():
+     cm = np.eye(3, dtype=np.int64) * 100
+     out = confusion_to_accuracy_metrics(cm)
+     assert out["global_pixel_accuracy"] == 1.0
+     assert out["mean_class_accuracy"] == 1.0
+     assert np.allclose(out["per_class_recall"], [1.0, 1.0, 1.0])
+
+
+ def test_two_class_mixed():
+     # GT: 100 class0, 100 class1; half wrong each
+     cm = np.array([[50, 50], [50, 50]], dtype=np.int64)
+     out = confusion_to_accuracy_metrics(cm)
+     assert out["global_pixel_accuracy"] == 0.5
+     assert abs(out["mean_class_accuracy"] - 0.5) < 1e-6
+
+
+ def test_one_class_absent_in_gt():
+     # Only class 0 appears in val GT; 80 correct, 20 predicted as class 1.
+     cm = np.array([[80, 20], [0, 0]], dtype=np.int64)
+     out = confusion_to_accuracy_metrics(cm)
+     assert abs(out["global_pixel_accuracy"] - 0.8) < 1e-9
+     assert abs(out["mean_class_accuracy"] - 0.8) < 1e-9
+     assert np.isnan(out["per_class_recall"][1])
+
+
+ def test_valid_class_miou_only_classes_with_gt():
+     # Both classes appear in GT, so valid-class mIoU averages both rows
+     # (a full mIoU would average in zeros for empty GT rows; there are none here).
+     cm = np.array([[90, 10], [10, 90]], dtype=np.int64)
+     # IoU class 0: 90/(90+10+10) = 90/110; class 1 is identical by symmetry.
+     v = valid_class_miou_from_confusion(cm)
+     iou0 = 90.0 / (90 + 10 + 10)
+     assert abs(v - iou0) < 1e-6
+
+
+ def test_valid_class_miou_ignores_empty_gt_rows():
+     # Class 1 has no GT pixels; valid-class mIoU averages only class 0.
+     cm = np.array([[80, 20], [0, 0]], dtype=np.int64)
+     v = valid_class_miou_from_confusion(cm)
+     # IoU class 0: TP=80, union = rows[0]+cols[0]-TP = 100+80-80 = 100
+     assert abs(v - 0.8) < 1e-9
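The tests pin down the metric semantics: per-class recall is the confusion-matrix diagonal over row sums (NaN when a class has no GT pixels), and valid-class mIoU averages IoU only over classes that actually appear in the GT. A reference sketch consistent with that asserted behavior, not the actual implementation in `desert_segmentation.metrics.iou`:

```python
import numpy as np

def _accuracy_metrics(cm: np.ndarray) -> dict:
    cm = cm.astype(np.float64)
    tp = np.diag(cm)
    gt = cm.sum(axis=1)                     # row i = GT pixels of class i
    with np.errstate(invalid="ignore", divide="ignore"):
        recall = tp / gt                    # NaN where a class has no GT
    has_gt = gt > 0
    return {
        "global_pixel_accuracy": tp.sum() / cm.sum(),
        "mean_class_accuracy": recall[has_gt].mean(),
        "per_class_recall": recall,
    }

def _valid_class_miou(cm: np.ndarray) -> float:
    cm = cm.astype(np.float64)
    tp = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - tp
    has_gt = cm.sum(axis=1) > 0             # average only over GT classes
    return float((tp[has_gt] / union[has_gt]).mean())
```

Both helpers satisfy every assertion in the file above (e.g. `_valid_class_miou(np.array([[80, 20], [0, 0]]))` gives 0.8).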
tests/test_mask_encoding.py ADDED
@@ -0,0 +1,33 @@
+ import numpy as np
+ import pytest
+
+ from desert_segmentation.data.mask_encoding import RawMaskCodec, default_desert_codec
+
+
+ def test_roundtrip_known_ids():
+     codec = default_desert_codec()
+     h, w = 32, 48
+     raw = np.full((h, w), 100, dtype=np.uint16)
+     raw[:, :10] = 10000
+     raw[10:20, :] = 7100
+     enc, unk = codec.encode_mask(raw)
+     assert unk == 0.0
+     assert enc.shape == (h, w)
+     back = codec.decode_to_raw(enc)
+     assert np.array_equal(back, raw)
+
+
+ def test_unknown_pixel_raises():
+     codec = RawMaskCodec(raw_ids=(1, 2), class_names=("a", "b"))
+     raw = np.array([[1, 2], [99, 1]], dtype=np.uint16)
+     with pytest.raises(ValueError):
+         codec.encode_mask(raw)
+
+
+ def test_lut_all_ids():
+     codec = default_desert_codec()
+     for rid in codec.raw_ids:
+         raw = np.full((4, 4), rid, dtype=np.uint16)
+         enc, unk = codec.encode_mask(raw)
+         assert unk == 0.0
+         assert np.unique(enc).size == 1
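These tests imply a lookup-table design: raw uint16 ids map to contiguous class indices, unknown ids are rejected, and decoding indexes back into `raw_ids`. A minimal sketch under those assumptions; the real `RawMaskCodec` may differ in details such as how the unknown fraction is reported:

```python
import numpy as np

class TinyCodec:
    def __init__(self, raw_ids: tuple):
        self.raw_ids = raw_ids
        # LUT over the full uint16 range; unknown ids map to sentinel 255.
        self._lut = np.full(65536, 255, dtype=np.uint8)
        for idx, rid in enumerate(raw_ids):
            self._lut[rid] = idx

    def encode_mask(self, raw: np.ndarray):
        enc = self._lut[raw.astype(np.uint16)]
        unk = float((enc == 255).mean())
        if unk > 0.0:
            raise ValueError(f"unknown raw ids present ({unk:.1%} of pixels)")
        return enc, unk

    def decode_to_raw(self, enc: np.ndarray) -> np.ndarray:
        inv = np.asarray(self.raw_ids, dtype=np.uint16)
        return inv[enc]

codec = TinyCodec((100, 7100, 10000))   # ids borrowed from the tests above
enc, unk = codec.encode_mask(np.full((4, 4), 7100, dtype=np.uint16))
assert unk == 0.0 and np.array_equal(codec.decode_to_raw(enc), np.full((4, 4), 7100))
```

Round-tripping works because `decode_to_raw` simply indexes the `raw_ids` tuple with the contiguous labels, which is exactly what `test_roundtrip_known_ids` checks.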