Spaces · Commit 77445cb · Parent(s): inital
- .gitignore +12 -0
- .gradio/certificate.pem +31 -0
- README.md +621 -0
- desert_segmentation/__init__.py +3 -0
- desert_segmentation/configs/default.yaml +77 -0
- desert_segmentation/data/__init__.py +4 -0
- desert_segmentation/data/dataset.py +129 -0
- desert_segmentation/data/mask_encoding.py +90 -0
- desert_segmentation/data/transforms.py +90 -0
- desert_segmentation/demo/__init__.py +15 -0
- desert_segmentation/demo/inference_ui.py +95 -0
- desert_segmentation/infer/__init__.py +3 -0
- desert_segmentation/infer/predict.py +210 -0
- desert_segmentation/losses/__init__.py +3 -0
- desert_segmentation/losses/combined.py +143 -0
- desert_segmentation/metrics/__init__.py +15 -0
- desert_segmentation/metrics/iou.py +143 -0
- desert_segmentation/models/__init__.py +3 -0
- desert_segmentation/models/factory.py +37 -0
- desert_segmentation/train/__init__.py +4 -0
- desert_segmentation/train/evaluate.py +30 -0
- desert_segmentation/train/trainer.py +205 -0
- desert_segmentation/utils/__init__.py +3 -0
- desert_segmentation/utils/config.py +37 -0
- desert_segmentation/utils/freq.py +66 -0
- desert_segmentation/utils/logging_utils.py +11 -0
- desert_segmentation/utils/seed.py +13 -0
- desert_segmentation/utils/viz.py +60 -0
- eval_summary.json +47 -0
- requirements-demo.txt +3 -0
- requirements.txt +10 -0
- scripts/demo_gradio.py +219 -0
- scripts/eval.py +146 -0
- scripts/eval_summary.py +259 -0
- scripts/infer.py +49 -0
- scripts/train.py +166 -0
- tests/test_confusion_metrics.py +50 -0
- tests/test_mask_encoding.py +33 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
testing/
training/
checkpoints/
eval_outputs/
infer_outputs/
logs/
__pycache__/
*.pyc
*.pyo
*.pyd
*.pyw
*.pyz
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
README.md
ADDED
@@ -0,0 +1,621 @@
# Desert Semantic Segmentation

End-to-end **semantic segmentation** for **off-road / desert** scenes: every pixel is classified into one of several terrain / object categories. The pipeline is built for **synthetic RGB + mask** data, **PyTorch**, **[segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch)** (SMP), **Albumentations**, and hackathon-style iteration (strong baselines, IoU-driven checkpoints, optional EMA / TTA / ONNX).

---

## Table of contents

1. [What this project does](#1-what-this-project-does)
2. [Problem statement and goals](#2-problem-statement-and-goals)
3. [Dataset layout and assumptions](#3-dataset-layout-and-assumptions)
4. [Label format (critical)](#4-label-format-critical)
5. [Repository structure](#5-repository-structure)
6. [Configuration (`default.yaml`)](#6-configuration-defaultyaml)
7. [High-level architecture](#7-high-level-architecture)
8. [Data pipeline (detailed)](#8-data-pipeline-detailed)
9. [Model](#9-model)
10. [Loss functions](#10-loss-functions)
11. [Metrics](#11-metrics)
12. [Training loop](#12-training-loop)
13. [Validation and evaluation scripts](#13-validation-and-evaluation-scripts)
14. [Inference (testing folder, sliding window, TTA, ONNX)](#14-inference-testing-folder-sliding-window-tta-onnx)
15. [Checkpoints and artifacts](#15-checkpoints-and-artifacts)
16. [How to run (commands)](#16-how-to-run-commands)
17. [Interactive demo (Gradio)](#17-interactive-demo-gradio)
18. [Tests](#18-tests)
19. [Dependencies and environment notes](#19-dependencies-and-environment-notes)
20. [Design decisions and limitations](#20-design-decisions-and-limitations)
21. [Extending the project](#21-extending-the-project)
22. [Flowcharts](#22-flowcharts)

---

## 1. What this project does

- **Input:** RGB color images (`Color_Images`).
- **Supervision:** Per-pixel class masks (`Segmentation`) aligned by **filename** with the RGB image.
- **Output:** A trained neural network that predicts a **class index per pixel** on validation, held-out **testing** images (no labels in repo), or any folder of images you point inference at.
- **Primary quality metric:** **mean Intersection-over-Union (mIoU)** on the validation set, plus **per-class IoU**, **frequency-weighted IoU (fwIoU)**, and a **confusion matrix**.

---

## 2. Problem statement and goals

| Goal | How we address it |
|------|-------------------|
| Accurate pixel-wise classification | DeepLabV3+ with ImageNet-pretrained encoder; CE + Dice loss; class-frequency weights |
| Robustness (synthetic → harder real domains) | Strong photometric + mild “desert-like” augmentations (sun flare, shadow, blur, noise, JPEG) |
| Class imbalance | Inverse log-frequency weights with a **cap**; rare-class-biased random crops |
| Stable training | AdamW, cosine decay with **warmup**, gradient clipping, optional **EMA** |
| Fast iteration | YAML-driven config; SMP for one-line model construction; scripts for train / eval / infer |
| Deployment story | Optional **ONNX** export; inference timing written to `latency.txt` |

**Note:** The original hackathon plan also mentioned **SegFormer-B2** as a balanced option. This codebase’s **default** is **DeepLabV3+ + ResNet-50**. UNet and FPN are supported in code; SegFormer is **not** implemented as a separate architecture in `models/factory.py` (you can experiment with **MiT** encoders under DeepLabV3+ if SMP supports your chosen encoder name).

---

## 3. Dataset layout and assumptions

All paths in config are **relative to the workspace root** (`--root` on the CLI, or the repo root by default).

```text
<root>/
  training/
    train/
      Color_Images/   # RGB training inputs
      Segmentation/   # Training masks (same filenames as Color_Images)
    val/
      Color_Images/   # RGB validation inputs
      Segmentation/   # Validation masks
  testing/
    Color_Images/     # Unlabeled images for final inference / demo
```

**Pairing rule:** For each split, every file in `Color_Images` must have a mask with the **same basename** in `Segmentation`. The dataset constructor raises if a mask is missing.

**Typical image size in this workspace:** RGB and masks are often **960×540** (masks are single-channel uint16 PNGs). Training uses **512×512** crops; validation pads to a **512×512** canvas for batching.

---

## 4. Label format (critical)

### 4.1 What the masks are

- Masks are read as **2D arrays** (single channel).
- In this dataset they behave as **`I;16` (16-bit unsigned)** semantic IDs: pixel values are **not** 0, 1, 2, …
  They are **dataset-specific raw IDs**, e.g. `100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000`.

### 4.2 Mapping raw IDs → training indices

The class `RawMaskCodec` in `desert_segmentation/data/mask_encoding.py`:

1. Builds a **lookup table (LUT)** from `max(raw_ids)` down to 0.
2. Maps each legal raw ID to a contiguous index **`0 … num_classes-1`** (uint8 for Albumentations compatibility).
3. **Raises** if any pixel is not in the configured `raw_ids` list (unknown pixels would map to sentinel `255` in the LUT and trigger an error).

**Why this matters:** Using the wrong mapping (or treating masks as 8-bit class indices) silently destroys learning.

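A minimal usage sketch of the codec (the API is as defined in `mask_encoding.py`; the tiny mask below is made up for illustration):

```python
import numpy as np
from desert_segmentation.data.mask_encoding import build_codec_from_config

# Two-class toy example: raw IDs 100 and 7100 become indices 0 and 1.
codec = build_codec_from_config([100, 7100], ["id_100", "id_7100"])
raw = np.array([[100, 100], [7100, 100]], dtype=np.uint16)
encoded, unknown_frac = codec.encode_mask(raw)   # encoded: uint8 array of 0/1
assert encoded.tolist() == [[0, 0], [1, 0]]
# encode_mask raises ValueError if raw contains an ID outside raw_ids.
```
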
### 4.3 Ignore index (255)

- **Training:** `ShiftScaleRotate` can introduce border pixels on the mask; those are filled with **`ignore_index` (255)**. Cross-entropy and Dice **ignore** those pixels.
- **Validation:** `PadIfNeeded` pads the mask with **255** so square tensors align; metrics and loss skip those pixels.

### 4.4 Class names

`class_names` in YAML are **display labels** (e.g. `id_100`, …). Replace them with semantic names (e.g. `sky`, `sand`) when you have the official ontology from the dataset provider.

---

## 5. Repository structure

```text
codewizard 2.0/
  README.md                   # This file
  requirements.txt            # Python dependencies
  requirements-demo.txt       # Optional: Gradio demo
  desert_segmentation/        # Importable package
    __init__.py
    configs/
      default.yaml            # Single source of truth for paths & hyperparameters
    data/
      dataset.py              # SegmentationDataset (pairing, crop, rare bias)
      transforms.py           # Albumentations train/val pipelines
      mask_encoding.py        # RawMaskCodec + build_codec_from_config
    models/
      factory.py              # SMP: DeepLabV3+, UNet, FPN
    losses/
      combined.py             # CE, weighted CE, focal, CE+Dice + weight helper
    metrics/
      iou.py                  # Confusion matrix, IoU, mIoU, fwIoU
    train/
      trainer.py              # Main training loop (AMP, EMA, scheduler, checkpoints)
      evaluate.py             # Batched validation metric pass
    infer/
      predict.py              # Sliding window, TTA, folder inference, ONNX export
    utils/
      config.py               # YAML load + path resolution
      seed.py                 # Reproducibility
      logging_utils.py        # Logging setup
      freq.py                 # Scan mask folders for class frequencies
      viz.py                  # Colorization + overlay + triplet PNG export
    demo/
      inference_ui.py         # Gradio helpers: legend HTML, validation, composites
  scripts/
    train.py                  # CLI: train from config
    eval.py                   # CLI: val metrics + confusion + visualization PNGs
    eval_summary.py           # CLI: mIoU (all + valid-GT), fwIoU, accuracies, GT counts, per-class table (+ JSON)
    infer.py                  # CLI: run on testing/ or export ONNX
    demo_gradio.py            # CLI: browser upload demo (Gradio)
  tests/
    test_mask_encoding.py     # Unit tests for codec / unknown pixels
```

**Scripts** add the repo root to `sys.path` so you can run them without installing the package as a wheel.

---

## 6. Configuration (`default.yaml`)

Key sections (see `desert_segmentation/configs/default.yaml` for the full file):

| Section | Purpose |
|---------|---------|
| `root` | Base path for resolving relative data paths (overridden by `--root` in scripts) |
| `data.*` | Relative dirs for train/val images and masks, test images, `raw_ids`, `class_names`, `crop_size`, `rare_class_crop_prob`, `weighted_sampler`, `weighted_sampler_eps`, `ignore_index` |
| `model.*` | `architecture` (`deeplabv3plus` \| `unet` \| `fpn`), `encoder_name`, `encoder_weights` |
| `train.*` | `batch_size`, `epochs`, `lr`, `weight_decay`, `warmup_ratio`, `amp`, `gradient_clip`, `seed`, `checkpoint_dir`, `log_interval`, `early_stop_patience` |
| `loss.*` | `name` (`ce` \| `weighted_ce` \| `ce_dice` \| `focal_ce` \| `focal_ce_dice`), `dice_weight`, `label_smoothing` (CE modes only), `class_weight_cap`, `focal_gamma` |
| `augmentation.strong` | Enables extra sun flare + shadow blocks in training |
| `ema.*` | Optional exponential moving average of weights for evaluation |
| `inference.*` | `tile_size`, `overlap` (for sliding window), `tta_flip`, `batch_size` (reserved for future batching) |

---

## 7. High-level architecture

```mermaid
flowchart TB
  subgraph inputs [Inputs]
    RGB[RGB images]
    GT[Ground truth masks]
  end
  subgraph prep [Preprocessing]
    Codec[RawMaskCodec LUT]
    Crop[Train: random 512 crop with rare bias]
    ValPad[Val: resize longest side then pad to 512]
    Aug[Albumentations geom plus color]
  end
  subgraph model [Model SMP]
    DL[DeepLabV3Plus default]
  end
  subgraph train [Training]
    Loss[CE plus Dice with class weights]
    Opt[AdamW plus cosine warmup LR]
    AMP[AMP if CUDA]
    EMA[EMA optional]
    CKPT[Best mIoU checkpoint]
  end
  subgraph out [Outputs]
    Metrics[mIoU per class IoU fwIoU confusion]
    Viz[Overlays triplets]
    ONNX[Optional ONNX]
  end
  RGB --> Codec
  GT --> Codec
  Codec --> Crop
  Codec --> ValPad
  Crop --> Aug
  Aug --> DL
  ValPad --> DL
  DL --> Loss
  Loss --> Opt
  Opt --> AMP
  Opt --> EMA
  DL --> Metrics
  Metrics --> CKPT
  Metrics --> Viz
  DL --> ONNX
```

---

## 8. Data pipeline (detailed)

### 8.1 `SegmentationDataset` (`data/dataset.py`)

1. **List images** in `images_dir` with extensions: `.png`, `.jpg`, `.jpeg`, `.bmp`, `.tif`, `.tiff`.
2. **Verify** each image has a mask with the same filename in `masks_dir`.
3. **Load RGB** with Pillow → `HxWx3` uint8.
4. **Load mask** as a 2D numpy array → cast to `uint16` → **`codec.encode_mask`** → `HxW` uint8 with values `0 … C-1` (255-padded regions are added later in transforms).

**Train mode (`mode="train"`):**

- **`_random_crop_bias_rare`:** Extract a **`crop_size × crop_size`** patch.
  - With probability `rare_class_crop_prob` (default **0.35**), pick the **rarest class** in that image (by histogram) and center the crop on a random pixel of that class (if any exist).
  - Otherwise pick a uniformly random center.
  - If the image is smaller than the crop, **zero-pad** the image and **255-pad** the mask (ignore regions).

**Val mode (`mode="val"`):**

- No random crop in the dataset; the **full** image goes to Albumentations.

### 8.2 Transforms (`data/transforms.py`)

**Train (`build_train_transforms`):**

- **Geometric:** `HorizontalFlip`, `ShiftScaleRotate` (shift, scale, ±10° rotation) with `mask_value=ignore_index` on borders.
- **Photometric:** brightness/contrast, hue/sat/value, Gaussian blur, Gaussian noise, JPEG compression simulation, RGB shift.
- **If `augmentation.strong`:** `RandomSunFlare`, `RandomShadow` (desert-relevant appearance stress).
- **Normalize:** ImageNet mean/std.
- **`ToTensorV2`:** Image → `float` tensor `CHW`; the mask is converted to `long` downstream in `__getitem__`.

**Val (`build_val_transforms`):**

- `LongestMaxSize(crop_size)` then `PadIfNeeded(crop_size, crop_size)` with **mask pad = 255** (ignored in loss/metrics).

### 8.3 Class frequency estimation (`utils/freq.py`)

Before training, `scripts/train.py` calls **`estimate_pixel_frequencies`** over **all** training mask files (`max_files` is configurable in code; the train script uses the full corpus). This yields a normalized per-class frequency vector, which is then used to build **class weights**; a sketch of such a scan follows below.

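A minimal sketch of what such a frequency scan computes (illustrative only; the actual implementation lives in `utils/freq.py` and is not reproduced here):

```python
import numpy as np
from pathlib import Path
from PIL import Image

def pixel_frequencies(mask_dir: Path, codec, num_classes: int) -> np.ndarray:
    """Normalized per-class pixel frequency over all masks in a folder."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for mp in sorted(mask_dir.glob("*.png")):
        raw = np.array(Image.open(mp)).astype(np.uint16)
        enc, _ = codec.encode_mask(raw)                  # 0..C-1 class indices
        counts += np.bincount(enc.ravel(), minlength=num_classes)
    return counts / max(counts.sum(), 1)                # sums to 1.0
```
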
---

## 9. Model

**Factory:** `desert_segmentation/models/factory.py`

| `architecture` | SMP class | Notes |
|----------------|-----------|-------|
| `deeplabv3plus` (default) | `smp.DeepLabV3Plus` | Mainline; strong decoder + atrous spatial pyramid |
| `unet` | `smp.Unet` | Classic encoder–decoder with skips |
| `fpn` | `smp.FPN` | Feature pyramid neck |

**Default encoder:** `resnet50` with `encoder_weights: imagenet`.

**Forward:** Input batch `N×3×H×W` → logits `N×C×H×W` where `C = num_classes`.

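Roughly what the factory constructs for the default config (SMP's public constructor API; `classes` comes from `len(raw_ids)`):

```python
import segmentation_models_pytorch as smp

# DeepLabV3+ with an ImageNet-pretrained ResNet-50 encoder and 10 output classes.
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=10,
)
# logits = model(images)  # images: N×3×H×W float tensor → logits: N×10×H×W
```
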
---

## 10. Loss functions

**File:** `desert_segmentation/losses/combined.py`

**Modes (`loss.name`):**

| Mode | Description |
|------|-------------|
| `ce` | Plain cross-entropy, unweighted |
| `weighted_ce` | Cross-entropy with a per-class `weight` tensor |
| `ce_dice` (default) | `CE(weighted) + dice_weight * multiclass_Dice_loss` |
| `focal_ce` | Focal-modulated CE; optional class weights on pixels |
| `focal_ce_dice` | `focal_ce` + `dice_weight * multiclass_Dice_loss` (same class weights in the focal term) |

**Shared options:**

- **`ignore_index`:** Pixels with label 255 are masked out of CE / focal / Dice.
- **`label_smoothing`:** Applied to **CE-based** modes (`ce`, `weighted_ce`, `ce_dice`) only; not used in `focal_ce` / `focal_ce_dice`.

**Class weights (`compute_class_weights_from_freq`):**

1. Start from the per-class pixel frequency `freq` on the training masks.
2. Set `w ∝ 1 / log(freq + ε)` and normalize by the mean.
3. Clamp the ratio `w / median(w)` to **`class_weight_cap`** (default **15**) so rare classes do not explode the loss (a numeric sketch follows below).

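A numeric sketch of the weight rule (the exact constants live in `combined.py`; this uses one common instantiation, the ENet-style `1 / ln(c + freq)` with `c = 1.02`, which keeps the logarithm positive):

```python
import numpy as np

def class_weights(freq: np.ndarray, cap: float = 15.0, c: float = 1.02) -> np.ndarray:
    """Inverse log-frequency weights, mean-normalized, ratio-clamped vs the median."""
    w = 1.0 / np.log(c + freq)           # rare classes (small freq) get larger w
    w = w / w.mean()                     # normalize so the average weight is 1.0
    ratio = np.clip(w / np.median(w), 1.0 / cap, cap)
    return np.median(w) * ratio          # clamp how far any class strays from median

freq = np.array([0.55, 0.30, 0.10, 0.04, 0.01])  # made-up class frequencies
print(class_weights(freq).round(3))              # rarest class gets the largest weight
```
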
---

## 11. Metrics

**File:** `desert_segmentation/metrics/iou.py`

1. **Confusion matrix** `C×C` (the implementation uses `idx = tgt * C + pred` then `bincount`; rows correspond to the **ground-truth class**, columns to the **predicted class**).
2. **Per-class IoU:**
   \(\text{IoU}_k = \frac{TP_k}{TP_k + FP_k + FN_k}\)
   with `TP_k = CM[k,k]`; `FN_k` comes from the row sum and `FP_k` from the column sum.
3. **mIoU:** Mean of per-class IoU over finite entries.
4. **fwIoU (frequency-weighted IoU):** \(\sum_k \text{IoU}_k \cdot p_k\) where \(p_k\) is the empirical frequency of class \(k\) in the ground-truth pixels (row marginals, since rows index ground truth).

**Note:** The docstring in `compute_confusion` mentions “pred rows, target columns”; the actual indexing is `tgt` (row) `* C +` `pred` (col) after reshape, so the docstring is stale.

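The bincount trick in a few lines (illustrative, mirroring the indexing described above):

```python
import numpy as np

def confusion(tgt: np.ndarray, pred: np.ndarray, C: int, ignore: int = 255) -> np.ndarray:
    """C×C confusion matrix; rows = ground truth, cols = prediction."""
    keep = tgt != ignore                               # drop ignore_index pixels
    idx = tgt[keep].astype(np.int64) * C + pred[keep]
    return np.bincount(idx, minlength=C * C).reshape(C, C)

def iou_per_class(cm: np.ndarray) -> np.ndarray:
    tp = np.diag(cm)
    # TP / (TP + FN + FP); absent classes yield nan, skipped when averaging mIoU.
    return tp / (cm.sum(1) + cm.sum(0) - tp)
```
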
---

## 12. Training loop

**File:** `desert_segmentation/train/trainer.py`

**Optimizer:** AdamW on all parameters.

**Learning rate:** `LambdaLR` with:

- **Linear warmup** for `warmup_ratio` of the total optimizer steps (default **8%**).
- **Cosine** decay from 1.0 down to a `min_ratio` of **0.01** (implemented in `_warmup_cosine_lambda`; a sketch of the shape follows below).

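A sketch of such a warmup-plus-cosine multiplier (illustrative; the real `_warmup_cosine_lambda` may differ in details):

```python
import math

def warmup_cosine(step: int, total_steps: int, warmup_ratio: float = 0.08,
                  min_ratio: float = 0.01) -> float:
    """LR multiplier for torch.optim.lr_scheduler.LambdaLR."""
    warmup = max(1, int(total_steps * warmup_ratio))
    if step < warmup:
        return (step + 1) / warmup                       # linear ramp 0 → 1
    t = (step - warmup) / max(1, total_steps - warmup)   # progress in [0, 1]
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * t))
```
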
**AMP (mixed precision):**

- Enabled only if `train.amp` is true **and** `torch.cuda.is_available()`.
- Uses `torch.cuda.amp.autocast` + `GradScaler` on CUDA.
- On **CPU**, AMP is off; training uses a standard FP32 backward pass (no scaler).

**Gradient clipping:** Global-norm clip when `gradient_clip > 0` (default **1.0**).

**EMA (optional):**

- If `ema.enabled`, after each optimizer step the code maintains a **shadow weight** copy per trainable parameter, with exponential decay **0.999** by default (see the sketch below).
- **Each epoch:** The training weights are **deep-copied**; the **EMA weights are copied into the model** for validation only; then the training snapshot is **restored** so optimization continues from the non-EMA weights.

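The per-step shadow update, as a minimal sketch (names are assumed; the trainer's actual bookkeeping is richer):

```python
import torch

@torch.no_grad()
def ema_update(shadow: dict, model: torch.nn.Module, decay: float = 0.999) -> None:
    """shadow ← decay * shadow + (1 - decay) * current weights."""
    for name, p in model.named_parameters():
        if p.requires_grad:
            shadow[name].mul_(decay).add_(p.detach(), alpha=1.0 - decay)
```
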
**Checkpointing:**

- Every epoch: `checkpoints/last.pt` (model, optional EMA dict, optimizer, full config, class names).
- **Best validation mIoU:** `checkpoints/best.pt` (adds `miou`, `per_class_iou`).

**Early stopping:** If validation mIoU does not improve for `early_stop_patience` epochs (default **12**), training stops.

**Optional smoke flags (`scripts/train.py`):**

- `--epochs N` — override the epoch count.
- `--max_train_batches K` — stop each training epoch after `K` batches (debug only; the scheduler still advances per batch).

**Logging:** `checkpoints/history.json` lists per-epoch `miou` and `fw_iou`.

---

## 13. Validation and evaluation scripts

**Core loop:** `desert_segmentation/train/evaluate.py` runs the model in `eval()` mode, accumulates the confusion matrix via `IoUMetrics`, and returns a dict.

**CLI:** `scripts/eval.py`

1. Loads the config and builds the validation dataset (same codec and val transforms as training).
2. Loads the checkpoint from `--checkpoint`.
3. **Weight loading priority:** If an `ema` dict exists in the checkpoint, the **EMA tensors are copied into the parameters** for evaluation; otherwise the `state_dict` from the `model` key is used.
4. Runs the full val loader → logs **mIoU**, **fwIoU**, and per-class IoU.
5. Writes:
   - `eval_outputs/metrics.json` (or `--out_dir`)
   - `confusion.npy`
   - Up to `--max_viz` side-by-side **RGB | GT | Pred** PNGs (`save_triplet` in `utils/viz.py`), with ImageNet denormalization for the RGB panels.

---

## 14. Inference (testing folder, sliding window, TTA, ONNX)

**CLI:** `scripts/infer.py`

### 14.1 Folder inference

- Reads `testing/Color_Images` (or whatever `data.test_images` points to).
- Loads the checkpoint with the same **EMA-first** rule as eval.
- For each image:
  - If **both** height and width ≤ `tile_size` (512): a single forward pass.
  - Else: **sliding window** with stride `tile_size * (1 - overlap)` (default overlap **0.25** → stride **384**).
  - Pads the image with **reflect** padding so the tile grid covers the corners; crops back to the original size.
  - Accumulates **per-class logits** weighted by a **2D Gaussian** (`sigma ∝ tile/3`) so tile borders blend smoothly; the final prediction is the **`argmax` over classes** per pixel (see the sketch after this list).

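A sketch of the Gaussian feather window used for blending (assumed shape; `predict.py` builds something equivalent):

```python
import numpy as np

def gaussian_window(tile: int) -> np.ndarray:
    """2D weight map peaking at the tile center, falling off toward the borders."""
    sigma = tile / 3.0
    ax = np.arange(tile) - (tile - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    return np.outer(g, g)                  # (tile, tile), max 1.0 at the center

# Accumulation per tile (conceptually):
#   acc[:, y:y+t, x:x+t] += logits * gaussian_window(t)
# The final label map is acc.argmax(axis=0), cropped back to the original H×W.
```
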
| 392 |
+
### 14.2 Test-time augmentation (TTA)
|
| 393 |
+
|
| 394 |
+
If `inference.tta_flip` is true: logits = **0.5 × (logits(x) + unflip(logits(flip(x))))** horizontally.
|
| 395 |
+
|
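In PyTorch terms (a one-liner sketch for an `N×C×H×W` logits tensor):

```python
import torch

def tta_hflip(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Flip width (dim=-1), run the model, flip the logits back, and average.
    return 0.5 * (model(x) + torch.flip(model(torch.flip(x, dims=[-1])), dims=[-1]))
```
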
### 14.3 Outputs

Under `--out_dir` (default `infer_outputs/`):

- `pred_<filename>` — color overlay (prediction tinted on the RGB).
- `triplet_<filename>` — **RGB | blank or GT | Pred** strip (the test set has no GT, so the middle panel is zeros in the current `save_triplet` usage).
- `latency.txt` — mean milliseconds per image and the device string.

### 14.4 ONNX

`python scripts/infer.py --checkpoint ... --onnx model.onnx` calls `export_onnx`: it builds the model on **CPU**, creates a dummy `1×3×512×512` input, and runs `torch.onnx.export` with dynamic axes for batch and spatial size.

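The export call, as a minimal sketch (argument names follow `torch.onnx.export`; the opset number is an assumption, and the wrapper in `predict.py` may differ):

```python
import torch

def export_onnx(model: torch.nn.Module, path: str) -> None:
    model = model.eval().cpu()
    dummy = torch.zeros(1, 3, 512, 512)
    torch.onnx.export(
        model, dummy, path,
        input_names=["image"], output_names=["logits"],
        dynamic_axes={
            "image": {0: "batch", 2: "height", 3: "width"},
            "logits": {0: "batch", 2: "height", 3: "width"},
        },
        opset_version=17,  # assumption: any recent opset handles these ops
    )
```
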
---

## 15. Checkpoints and artifacts

| Artifact | Contents |
|----------|----------|
| `checkpoints/best.pt` | `model`, `ema` (optional), `miou`, `per_class_iou`, `config`, `class_names` |
| `checkpoints/last.pt` | Latest epoch snapshot + optimizer |
| `checkpoints/history.json` | List of `{epoch, miou, fw_iou}` |
| `eval_outputs/*` | `metrics.json`, `confusion.npy`, visualization PNGs |
| `infer_outputs/*` | Overlays, triplets, `latency.txt` |

---

## 16. How to run (commands)

From the repository root (adjust paths if yours differ).

### 16.1 Install

```powershell
python -m pip install -r requirements.txt
```

### 16.2 Train

```powershell
$env:PYTHONPATH="."
python scripts\train.py --root "d:\codewizard 2.0"
```

Optional:

```powershell
python scripts\train.py --root "d:\codewizard 2.0" --config desert_segmentation\configs\default.yaml --epochs 5 --max_train_batches 50
```

**Imbalanced classes (optional YAML):** set `loss.name` to `focal_ce_dice` for focal + Dice; tune `class_weight_cap`, `rare_class_crop_prob`, and/or `data.weighted_sampler: true` to oversample train images that contain rare classes (this scans all train masks once at startup, which can take a minute on large sets).

### 16.3 Evaluate (validation)

```powershell
python scripts\eval.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir eval_outputs
```

**Metric summary (no PNGs):** `scripts\eval_summary.py` prints **mIoU (all classes)** and **mIoU (classes with val GT)** (the latter ignores absent classes, so it is easier to interpret on sparse val labels), **fwIoU**, **global / mean class accuracy**, **val GT pixel counts per class**, and a **per-class IoU / recall** table. It uses the same validation forward pass as `eval.py`. Optional: `--json-out eval_summary.json` (includes `miou_valid_gt_classes`, `val_gt_pixel_counts`).

```powershell
python scripts\eval_summary.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --json-out eval_summary.json
```

To print only the **mIoU** and **per-class IoU** stored inside the checkpoint (no GPU eval): `python scripts\eval_summary.py --from-checkpoint-only --checkpoint checkpoints\best.pt`

### 16.4 Infer on `testing/Color_Images`

```powershell
python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --out_dir infer_outputs --limit 20
```

### 16.5 Export ONNX

```powershell
python scripts\infer.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt --onnx model.onnx
```

---

## 17. Interactive demo (Gradio)

Upload an RGB image in the browser and get a **colored class mask**, an **overlay**, a **side-by-side strip** (RGB | mask | overlay), a **fixed legend** (colors match `palette()` in training), the **inference time**, and the **dominant classes** (pixel histogram). Uses the same path as CLI inference: `_load_model_for_inference` and `predict_image` in `desert_segmentation/infer/predict.py` (EMA weights are preferred when present in the checkpoint).

**Install** (base + demo extras):

```powershell
python -m pip install -r requirements.txt -r requirements-demo.txt
```

**Run** (from the repo root; the model loads **once** at startup — look for a log line `Model ready`):

```powershell
$env:PYTHONPATH="."
python scripts\demo_gradio.py --root "d:\codewizard 2.0" --checkpoint checkpoints\best.pt
```

**CLI flags:** `--host` (default `127.0.0.1`), `--port` (default `7860`), `--share` (temporary public Gradio link), `--max-side`, `--max-megapixels` (reject huge uploads before inference).

**Environment variables** (optional defaults if the flags are omitted):

| Variable | Purpose |
|----------|---------|
| `ROOT` | Workspace root (same as `--root`) |
| `CHECKPOINT_PATH` | Path to `best.pt` (relative paths resolve under `ROOT`) |

**Advanced panel:** TTA on/off, tile overlap slider, tile size slider (256–2048, step 64). Overrides are passed into `predict_image` only; the checkpoint file is not modified.

**v1 limitations:** No per-pixel **confidence heatmap** for full sliding-window runs (only the `argmax` is returned from `predict_image`). See the planned follow-up to add logits fusion if needed.

**Windows:** Use backslashes or quoted paths as above; the first launch may be slow while dependencies initialize.

**Follow-ups (not in v1):** full-resolution **confidence** heatmap (needs a logits path in `predict.py`); **ZIP** batch upload; a **two-checkpoint** comparison UI; client-side **ONNX** inference.

---

## 18. Tests

```powershell
python -m pytest tests\test_mask_encoding.py -q
```

Covers:

- Round-trip **raw mask ↔ class indices** for known IDs.
- An **unknown raw pixel** raises `ValueError`.
- LUT correctness for each configured raw ID.

---

## 19. Dependencies and environment notes

**`requirements.txt`:**

- `torch`, `torchvision`, `numpy`, `Pillow`, `PyYAML`
- `albumentations` pinned to `<1.5` to reduce optional native-build issues on some Windows setups
- `segmentation-models-pytorch` (SMP)
- `tqdm`, `pytest`
- Optional demo: `requirements-demo.txt` adds **Gradio**

**Windows:** `scripts/train.py` and `scripts/eval.py` set `num_workers=0` for the `DataLoader` on NT to avoid multiprocessing friction.

**SMP pretrained weights:** The first run may download encoder weights (e.g. ResNet-50 ImageNet) via the SMP / Hugging Face hubs, depending on the SMP version.

---

## 20. Design decisions and limitations

| Topic | Decision / limitation |
|-------|------------------------|
| Mask modes | **16-bit raw IDs** supported via LUT; **P-mode palette** and **RGB color masks** are *not* auto-detected in this codebase—extend `mask_encoding.py` if your dataset uses them |
| SegFormer | **Not** a separate `architecture` enum; the plan mentioned SegFormer-B2 as an alternative—it would require additional factory code or a supported SMP encoder |
| Val resolution | Images are **letterboxed** to 512×512 for batching; padded regions are excluded from mIoU via ignore—fine for a hackathon; for publication-grade eval consider sliding-window validation too |
| Inference fusion | Overlapping tiles add **Gaussian-weighted logits** per class into an accumulator; the final label is the **`argmax` over the accumulated logits** (feathered overlap fusion). A per-pixel `weight` tensor is also accumulated in code for possible future normalization extensions |
| Poly LR / sync BN | **Not** implemented (cosine + warmup only) |
| Ensemble | **Not** implemented (single model + optional EMA) |

---

## 21. Extending the project

1. **New classes / raw IDs:** Edit `data.raw_ids` and `data.class_names` in the YAML; the frequency rescan happens automatically in `train.py`.
2. **UNet / FPN:** Set `model.architecture` to `unet` or `fpn`; pick a valid `encoder_name` for SMP.
3. **Larger encoder:** e.g. `encoder_name: resnet101` for DeepLabV3+.
4. **Loss ablation:** Set `loss.name` to `ce`, `weighted_ce`, `focal_ce`, or `focal_ce_dice`; tune `dice_weight`, `label_smoothing`, `class_weight_cap`.
5. **Stronger aug:** Add Albumentations ops in `transforms.py` (keep `additional_targets={"mask": "mask"}` for paired geometry).

---

## 22. Flowcharts

### 22.1 Training epoch (simplified)

```mermaid
flowchart TD
  start[Start epoch]
  trainLoop[For each batch]
  fwd[Forward logits]
  lossStep[Compute loss CE plus Dice]
  backward[Backward plus clip]
  stepOpt[Optimizer step plus scheduler step]
  emaUp[Update EMA if enabled]
  endTrain[End train batches]
  snap[Snapshot model weights]
  applyEMA[Copy EMA into model if enabled]
  valRun[Run validation mIoU]
  restore[Restore snapshot weights]
  better{New best mIoU?}
  saveBest[Save best.pt]
  early{Patience exceeded?}
  stop[Stop training]
  start --> trainLoop
  trainLoop --> fwd --> lossStep --> backward --> stepOpt --> emaUp
  emaUp --> trainLoop
  trainLoop --> endTrain
  endTrain --> snap --> applyEMA --> valRun --> restore --> better
  better -->|yes| saveBest --> early
  better -->|no| early
  early -->|yes| stop
  early -->|no| start
```

### 22.2 Inference on large images

```mermaid
flowchart LR
  img[Input RGB HxW]
  pad[Reflect pad to tile grid]
  tiles[For each tile]
  fwdT[Forward logits optional TTA]
  g[Multiply by Gaussian feather]
  acc[Accumulate class logits maps]
  argmax[Argmax over classes]
  cropBack[Crop to original HxW]
  img --> pad --> tiles --> fwdT --> g --> acc --> argmax --> cropBack
```

---

## Acknowledgments

- **segmentation_models_pytorch** (Pavel Iakubovskii and contributors) for modular segmentation architectures.
- **Albumentations** for fast, paired image–mask augmentations.

---

*Generated to document the implementation in this repository as of the README authoring date. For the original hackathon planning narrative, see your separate plan document (not stored in this repo’s `README`).*
desert_segmentation/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""Desert semantic segmentation training and inference package."""

__version__ = "0.1.0"
desert_segmentation/configs/default.yaml
ADDED
@@ -0,0 +1,77 @@
# Paths are relative to `root` unless absolute.
root: "."

data:
  train_images: "training/train/Color_Images"
  train_masks: "training/train/Segmentation"
  val_images: "training/val/Color_Images"
  val_masks: "training/val/Segmentation"
  test_images: "testing/Color_Images"

  # Raw uint16 label IDs (must match PNG values) and display names
  raw_ids: [100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000]
  class_names:
    - "id_100"
    - "id_200"
    - "id_300"
    - "id_500"
    - "id_550"
    - "id_600"
    - "id_700"
    - "id_800"
    - "id_7100"
    - "id_10000"

  crop_size: 512
  num_workers: 4
  # Prefer crops containing underrepresented classes (probability 0–1). Higher = more
  # training crops centered on rare-class pixels (see SegmentationDataset).
  rare_class_crop_prob: 0.35
  # Oversample images that contain rare classes (scans train masks at startup).
  weighted_sampler: false
  weighted_sampler_eps: 1.0e-6
  ignore_index: 255

model:
  architecture: "deeplabv3plus"
  encoder_name: "resnet50"
  encoder_weights: "imagenet"
  # Alternative: mit_b2 with deeplabv3plus if supported by SMP
  # encoder_name: "mit_b2"

train:
  batch_size: 4
  epochs: 40
  lr: 0.0003
  weight_decay: 0.0005
  warmup_ratio: 0.08
  amp: true
  gradient_clip: 1.0
  seed: 42
  checkpoint_dir: "checkpoints"
  log_interval: 20
  early_stop_patience: 12

loss:
  # ce | weighted_ce | ce_dice | focal_ce | focal_ce_dice
  name: "focal_ce_dice"
  dice_weight: 0.5
  # Used only for CE-based modes (ce, weighted_ce, ce_dice). Ignored for focal_ce / focal_ce_dice.
  label_smoothing: 0.05
  # Inverse log-frequency class weights; ratio clamped to [1/cap, cap] vs median. Typical cap 5–25;
  # higher = stronger upweight for rare classes (watch for instability).
  class_weight_cap: 15.0
  focal_gamma: 2.0

augmentation:
  strong: true

ema:
  enabled: true
  decay: 0.999

inference:
  tile_size: 512
  overlap: 0.25
  tta_flip: true
  batch_size: 1
desert_segmentation/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
from desert_segmentation.data.dataset import SegmentationDataset
from desert_segmentation.data.mask_encoding import RawMaskCodec

__all__ = ["SegmentationDataset", "RawMaskCodec"]
desert_segmentation/data/dataset.py
ADDED
@@ -0,0 +1,129 @@
"""Image / mask dataset with optional rare-class biased cropping."""

from __future__ import annotations

import os
import random
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from desert_segmentation.data.mask_encoding import RawMaskCodec


def _list_images(dir_path: Path) -> List[str]:
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
    return sorted(
        f for f in os.listdir(dir_path) if Path(f).suffix.lower() in exts
    )


class SegmentationDataset(Dataset):
    def __init__(
        self,
        images_dir: Path,
        masks_dir: Path,
        codec: RawMaskCodec,
        transform: Optional[Callable] = None,
        mode: str = "train",
        crop_size: int = 512,
        rare_class_crop_prob: float = 0.35,
        ignore_index: int = 255,
        seed: int = 42,
    ) -> None:
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.codec = codec
        self.transform = transform
        self.mode = mode
        self.crop_size = crop_size
        self.rare_class_crop_prob = rare_class_crop_prob if mode == "train" else 0.0
        self.ignore_index = ignore_index
        self._rng = random.Random(seed)

        names = _list_images(self.images_dir)
        self._pairs: List[Tuple[str, str]] = []
        for n in names:
            mp = self.masks_dir / n
            if not mp.is_file():
                raise FileNotFoundError(f"Missing mask for {n}: {mp}")
            self._pairs.append((str(self.images_dir / n), str(mp)))

        if not self._pairs:
            raise RuntimeError(f"No images in {self.images_dir}")

    def __len__(self) -> int:
        return len(self._pairs)

    @property
    def image_names(self) -> List[str]:
        """Basenames aligned with dataset indices (for weighted sampling)."""
        return [Path(p[0]).name for p in self._pairs]

    def _load_pair(self, ip: str, mp: str) -> Tuple[np.ndarray, np.ndarray]:
        image = np.array(Image.open(ip).convert("RGB"))
        raw_mask = np.array(Image.open(mp))
        if raw_mask.ndim == 2:
            enc, _ = self.codec.encode_mask(raw_mask.astype(np.uint16))
        else:
            raise ValueError(f"Expected single-channel mask, got shape {raw_mask.shape}")
        return image, enc

    def _random_crop_bias_rare(
        self, image: np.ndarray, mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        h, w = image.shape[:2]
        ch, cw = self.crop_size, self.crop_size
        if h < ch or w < cw:
            # Pad small images: zeros for RGB, ignore_index for the mask.
            pad_h = max(0, ch - h)
            pad_w = max(0, cw - w)
            image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
            mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=self.ignore_index)
            h, w = image.shape[:2]

        if self.mode == "train" and self._rng.random() < self.rare_class_crop_prob:
            # Center the crop on a random pixel of the rarest class present.
            hist, _ = np.histogram(mask.flatten(), bins=self.codec.num_classes, range=(0, self.codec.num_classes))
            rare = int(np.argmin(hist))
            ys, xs = np.where(mask == rare)
            if len(xs) > 0:
                idx = self._rng.randrange(len(xs))
                cx, cy = int(xs[idx]), int(ys[idx])
            else:
                cx, cy = w // 2, h // 2
        else:
            cx, cy = self._rng.randrange(w), self._rng.randrange(h)

        x0 = np.clip(cx - cw // 2, 0, w - cw)
        y0 = np.clip(cy - ch // 2, 0, h - ch)
        return image[y0 : y0 + ch, x0 : x0 + cw], mask[y0 : y0 + ch, x0 : x0 + cw]

    def __getitem__(self, idx: int) -> dict:
        ip, mp = self._pairs[idx]
        image, mask = self._load_pair(ip, mp)

        if self.mode == "train":
            image, mask = self._random_crop_bias_rare(image, mask)

        if self.transform is not None:
            t = self.transform(image=image, mask=mask)
            image = t["image"]
            mask = t["mask"]

        if isinstance(mask, torch.Tensor):
            mask_t = mask
        else:
            mask_t = torch.from_numpy(np.asarray(mask))
        if mask_t.dtype in (torch.float32, torch.float16):
            # Some transform stacks return float masks scaled to [0, 1]; undo that.
            mask_t = (mask_t * 255.0).round().clamp(0, 255).long()
        else:
            mask_t = mask_t.long()

        return {
            "image": image,
            "mask": mask_t,
            "path": ip,
        }
desert_segmentation/data/mask_encoding.py
ADDED
@@ -0,0 +1,90 @@
"""Decode 16-bit raw mask values to contiguous class indices and back."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import numpy as np


@dataclass(frozen=True)
class RawMaskCodec:
    """Maps dataset-specific raw label IDs (e.g. uint16 PNG values) to 0..num_classes-1."""

    raw_ids: Tuple[int, ...]
    class_names: Tuple[str, ...]

    def __post_init__(self) -> None:
        if len(self.raw_ids) != len(self.class_names):
            raise ValueError("raw_ids and class_names must have the same length")
        if len(set(self.raw_ids)) != len(self.raw_ids):
            raise ValueError("raw_ids must be unique")

    @property
    def num_classes(self) -> int:
        return len(self.raw_ids)

    @property
    def raw_to_index(self) -> Dict[int, int]:
        return {r: i for i, r in enumerate(self.raw_ids)}

    @property
    def index_to_raw(self) -> Dict[int, int]:
        return {i: r for i, r in enumerate(self.raw_ids)}

    def _build_lut(self) -> np.ndarray:
        # 255 is the sentinel for "unknown raw value"; legal IDs map to 0..C-1.
        max_id = max(self.raw_ids)
        lut = np.full(max_id + 1, 255, dtype=np.uint8)
        for i, rid in enumerate(self.raw_ids):
            lut[rid] = i
        return lut

    def encode_mask(self, raw: np.ndarray) -> Tuple[np.ndarray, float]:
        """Map raw uint16 labels to uint8 class indices 0..C-1. Returns (encoded, unknown_fraction)."""
        if raw.ndim != 2:
            raise ValueError(f"Expected HxW mask, got shape {raw.shape}")
        lut = self._build_lut()
        if int(raw.max()) >= lut.size:
            raise ValueError(f"Mask value {int(raw.max())} exceeds LUT; extend raw_ids in config.")
        out = lut[raw.astype(np.int64, copy=False)]
        unknown_frac = float((out == 255).mean())
        if unknown_frac > 0:
            bad = out == 255
            raise ValueError(
                f"Unknown mask pixels: {unknown_frac:.6f} of image. "
                f"Unique unknown raw values: {np.unique(raw[bad])[:16]}"
            )
        return out.astype(np.uint8), unknown_frac

    def decode_to_raw(self, class_indices: np.ndarray) -> np.ndarray:
        """Map class indices back to raw dataset IDs (for visualization/export)."""
        arr = np.asarray(class_indices)
        raw = np.zeros_like(arr, dtype=np.uint16)
        for i, rid in enumerate(self.raw_ids):
            raw[arr == i] = rid
        return raw


def build_codec_from_config(raw_id_list: Sequence[int], names: Sequence[str]) -> RawMaskCodec:
    pairs = sorted(zip(raw_id_list, names), key=lambda x: x[0])
    r, n = zip(*pairs)
    return RawMaskCodec(raw_ids=tuple(int(x) for x in r), class_names=tuple(str(x) for x in n))


def default_desert_codec() -> RawMaskCodec:
    """Default codec for this workspace: 10 classes with fixed raw IDs (see dataset scan)."""
    raw_ids = (100, 200, 300, 500, 550, 600, 700, 800, 7100, 10000)
    names = (
        "id_100",
        "id_200",
        "id_300",
        "id_500",
        "id_550",
        "id_600",
        "id_700",
        "id_800",
        "id_7100",
        "id_10000",
    )
    return RawMaskCodec(raw_ids=raw_ids, class_names=names)
desert_segmentation/data/transforms.py
ADDED
@@ -0,0 +1,90 @@
"""Albumentations pipelines for images and class masks."""

from __future__ import annotations

from typing import Any, Tuple

import albumentations as A
from albumentations.pytorch import ToTensorV2


def _base_normalize() -> A.Normalize:
    return A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))


def build_train_transforms(
    crop_size: int,
    strong: bool = True,
    ignore_index: int = 255,
) -> A.Compose:
    """Spatial crops are applied in `SegmentationDataset` (with rare-class bias)."""
    del crop_size
    geometric: list[Any] = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.02,
            scale_limit=0.12,
            rotate_limit=10,
            border_mode=0,
            mask_value=ignore_index,
            p=0.55,
        ),
    ]
    color: list[Any] = [
        A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.55),
        A.HueSaturationValue(hue_shift_limit=14, sat_shift_limit=22, val_shift_limit=14, p=0.4),
        A.GaussianBlur(blur_limit=(3, 5), p=0.22),
        A.GaussNoise(var_limit=(8.0, 48.0), p=0.25),
        A.ImageCompression(quality_lower=70, quality_upper=100, p=0.25),
        A.RGBShift(r_shift_limit=18, g_shift_limit=18, b_shift_limit=18, p=0.28),
    ]
    if strong:
        color.extend(
            [
                A.RandomSunFlare(
                    flare_roi=(0.45, 0.0, 1.0, 0.42),
                    angle_lower=0.4,
                    p=0.12,
                ),
                A.RandomShadow(
                    shadow_roi=(0, 0.42, 1, 1),
                    num_shadows_lower=1,
                    num_shadows_upper=2,
                    p=0.16,
                ),
            ]
        )
    return A.Compose(
        geometric + color + [_base_normalize(), ToTensorV2()],
        additional_targets={"mask": "mask"},
    )


def build_val_transforms(
    crop_size: int,
    ignore_index: int = 255,
) -> A.Compose:
    return A.Compose(
        [
            A.LongestMaxSize(max_size=crop_size),
            A.PadIfNeeded(
                min_height=crop_size,
                min_width=crop_size,
                border_mode=0,
                value=0,
                mask_value=ignore_index,
            ),
            _base_normalize(),
            ToTensorV2(),
        ],
        additional_targets={"mask": "mask"},
    )


def apply_transform(
    transform: A.Compose,
    image,
    mask,
) -> Tuple[Any, Any]:
    out = transform(image=image, mask=mask)
    return out["image"], out["mask"]
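A minimal sketch of the validation pipeline in action (dummy arrays; assumes albumentations from requirements.txt is installed):

# usage sketch (not part of the file above)
import numpy as np
from desert_segmentation.data.transforms import apply_transform, build_val_transforms

tf = build_val_transforms(crop_size=512)
image = np.zeros((384, 640, 3), dtype=np.uint8)   # HxWx3 RGB
mask = np.zeros((384, 640), dtype=np.uint8)       # HxW class indices
img_t, mask_t = apply_transform(tf, image, mask)
# LongestMaxSize + PadIfNeeded yield 512x512; ToTensorV2 returns torch tensors
# of shapes [3, 512, 512] and [512, 512], with padded mask pixels set to 255.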
desert_segmentation/demo/__init__.py
ADDED
@@ -0,0 +1,15 @@
from desert_segmentation.demo.inference_ui import (
    build_legend_rows,
    dominant_classes_markdown,
    legend_table_html,
    side_by_side_strip,
    validate_rgb_array,
)

__all__ = [
    "build_legend_rows",
    "dominant_classes_markdown",
    "legend_table_html",
    "side_by_side_strip",
    "validate_rgb_array",
]
desert_segmentation/demo/inference_ui.py
ADDED
@@ -0,0 +1,95 @@
"""Helpers for Gradio / web demo: legend, validation, composites."""

from __future__ import annotations

import html
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np

from desert_segmentation.utils.viz import palette


def validate_rgb_array(
    arr: np.ndarray,
    max_side: int = 4096,
    max_megapixels: float = 16.0,
) -> None:
    """Raises ValueError with a user-facing message if invalid or too large."""
    if arr is None:
        raise ValueError("No image provided.")
    if not isinstance(arr, np.ndarray):
        arr = np.asarray(arr)
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError(f"Expected RGB image HxWx3, got shape {getattr(arr, 'shape', None)}")
    h, w = arr.shape[0], arr.shape[1]
    if h < 1 or w < 1:
        raise ValueError("Image is empty.")
    if max(h, w) > max_side:
        raise ValueError(f"Image too large: max side is {max_side}px (got {h}x{w}).")
    mp = (h * w) / 1_000_000.0
    if mp > max_megapixels:
        raise ValueError(f"Image too large: max {max_megapixels} megapixels (got {mp:.1f} MP).")


def build_legend_rows(class_names: Sequence[str], num_classes: int, seed: int = 42) -> Tuple[List[Dict[str, Any]], np.ndarray]:
    """Returns list of {index, name, hex, r, g, b} and color table (same seed as viz.palette)."""
    colors = palette(num_classes, seed=seed)
    rows: List[Dict[str, Any]] = []
    for i, name in enumerate(class_names):
        r, g, b = (int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2]))
        rows.append(
            {
                "index": i,
                "name": str(name),
                "hex": f"#{r:02x}{g:02x}{b:02x}",
                "r": r,
                "g": g,
                "b": b,
            }
        )
    return rows, colors


def legend_table_html(rows: Sequence[Dict[str, Any]]) -> str:
    """Small HTML table with color swatches for Gradio gr.HTML."""
    parts = [
        "<table style='border-collapse:collapse;font-size:14px'>",
        "<thead><tr><th>Swatch</th><th>#</th><th>Name</th><th>Hex</th></tr></thead><tbody>",
    ]
    for row in rows:
        sw = f"background-color:{row['hex']};width:32px;height:22px;border:1px solid #888"
        safe_name = html.escape(str(row["name"]))
        parts.append(
            f"<tr><td><div style='{sw}'></div></td>"
            f"<td>{row['index']}</td><td>{safe_name}</td><td><code>{row['hex']}</code></td></tr>"
        )
    parts.append("</tbody></table>")
    return "".join(parts)


def dominant_classes_markdown(pred: np.ndarray, class_names: Sequence[str], top_k: int = 3) -> str:
    flat = pred.reshape(-1).astype(np.int64, copy=False)
    n = len(class_names)
    counts = np.bincount(flat, minlength=n)
    total = int(counts.sum())
    if total == 0:
        return "_No pixels._"
    order = np.argsort(-counts)
    lines: List[str] = []
    for i in order[:top_k]:
        c = int(counts[i])
        if c == 0:
            continue
        pct = 100.0 * c / total
        name = class_names[i] if i < len(class_names) else str(i)
        lines.append(f"- **{name}** (class {i}): **{pct:.1f}%**")
    return "\n".join(lines) if lines else "_No dominant classes._"


def side_by_side_strip(rgb: np.ndarray, mask_rgb: np.ndarray, overlay_rgb: np.ndarray, gap: int = 8) -> np.ndarray:
    """Horizontal strip: RGB | colored mask | overlay."""
    h, w = rgb.shape[:2]
    gap_arr = np.zeros((h, gap, 3), dtype=np.uint8)
    return np.concatenate([rgb, gap_arr, mask_rgb, gap_arr, overlay_rgb], axis=1)
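A short sketch of the legend helpers above (toy class names and prediction):

# usage sketch (not part of the file above)
import numpy as np
from desert_segmentation.demo.inference_ui import (
    build_legend_rows, dominant_classes_markdown, legend_table_html,
)

names = ("id_100", "id_200", "id_300")
rows, colors = build_legend_rows(names, num_classes=len(names))
table = legend_table_html(rows)                # HTML string, ready for gr.HTML(...)
pred = np.array([[2, 2], [2, 0]])              # toy 2x2 class map
print(dominant_classes_markdown(pred, names))  # id_300 at 75.0%, then id_100 at 25.0%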
desert_segmentation/infer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from desert_segmentation.infer.predict import predict_image, predict_folder

__all__ = ["predict_image", "predict_folder"]
desert_segmentation/infer/predict.py
ADDED
@@ -0,0 +1,210 @@
"""Sliding-window inference with optional horizontal-flip TTA and ONNX export helper."""

from __future__ import annotations

import logging
import os
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm

from desert_segmentation.data.mask_encoding import RawMaskCodec, build_codec_from_config
from desert_segmentation.models.factory import create_model
from desert_segmentation.utils.viz import blend_overlay, colorize_mask, palette, save_triplet

logger = logging.getLogger(__name__)


def _gaussian_2d(h: int, w: int) -> np.ndarray:
    yy, xx = np.ogrid[:h, :w]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    sig = min(h, w) / 3.0
    g = np.exp(-(((yy - cy) ** 2 + (xx - cx) ** 2) / (2.0 * sig**2)))
    return g.astype(np.float32)


def _preprocess(
    rgb: np.ndarray,
    mean: Tuple[float, float, float],
    std: Tuple[float, float, float],
) -> torch.Tensor:
    # Keep float32 end-to-end: np.array(mean) defaults to float64 and would upcast x -> conv2d dtype mismatch.
    x = rgb.astype(np.float32, copy=False) / 255.0
    m = np.asarray(mean, dtype=np.float32).reshape(1, 1, 3)
    s = np.asarray(std, dtype=np.float32).reshape(1, 1, 3)
    x = (x - m) / s
    t = torch.from_numpy(np.ascontiguousarray(x)).permute(2, 0, 1).unsqueeze(0)
    return t.float()


@torch.no_grad()
def _forward_logits(
    model: nn.Module,
    x: torch.Tensor,
    device: torch.device,
    tta_flip: bool,
) -> torch.Tensor:
    logits = model(x)
    if not tta_flip:
        return logits
    xf = torch.flip(x, dims=[3])
    lf = model(xf)
    lf = torch.flip(lf, dims=[3])
    return (logits + lf) * 0.5


def _tile_starts(length: int, tile: int, stride: int) -> List[int]:
    if length <= tile:
        return [0]
    last_pos = length - tile
    starts = list(range(0, last_pos + 1, stride))
    if not starts:
        return [0]
    if starts[-1] != last_pos:
        starts.append(last_pos)
    return sorted(set(starts))


@torch.no_grad()
def predict_image(
    model: nn.Module,
    image_np: np.ndarray,
    device: torch.device,
    tile_size: int,
    overlap: float,
    tta_flip: bool,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
) -> np.ndarray:
    """Returns HxW int class map."""
    h, w = image_np.shape[:2]
    if h <= tile_size and w <= tile_size:
        t = _preprocess(image_np, mean, std).to(device)
        logits = _forward_logits(model, t, device, tta_flip)
        return logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int64)

    stride = max(1, int(tile_size * (1.0 - overlap)))
    g = _gaussian_2d(tile_size, tile_size)

    n_ty = len(_tile_starts(h, tile_size, stride))
    n_tx = len(_tile_starts(w, tile_size, stride))
    H_pad = (n_ty - 1) * stride + tile_size
    W_pad = (n_tx - 1) * stride + tile_size
    pad_h = max(0, H_pad - h)
    pad_w = max(0, W_pad - w)
    img_p = np.pad(image_np, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
    H, W = img_p.shape[:2]

    # Probe one tile to discover the number of classes for the accumulator.
    t0 = _preprocess(img_p[0:tile_size, 0:tile_size], mean, std).to(device)
    logits0 = _forward_logits(model, t0, device, tta_flip)
    num_classes = int(logits0.shape[1])
    acc = np.zeros((num_classes, H, W), dtype=np.float32)
    weight = np.zeros((H, W), dtype=np.float32)

    for y in _tile_starts(H, tile_size, stride):
        for x in _tile_starts(W, tile_size, stride):
            tile = img_p[y : y + tile_size, x : x + tile_size]
            t = _preprocess(tile, mean, std).to(device)
            logits = _forward_logits(model, t, device, tta_flip)
            probs = torch.softmax(logits, dim=1)
            ls = probs.squeeze(0).cpu().numpy()
            acc[:, y : y + tile_size, x : x + tile_size] += ls * g
            weight[y : y + tile_size, x : x + tile_size] += g

    # Dividing acc by the per-pixel weight would not change the argmax, so it is skipped.
    pred = np.argmax(acc, axis=0).astype(np.int64)
    return pred[:h, :w]


def _load_model_for_inference(
    checkpoint_path: Path,
    device: torch.device,
) -> Tuple[nn.Module, dict, RawMaskCodec]:
    try:
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        ckpt = torch.load(checkpoint_path, map_location=device)
    cfg = ckpt["config"]
    raw_ids = cfg["data"]["raw_ids"]
    names = ckpt.get("class_names") or tuple(cfg["data"].get("class_names") or ())
    if not names:
        names = tuple(str(x) for x in raw_ids)
    codec = build_codec_from_config(raw_ids, names)
    model = create_model(cfg["model"], num_classes=codec.num_classes).to(device)
    if ckpt.get("model") is not None:
        model.load_state_dict(ckpt["model"])
    if ckpt.get("ema") is not None:
        for n, p in model.named_parameters():
            if n in ckpt["ema"]:
                p.data.copy_(ckpt["ema"][n].to(device))
    model.eval()
    return model, cfg, codec


@torch.no_grad()
def predict_folder(
    checkpoint_path: Path,
    image_dir: Path,
    out_dir: Path,
    device: Optional[torch.device] = None,
    limit: Optional[int] = None,
) -> None:
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, cfg, codec = _load_model_for_inference(checkpoint_path, device)
    icfg = cfg.get("inference") or {}
    tile_size = int(icfg.get("tile_size", 512))
    overlap = float(icfg.get("overlap", 0.25))
    tta = bool(icfg.get("tta_flip", True))

    out_dir.mkdir(parents=True, exist_ok=True)
    colors = palette(codec.num_classes)
    names = sorted(f for f in os.listdir(image_dir) if f.lower().endswith((".png", ".jpg", ".jpeg")))
    if limit is not None:
        names = names[:limit]
    times: List[float] = []
    for name in tqdm(names, desc="infer"):
        ip = image_dir / name
        rgb = np.array(Image.open(ip).convert("RGB"))
        t0 = time.perf_counter()
        pred = predict_image(model, rgb, device, tile_size, overlap, tta)
        times.append(time.perf_counter() - t0)
        overlay = blend_overlay(rgb, colorize_mask(pred, colors))
        Image.fromarray(overlay).save(out_dir / f"pred_{name}")
        save_triplet(out_dir / f"triplet_{name}", rgb, None, pred, colors)

    if times:
        mean_ms = float(np.mean(times) * 1000.0)
        logger.info("mean inference time: %.2f ms (device=%s)", mean_ms, device)
        with (out_dir / "latency.txt").open("w", encoding="utf-8") as f:
            f.write(f"mean_ms_per_image={mean_ms:.4f}\n")
            f.write(f"device={device}\n")


def export_onnx(
    checkpoint_path: Path,
    out_onnx: Path,
    height: int = 512,
    width: int = 512,
    opset: int = 17,
) -> None:
    device = torch.device("cpu")
    model, _, _ = _load_model_for_inference(checkpoint_path, device)
    model.eval()
    dummy = torch.randn(1, 3, height, width, device=device)
    torch.onnx.export(
        model,
        dummy,
        str(out_onnx),
        input_names=["input"],
        output_names=["logits"],
        opset_version=opset,
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "logits": {0: "batch", 2: "h", 3: "w"},
        },
    )
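The tiling grid above always places the last window flush with the image border; a quick numeric check (tile 512, overlap 0.25):

# usage sketch (not part of the file above)
from desert_segmentation.infer.predict import _tile_starts

stride = max(1, int(512 * (1.0 - 0.25)))  # 384
print(_tile_starts(1000, 512, stride))    # [0, 384, 488] -- last start clamped to 1000 - 512
print(_tile_starts(400, 512, stride))     # [0] -- image smaller than one tile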
desert_segmentation/losses/__init__.py
ADDED
@@ -0,0 +1,3 @@
from desert_segmentation.losses.combined import build_loss

__all__ = ["build_loss"]
desert_segmentation/losses/combined.py
ADDED
@@ -0,0 +1,143 @@
"""Segmentation losses: CE, weighted CE, focal, Dice, and combinations."""

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _ce(
    logits: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor],
    ignore_index: int,
    label_smoothing: float,
) -> torch.Tensor:
    return F.cross_entropy(
        logits,
        target,
        weight=weight,
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
    )


def _focal_ce(
    logits: torch.Tensor,
    target: torch.Tensor,
    gamma: float,
    weight: Optional[torch.Tensor],
    ignore_index: int,
) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    valid = target != ignore_index
    tgt_clamped = target.clone()
    tgt_clamped[~valid] = 0
    log_pt = log_probs.gather(1, tgt_clamped.unsqueeze(1)).squeeze(1)
    pt = probs.gather(1, tgt_clamped.unsqueeze(1)).squeeze(1)
    focal = (1 - pt) ** gamma * (-log_pt)
    if weight is not None:
        focal = focal * weight[tgt_clamped]
    focal = focal * valid.float()
    return focal.sum() / (valid.float().sum().clamp_min(1.0))


def _multiclass_dice(
    logits: torch.Tensor,
    target: torch.Tensor,
    ignore_index: int,
    eps: float = 1e-6,
) -> torch.Tensor:
    probs = F.softmax(logits, dim=1)
    _, c, _, _ = probs.shape
    tgt = target
    valid = tgt != ignore_index
    dice_losses = []
    for k in range(c):
        pk = probs[:, k]
        tk = (tgt == k).float()
        m = valid.float()
        pk, tk = pk * m, tk * m
        inter = (pk * tk).sum(dim=(1, 2))
        denom = pk.sum(dim=(1, 2)) + tk.sum(dim=(1, 2)) + eps
        dice = 1.0 - (2.0 * inter + eps) / denom
        dice_losses.append(dice.mean())
    return torch.stack(dice_losses).mean()


class CombinedSegLoss(nn.Module):
    def __init__(
        self,
        mode: str,
        num_classes: int,
        ignore_index: int = 255,
        class_weights: Optional[torch.Tensor] = None,
        dice_weight: float = 0.5,
        label_smoothing: float = 0.05,
        focal_gamma: float = 2.0,
    ) -> None:
        super().__init__()
        self.mode = mode
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.register_buffer("class_weights", class_weights if class_weights is not None else torch.ones(num_classes))
        self.dice_weight = dice_weight
        self.label_smoothing = label_smoothing
        self.focal_gamma = focal_gamma

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        w = self.class_weights
        if self.mode == "ce":
            loss = _ce(logits, target, None, self.ignore_index, self.label_smoothing)
        elif self.mode == "weighted_ce":
            loss = _ce(logits, target, w, self.ignore_index, self.label_smoothing)
        elif self.mode == "focal_ce":
            loss = _focal_ce(logits, target, self.focal_gamma, w, self.ignore_index)
        elif self.mode == "ce_dice":
            ce = _ce(logits, target, w, self.ignore_index, self.label_smoothing)
            dice = _multiclass_dice(logits, target, self.ignore_index)
            loss = ce + self.dice_weight * dice
        elif self.mode == "focal_ce_dice":
            focal = _focal_ce(logits, target, self.focal_gamma, w, self.ignore_index)
            dice = _multiclass_dice(logits, target, self.ignore_index)
            loss = focal + self.dice_weight * dice
        else:
            raise ValueError(f"Unknown loss mode {self.mode}")
        return loss, {"loss": float(loss.detach().cpu())}


def build_loss(
    loss_cfg: dict,
    num_classes: int,
    class_weights: Optional[torch.Tensor],
    ignore_index: int,
) -> CombinedSegLoss:
    mode = loss_cfg.get("name", "ce_dice")
    return CombinedSegLoss(
        mode=mode,
        num_classes=num_classes,
        ignore_index=ignore_index,
        class_weights=class_weights,
        dice_weight=float(loss_cfg.get("dice_weight", 0.5)),
        label_smoothing=float(loss_cfg.get("label_smoothing", 0.0)),
        focal_gamma=float(loss_cfg.get("focal_gamma", 2.0)),
    )


def compute_class_weights_from_freq(
    freq: torch.Tensor,
    cap: float = 15.0,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Inverse log frequency with mean normalization and per-class cap on max/min ratio."""
    # ENet-style 1 / ln(c + p_c) with c = 1.02: the offset keeps the log positive, so rare
    # classes (small freq) get the largest weights. (A bare 1 / log(freq) is negative for
    # freq < 1 and inverts the ordering, upweighting the common classes instead.)
    w = 1.0 / torch.log(1.02 + freq + eps)
    w = w / w.mean()
    ratio = w / w.median()
    ratio = torch.clamp(ratio, max=cap)
    w = ratio * w.median()
    return w
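A dummy-batch sketch of the loss factory above:

# usage sketch (not part of the file above)
import torch
from desert_segmentation.losses.combined import build_loss

criterion = build_loss({"name": "ce_dice", "dice_weight": 0.5},
                       num_classes=10, class_weights=None, ignore_index=255)
logits = torch.randn(2, 10, 64, 64, requires_grad=True)  # N x C x H x W
target = torch.randint(0, 10, (2, 64, 64))               # N x H x W class indices (255 = ignore)
loss, logs = criterion(logits, target)                   # scalar tensor + {"loss": float}
loss.backward()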
desert_segmentation/metrics/__init__.py
ADDED
@@ -0,0 +1,15 @@
from desert_segmentation.metrics.iou import (
    IoUMetrics,
    compute_confusion,
    confusion_to_accuracy_metrics,
    gt_pixel_counts,
    valid_class_miou_from_confusion,
)

__all__ = [
    "IoUMetrics",
    "compute_confusion",
    "confusion_to_accuracy_metrics",
    "gt_pixel_counts",
    "valid_class_miou_from_confusion",
]
desert_segmentation/metrics/iou.py
ADDED
@@ -0,0 +1,143 @@
"""Per-class IoU, mIoU, frequency-weighted IoU, confusion matrix."""

from __future__ import annotations

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch


def compute_confusion(
    logits: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    ignore_index: int = 255,
) -> torch.Tensor:
    """Accumulate confusion matrix (target rows, pred columns) — shape CxC."""
    pred = logits.argmax(dim=1).view(-1)
    tgt = target.view(-1)
    valid = tgt != ignore_index
    pred = pred[valid]
    tgt = tgt[valid]
    if pred.numel() == 0:
        return torch.zeros(num_classes, num_classes, dtype=torch.int64, device=logits.device)
    idx = tgt * num_classes + pred
    cm = torch.bincount(idx, minlength=num_classes * num_classes).reshape(num_classes, num_classes)
    return cm


def confusion_to_accuracy_metrics(
    cm: Union[np.ndarray, torch.Tensor],
    eps: float = 1e-12,
) -> Dict[str, float | np.ndarray]:
    """Pixel accuracies from confusion ``cm[gt_i, pred_j]`` (same layout as ``IoUMetrics``).

    - **global_pixel_accuracy:** ``trace(cm) / sum(cm)`` — fraction of pixels correct.
    - **mean_class_accuracy:** mean of per-class **recall** ``cm[k,k] / sum_j cm[k,j]`` over
      classes with at least one ground-truth pixel (ignores empty rows).

    Returns ``per_class_recall`` aligned with class index for optional reporting.
    """
    if isinstance(cm, torch.Tensor):
        cm = cm.detach().cpu().numpy()
    cm = np.asarray(cm, dtype=np.float64)
    total = cm.sum()
    if total <= eps:
        z = np.zeros(cm.shape[0], dtype=np.float64)
        return {
            "global_pixel_accuracy": 0.0,
            "mean_class_accuracy": 0.0,
            "per_class_recall": z,
        }
    trace = np.trace(cm)
    global_acc = float(trace / total)
    row_sums = cm.sum(axis=1)
    diag = np.diag(cm)
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class_recall = np.where(row_sums > eps, diag / np.maximum(row_sums, eps), np.nan)
    present = row_sums > eps
    mean_class_acc = (
        float(np.nanmean(per_class_recall[present])) if np.any(present) else 0.0
    )
    return {
        "global_pixel_accuracy": global_acc,
        "mean_class_accuracy": mean_class_acc,
        "per_class_recall": per_class_recall,
    }


def gt_pixel_counts(cm: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Ground-truth pixel counts per class: ``sum_j cm[gt_k, pred_j]`` (row sums)."""
    if isinstance(cm, torch.Tensor):
        cm = cm.detach().cpu().numpy()
    cm = np.asarray(cm, dtype=np.float64)
    return np.sum(cm, axis=1).astype(np.int64)


def valid_class_miou_from_confusion(
    cm: Union[np.ndarray, torch.Tensor],
    eps: float = 1e-6,
) -> float:
    """Mean IoU over classes that have at least one ground-truth pixel on the val set.

    Unlike full mIoU (mean over all classes, often many zeros when a class is absent from
    val GT), this only averages **finite** per-class IoU values for rows with ``GT > 0``.
    Returns ``0.0`` if no class has any GT pixels.
    """
    if isinstance(cm, torch.Tensor):
        cm = cm.detach().cpu().numpy()
    cm = np.asarray(cm, dtype=np.float64)
    diag = np.diag(cm)
    rows = cm.sum(axis=1)
    cols = cm.sum(axis=0)
    union = rows + cols - diag + eps
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = diag / union
    present = rows > 0
    if not np.any(present):
        return 0.0
    finite = present & np.isfinite(iou)
    if not np.any(finite):
        return 0.0
    return float(np.mean(iou[finite]))


def confusion_to_iou(cm: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
    """Returns per-class IoU, mean IoU, frequency-weighted IoU."""
    diag = torch.diag(cm).float()
    rows = cm.sum(dim=1).float()  # ground-truth pixels per class
    cols = cm.sum(dim=0).float()  # predicted pixels per class
    union = rows + cols - diag + 1e-6
    iou = diag / union
    miou = iou[torch.isfinite(iou)].mean().item()
    freq = rows / (rows.sum() + 1e-6)  # fwIoU weights classes by ground-truth frequency
    fw_iou = (iou * freq).sum().item()
    return iou, miou, fw_iou


class IoUMetrics:
    def __init__(self, num_classes: int, ignore_index: int = 255, device: Optional[torch.device] = None):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.device = device or torch.device("cpu")
        self.reset()

    def reset(self) -> None:
        self._cm = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int64, device=self.device)

    @torch.no_grad()
    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        logits = logits.to(self.device)
        target = target.to(self.device)
        self._cm += compute_confusion(logits, target, self.num_classes, self.ignore_index).to(self.device)

    def compute(self) -> Dict[str, float | np.ndarray]:
        cm = self._cm.cpu()
        iou, miou, fw_iou = confusion_to_iou(cm)
        return {
            "per_class_iou": iou.numpy(),
            "miou": miou,
            "fw_iou": fw_iou,
            "confusion": cm.numpy(),
        }
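A dummy-batch sketch of the metric accumulator above:

# usage sketch (not part of the file above)
import torch
from desert_segmentation.metrics.iou import IoUMetrics

metrics = IoUMetrics(num_classes=10, ignore_index=255)
for _ in range(3):                        # e.g. three validation batches
    logits = torch.randn(2, 10, 64, 64)
    target = torch.randint(0, 10, (2, 64, 64))
    metrics.update(logits, target)
out = metrics.compute()
print(out["miou"], out["fw_iou"], out["per_class_iou"].shape)  # floats + (10,)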
desert_segmentation/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from desert_segmentation.models.factory import create_model

__all__ = ["create_model"]
desert_segmentation/models/factory.py
ADDED
@@ -0,0 +1,37 @@
"""Build segmentation models via segmentation_models_pytorch."""

from __future__ import annotations

from typing import Any, Dict

import segmentation_models_pytorch as smp
import torch.nn as nn


def create_model(model_cfg: Dict[str, Any], num_classes: int) -> nn.Module:
    arch = (model_cfg.get("architecture") or "deeplabv3plus").lower()
    encoder_name = model_cfg.get("encoder_name", "resnet50")
    encoder_weights = model_cfg.get("encoder_weights", "imagenet")

    if arch == "deeplabv3plus":
        return smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
        )
    if arch == "unet":
        return smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
        )
    if arch == "fpn":
        return smp.FPN(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
        )
    raise ValueError(f"Unknown architecture: {arch}")
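A shape-check sketch for the factory (encoder_weights=None avoids the ImageNet download; assumes segmentation_models_pytorch is installed):

# usage sketch (not part of the file above)
import torch
from desert_segmentation.models.factory import create_model

model = create_model(
    {"architecture": "deeplabv3plus", "encoder_name": "resnet50", "encoder_weights": None},
    num_classes=10,
)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 512, 512))
print(logits.shape)  # torch.Size([1, 10, 512, 512])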
desert_segmentation/train/__init__.py
ADDED
@@ -0,0 +1,4 @@
from desert_segmentation.train.evaluate import evaluate
from desert_segmentation.train.trainer import train

__all__ = ["evaluate", "train"]
desert_segmentation/train/evaluate.py
ADDED
@@ -0,0 +1,30 @@
"""Validation loop and metric aggregation."""

from __future__ import annotations

from typing import Optional

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from desert_segmentation.metrics.iou import IoUMetrics


@torch.no_grad()
def evaluate(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    num_classes: int,
    ignore_index: int = 255,
    desc: str = "val",
) -> dict:
    model.eval()
    metrics = IoUMetrics(num_classes=num_classes, ignore_index=ignore_index, device=device)
    for batch in tqdm(loader, desc=desc, leave=False):
        images = batch["image"].to(device, non_blocking=True)
        masks = batch["mask"].to(device, non_blocking=True)
        logits = model(images)
        metrics.update(logits, masks)
    return metrics.compute()
desert_segmentation/train/trainer.py
ADDED
@@ -0,0 +1,205 @@
"""Training loop with AMP, cosine+warmup, EMA, best-mIoU checkpointing."""

from __future__ import annotations

import copy
import json
import logging
import math
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from desert_segmentation.train.evaluate import evaluate

logger = logging.getLogger(__name__)


class ModelEMA:
    """Exponential moving average of model parameters."""

    def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
        self.decay = decay
        self.shadow: Dict[str, torch.Tensor] = {}
        self._collect(model)

    @torch.no_grad()
    def _collect(self, model: nn.Module) -> None:
        for n, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[n] = p.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    @torch.no_grad()
    def copy_to(self, model: nn.Module) -> None:
        for n, p in model.named_parameters():
            if n in self.shadow:
                p.data.copy_(self.shadow[n])


def _warmup_cosine_lambda(
    total_steps: int,
    warmup_steps: int,
    min_ratio: float = 0.01,
) -> Any:
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * progress))

    return lr_lambda


def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    cfg: Dict[str, Any],
    num_classes: int,
    ignore_index: int,
    checkpoint_dir: Path,
    class_names: tuple[str, ...],
    max_train_batches: Optional[int] = None,
) -> Dict[str, Any]:
    tcfg = cfg["train"]
    epochs = int(tcfg["epochs"])
    lr = float(tcfg["lr"])
    wd = float(tcfg["weight_decay"])
    amp_enabled = bool(tcfg.get("amp", True)) and torch.cuda.is_available()
    clip = float(tcfg.get("gradient_clip", 0.0))
    warmup_ratio = float(tcfg.get("warmup_ratio", 0.08))
    patience = int(tcfg.get("early_stop_patience", 20))
    log_interval = int(tcfg.get("log_interval", 20))

    ema_cfg = cfg.get("ema") or {}
    use_ema = bool(ema_cfg.get("enabled", False))
    ema_decay = float(ema_cfg.get("decay", 0.999))
    ema: Optional[ModelEMA] = ModelEMA(model, decay=ema_decay) if use_ema else None

    opt = AdamW(model.parameters(), lr=lr, weight_decay=wd)
    steps_per_epoch = max(1, len(train_loader))
    total_steps = steps_per_epoch * epochs
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    sched = LambdaLR(opt, _warmup_cosine_lambda(total_steps, warmup_steps))
    scaler: Optional[GradScaler] = GradScaler() if amp_enabled else None

    best_miou = -1.0
    bad_epochs = 0
    history: list = []

    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_path = checkpoint_dir / "best.pt"
    last_path = checkpoint_dir / "last.pt"

    global_step = 0
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        n_log = 0
        pbar = tqdm(train_loader, desc=f"train {epoch}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            images = batch["image"].to(device, non_blocking=True)
            masks = batch["mask"].to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            with autocast(enabled=amp_enabled):
                logits = model(images)
                loss, _ = criterion(logits, masks)
            if scaler is not None:
                scaler.scale(loss).backward()
                if clip > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                if clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                opt.step()
            sched.step()
            global_step += 1
            if ema is not None:
                ema.update(model)
            running += float(loss.detach().cpu())
            n_log += 1
            if global_step % log_interval == 0:
                pbar.set_postfix(loss=f"{running / max(n_log, 1):.4f}")
                running = 0.0
                n_log = 0
            if max_train_batches is not None and (batch_idx + 1) >= max_train_batches:
                break

        backup = copy.deepcopy(model.state_dict())
        if ema is not None:
            ema.copy_to(model)
        val_metrics = evaluate(
            model,
            val_loader,
            device,
            num_classes=num_classes,
            ignore_index=ignore_index,
            desc=f"val {epoch}",
        )
        model.load_state_dict(backup)

        miou = float(val_metrics["miou"])
        row = {"epoch": epoch, "miou": miou, "fw_iou": float(val_metrics["fw_iou"])}
        history.append(row)
        logger.info(
            "epoch %s | val mIoU=%.4f fwIoU=%.4f",
            epoch,
            miou,
            row["fw_iou"],
        )

        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "ema": ema.shadow if ema is not None else None,
                "optimizer": opt.state_dict(),
                "config": cfg,
                "class_names": class_names,
            },
            last_path,
        )

        if miou > best_miou:
            best_miou = miou
            bad_epochs = 0
            save_payload = {
                "epoch": epoch,
                "model": model.state_dict(),
                "ema": ema.shadow if ema is not None else None,
                "miou": miou,
                "per_class_iou": val_metrics["per_class_iou"].tolist(),
                "config": cfg,
                "class_names": class_names,
            }
            torch.save(save_payload, best_path)
            logger.info("saved new best checkpoint mIoU=%.4f -> %s", miou, best_path)
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                logger.info("early stopping at epoch %s (no improvement %s epochs)", epoch, patience)
                break

    with (checkpoint_dir / "history.json").open("w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    return {"best_miou": best_miou, "best_path": str(best_path), "history": history}
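A numeric spot-check of the LR multiplier produced by _warmup_cosine_lambda (the base LR is scaled by these values):

# usage sketch (not part of the file above)
from desert_segmentation.train.trainer import _warmup_cosine_lambda

fn = _warmup_cosine_lambda(total_steps=100, warmup_steps=10)
print(fn(0))    # 0.1   -- first warmup step (linear ramp)
print(fn(9))    # 1.0   -- warmup complete
print(fn(55))   # 0.505 -- halfway through the cosine decay
print(fn(100))  # 0.01  -- cosine floor == min_ratio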
desert_segmentation/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from desert_segmentation.utils.seed import set_seed

__all__ = ["set_seed"]
desert_segmentation/utils/config.py
ADDED
@@ -0,0 +1,37 @@
"""Load YAML config and resolve paths relative to workspace root."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict

import yaml


def load_config(path: Path | str, root: Path | None = None) -> Dict[str, Any]:
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if root is None:
        root = Path(cfg.get("root", ".")).resolve()
    else:
        root = Path(root).resolve()
    cfg["root"] = str(root)
    return cfg


def resolve_path(root: Path, *parts: str) -> Path:
    return (root / Path(*parts)).resolve()


def get_paths(cfg: Dict[str, Any]) -> Dict[str, Path]:
    root = Path(cfg["root"])
    d = cfg["data"]
    return {
        "train_images": resolve_path(root, d["train_images"]),
        "train_masks": resolve_path(root, d["train_masks"]),
        "val_images": resolve_path(root, d["val_images"]),
        "val_masks": resolve_path(root, d["val_masks"]),
        "test_images": resolve_path(root, d["test_images"]),
    }
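A loading sketch (run from the repo root; assumes default.yaml defines the data.* path keys consumed by get_paths):

# usage sketch (not part of the file above)
from desert_segmentation.utils.config import get_paths, load_config

cfg = load_config("desert_segmentation/configs/default.yaml")
paths = get_paths(cfg)         # dict of absolute Paths: train_images, train_masks, ...
print(paths["train_images"])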
desert_segmentation/utils/freq.py
ADDED
@@ -0,0 +1,66 @@
"""Estimate class pixel frequencies from mask files (fast path for loss weighting)."""

from __future__ import annotations

import os
from pathlib import Path
from typing import List, Sequence

import numpy as np
import torch
from PIL import Image

from desert_segmentation.data.mask_encoding import RawMaskCodec


def list_masks(dir_path: Path) -> List[str]:
    return sorted(f for f in os.listdir(dir_path) if f.lower().endswith(".png"))


@torch.no_grad()
def estimate_pixel_frequencies(
    masks_dir: Path,
    codec: RawMaskCodec,
    max_files: int | None = 800,
) -> torch.Tensor:
    paths = list_masks(masks_dir)
    if max_files is not None:
        paths = paths[:max_files]
    counts = np.zeros(codec.num_classes, dtype=np.int64)
    for name in paths:
        raw = np.array(Image.open(masks_dir / name))
        enc, _ = codec.encode_mask(raw.astype(np.uint16))
        for c in range(codec.num_classes):
            counts[c] += int((enc == c).sum())
    freq = counts.astype(np.float64) / max(counts.sum(), 1)
    return torch.tensor(freq, dtype=torch.float32)


def per_image_sampling_weights(
    masks_dir: Path,
    image_basenames: Sequence[str],
    codec: RawMaskCodec,
    freq: torch.Tensor,
    eps: float = 1e-6,
) -> torch.DoubleTensor:
    """Weights for ``WeightedRandomSampler``: upweight images containing rare classes.

    For each mask, ``w_i = sum_{c : n_{ic}>0} 1 / (freq[c] + eps)``, then weights are
    scaled to mean 1.0. ``image_basenames`` must match the order of
    ``SegmentationDataset`` indices (same filenames as train pairs).
    """
    masks_dir = Path(masks_dir)
    f = freq.detach().cpu().numpy().astype(np.float64)
    raw_weights = np.zeros(len(image_basenames), dtype=np.float64)
    for i, name in enumerate(image_basenames):
        raw = np.array(Image.open(masks_dir / name))
        enc, _ = codec.encode_mask(raw.astype(np.uint16))
        present = np.zeros(codec.num_classes, dtype=bool)
        for c in range(codec.num_classes):
            present[c] = bool((enc == c).any())
        raw_weights[i] = sum(1.0 / (f[c] + eps) for c in range(codec.num_classes) if present[c])
    m = raw_weights.mean()
    if m <= 0:
        return torch.ones(len(image_basenames), dtype=torch.double)
    scaled = raw_weights / m
    return torch.tensor(scaled, dtype=torch.double)
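A sketch wiring the helpers above into a rare-class sampler (the masks_dir path is hypothetical):

# usage sketch (not part of the file above)
from pathlib import Path
from torch.utils.data import WeightedRandomSampler

from desert_segmentation.data.mask_encoding import default_desert_codec
from desert_segmentation.utils.freq import estimate_pixel_frequencies, list_masks, per_image_sampling_weights

codec = default_desert_codec()
masks_dir = Path("training/train/masks")  # hypothetical location of the training masks
freq = estimate_pixel_frequencies(masks_dir, codec, max_files=200)
weights = per_image_sampling_weights(masks_dir, list_masks(masks_dir), codec, freq)
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
# pass sampler=sampler to the training DataLoader (mutually exclusive with shuffle=True)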
desert_segmentation/utils/logging_utils.py
ADDED
@@ -0,0 +1,11 @@
import logging
import sys


def setup_logging(level: int = logging.INFO) -> None:
    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
        stream=sys.stdout,
    )
desert_segmentation/utils/seed.py
ADDED
@@ -0,0 +1,13 @@
import os
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
desert_segmentation/utils/viz.py
ADDED
@@ -0,0 +1,60 @@
"""Color overlays and side-by-side panels for segmentation."""

from __future__ import annotations

from pathlib import Path
from typing import List, Sequence, Tuple

import numpy as np
from PIL import Image, ImageDraw, ImageFont


def palette(num_classes: int, seed: int = 42) -> np.ndarray:
    rng = np.random.default_rng(seed)
    colors = rng.integers(32, 256, size=(num_classes, 3), dtype=np.uint8)
    colors[0] = np.array([128, 128, 128], dtype=np.uint8)
    return colors


def colorize_mask(mask: np.ndarray, colors: np.ndarray) -> np.ndarray:
    """mask HxW int 0..C-1 -> RGB uint8"""
    m = mask.clip(0, len(colors) - 1)
    return colors[m]


def blend_overlay(
    image_rgb: np.ndarray,
    colored_mask: np.ndarray,
    alpha: float = 0.55,
) -> np.ndarray:
    return (image_rgb.astype(np.float32) * (1 - alpha) + colored_mask.astype(np.float32) * alpha).clip(
        0, 255
    ).astype(np.uint8)


def save_triplet(
    out_path: Path,
    rgb: np.ndarray,
    gt: np.ndarray | None,
    pred: np.ndarray,
    class_colors: np.ndarray,
    titles: Tuple[str, str, str] = ("RGB", "GT", "Pred"),
) -> None:
    h, w = rgb.shape[:2]
    panels: List[np.ndarray] = [rgb]
    if gt is not None:
        panels.append(blend_overlay(rgb, colorize_mask(gt, class_colors)))
    else:
        panels.append(np.zeros_like(rgb))
    panels.append(blend_overlay(rgb, colorize_mask(pred, class_colors)))

    # Optional text strip (simple border)
    gap = 8
    total_w = w * len(panels) + gap * (len(panels) - 1)
    canvas = np.zeros((h, total_w, 3), dtype=np.uint8)
    x = 0
    for p in panels:
        canvas[:, x : x + w] = p
        x += w + gap
    Image.fromarray(canvas).save(out_path)
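An overlay sketch using the helpers above (toy arrays):

# usage sketch (not part of the file above)
import numpy as np
from desert_segmentation.utils.viz import blend_overlay, colorize_mask, palette

colors = palette(num_classes=10)  # deterministic: seed 42, class 0 fixed gray
mask = np.random.default_rng(0).integers(0, 10, size=(128, 128))
rgb = np.full((128, 128, 3), 200, dtype=np.uint8)
overlay = blend_overlay(rgb, colorize_mask(mask, colors), alpha=0.55)
print(overlay.shape, overlay.dtype)  # (128, 128, 3) uint8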
eval_summary.json
ADDED
@@ -0,0 +1,47 @@
{
  "checkpoint": "D:\\codewizard 2.0\\checkpoints\\best.pt",
  "val_dir": "D:\\codewizard 2.0\\training\\val\\Color_Images",
  "num_val_samples": 317,
  "miou": 0.07851162552833557,
  "miou_all_classes": 0.07851162552833557,
  "miou_valid_gt_classes": 0.07851162270064849,
  "fw_iou": 0.3974744379520416,
  "global_pixel_accuracy": 0.448105526939844,
  "mean_class_accuracy": 0.15349954460756882,
  "per_class_iou": {
    "id_100": 0.0,
    "id_200": 0.0,
    "id_300": 0.25856709480285645,
    "id_500": 0.0,
    "id_550": 0.0,
    "id_600": 0.0,
    "id_700": 0.0,
    "id_800": 0.0,
    "id_7100": 0.0,
    "id_10000": 0.5265491604804993
  },
  "per_class_recall": {
    "id_100": 0.0,
    "id_200": 0.0,
    "id_300": 0.7182908230723474,
    "id_500": 0.0,
    "id_550": 0.0,
    "id_600": 0.0,
    "id_700": 0.0,
    "id_800": 0.0,
    "id_7100": 0.0,
    "id_10000": 0.8167046230033408
  },
  "val_gt_pixel_counts": {
    "id_100": 1902003,
    "id_200": 2808908,
    "id_300": 9019195,
    "id_500": 512976,
    "id_550": 1976309,
    "id_600": 1138100,
    "id_700": 30968,
    "id_800": 566002,
    "id_7100": 11074438,
    "id_10000": 17714653
  }
}
requirements-demo.txt
ADDED
@@ -0,0 +1,3 @@
# Optional: interactive Gradio demo (install alongside requirements.txt)
# pip install -r requirements.txt -r requirements-demo.txt
gradio>=4.44.0,<6
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch>=2.0.0
torchvision>=0.15.0
numpy>=1.24.0
Pillow>=10.0.0
PyYAML>=6.0
# Pin avoids optional native build deps (e.g. stringzilla) on some Windows/Python setups
albumentations>=1.3.1,<1.5
segmentation-models-pytorch>=0.3.3
tqdm>=4.66.0
pytest>=7.4.0
scripts/demo_gradio.py
ADDED
@@ -0,0 +1,219 @@
#!/usr/bin/env python3
"""Gradio demo: upload RGB image, get colored mask, overlay, legend, and timing."""

from __future__ import annotations

import argparse
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Tuple

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import gradio as gr
import numpy as np
import torch
from PIL import Image

from desert_segmentation.demo.inference_ui import (
    build_legend_rows,
    dominant_classes_markdown,
    legend_table_html,
    side_by_side_strip,
    validate_rgb_array,
)
from desert_segmentation.infer.predict import _load_model_for_inference, predict_image
from desert_segmentation.utils.viz import blend_overlay, colorize_mask

logger = logging.getLogger(__name__)

_STATE: Dict[str, Any] = {}


def _to_uint8_rgb(arr: Any) -> np.ndarray:
    if arr is None:
        raise gr.Error("Please upload an image.")
    if isinstance(arr, Image.Image):
        arr = np.array(arr.convert("RGB"))
    a = np.asarray(arr)
    if a.ndim == 2:
        raise gr.Error("Expected a color RGB image, got grayscale.")
    if a.ndim == 3 and a.shape[2] == 4:
        a = a[:, :, :3]
    if a.ndim != 3 or a.shape[2] != 3:
        raise gr.Error(f"Expected HxWx3 RGB image, got shape {a.shape}.")
    if np.issubdtype(a.dtype, np.floating) and float(a.max()) <= 1.0 + 1e-6:
        a = (np.clip(a, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    elif a.dtype != np.uint8:
        a = np.clip(a, 0, 255).astype(np.uint8)
    return np.ascontiguousarray(a)


def _init_state(checkpoint: Path, device: torch.device) -> None:
    global _STATE
    if _STATE:
        return
    logger.info("Loading checkpoint: %s", checkpoint)
    model, cfg, codec = _load_model_for_inference(checkpoint, device)
    icfg = cfg.get("inference") or {}
    legend_rows, colors = build_legend_rows(codec.class_names, codec.num_classes, seed=42)
    _STATE.update(
        {
            "model": model,
            "cfg": cfg,
            "codec": codec,
            "device": device,
            "icfg": icfg,
            "legend_rows": legend_rows,
            "colors": colors,
            "legend_html_static": legend_table_html(legend_rows),
        },
    )
    logger.info(
        "Model ready | classes=%s | device=%s | default tile=%s overlap=%s tta=%s",
        codec.num_classes,
        device,
        icfg.get("tile_size", 512),
        icfg.get("overlap", 0.25),
        icfg.get("tta_flip", True),
    )


def _run(
    image_input: Any,
    use_tta: bool,
    overlap: float,
    tile_size: float,
    max_side: int,
    max_megapixels: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str, str]:
    rgb = _to_uint8_rgb(image_input)
    try:
        validate_rgb_array(rgb, max_side=max_side, max_megapixels=max_megapixels)
    except ValueError as e:
        raise gr.Error(str(e)) from e

    st = _STATE
    model = st["model"]
    device = st["device"]
    icfg = st["icfg"]
    codec = st["codec"]
    colors = st["colors"]

    tile = int(round(float(tile_size))) if tile_size is not None else int(icfg.get("tile_size", 512))
    tile = max(256, min(tile, 2048))
    ov = float(overlap)
    ov = max(0.0, min(ov, 0.5))

    t0 = time.perf_counter()
    pred = predict_image(model, rgb, device, tile, ov, bool(use_tta))
    ms = (time.perf_counter() - t0) * 1000.0

    colored = colorize_mask(pred, colors)
    overlay = blend_overlay(rgb, colored)
    strip = side_by_side_strip(rgb, colored, overlay)

    dev_str = str(device)
    if device.type == "cpu":
        dev_str += " (CPU mode — slower than GPU)"

    stats = (
        f"**Inference:** {ms:.1f} ms \n"
        f"**Device:** {dev_str} \n"
        f"**Tile size:** {tile} | **Overlap:** {ov:.2f} | **TTA:** {use_tta}"
    )
    dominant = "### Dominant classes in this image\n" + dominant_classes_markdown(pred, codec.class_names, top_k=3)

    return rgb, colored, overlay, strip, stats, dominant


def main() -> None:
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser(description="Gradio demo for desert semantic segmentation")
    parser.add_argument("--root", type=str, default=os.environ.get("ROOT"), help="Workspace root (default: repo root or env ROOT)")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=os.environ.get("CHECKPOINT_PATH"),
        help="Path to best.pt (default: env CHECKPOINT_PATH or <root>/checkpoints/best.pt)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true", help="Create a temporary public Gradio link")
    parser.add_argument("--max-side", type=int, default=4096)
    parser.add_argument("--max-megapixels", type=float, default=16.0)
    args = parser.parse_args()

    root = Path(args.root or ROOT).resolve()
    ckpt_arg = args.checkpoint or str(root / "checkpoints" / "best.pt")
    ckpt = Path(ckpt_arg)
    if not ckpt.is_absolute():
        ckpt = (root / ckpt).resolve()
    if not ckpt.is_file():
        raise SystemExit(f"Checkpoint not found: {ckpt}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _init_state(ckpt, device)

    icfg = _STATE["icfg"]
    def_tta = bool(icfg.get("tta_flip", True))
    def_ov = float(icfg.get("overlap", 0.25))
    def_tile = float(icfg.get("tile_size", 512))

    intro = """## Desert semantic segmentation demo

This is **semantic segmentation**: each pixel is assigned one of several **classes** (terrain, vegetation, sky, etc.).
It is **not** bounding-box object detection.

**How to read the outputs:**
- **Colored mask:** each color is one class (see legend).
- **Overlay:** prediction blended on your photo.
- **Strip:** original | mask | side-by-side for screenshots.

_Confidence heatmaps for full-resolution sliding windows are not in this demo (v1); see README._
"""

    cpu_note = ""
    if device.type == "cpu":
        cpu_note = "\n\n> Running on **CPU** — expect slower inference. Use a CUDA GPU for best speed.\n"

    with gr.Blocks(title="Desert segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown(intro + cpu_note)
        inp = gr.Image(type="numpy", label="Upload RGB image", sources=["upload"])
        with gr.Accordion("Advanced", open=False):
            use_tta = gr.Checkbox(label="TTA (horizontal flip average)", value=def_tta)
            overlap = gr.Slider(0.0, 0.5, value=def_ov, step=0.05, label="Tile overlap")
            tile_sz = gr.Slider(256, 2048, value=int(def_tile), step=64, label="Tile size (pixels)")
        run_btn = gr.Button("Run segmentation", variant="primary")

        with gr.Row():
            out_orig = gr.Image(label="Input", type="numpy")
            out_mask = gr.Image(label="Colored class mask", type="numpy")
            out_overlay = gr.Image(label="Overlay", type="numpy")
        out_strip = gr.Image(label="RGB | mask | overlay", type="numpy")
        stats_md = gr.Markdown("")
        dominant_md = gr.Markdown("")
        gr.Markdown("### Class legend (fixed palette)")
        gr.HTML(_STATE["legend_html_static"])

        def _fn(img, tta, ov, ts):
            return _run(img, tta, ov, ts, args.max_side, args.max_megapixels)

        run_btn.click(
            fn=_fn,
            inputs=[inp, use_tta, overlap, tile_sz],
            outputs=[out_orig, out_mask, out_overlay, out_strip, stats_md, dominant_md],
        )

    logger.info("Launching Gradio on http://%s:%s", args.host, args.port)
    demo.launch(server_name=args.host, server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
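For reference, a minimal sketch of the same inference path the demo wraps, without the UI. The checkpoint and image paths are illustrative placeholders; the function signatures follow the imports used above:

from pathlib import Path

import numpy as np
import torch
from PIL import Image

from desert_segmentation.demo.inference_ui import build_legend_rows
from desert_segmentation.infer.predict import _load_model_for_inference, predict_image
from desert_segmentation.utils.viz import blend_overlay, colorize_mask

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Illustrative paths: point these at your checkpoint and input image.
model, cfg, codec = _load_model_for_inference(Path("checkpoints/best.pt"), device)
rgb = np.array(Image.open("example.jpg").convert("RGB"))
pred = predict_image(model, rgb, device, 512, 0.25, True)  # tile=512, overlap=0.25, flip TTA
_, colors = build_legend_rows(codec.class_names, codec.num_classes, seed=42)
Image.fromarray(blend_overlay(rgb, colorize_mask(pred, colors))).save("overlay.png")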
scripts/eval.py
ADDED
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""Run validation, print metrics, save confusion matrix and overlays."""

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from desert_segmentation.data.dataset import SegmentationDataset
from desert_segmentation.data.mask_encoding import build_codec_from_config
from desert_segmentation.data.transforms import build_val_transforms
from desert_segmentation.models.factory import create_model
from desert_segmentation.train.evaluate import evaluate
from desert_segmentation.utils.config import get_paths, load_config
from desert_segmentation.utils.logging_utils import setup_logging
from desert_segmentation.utils.seed import set_seed
from desert_segmentation.utils.viz import palette, save_triplet

logger = logging.getLogger(__name__)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"))
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--root", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default="eval_outputs")
    parser.add_argument("--max_viz", type=int, default=24)
    args = parser.parse_args()

    root = Path(args.root or ROOT).resolve()
    cfg = load_config(args.config, root=root)
    setup_logging()
    set_seed(int(cfg["train"]["seed"]))

    paths = get_paths(cfg)
    raw_ids = cfg["data"]["raw_ids"]
    names = tuple(cfg["data"]["class_names"])
    codec = build_codec_from_config(raw_ids, names)
    ignore_index = int(cfg["data"].get("ignore_index", 255))
    crop_size = int(cfg["data"]["crop_size"])

    val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)
    val_ds = SegmentationDataset(
        paths["val_images"],
        paths["val_masks"],
        codec=codec,
        transform=val_tf,
        mode="val",
        crop_size=crop_size,
        rare_class_crop_prob=0.0,
        ignore_index=ignore_index,
        seed=int(cfg["train"]["seed"]),
    )

    nw = 0 if os.name == "nt" else int(cfg["data"].get("num_workers", 4))
    val_loader = DataLoader(
        val_ds,
        batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
        shuffle=False,
        num_workers=nw,
        pin_memory=torch.cuda.is_available(),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        ckpt = torch.load(Path(args.checkpoint), map_location=device, weights_only=False)
    except TypeError:
        ckpt = torch.load(Path(args.checkpoint), map_location=device)
    cfg_ck = ckpt["config"]
    model = create_model(cfg_ck["model"], num_classes=codec.num_classes).to(device)
    if ckpt.get("ema") is not None:
        for n, p in model.named_parameters():
            if n in ckpt["ema"]:
                p.data.copy_(ckpt["ema"][n].to(device))
    else:
        model.load_state_dict(ckpt["model"])
    model.eval()

    metrics = evaluate(model, val_loader, device, num_classes=codec.num_classes, ignore_index=ignore_index)
    logger.info("mIoU=%.4f fwIoU=%.4f", metrics["miou"], metrics["fw_iou"])
    per = metrics["per_class_iou"]
    for i, name in enumerate(codec.class_names):
        logger.info("  %s IoU=%.4f", name, float(per[i]))

    out_dir = Path(args.out_dir)
    if not out_dir.is_absolute():
        out_dir = root / out_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    with (out_dir / "metrics.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "miou": float(metrics["miou"]),
                "fw_iou": float(metrics["fw_iou"]),
                "per_class_iou": {codec.class_names[i]: float(per[i]) for i in range(len(codec.class_names))},
            },
            f,
            indent=2,
        )
    np.save(out_dir / "confusion.npy", metrics["confusion"])

    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
    colors = palette(codec.num_classes)
    n = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="viz"):
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)
            logits = model(images)
            pred = logits.argmax(dim=1).cpu().numpy()
            gt = masks.cpu().numpy()
            for b in range(images.shape[0]):
                if n >= args.max_viz:
                    break
                t = images[b].cpu().permute(1, 2, 0).numpy()
                rgb = (t * std + mean) * 255.0
                rgb = np.clip(rgb, 0, 255).astype(np.uint8)
                save_triplet(
                    out_dir / f"val_{n:04d}.png",
                    rgb,
                    gt[b],
                    pred[b],
                    colors,
                )
                n += 1
            if n >= args.max_viz:
                break


if __name__ == "__main__":
    main()
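The confusion matrix saved to confusion.npy is enough to recompute the IoU numbers offline. A short sketch using the standard IoU-from-confusion formula; the eval_outputs path assumes the default --out_dir:

import numpy as np

# Recompute per-class IoU from the saved confusion matrix.
cm = np.load("eval_outputs/confusion.npy").astype(np.float64)
tp = np.diag(cm)
union = cm.sum(axis=1) + cm.sum(axis=0) - tp   # GT + predicted - TP, per class
iou = tp / np.maximum(union, 1.0)
print("per-class IoU:", np.round(iou, 4))
print("mIoU:", float(iou.mean()))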
scripts/eval_summary.py
ADDED
@@ -0,0 +1,259 @@
#!/usr/bin/env python3
"""Print segmentation metrics: mIoU (all classes + valid-GT-only), fwIoU, accuracies, GT counts.

Runs a full validation pass by default (same setup as ``scripts/eval.py``). With
``--from-checkpoint-only``, only prints metrics stored inside the checkpoint file
(mIoU and per-class IoU when present); full metrics require a validation forward pass."""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
from pathlib import Path
from typing import Any, Dict, List

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch
from torch.utils.data import DataLoader

from desert_segmentation.data.dataset import SegmentationDataset
from desert_segmentation.data.mask_encoding import build_codec_from_config
from desert_segmentation.data.transforms import build_val_transforms
from desert_segmentation.metrics.iou import (
    confusion_to_accuracy_metrics,
    gt_pixel_counts,
    valid_class_miou_from_confusion,
)
from desert_segmentation.models.factory import create_model
from desert_segmentation.train.evaluate import evaluate
from desert_segmentation.utils.config import get_paths, load_config
from desert_segmentation.utils.logging_utils import setup_logging
from desert_segmentation.utils.seed import set_seed


def _load_checkpoint(path: Path, device: torch.device) -> Dict[str, Any]:
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        return torch.load(path, map_location=device)


def _print_table(rows: List[List[str]]) -> None:
    widths = [max(len(rows[i][c]) for i in range(len(rows))) for c in range(len(rows[0]))]
    for row in rows:
        line = " ".join(row[c].ljust(widths[c]) for c in range(len(row)))
        print(line)


def run_from_checkpoint_only(ckpt_path: Path) -> int:
    ckpt = _load_checkpoint(ckpt_path, torch.device("cpu"))
    print(f"Checkpoint: {ckpt_path.resolve()}")
    print()
    if "miou" in ckpt:
        print(f"  mIoU (stored): {float(ckpt['miou']):.6f}")
    else:
        print("  mIoU: (not stored in this file)")
    names = ckpt.get("class_names")
    per = ckpt.get("per_class_iou")
    if per is not None and names is not None:
        print("  Per-class IoU (stored):")
        for i, name in enumerate(names):
            print(f"    [{i}] {name}: {float(per[i]):.6f}")
    elif per is not None:
        print("  Per-class IoU (stored):")
        for i, v in enumerate(per):
            print(f"    [{i}]: {float(v):.6f}")
    else:
        print("  Per-class IoU: (not stored in this file)")
    print()
    print(
        "Note: fwIoU, global pixel accuracy, and mean class accuracy are not saved in "
        "checkpoints. Run without --from-checkpoint-only to compute them on the val set."
    )
    return 0


def run_full_eval(args: argparse.Namespace) -> int:
    root = Path(args.root or ROOT).resolve()
    cfg = load_config(args.config, root=root)
    setup_logging()
    set_seed(int(cfg["train"]["seed"]))

    paths = get_paths(cfg)
    raw_ids = cfg["data"]["raw_ids"]
    names = tuple(cfg["data"]["class_names"])
    codec = build_codec_from_config(raw_ids, names)
    ignore_index = int(cfg["data"].get("ignore_index", 255))
    crop_size = int(cfg["data"]["crop_size"])

    val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)
    val_ds = SegmentationDataset(
        paths["val_images"],
        paths["val_masks"],
        codec=codec,
        transform=val_tf,
        mode="val",
        crop_size=crop_size,
        rare_class_crop_prob=0.0,
        ignore_index=ignore_index,
        seed=int(cfg["train"]["seed"]),
    )

    nw = 0 if os.name == "nt" else int(cfg["data"].get("num_workers", 4))
    val_loader = DataLoader(
        val_ds,
        batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
        shuffle=False,
        num_workers=nw,
        pin_memory=torch.cuda.is_available(),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_path = Path(args.checkpoint)
    ckpt = _load_checkpoint(ckpt_path, device)
    cfg_ck = ckpt["config"]
    model = create_model(cfg_ck["model"], num_classes=codec.num_classes).to(device)
    if ckpt.get("ema") is not None:
        for n, p in model.named_parameters():
            if n in ckpt["ema"]:
                p.data.copy_(ckpt["ema"][n].to(device))
    else:
        model.load_state_dict(ckpt["model"])
    model.eval()

    metrics = evaluate(
        model,
        val_loader,
        device,
        num_classes=codec.num_classes,
        ignore_index=ignore_index,
        desc="eval_summary",
    )
    cm = metrics["confusion"]
    acc = confusion_to_accuracy_metrics(cm)
    miou_valid = float(valid_class_miou_from_confusion(cm))
    gt_counts = gt_pixel_counts(cm)

    miou = float(metrics["miou"])
    fw_iou = float(metrics["fw_iou"])
    gpa = float(acc["global_pixel_accuracy"])
    mca = float(acc["mean_class_accuracy"])
    per_iou = metrics["per_class_iou"]
    per_rec = acc["per_class_recall"]

    def _rec_str(i: int) -> str:
        v = float(per_rec[i])
        if math.isnan(v):
            return "n/a"
        return f"{v:.6f}"

    print()
    print(f"Checkpoint: {ckpt_path.resolve()}")
    print(f"Val images: {paths['val_images']}")
    print(f"Val samples: {len(val_ds)}")
    print()
    print("  mIoU (all classes):     {:.6f}".format(miou))
    print("  mIoU (classes w/ GT):   {:.6f}".format(miou_valid))
    print("  Frequency-weighted IoU: {:.6f}".format(fw_iou))
    print("  Global pixel accuracy:  {:.6f}".format(gpa))
    print("  Mean class accuracy:    {:.6f}".format(mca))
    print("  (mean of per-class recall over classes with GT pixels)")
    print()
    table: List[List[str]] = [["cls", "name", "IoU", "recall"]]
    for i, name in enumerate(codec.class_names):
        table.append(
            [
                str(i),
                name,
                f"{float(per_iou[i]):.6f}",
                _rec_str(i),
            ]
        )
    _print_table(table)
    print()
    print("  Val GT pixels per class (full val set):")
    for i, name in enumerate(codec.class_names):
        print(f"    [{i}] {name}: {int(gt_counts[i])}")
    print()

    payload = {
        "checkpoint": str(ckpt_path.resolve()),
        "val_dir": str(paths["val_images"]),
        "num_val_samples": len(val_ds),
        "miou": miou,
        "miou_all_classes": miou,
        "miou_valid_gt_classes": miou_valid,
        "fw_iou": fw_iou,
        "global_pixel_accuracy": gpa,
        "mean_class_accuracy": mca,
        "per_class_iou": {codec.class_names[i]: float(per_iou[i]) for i in range(len(codec.class_names))},
        "per_class_recall": {
            codec.class_names[i]: (None if math.isnan(float(per_rec[i])) else float(per_rec[i]))
            for i in range(len(codec.class_names))
        },
        "val_gt_pixel_counts": {codec.class_names[i]: int(gt_counts[i]) for i in range(len(codec.class_names))},
    }

    if args.json_out:
        out = Path(args.json_out)
        if not out.is_absolute():
            out = root / out
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        print(f"Wrote {out}")

    return 0


def main() -> None:
    parser = argparse.ArgumentParser(description="Segmentation metric summary (val set).")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to .pt checkpoint (default: <root>/checkpoints/best.pt)",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"),
    )
    parser.add_argument("--root", type=str, default=None, help="Workspace root (defaults to repo root)")
    parser.add_argument(
        "--from-checkpoint-only",
        action="store_true",
        help="Only print mIoU/per-class IoU stored in the file (no forward pass).",
    )
    parser.add_argument(
        "--json-out",
        type=str,
        default=None,
        help="Optional path to write full metrics JSON (relative to --root unless absolute).",
    )
    args = parser.parse_args()
    root = Path(args.root or ROOT).resolve()
    ck_path = Path(args.checkpoint) if args.checkpoint else root / "checkpoints" / "best.pt"

    if args.from_checkpoint_only:
        if not ck_path.is_file():
            print(f"Error: checkpoint not found: {ck_path}", file=sys.stderr)
            sys.exit(1)
        sys.exit(run_from_checkpoint_only(ck_path))

    args.checkpoint = str(ck_path)
    args.root = str(root)
    if not ck_path.is_file():
        print(f"Error: checkpoint not found: {ck_path}", file=sys.stderr)
        sys.exit(1)
    sys.exit(run_full_eval(args))


if __name__ == "__main__":
    main()
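Frequency-weighted IoU, one of the numbers this script prints, weights each class's IoU by that class's share of GT pixels, so dominant classes count for more. A sketch of the standard definition, which is assumed (not verified) to match what metrics.iou computes:

import numpy as np

def fw_iou_from_confusion(cm: np.ndarray) -> float:
    # Frequency-weighted IoU: per-class IoU weighted by GT pixel share.
    tp = np.diag(cm).astype(np.float64)
    gt = cm.sum(axis=1).astype(np.float64)           # GT pixels per class
    union = gt + cm.sum(axis=0) - tp
    iou = np.where(union > 0, tp / np.maximum(union, 1.0), 0.0)
    freq = gt / max(gt.sum(), 1.0)
    return float((freq * iou).sum())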
scripts/infer.py
ADDED
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""Run inference on testing/Color_Images; optional ONNX export."""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from desert_segmentation.infer.predict import export_onnx, predict_folder
from desert_segmentation.utils.config import get_paths, load_config
from desert_segmentation.utils.logging_utils import setup_logging

logger = logging.getLogger(__name__)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"))
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--root", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default="infer_outputs")
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--onnx", type=str, default=None, help="If set, export ONNX to this path and exit")
    args = parser.parse_args()

    root = Path(args.root or ROOT).resolve()
    cfg = load_config(args.config, root=root)
    setup_logging()

    if args.onnx:
        export_onnx(Path(args.checkpoint), Path(args.onnx))
        logger.info("exported ONNX to %s", args.onnx)
        return

    paths = get_paths(cfg)
    out_dir = Path(args.out_dir)
    if not out_dir.is_absolute():
        out_dir = root / out_dir
    predict_folder(Path(args.checkpoint), paths["test_images"], out_dir, limit=args.limit)


if __name__ == "__main__":
    main()
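Programmatic use mirrors the CLI. This sketch makes the same call the script makes; the paths are illustrative and follow the config defaults:

from pathlib import Path

from desert_segmentation.infer.predict import predict_folder

# Run tiled inference on the first 8 test images (illustrative paths).
predict_folder(Path("checkpoints/best.pt"), Path("testing/Color_Images"), Path("infer_outputs"), limit=8)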
scripts/train.py
ADDED
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Train semantic segmentation model from YAML config."""

from __future__ import annotations

import argparse
import logging
import os
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

from desert_segmentation.data.dataset import SegmentationDataset
from desert_segmentation.data.mask_encoding import build_codec_from_config
from desert_segmentation.data.transforms import build_train_transforms, build_val_transforms
from desert_segmentation.losses.combined import build_loss, compute_class_weights_from_freq
from desert_segmentation.models.factory import create_model
from desert_segmentation.train.trainer import train
from desert_segmentation.utils.config import get_paths, load_config
from desert_segmentation.utils.freq import estimate_pixel_frequencies, per_image_sampling_weights
from desert_segmentation.utils.logging_utils import setup_logging
from desert_segmentation.utils.seed import set_seed

logger = logging.getLogger(__name__)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default=str(ROOT / "desert_segmentation" / "configs" / "default.yaml"),
    )
    parser.add_argument("--root", type=str, default=None, help="Workspace root (defaults to repo root)")
    parser.add_argument("--epochs", type=int, default=None, help="Override epochs (smoke tests)")
    parser.add_argument("--max_train_batches", type=int, default=None, help="Stop each epoch after N batches (smoke tests)")
    args = parser.parse_args()

    root = Path(args.root or ROOT).resolve()
    cfg = load_config(args.config, root=root)
    if args.epochs is not None:
        cfg["train"]["epochs"] = int(args.epochs)
    setup_logging()
    set_seed(int(cfg["train"]["seed"]))

    paths = get_paths(cfg)
    raw_ids = cfg["data"]["raw_ids"]
    names = tuple(cfg["data"]["class_names"])
    codec = build_codec_from_config(raw_ids, names)
    ignore_index = int(cfg["data"].get("ignore_index", 255))
    crop_size = int(cfg["data"]["crop_size"])

    train_tf = build_train_transforms(
        crop_size=crop_size,
        strong=bool(cfg.get("augmentation", {}).get("strong", True)),
        ignore_index=ignore_index,
    )
    val_tf = build_val_transforms(crop_size=crop_size, ignore_index=ignore_index)

    train_ds = SegmentationDataset(
        paths["train_images"],
        paths["train_masks"],
        codec=codec,
        transform=train_tf,
        mode="train",
        crop_size=crop_size,
        rare_class_crop_prob=float(cfg["data"].get("rare_class_crop_prob", 0.35)),
        ignore_index=ignore_index,
        seed=int(cfg["train"]["seed"]),
    )
    val_ds = SegmentationDataset(
        paths["val_images"],
        paths["val_masks"],
        codec=codec,
        transform=val_tf,
        mode="val",
        crop_size=crop_size,
        rare_class_crop_prob=0.0,
        ignore_index=ignore_index,
        seed=int(cfg["train"]["seed"]),
    )

    nw = int(cfg["data"].get("num_workers", 4))
    if os.name == "nt":
        nw = 0

    val_loader = DataLoader(
        val_ds,
        batch_size=int(cfg["train"].get("val_batch_size", cfg["train"]["batch_size"])),
        shuffle=False,
        num_workers=nw,
        pin_memory=torch.cuda.is_available(),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(cfg["model"], num_classes=codec.num_classes).to(device)

    freq = estimate_pixel_frequencies(paths["train_masks"], codec, max_files=None)
    cap = float(cfg.get("loss", {}).get("class_weight_cap", 15.0))
    class_w = compute_class_weights_from_freq(freq, cap=cap).to(device)
    logger.info("class pixel frequencies (train masks): %s", freq.tolist())

    use_weighted_sampler = bool(cfg.get("data", {}).get("weighted_sampler", False))
    sampler: WeightedRandomSampler | None = None
    if use_weighted_sampler:
        eps = float(cfg.get("data", {}).get("weighted_sampler_eps", 1e-6))
        logger.info("computing per-image sampling weights (scanning train masks)...")
        sample_w = per_image_sampling_weights(
            paths["train_masks"],
            train_ds.image_names,
            codec,
            freq,
            eps=eps,
        )
        sampler = WeightedRandomSampler(
            sample_w,
            num_samples=len(train_ds),
            replacement=True,
            generator=torch.Generator().manual_seed(int(cfg["train"]["seed"])),
        )

    train_loader = DataLoader(
        train_ds,
        batch_size=int(cfg["train"]["batch_size"]),
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=nw,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    criterion = build_loss(
        cfg.get("loss", {}),
        num_classes=codec.num_classes,
        class_weights=class_w,
        ignore_index=ignore_index,
    ).to(device)

    ckpt_dir = Path(cfg["train"]["checkpoint_dir"])
    if not ckpt_dir.is_absolute():
        ckpt_dir = root / ckpt_dir

    out = train(
        model,
        train_loader,
        val_loader,
        criterion,
        device,
        cfg,
        num_classes=codec.num_classes,
        ignore_index=ignore_index,
        checkpoint_dir=ckpt_dir,
        class_names=codec.class_names,
        max_train_batches=args.max_train_batches,
    )
    logger.info("finished best_mIoU=%s path=%s", out["best_miou"], out["best_path"])


if __name__ == "__main__":
    main()
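A toy illustration of the rare-class oversampling that the data.weighted_sampler option enables: images with larger weights are drawn more often within an epoch. This uses only the torch sampler API, independent of this repo:

import torch
from torch.utils.data import WeightedRandomSampler

# Four images; the last one (index 3) is treated as containing rare classes.
weights = torch.tensor([0.1, 0.1, 0.1, 1.0])
sampler = WeightedRandomSampler(
    weights,
    num_samples=8,
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
print(list(sampler))  # sampled dataset indices; index 3 should dominate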
tests/test_confusion_metrics.py
ADDED
@@ -0,0 +1,50 @@
"""Tests for accuracy metrics derived from confusion matrices."""

import numpy as np

from desert_segmentation.metrics.iou import (
    confusion_to_accuracy_metrics,
    valid_class_miou_from_confusion,
)


def test_perfect_confusion():
    cm = np.eye(3, dtype=np.int64) * 100
    out = confusion_to_accuracy_metrics(cm)
    assert out["global_pixel_accuracy"] == 1.0
    assert out["mean_class_accuracy"] == 1.0
    assert np.allclose(out["per_class_recall"], [1.0, 1.0, 1.0])


def test_two_class_mixed():
    # GT: 100 class0, 100 class1; half wrong each
    cm = np.array([[50, 50], [50, 50]], dtype=np.int64)
    out = confusion_to_accuracy_metrics(cm)
    assert out["global_pixel_accuracy"] == 0.5
    assert abs(out["mean_class_accuracy"] - 0.5) < 1e-6


def test_one_class_absent_in_gt():
    # Only class 0 appears in val GT; 80 correct, 20 predicted as class 1.
    cm = np.array([[80, 20], [0, 0]], dtype=np.int64)
    out = confusion_to_accuracy_metrics(cm)
    assert abs(out["global_pixel_accuracy"] - 0.8) < 1e-9
    assert abs(out["mean_class_accuracy"] - 0.8) < 1e-9
    assert np.isnan(out["per_class_recall"][1])


def test_valid_class_miou_only_classes_with_gt():
    # Two classes in GT; full mIoU averages zeros for empty rows if any — here both rows have GT.
    cm = np.array([[90, 10], [10, 90]], dtype=np.int64)
    # IoU class0: 90/(90+10+10)=90/110, class1: 90/110
    v = valid_class_miou_from_confusion(cm)
    iou0 = 90.0 / (90 + 10 + 10)
    assert abs(v - iou0) < 1e-6


def test_valid_class_miou_ignores_empty_gt_rows():
    # Class 1 has no GT pixels; valid-class mIoU averages only class 0.
    cm = np.array([[80, 20], [0, 0]], dtype=np.int64)
    v = valid_class_miou_from_confusion(cm)
    # IoU class 0: TP=80, union = rows[0]+cols[0]-TP = 100+80-80 = 100
    assert abs(v - 0.8) < 1e-9
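These tests consume a ready-made confusion matrix. A common way to build one from flat GT/prediction arrays is the bincount trick below; it is a sketch, and it is assumed (not verified) that metrics.iou accumulates the same way:

import numpy as np

def confusion(gt, pred, num_classes, ignore_index=255):
    # Standard bincount construction: row = GT class, column = predicted class.
    valid = gt != ignore_index
    idx = gt[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

gt = np.array([0, 0, 1, 1, 255])
pred = np.array([0, 1, 1, 1, 0])
print(confusion(gt, pred, 2))  # [[1, 1], [0, 2]]; the ignored pixel is dropped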
tests/test_mask_encoding.py
ADDED
@@ -0,0 +1,33 @@
import numpy as np
import pytest

from desert_segmentation.data.mask_encoding import RawMaskCodec, default_desert_codec


def test_roundtrip_known_ids():
    codec = default_desert_codec()
    h, w = 32, 48
    raw = np.full((h, w), 100, dtype=np.uint16)
    raw[:, :10] = 10000
    raw[10:20, :] = 7100
    enc, unk = codec.encode_mask(raw)
    assert unk == 0.0
    assert enc.shape == (h, w)
    back = codec.decode_to_raw(enc)
    assert np.array_equal(back, raw)


def test_unknown_pixel_raises():
    codec = RawMaskCodec(raw_ids=(1, 2), class_names=("a", "b"))
    raw = np.array([[1, 2], [99, 1]], dtype=np.uint16)
    with pytest.raises(ValueError):
        codec.encode_mask(raw)


def test_lut_all_ids():
    codec = default_desert_codec()
    for rid in codec.raw_ids:
        raw = np.full((4, 4), rid, dtype=np.uint16)
        enc, unk = codec.encode_mask(raw)
        assert unk == 0.0
        assert np.unique(enc).size == 1
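For intuition, a sketch of the lookup-table idea a codec like RawMaskCodec can use to map sparse raw IDs (100, 200, ..., 10000) to contiguous train IDs. The sentinel value and exact behavior here are assumptions for illustration, not the repo's implementation:

import numpy as np

raw_ids = (100, 200, 300)                      # illustrative subset of raw IDs
lut = np.full(65536, 255, dtype=np.uint8)      # assumed sentinel for unmapped IDs
for i, rid in enumerate(raw_ids):
    lut[rid] = i                               # raw ID -> contiguous train ID

raw = np.array([[100, 300], [200, 100]], dtype=np.uint16)
enc = lut[raw]                                 # vectorized per-pixel lookup
unknown_frac = float((enc == 255).mean())      # 0.0 when every raw ID is known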