webVishnu committed on
Commit
e84b383
·
verified ·
1 Parent(s): 9a6b7b7

Update README.md

Files changed (1)
  1. README.md +13 -636
README.md CHANGED
@@ -1,641 +1,18 @@
- # Desert Semantic Segmentation
-
- End-to-end **semantic segmentation** for **off-road / desert** scenes: every pixel is classified into one of several terrain / object categories. The pipeline is built for **synthetic RGB + mask** data, **PyTorch**, **[segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch)** (SMP), **Albumentations**, and hackathon-style iteration (strong baselines, IoU-driven checkpoints, optional EMA / TTA / ONNX).
-
- ---
-
- ## Table of contents
-
- 1. [What this project does](#1-what-this-project-does)
- 2. [Problem statement and goals](#2-problem-statement-and-goals)
- 3. [Dataset layout and assumptions](#3-dataset-layout-and-assumptions)
- 4. [Label format (critical)](#4-label-format-critical)
- 5. [Repository structure](#5-repository-structure)
- 6. [Configuration (`default.yaml`)](#6-configuration-defaultyaml)
- 7. [High-level architecture](#7-high-level-architecture)
- 8. [Data pipeline (detailed)](#8-data-pipeline-detailed)
- 9. [Model](#9-model)
- 10. [Loss functions](#10-loss-functions)
- 11. [Metrics](#11-metrics)
- 12. [Training loop](#12-training-loop)
- 13. [Validation and evaluation scripts](#13-validation-and-evaluation-scripts)
- 14. [Inference (testing folder, sliding window, TTA, ONNX)](#14-inference-testing-folder-sliding-window-tta-onnx)
- 15. [Checkpoints and artifacts](#15-checkpoints-and-artifacts)
- 16. [How to run (commands)](#16-how-to-run-commands)
- 17. [Interactive demo (Gradio)](#17-interactive-demo-gradio)
- 18. [Tests](#18-tests)
- 19. [Dependencies and environment notes](#19-dependencies-and-environment-notes)
- 20. [Design decisions and limitations](#20-design-decisions-and-limitations)
- 21. [Extending the project](#21-extending-the-project)
- 22. [Flowcharts](#22-flowcharts)
-
- ---
-
- ## 1. What this project does
-
- - **Input:** RGB color images (`Color_Images`).
- - **Supervision:** Per-pixel class masks (`Segmentation`) aligned by **filename** with the RGB image.
- - **Output:** A trained neural network that predicts a **class index per pixel** on validation, held-out **testing** images (no labels in repo), or any folder of images you point inference at.
- - **Primary quality metric:** **mean Intersection-over-Union (mIoU)** on the validation set, plus **per-class IoU**, **frequency-weighted IoU (fwIoU)**, and a **confusion matrix**.
-
- ---
-
- ## 2. Problem statement and goals
-
- | Goal | How we address it |
- | -------------------------------------------- | -------------------------------------------------------------------------------------------- |
- | Accurate pixel-wise classification | DeepLabV3+ with ImageNet-pretrained encoder; CE + Dice loss; class-frequency weights |
- | Robustness (synthetic → harder real domains) | Strong photometric + mild “desert-like” augmentations (sun flare, shadow, blur, noise, JPEG) |
- | Class imbalance | Inverse log-frequency weights with a **cap**; rare-class-biased random crops |
- | Stable training | AdamW, cosine decay with **warmup**, gradient clipping, optional **EMA** |
- | Fast iteration | YAML-driven config; SMP for one-line model construction; scripts for train / eval / infer |
- | Deployment story | Optional **ONNX** export; inference timing written to `latency.txt` |
-
- **Note:** The original hackathon plan also mentioned **SegFormer-B2** as a balanced option. This codebase’s **default** is **DeepLabV3+ + ResNet-50**. UNet and FPN are supported in code; SegFormer is **not** implemented as a separate architecture in `models/factory.py` (you can experiment with **MiT** encoders under DeepLabV3+ if SMP supports your chosen encoder name).
-
- ---
-
- ## 3. Dataset layout and assumptions
-
- All paths in config are **relative to the workspace root** (`--root` on the CLI, or the repo root by default).
-
- ```text
- <root>/
-   training/
-     train/
-       Color_Images/   # RGB training inputs
-       Segmentation/   # Training masks (same filenames as Color_Images)
-     val/
-       Color_Images/   # RGB validation inputs
-       Segmentation/   # Validation masks
-   testing/
-     Color_Images/     # Unlabeled images for final inference / demo
- ```
-
- **Pairing rule:** For each split, every file in `Color_Images` must have a mask with the **same basename** in `Segmentation`. The dataset constructor raises if a mask is missing.
-
- **Typical image size in this workspace:** RGB and masks are often **960×540** (masks are single-channel uint16 PNGs). Training uses **512×512** crops; validation pads to a **512×512** canvas for batching.
-
- ---
-
- ## 4. Label format (critical)
-
- ### 4.1 What the masks are
-
- - Masks are read as **2D arrays** (single channel).
- - In this dataset they behave as **`I;16` (16-bit unsigned)** semantic IDs: pixel values are **not** 0, 1, 2, …;
-   they are **dataset-specific raw IDs**, e.g. `100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000`.
-
- ### 4.2 Mapping raw IDs → training indices
-
- The class `RawMaskCodec` in `desert_segmentation/data/mask_encoding.py`:
-
- 1. Builds a **lookup table (LUT)** indexed from 0 up to `max(raw_ids)`.
- 2. Maps each legal raw ID to a contiguous index **`0 … num_classes-1`** (uint8 for Albumentations compatibility).
- 3. **Raises** if any pixel is not in the configured `raw_ids` list (unknown pixels would map to sentinel `255` in the LUT and trigger an error).
-
- **Why this matters:** Using the wrong mapping (or treating masks as 8-bit class indices) silently destroys learning.
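
For concreteness, here is a minimal numpy sketch of that LUT encoding; the names, sort order, and exact error handling are illustrative, not the real `RawMaskCodec` API:

```python
import numpy as np

def build_lut(raw_ids, sentinel=255):
    """LUT over raw values 0 … max(raw_ids); unlisted values map to the sentinel."""
    lut = np.full(max(raw_ids) + 1, sentinel, dtype=np.uint8)
    for class_index, raw_id in enumerate(sorted(raw_ids)):
        lut[raw_id] = class_index
    return lut

def encode_mask(mask, lut, sentinel=255):
    """Map a raw uint16 mask to contiguous class indices; raise on unknown IDs."""
    mask = mask.astype(np.uint16)
    if mask.max() >= len(lut):
        raise ValueError("Mask contains a raw ID above max(raw_ids)")
    encoded = lut[mask]                      # vectorized per-pixel lookup
    if (encoded == sentinel).any():          # unknown raw ID -> hard failure
        raise ValueError("Mask contains raw IDs not listed in data.raw_ids")
    return encoded                           # HxW uint8, values 0 … C-1
```
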
-
- ### 4.3 Ignore index (255)
-
- - **Training:** `ShiftScaleRotate` can introduce border pixels on the mask; those are filled with **`ignore_index` (255)**. Cross-entropy and Dice **ignore** those pixels.
- - **Validation:** `PadIfNeeded` pads the mask with **255** so square tensors align; metrics and loss skip those pixels.
-
- ### 4.4 Class names
-
- `class_names` in YAML are **display labels** (e.g. `id_100`, …). Replace them with semantic names (e.g. `sky`, `sand`) when you have the official ontology from the dataset provider.
-
  ---
-
- ## 5. Repository structure
-
- ```text
- codewizard 2.0/
-   README.md                  # This file
-   requirements.txt           # Python dependencies
-   requirements-demo.txt      # Optional: Gradio demo
-   desert_segmentation/       # Importable package
-     __init__.py
-     configs/
-       default.yaml           # Single source of truth for paths & hyperparameters
-     data/
-       dataset.py             # SegmentationDataset (pairing, crop, rare bias)
-       transforms.py          # Albumentations train/val pipelines
-       mask_encoding.py       # RawMaskCodec + build_codec_from_config
-     models/
-       factory.py             # SMP: DeepLabV3+, UNet, FPN
-     losses/
-       combined.py            # CE, weighted CE, focal, CE+Dice + weight helper
-     metrics/
-       iou.py                 # Confusion matrix, IoU, mIoU, fwIoU
-     train/
-       trainer.py             # Main training loop (AMP, EMA, scheduler, checkpoints)
-       evaluate.py            # Batched validation metric pass
-     infer/
-       predict.py             # Sliding window, TTA, folder inference, ONNX export
-     utils/
-       config.py              # YAML load + path resolution
-       seed.py                # Reproducibility
-       logging_utils.py       # Logging setup
-       freq.py                # Scan mask folders for class frequencies
-       viz.py                 # Colorization + overlay + triplet PNG export
-     demo/
-       inference_ui.py        # Gradio helpers: legend HTML, validation, composites
-   scripts/
-     train.py                 # CLI: train from config
-     eval.py                  # CLI: val metrics + confusion + visualization PNGs
-     eval_summary.py          # CLI: mIoU (all + valid-GT), fwIoU, accuracies, GT counts, per-class table (+ JSON)
-     infer.py                 # CLI: run on testing/ or export ONNX
-     demo_gradio.py           # CLI: browser upload demo (Gradio)
-   tests/
-     test_mask_encoding.py    # Unit tests for codec / unknown pixels
- ```
-
- **Scripts** add the repo root to `sys.path` so you can run them without installing the package as a wheel.
-
- ---
-
- ## 6. Configuration (`default.yaml`)
-
- Key sections (see `desert_segmentation/configs/default.yaml` for the full file):
-
- | Section | Purpose |
- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
- | `root` | Base path for resolving relative data paths (overridden by `--root` in scripts) |
- | `data.*` | Relative dirs for train/val images and masks, test images, `raw_ids`, `class_names`, `crop_size`, `rare_class_crop_prob`, `weighted_sampler`, `weighted_sampler_eps`, `ignore_index` |
- | `model.*` | `architecture` (`deeplabv3plus` \| `unet` \| `fpn`), `encoder_name`, `encoder_weights` |
- | `train.*` | `batch_size`, `epochs`, `lr`, `weight_decay`, `warmup_ratio`, `amp`, `gradient_clip`, `seed`, `checkpoint_dir`, `log_interval`, `early_stop_patience` |
- | `loss.*` | `name` (`ce` \| `weighted_ce` \| `ce_dice` \| `focal_ce` \| `focal_ce_dice`), `dice_weight`, `label_smoothing` (CE modes only), `class_weight_cap`, `focal_gamma` |
- | `augmentation.strong` | Enables extra sun flare + shadow blocks in training |
- | `ema.*` | Optional exponential moving average of weights for evaluation |
- | `inference.*` | `tile_size`, `overlap` (for sliding window), `tta_flip`, `batch_size` (reserved for future batching) |
-
- ---
-
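
The loader behind this is small; here is a sketch of the YAML load + path resolution that `utils/config.py` provides, assuming PyYAML (the helper name and the exact `data.*` keys are illustrative):

```python
from pathlib import Path
from typing import Optional
import yaml

def load_config(config_path: str, root_override: Optional[str] = None) -> dict:
    """Load the YAML config and resolve data directories against the workspace root."""
    cfg = yaml.safe_load(Path(config_path).read_text(encoding="utf-8"))
    root = Path(root_override) if root_override else Path(cfg.get("root", "."))
    # Assumed key names; the real config may differ.
    for key in ("train_images", "train_masks", "val_images", "val_masks", "test_images"):
        if key in cfg["data"]:
            cfg["data"][key] = str(root / cfg["data"][key])  # relative -> absolute
    return cfg
```
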
- ## 7. High-level architecture
-
- ```mermaid
- flowchart TB
-   subgraph inputs [Inputs]
-     RGB[RGB images]
-     GT[Ground truth masks]
-   end
-   subgraph prep [Preprocessing]
-     Codec[RawMaskCodec LUT]
-     Crop[Train: random 512 crop with rare bias]
-     ValPad[Val: resize longest side then pad to 512]
-     Aug[Albumentations geom plus color]
-   end
-   subgraph model [Model SMP]
-     DL[DeepLabV3Plus default]
-   end
-   subgraph train [Training]
-     Loss[CE plus Dice with class weights]
-     Opt[AdamW plus cosine warmup LR]
-     AMP[AMP if CUDA]
-     EMA[EMA optional]
-     CKPT[Best mIoU checkpoint]
-   end
-   subgraph out [Outputs]
-     Metrics[mIoU per class IoU fwIoU confusion]
-     Viz[Overlays triplets]
-     ONNX[Optional ONNX]
-   end
-   RGB --> Codec
-   GT --> Codec
-   Codec --> Crop
-   Codec --> ValPad
-   Crop --> Aug
-   Aug --> DL
-   ValPad --> DL
-   DL --> Loss
-   Loss --> Opt
-   Opt --> AMP
-   Opt --> EMA
-   DL --> Metrics
-   Metrics --> CKPT
-   Metrics --> Viz
-   DL --> ONNX
- ```
-
  ---
 
- ## 8. Data pipeline (detailed)
-
- ### 8.1 `SegmentationDataset` (`data/dataset.py`)
-
- 1. **List images** in `images_dir` with extensions: `.png`, `.jpg`, `.jpeg`, `.bmp`, `.tif`, `.tiff`.
- 2. **Verify** each image has a mask with the same filename in `masks_dir`.
- 3. **Load RGB** with Pillow → `HxWx3` uint8.
- 4. **Load mask** as numpy 2D → cast to `uint16` → **`codec.encode_mask`** → `HxW` uint8 with values `0 … C-1` (or padded 255 later in transforms).
-
- **Train mode (`mode="train"`):**
-
- - **`_random_crop_bias_rare`:** Extract a **`crop_size × crop_size`** patch (sketched below).
-   - With probability `rare_class_crop_prob` (default **0.35**), pick the **rarest class** in that image (by histogram) and center the crop on a random pixel of that class (if any exist).
-   - Otherwise pick a uniformly random center.
-   - If the image is smaller than the crop, **zero-pad** the image and **255-pad** the mask (ignore regions).
-
- **Val mode (`mode="val"`):**
-
- - No random crop in the dataset; the **full** image goes to Albumentations.
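
A sketch of that biased crop logic, assuming a numpy image `HxWx3` and an encoded mask `HxW` (illustrative, not the exact `_random_crop_bias_rare` implementation):

```python
import numpy as np

def random_crop_bias_rare(img, mask, crop=512, rare_prob=0.35, ignore=255, rng=np.random):
    """Crop a crop×crop patch, centered on the rarest class with probability rare_prob."""
    h, w = mask.shape
    if h < crop or w < crop:                       # small image: zero-pad RGB, 255-pad mask
        img = np.pad(img, ((0, max(0, crop - h)), (0, max(0, crop - w)), (0, 0)))
        mask = np.pad(mask, ((0, max(0, crop - h)), (0, max(0, crop - w))),
                      constant_values=ignore)
        h, w = mask.shape
    cy, cx = rng.randint(h), rng.randint(w)        # default: uniform random center
    if rng.random() < rare_prob:
        classes, counts = np.unique(mask[mask != ignore], return_counts=True)
        if classes.size:                           # center on a pixel of the rarest class
            ys, xs = np.nonzero(mask == classes[np.argmin(counts)])
            pick = rng.randint(len(ys))
            cy, cx = ys[pick], xs[pick]
    y0 = int(np.clip(cy - crop // 2, 0, h - crop))
    x0 = int(np.clip(cx - crop // 2, 0, w - crop))
    return img[y0:y0 + crop, x0:x0 + crop], mask[y0:y0 + crop, x0:x0 + crop]
```
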
-
- ### 8.2 Transforms (`data/transforms.py`)
-
- **Train (`build_train_transforms`):**
-
- - **Geometric:** `HorizontalFlip`, `ShiftScaleRotate` (shift, scale, ±10° rotation) with `mask_value=ignore_index` on borders.
- - **Photometric:** brightness/contrast, hue/sat/value, Gaussian blur, Gaussian noise, JPEG compression simulation, RGB shift.
- - **If `augmentation.strong`:** `RandomSunFlare`, `RandomShadow` (desert-relevant appearance stress).
- - **Normalize:** ImageNet mean/std.
- - **`ToTensorV2`:** Image → `float` tensor `CHW`; the mask is passed through so `__getitem__` converts it to `long` downstream.
-
- **Val (`build_val_transforms`):**
-
- - `LongestMaxSize(crop_size)` then `PadIfNeeded(crop_size, crop_size)` with **mask pad = 255** (ignored in loss/metrics).
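
A condensed sketch of both pipelines, assuming Albumentations `<1.5` (the probabilities and limits here are illustrative defaults, not necessarily the exact values in `transforms.py`):

```python
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

IGNORE = 255

def build_train_transforms(crop_size: int = 512, strong: bool = False) -> A.Compose:
    ops = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=10, border_mode=cv2.BORDER_CONSTANT,
                           value=0, mask_value=IGNORE, p=0.5),   # borders -> ignore
        A.RandomBrightnessContrast(p=0.5),
        A.HueSaturationValue(p=0.3),
        A.GaussianBlur(p=0.2),
        A.GaussNoise(p=0.2),
        A.ImageCompression(quality_lower=60, p=0.2),             # JPEG simulation
        A.RGBShift(p=0.2),
    ]
    if strong:
        ops += [A.RandomSunFlare(p=0.1), A.RandomShadow(p=0.2)]
    ops += [A.Normalize(), ToTensorV2()]                         # ImageNet mean/std
    return A.Compose(ops)

def build_val_transforms(crop_size: int = 512) -> A.Compose:
    return A.Compose([
        A.LongestMaxSize(crop_size),
        A.PadIfNeeded(crop_size, crop_size, border_mode=cv2.BORDER_CONSTANT,
                      value=0, mask_value=IGNORE),
        A.Normalize(),
        ToTensorV2(),
    ])
```
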
-
- ### 8.3 Class frequency estimation (`utils/freq.py`)
-
- Before training, `scripts/train.py` calls **`estimate_pixel_frequencies`** over **all** training mask files (configurable `max_files` in code; the train script uses the full corpus). This yields a normalized frequency vector per class, which is used to build the **class weights**.
-
- ---
-
- ## 9. Model
-
- **Factory:** `desert_segmentation/models/factory.py`
-
- | `architecture` | SMP class | Notes |
- | ------------------------- | ------------------- | ------------------------------------------------- |
- | `deeplabv3plus` (default) | `smp.DeepLabV3Plus` | Mainline; strong decoder + atrous spatial pyramid |
- | `unet` | `smp.Unet` | Classic encoder–decoder with skips |
- | `fpn` | `smp.FPN` | Feature pyramid neck |
-
- **Default encoder:** `resnet50` with `encoder_weights: imagenet`.
-
- **Forward:** Input batch `N×3×H×W` → logits `N×C×H×W` where `C = num_classes`.
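
The factory reduces to roughly one SMP call per architecture; a sketch (argument names follow the SMP API, the dispatch dict is illustrative):

```python
import segmentation_models_pytorch as smp

ARCHS = {"deeplabv3plus": smp.DeepLabV3Plus, "unet": smp.Unet, "fpn": smp.FPN}

def build_model(architecture: str, encoder_name: str, encoder_weights, num_classes: int):
    """One-line SMP construction: N×3×H×W in, N×C×H×W logits out."""
    return ARCHS[architecture](
        encoder_name=encoder_name,        # e.g. "resnet50"
        encoder_weights=encoder_weights,  # e.g. "imagenet" or None
        in_channels=3,
        classes=num_classes,
    )
```
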
- ## 10. Loss functions
-
- **File:** `desert_segmentation/losses/combined.py`
-
- **Modes (`loss.name`):**
-
- | Mode | Description |
- | ------------------- | ------------------------------------------------------------------------------------ |
- | `ce` | Plain cross-entropy, unweighted |
- | `weighted_ce` | Cross-entropy with per-class `weight` tensor |
- | `ce_dice` (default) | `CE(weighted) + dice_weight * multiclass_Dice_loss` |
- | `focal_ce` | Focal-modulated CE; optional class weights on pixels |
- | `focal_ce_dice` | `focal_ce` + `dice_weight * multiclass_Dice_loss` (same class weights in focal term) |
-
- **Shared options:**
-
- - **`ignore_index`:** Pixels with label 255 are masked out of CE / focal / Dice.
- - **`label_smoothing`:** Applied to **CE-based** modes (`ce`, `weighted_ce`, `ce_dice`) only; not used in `focal_ce` / `focal_ce_dice`.
-
- **Class weights (`compute_class_weights_from_freq`):**
-
- 1. Start from per-class pixel frequency `freq` on the training masks.
- 2. `w ∝ 1 / log(freq + ε)`, normalize by the mean.
- 3. Clamp the ratio `w / median(w)` to **`class_weight_cap`** (default **15**) so rare classes do not explode the loss. A sketch of this recipe follows the list.
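
A sketch of steps 1–3, assuming an ε slightly above 1 (ENet-style) so the log stays positive for small frequencies; the real helper may choose ε differently:

```python
import numpy as np

def class_weights_from_freq(freq: np.ndarray, cap: float = 15.0, eps: float = 1.02) -> np.ndarray:
    """Inverse-log-frequency weights, mean-normalized, clamped to cap * median."""
    w = 1.0 / np.log(freq + eps)           # rarer class -> larger weight
    w = w / w.mean()                       # normalize by the mean
    w = np.minimum(w, cap * np.median(w))  # cap the ratio w / median(w)
    return w.astype(np.float32)

# Typical use: nn.CrossEntropyLoss(weight=torch.from_numpy(w), ignore_index=255)
```
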
-
- ---
-
- ## 11. Metrics
-
- **File:** `desert_segmentation/metrics/iou.py`
-
- 1. **Confusion matrix** `C×C` (implementation uses `idx = tgt * C + pred` then `bincount`; rows correspond to the **ground-truth class**, columns to the **predicted class**).
- 2. **Per-class IoU:** $\text{IoU}_k = \frac{TP_k}{TP_k + FP_k + FN_k}$ with `TP_k = CM[k,k]`, `FN_k` from the row sum and `FP_k` from the column sum.
- 3. **mIoU:** Mean of per-class IoU over finite entries.
- 4. **fwIoU (frequency-weighted IoU):** $\text{fwIoU} = \sum_k \text{IoU}_k \cdot p_k$, where $p_k$ is the empirical frequency of class $k$ in the ground-truth pixels (row marginals).
-
- **Note:** The docstring in `compute_confusion` mentions “pred rows, target columns”; the actual indexing is `tgt * C + pred` after reshape, i.e. **target rows, predicted columns**.
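
A compact reference implementation consistent with the description above (a sketch, not the exact `iou.py` code); `pred` and `tgt` are tensors of class indices:

```python
import torch

def confusion_matrix(pred, tgt, num_classes, ignore_index=255):
    """CxC confusion matrix: rows = ground truth, cols = prediction."""
    keep = tgt != ignore_index
    idx = tgt[keep].long() * num_classes + pred[keep].long()
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def iou_from_confusion(cm):
    tp = cm.diag().float()
    fp = cm.sum(0).float() - tp            # column sums minus diagonal
    fn = cm.sum(1).float() - tp            # row sums minus diagonal
    iou = tp / (tp + fp + fn)              # NaN where a class never appears
    miou = iou[torch.isfinite(iou)].mean()
    freq = cm.sum(1).float() / cm.sum()    # GT class frequency (row marginals)
    fwiou = (iou.nan_to_num(0.0) * freq).sum()
    return iou, miou, fwiou
```
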
-
- ---
-
- ## 12. Training loop
-
- **File:** `desert_segmentation/train/trainer.py`
-
- **Optimizer:** AdamW on all parameters.
-
- **Learning rate:** `LambdaLR` with:
-
- - **Linear warmup** for `warmup_ratio` of total optimizer steps (default **8%**).
- - **Cosine** decay from 1.0 down to `min_ratio` **0.01** (implemented in `_warmup_cosine_lambda`; see the sketch below).
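
The schedule is easy to restate; a sketch of what `_warmup_cosine_lambda` computes, per the bullets above:

```python
import math

def warmup_cosine_lambda(total_steps: int, warmup_ratio: float = 0.08, min_ratio: float = 0.01):
    """LR multiplier: linear warmup, then cosine decay from 1.0 to min_ratio."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / warmup_steps
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * t))
    return fn

# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_lambda(total_steps))
```
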
-
- **AMP (mixed precision):**
-
- - Enabled only if `train.amp` is true **and** `torch.cuda.is_available()`.
- - Uses `torch.cuda.amp.autocast` + `GradScaler` when on CUDA.
- - On **CPU**, AMP is off; training uses standard FP32 backward (no scaler).
-
- **Gradient clipping:** Global norm clip when `gradient_clip > 0` (default **1.0**).
-
- **EMA (optional):**
-
- - If `ema.enabled`, after each optimizer step the code maintains a **shadow weight** copy per trainable parameter, with exponential decay **0.999** by default.
- - **Each epoch:** Training weights are **deep-copied**; **EMA weights are copied into the model** for validation only; then the training snapshot is **restored** so optimization continues from the non-EMA weights. A sketch follows this list.
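
A minimal sketch of that EMA bookkeeping (the class name is hypothetical; the trainer's actual implementation may differ in detail):

```python
import torch

class EmaShadow:
    """Exponential moving average of trainable parameters (decay 0.999 by default)."""
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        # shadow = decay * shadow + (1 - decay) * param, after each optimizer step
        for n, p in model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model: torch.nn.Module) -> None:
        # Used at validation time, after deep-copying the training weights
        for n, p in model.named_parameters():
            if n in self.shadow:
                p.copy_(self.shadow[n])
```
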
-
- **Checkpointing:**
 
- - Every epoch: `checkpoints/last.pt` (model, optional EMA dict, optimizer, full config, class names).
- - **Best validation mIoU:** `checkpoints/best.pt` (adds `miou`, `per_class_iou`).
-
- **Early stopping:** If validation mIoU does not improve for `early_stop_patience` epochs (default **12**), training stops.
-
- **Optional smoke flags (`scripts/train.py`):**
-
- - `--epochs N`: override epoch count.
- - `--max_train_batches K`: stop each training epoch after `K` batches (debug only; the scheduler still advances per batch).
-
- **Logging:** `checkpoints/history.json` lists per-epoch `miou` and `fw_iou`.
-
- ---
-
- ## 13. Validation and evaluation scripts
-
- **Core loop:** `desert_segmentation/train/evaluate.py` runs the model in `eval()` mode, accumulates confusion counts via `IoUMetrics`, and returns a dict.
-
- **CLI:** `scripts/eval.py`
-
- 1. Loads the config and builds the validation dataset (same codec and val transforms as training).
- 2. Loads the checkpoint from `--checkpoint`.
- 3. **Weight loading priority:** If an `ema` dict exists in the checkpoint, **EMA tensors are copied into the parameters** for evaluation; otherwise the `state_dict` from the `model` key is used.
- 4. Runs the full val loader → logs **mIoU**, **fwIoU**, per-class IoU.
- 5. Writes:
-    - `eval_outputs/metrics.json` (or `--out_dir`)
-    - `confusion.npy`
-    - Up to `--max_viz` side-by-side **RGB | GT | Pred** PNGs (`save_triplet` in `utils/viz.py`), with ImageNet denormalization for the RGB panels.
-
- ---
-
- ## 14. Inference (testing folder, sliding window, TTA, ONNX)
-
- **CLI:** `scripts/infer.py`
-
- ### 14.1 Folder inference
-
- - Reads `testing/Color_Images` (or whatever `data.test_images` points to).
- - Loads checkpoint with the same **EMA-first** rule as eval.
- - For each image:
-   - If **both** height and width ≤ `tile_size` (512): single forward pass.
-   - Else: **sliding window** with stride `tile_size * (1 - overlap)` (default overlap **0.25** → stride **384**).
-   - Pads the image with **reflect** padding so the tile grid covers corners; crops back to original size.
-   - Accumulates **per-class logits** weighted by a **2D Gaussian** (`sigma ∝ tile/3`) so tile borders blend smoothly; the final prediction is the **`argmax` over classes** per pixel (see the sketch after this list).
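
A sketch of the feathered fusion, assuming the image has already been reflect-padded so the tile grid covers it exactly (illustrative; see `infer/predict.py` for the real code):

```python
import numpy as np
import torch

def gaussian_feather(tile: int) -> torch.Tensor:
    """2D Gaussian window (sigma ~ tile/3) that down-weights tile borders."""
    ax = np.arange(tile) - (tile - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2 * (tile / 3.0) ** 2))
    return torch.from_numpy(np.outer(g, g).astype(np.float32))

@torch.no_grad()
def fuse_tiles(model, image, num_classes, tile=512, overlap=0.25):
    """image: normalized 3xHxW tensor, reflect-padded to the tile grid."""
    _, h, w = image.shape
    stride = int(tile * (1 - overlap))                 # 512 * 0.75 = 384
    feather = gaussian_feather(tile)
    acc = torch.zeros(num_classes, h, w)
    for y0 in range(0, h - tile + 1, stride):
        for x0 in range(0, w - tile + 1, stride):
            logits = model(image[:, y0:y0 + tile, x0:x0 + tile].unsqueeze(0))[0]
            acc[:, y0:y0 + tile, x0:x0 + tile] += logits.cpu() * feather
    return acc.argmax(0)                               # HxW class indices
```
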
-
- ### 14.2 Test-time augmentation (TTA)
-
- If `inference.tta_flip` is true: logits = **0.5 × (logits(x) + unflip(logits(flip(x))))** horizontally.
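
Horizontal-flip TTA in a few lines of tensor ops (a sketch):

```python
import torch

@torch.no_grad()
def tta_flip_logits(model, batch):
    """Average logits over identity and horizontal flip: 0.5*(f(x) + unflip(f(flip(x))))."""
    logits = model(batch)
    flipped = model(torch.flip(batch, dims=[-1]))
    return 0.5 * (logits + torch.flip(flipped, dims=[-1]))
```
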
-
- ### 14.3 Outputs
-
- Under `--out_dir` (default `infer_outputs/`):
-
- - `pred_<filename>`: color overlay (prediction tinted on RGB).
- - `triplet_<filename>`: **RGB | blank or GT | Pred** strip (the test set has no GT, so the middle panel is zeros in the current `save_triplet` usage).
- - `latency.txt`: mean milliseconds per image and the device string.
-
- ### 14.4 ONNX
-
- `python scripts/infer.py --checkpoint ... --onnx model.onnx` calls `export_onnx`: builds the model on **CPU**, dummy input `1×3×512×512`, `torch.onnx.export` with dynamic axes for batch and spatial size.
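
What that export plausibly looks like with `torch.onnx.export` (the opset version and tensor names here are assumptions):

```python
import torch

def export_onnx(model: torch.nn.Module, path: str = "model.onnx") -> None:
    """Export on CPU with a 1x3x512x512 dummy input and dynamic batch/spatial axes."""
    model = model.eval().cpu()
    dummy = torch.randn(1, 3, 512, 512)
    torch.onnx.export(
        model, dummy, path,
        input_names=["input"], output_names=["logits"],
        dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                      "logits": {0: "batch", 2: "height", 3: "width"}},
        opset_version=17,  # assumed; pick one your ONNX runtime supports
    )
```
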
-
- ---
-
- ## 15. Checkpoints and artifacts
-
- | Artifact | Contents |
- | -------------------------- | --------------------------------------------------------------------------- |
- | `checkpoints/best.pt` | `model`, `ema` (optional), `miou`, `per_class_iou`, `config`, `class_names` |
- | `checkpoints/last.pt` | Latest epoch snapshot + optimizer |
- | `checkpoints/history.json` | List of `{epoch, miou, fw_iou}` |
- | `eval_outputs/*` | `metrics.json`, `confusion.npy`, visualization PNGs |
- | `infer_outputs/*` | Overlays, triplets, `latency.txt` |
-
- ---
-
- ## 16. How to run (commands)
-
- From the repository root (adjust paths if yours differ).
-
- ### 16.1 Install
-
- ```powershell
- python -m pip install -r requirements.txt
- ```
-
- ### 16.2 Train
-
- ```powershell
- $env:PYTHONPATH="."
- python scripts\train.py --root "d:\codewizard 2.0"
- ```
-
- Optional:
-
- ```powershell
- python scripts\train.py --root "d:\codewizard 2.0" --config desert_segmentation\configs\default.yaml --epochs 5 --max_train_batches 50
- ```
-
- **Imbalanced classes (optional YAML):** set `loss.name` to `focal_ce_dice` for focal + Dice; tune `class_weight_cap`, `rare_class_crop_prob`, and/or `data.weighted_sampler: true` to oversample train images that contain rare classes (this scans all train masks once at startup and can take a minute on large sets).
-
- ### 16.3 Evaluate (validation)
-
- ```powershell
- python scripts\eval.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir eval_outputs
- ```
-
- **Metric summary (no PNGs):** `scripts\eval_summary.py` prints **mIoU (all classes)** and **mIoU (classes with val GT)** (the latter ignores absent classes, so it is easier to interpret on sparse val labels), **fwIoU**, **global / mean class accuracy**, **val GT pixel counts per class**, and a **per-class IoU / recall** table. Same validation forward pass as `eval.py`. Optional: `--json-out eval_summary.json` (includes `miou_valid_gt_classes`, `val_gt_pixel_counts`).
-
- ```powershell
- python scripts\eval_summary.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --json-out eval_summary.json
- ```
-
- To print only the **mIoU** and **per-class IoU** stored inside the checkpoint (no GPU eval): `python scripts\eval_summary.py --from-checkpoint-only --checkpoint checkpoints\best.pt`
-
- ### 16.4 Infer on `testing/Color_Images`
-
- ```powershell
- python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir infer_outputs --limit 20
- ```
-
- ### 16.5 Export ONNX
-
- ```powershell
- python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --onnx model.onnx
- ```
-
- ---
-
- ## 17. Interactive demo (Gradio)
-
- Upload an RGB image in the browser and get a **colored class mask**, an **overlay**, a **side-by-side strip** (RGB | mask | overlay), a **fixed legend** (colors match `palette()` in training), **inference time**, and **dominant classes** (pixel histogram). Uses the same path as CLI inference: `_load_model_for_inference` and `predict_image` in `desert_segmentation/infer/predict.py` (EMA weights preferred when present in the checkpoint).
-
- **Install** (base + demo extras):
-
- ```powershell
- python -m pip install -r requirements.txt -r requirements-demo.txt
- ```
-
- **Run** (from the repo root; the model loads **once** at startup; look for the log line `Model ready`):
-
- ```powershell
- $env:PYTHONPATH="."
- python scripts\demo_gradio.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt
- ```
-
- **CLI flags:** `--host` (default `127.0.0.1`), `--port` (default `7860`), `--share` (temporary public Gradio link), `--max-side`, `--max-megapixels` (reject huge uploads before inference).
-
- **Environment variables** (optional defaults if flags are omitted):
-
- | Variable | Purpose |
- | ----------------- | ------------------------------------------------------- |
- | `ROOT` | Workspace root (same as `--root`) |
- | `CHECKPOINT_PATH` | Path to `best.pt` (relative paths resolve under `ROOT`) |
-
- **Advanced panel:** TTA on/off, tile overlap slider, tile size slider (256–2048, step 64). Overrides are passed into `predict_image` only; the checkpoint file is not modified.
-
- **v1 limitations:** No per-pixel **confidence heatmap** for full sliding-window runs (only the `argmax` is returned from `predict_image`). See the plan follow-up to add logits fusion if needed.
-
- **Windows:** Use backslashes or quoted paths as above; the first launch may be slow while dependencies initialize.
-
- **Follow-ups (not in v1):** full-resolution **confidence** heatmap (needs a logits path in `predict.py`); **ZIP** batch upload; **two-checkpoint** comparison UI; client-side **ONNX** inference.
-
- ---
-
- ## 18. Tests
-
- ```powershell
- python -m pytest tests\test_mask_encoding.py -q
- ```
-
- Covers:
-
- - Round-trip **raw mask ↔ class indices** for known IDs.
- - **Unknown raw pixel** raises `ValueError`.
- - LUT correctness for each configured raw ID.
-
- ---
-
- ## 19. Dependencies and environment notes
-
- **`requirements.txt`:**
-
- - `torch`, `torchvision`, `numpy`, `Pillow`, `PyYAML`
- - `albumentations` pinned to `<1.5` to reduce optional native build issues on some Windows setups
- - `segmentation-models-pytorch` (SMP)
- - `tqdm`, `pytest`
- - Optional demo: `requirements-demo.txt` adds **Gradio**
-
- **Windows:** `scripts/train.py` and `scripts/eval.py` set `num_workers=0` for `DataLoader` on NT to avoid multiprocessing friction.
-
- **SMP pretrained weights:** The first run may download encoder weights (e.g. ResNet-50 ImageNet) via SMP / Hugging Face hubs, depending on the SMP version.
-
- ---
-
- ## 20. Design decisions and limitations
-
- | Topic | Decision / limitation |
- | ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
- | Mask modes | **16-bit raw IDs** supported via LUT; **P-mode palette** and **RGB color masks** are *not* auto-detected in this codebase; extend `mask_encoding.py` if your dataset uses them |
- | SegFormer | **Not** a separate `architecture` enum; the plan mentioned SegFormer-B2 as an alternative, which would require additional factory code or a supported SMP encoder |
- | Val resolution | Images are **letterboxed** to 512×512 for batching; padded regions are excluded from mIoU via the ignore index. Fine for hackathon use; for publication-grade eval, consider sliding-window val too |
- | Inference fusion | Overlapping tiles add **Gaussian-weighted logits** per class into an accumulator; the final label is the **`argmax` over the accumulated logits** (feathered overlap fusion). A per-pixel `weight` tensor is also accumulated in code for possible future normalization extensions |
- | Poly LR / sync BN | **Not** implemented (cosine + warmup only) |
- | Ensemble | **Not** implemented (single model + optional EMA) |
-
- ---
-
- ## 21. Extending the project
-
- 1. **New classes / raw IDs:** Edit `data.raw_ids` and `data.class_names` in YAML; the frequency scan reruns automatically in `train.py`.
- 2. **UNet / FPN:** Set `model.architecture` to `unet` or `fpn`; pick a valid `encoder_name` for SMP.
- 3. **Larger encoder:** e.g. `encoder_name: resnet101` for DeepLabV3+.
- 4. **Loss ablation:** Set `loss.name` to `ce`, `weighted_ce`, `focal_ce`, or `focal_ce_dice`; tune `dice_weight`, `label_smoothing`, `class_weight_cap`.
- 5. **Stronger aug:** Add Albumentations ops in `transforms.py` (keep `additional_targets={"mask":"mask"}` for paired geometry).
-
- ---
-
- ## 22. Flowcharts
-
- ### 22.1 Training epoch (simplified)
-
- ```mermaid
- flowchart TD
-   start[Start epoch]
-   trainLoop[For each batch]
-   fwd[Forward logits]
-   lossStep[Compute loss CE plus Dice]
-   backward[Backward plus clip]
-   stepOpt[Optimizer step plus scheduler step]
-   emaUp[Update EMA if enabled]
-   endTrain[End train batches]
-   snap[Snapshot model weights]
-   applyEMA[Copy EMA into model if enabled]
-   valRun[Run validation mIoU]
-   restore[Restore snapshot weights]
-   better{New best mIoU?}
-   saveBest[Save best.pt]
-   early{Patience exceeded?}
-   stop[Stop training]
-   start --> trainLoop
-   trainLoop --> fwd --> lossStep --> backward --> stepOpt --> emaUp
-   emaUp --> trainLoop
-   trainLoop --> endTrain
-   endTrain --> snap --> applyEMA --> valRun --> restore --> better
-   better -->|yes| saveBest --> early
-   better -->|no| early
-   early -->|yes| stop
-   early -->|no| start
- ```
-
- ### 22.2 Inference on large images
-
- ```mermaid
- flowchart LR
-   img[Input RGB HxW]
-   pad[Reflect pad to tile grid]
-   tiles[For each tile]
-   fwdT[Forward logits optional TTA]
-   g[Multiply by Gaussian feather]
-   acc[Accumulate class logits maps]
-   argmax[Argmax over classes]
-   cropBack[Crop to original HxW]
-   img --> pad --> tiles --> fwdT --> g --> acc --> argmax --> cropBack
- ```
-
- ---
-
- ## Acknowledgments
-
- - **segmentation_models_pytorch** (Pavel Iakubovskii and contributors) for modular segmentation architectures.
- - **Albumentations** for fast, paired image–mask augmentations.
-
- ---
 
- *Generated to document the implementation in this repository as of the README authoring date. For the original hackathon planning narrative, see your separate plan document (not stored in this repo’s `README`).*
  ---
+ title: Desert semantic segmentation
+ emoji: 🏜️
+ colorFrom: yellow
+ colorTo: gray
+ sdk: gradio
+ sdk_version: "5.12.0"
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Desert RGB → per-pixel class mask, overlay, metrics.
  ---
 
+ # Desert semantic segmentation
 
+ Interactive demo for desert / off-road **semantic segmentation**. Upload an RGB image and click **Run segmentation**.
 
+ **Weights:** set the Space variable `HF_HUB_CHECKPOINT_REPO` to `webVishnu/desert-seg-best` (and, optionally, `HF_HUB_CHECKPOINT_FILENAME` if the file is not named `best.pt`).