Commit 3e7a77e, committed by XuJunHao-TJ
1 Parent(s): 6c23ccc
Upload 4 files
- LICENSE.txt +21 -0
- README.md +21 -9
- example.ipynb +776 -0
- gitignore.txt +106 -0
LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Jonathan Ventura
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,24 @@
+# canopy
+Automatic tree species classification from remote sensing data
+
+Code from our paper:
+
+[Fricker, G. A., Ventura, J. D., Wolf, J. A., North, M. P., Davis, F. W., & Franklin, J. (2019). A Convolutional Neural Network Classifier Identifies Tree Species in Mixed-Conifer Forest from Hyperspectral Imagery. Remote Sensing, 11(19), 2326.](https://www.mdpi.com/2072-4292/11/19/2326)
+
 ---
-
-
-
-
-
-
-app_file: app.py
-pinned: false
+To create a conda environment:
+
+conda create -n canopy tensorflow=1.10.0 pip ;
+conda activate canopy ;
+pip3 install -r docker/requirements.txt ;
+
 ---
 
-
+To run the example experiment:
+
+mkdir hyperspectral ;
+python -m experiment.download ;
+python -m experiment.preprocess --out hyperspectral;
+python -m experiment.train --out hyperspectral;
+python -m experiment.test --out hyperspectral;
+python -m experiment.analyze --out hyperspectral;
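The README's run steps above could be scripted rather than typed one by one. The following is a minimal sketch, not part of the commit: the stage names and the `--out` flag are copied from the README, while `build_commands`, `run_pipeline`, and the `dry_run` switch are hypothetical helpers introduced only for illustration.

```python
import os
import subprocess
import sys

# Stage names as listed in the README, in the order they are run there.
STAGES = ["download", "preprocess", "train", "test", "analyze"]


def build_commands(out_dir, python=sys.executable):
    """Return one argv list per README stage (hypothetical helper)."""
    commands = []
    for stage in STAGES:
        argv = [python, "-m", "experiment." + stage]
        # In the README every stage except download receives --out.
        if stage != "download":
            argv += ["--out", out_dir]
        commands.append(argv)
    return commands


def run_pipeline(out_dir, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    if not dry_run:
        # Mirrors `mkdir hyperspectral` from the README.
        os.makedirs(out_dir, exist_ok=True)
    for argv in build_commands(out_dir):
        print(" ".join(argv))
        if not dry_run:
            # check=True stops the pipeline at the first failing stage.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_pipeline("hyperspectral")
```

With the default `dry_run=True` the script only prints the five commands; passing `dry_run=False` assumes it is run from the repository root inside the `canopy` conda environment.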
example.ipynb
ADDED
@@ -0,0 +1,776 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"### Tree species classification example\n",
|
8 |
+
"This notebook gives an example of using a convolutional neural network to classify tree species in the Sierra Nevada forest."
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"metadata": {},
|
14 |
+
"source": [
|
15 |
+
"First we download the NEON data and label files from our dataset stored on Zenodo."
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"cell_type": "code",
|
20 |
+
"execution_count": 1,
|
21 |
+
"metadata": {},
|
22 |
+
"outputs": [],
|
23 |
+
"source": [
|
24 |
+
"import os\n",
|
25 |
+
"import sys\n",
|
26 |
+
"import tqdm\n",
|
27 |
+
"import argparse\n",
|
28 |
+
"\n",
|
29 |
+
"from wget import download\n",
|
30 |
+
"\n",
|
31 |
+
"from experiment.paths import *\n",
|
32 |
+
"\n",
|
33 |
+
"# make output directory if necessary\n",
|
34 |
+
"if not os.path.exists('data'):\n",
|
35 |
+
" os.makedirs('data')\n",
|
36 |
+
"\n",
|
37 |
+
"files = [ 'Labels_Trimmed_Selective.CPG',\n",
|
38 |
+
" 'Labels_Trimmed_Selective.dbf',\n",
|
39 |
+
" 'Labels_Trimmed_Selective.prj',\n",
|
40 |
+
" 'Labels_Trimmed_Selective.sbn',\n",
|
41 |
+
" 'Labels_Trimmed_Selective.sbx',\n",
|
42 |
+
" 'Labels_Trimmed_Selective.shp',\n",
|
43 |
+
" 'Labels_Trimmed_Selective.shp.xml',\n",
|
44 |
+
" 'Labels_Trimmed_Selective.shx',\n",
|
45 |
+
" 'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif',\n",
|
46 |
+
" 'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.aux.xml',\n",
|
47 |
+
" 'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.enp',\n",
|
48 |
+
" 'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.ovr',\n",
|
49 |
+
" 'D17_CHM_all.tfw',\n",
|
50 |
+
" 'D17_CHM_all.tif',\n",
|
51 |
+
" 'D17_CHM_all.tif.aux.xml',\n",
|
52 |
+
" 'D17_CHM_all.tif.ovr',\n",
|
53 |
+
" ]\n",
|
54 |
+
"\n",
|
55 |
+
"for f in files:\n",
|
56 |
+
" if not os.path.exists('data/%s'%f):\n",
|
57 |
+
" print('downloading %s'%f)\n",
|
58 |
+
" download('https://zenodo.org/record/3468720/files/%s?download=1'%f,'data/%s'%f)\n",
|
59 |
+
" print('')"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "markdown",
|
64 |
+
"metadata": {},
|
65 |
+
"source": [
|
66 |
+
"Next we load and co-register our data sources, including the hyperspectral image, the canopy height model, and the tree labels. Then we build a dataset of patches and their corresponding labels and store it in an HDF5 file for easy use in Keras."
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 2,
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [
|
74 |
+
{
|
75 |
+
"name": "stderr",
|
76 |
+
"output_type": "stream",
|
77 |
+
"text": [
|
78 |
+
"100%|██████████| 15668/15668 [05:17<00:00, 49.38it/s]\n",
|
79 |
+
"100%|██████████| 1909/1909 [00:39<00:00, 48.41it/s]\n"
|
80 |
+
]
|
81 |
+
}
|
82 |
+
],
|
83 |
+
"source": [
|
84 |
+
"import numpy as np\n",
|
85 |
+
"import tqdm\n",
|
86 |
+
"from experiment.paths import *\n",
|
87 |
+
"import os\n",
|
88 |
+
"\n",
|
89 |
+
"from canopy.vector_utils import *\n",
|
90 |
+
"from canopy.extract import *\n",
|
91 |
+
"import h5py as h5\n",
|
92 |
+
"\n",
|
93 |
+
"from sklearn.model_selection import train_test_split\n",
|
94 |
+
"from sklearn.cluster import KMeans\n",
|
95 |
+
"\n",
|
96 |
+
"# Load the metadata from the image.\n",
|
97 |
+
"with rasterio.open(image_uri) as src:\n",
|
98 |
+
" image_meta = src.meta.copy()\n",
|
99 |
+
"\n",
|
100 |
+
"os.makedirs('example',exist_ok=True)\n",
|
101 |
+
"\n",
|
102 |
+
"seed = 0\n",
|
103 |
+
"\n",
|
104 |
+
"# Load the shapefile and transform it to the hyperspectral image's CRS.\n",
|
105 |
+
"polygons, labels = load_and_transform_shapefile(labels_shp_uri,'SP',image_meta['crs'])\n",
|
106 |
+
"\n",
|
107 |
+
"# Cluster polygons for use in stratified sampling\n",
|
108 |
+
"centroids = np.stack([np.mean(np.array(poly['coordinates'][0]),axis=0) for poly in polygons])\n",
|
109 |
+
"cluster_ids = KMeans(10).fit_predict(centroids)\n",
|
110 |
+
"rasterize_shapefile(polygons, cluster_ids, image_meta, 'example/clusters.tiff')\n",
|
111 |
+
"stratify = cluster_ids\n",
|
112 |
+
"\n",
|
113 |
+
"# alternative: stratify by species label\n",
|
114 |
+
"# stratify = labels\n",
|
115 |
+
"\n",
|
116 |
+
"# Split up polygons into train, val, test here\n",
|
117 |
+
"train_inds, test_inds = train_test_split(range(len(polygons)),test_size=0.1,random_state=seed,stratify=stratify)\n",
|
118 |
+
"\n",
|
119 |
+
"# Save ids of train,val,test polygons\n",
|
120 |
+
"with open('example/' + train_ids_uri,'w') as f:\n",
|
121 |
+
" f.writelines([\"%d\\n\"%ind for ind in train_inds])\n",
|
122 |
+
"with open('example/' + test_ids_uri,'w') as f:\n",
|
123 |
+
" f.writelines([\"%d\\n\"%ind for ind in test_inds])\n",
|
124 |
+
"\n",
|
125 |
+
"# Separate out polygons\n",
|
126 |
+
"train_polygons = [polygons[ind] for ind in train_inds]\n",
|
127 |
+
"train_labels = [labels[ind] for ind in train_inds]\n",
|
128 |
+
"test_polygons = [polygons[ind] for ind in test_inds]\n",
|
129 |
+
"test_labels = [labels[ind] for ind in test_inds]\n",
|
130 |
+
"\n",
|
131 |
+
"# Rasterize the shapefile to a TIFF. Using LZW compression, the resulting file is pretty small.\n",
|
132 |
+
"train_labels_raster = rasterize_shapefile(train_polygons, train_labels, image_meta, 'example/' + train_labels_uri)\n",
|
133 |
+
"test_labels_raster = rasterize_shapefile(test_polygons, test_labels, image_meta, 'example/' + test_labels_uri)\n",
|
134 |
+
"\n",
|
135 |
+
"# Extract patches and labels\n",
|
136 |
+
"patch_radius = 7\n",
|
137 |
+
"height_threshold = 5\n",
|
138 |
+
"train_image_patches, train_patch_labels = extract_patches(image_uri,patch_radius,chm_uri,height_threshold,'example/' + train_labels_uri)\n",
|
139 |
+
"test_image_patches, test_patch_labels = extract_patches(image_uri,patch_radius,chm_uri,height_threshold,'example/' + test_labels_uri)\n"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"metadata": {},
|
145 |
+
"source": [
|
146 |
+
"Now we set up and train the convolutional neural network model."
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "code",
|
151 |
+
"execution_count": 6,
|
152 |
+
"metadata": {},
|
153 |
+
"outputs": [
|
154 |
+
{
|
155 |
+
"name": "stdout",
|
156 |
+
"output_type": "stream",
|
157 |
+
"text": [
|
158 |
+
"class weights: [ 0.74829501 2.29405615 1.21758085 0.48317187 0.7970631 24.93668831\n",
|
159 |
+
" 2.45540281 0.61169959]\n",
|
160 |
+
"(15, 15, 32) int16\n",
|
161 |
+
"() uint8\n",
|
162 |
+
"_________________________________________________________________\n",
|
163 |
+
"Layer (type) Output Shape Param # \n",
|
164 |
+
"=================================================================\n",
|
165 |
+
"input_7 (InputLayer) (None, 15, 15, 32) 0 \n",
|
166 |
+
"_________________________________________________________________\n",
|
167 |
+
"conv2d_16 (Conv2D) (None, 13, 13, 32) 9248 \n",
|
168 |
+
"_________________________________________________________________\n",
|
169 |
+
"conv2d_17 (Conv2D) (None, 11, 11, 64) 18496 \n",
|
170 |
+
"_________________________________________________________________\n",
|
171 |
+
"conv2d_18 (Conv2D) (None, 9, 9, 128) 73856 \n",
|
172 |
+
"_________________________________________________________________\n",
|
173 |
+
"conv2d_19 (Conv2D) (None, 7, 7, 128) 147584 \n",
|
174 |
+
"_________________________________________________________________\n",
|
175 |
+
"conv2d_20 (Conv2D) (None, 5, 5, 128) 147584 \n",
|
176 |
+
"_________________________________________________________________\n",
|
177 |
+
"conv2d_21 (Conv2D) (None, 3, 3, 128) 147584 \n",
|
178 |
+
"_________________________________________________________________\n",
|
179 |
+
"conv2d_22 (Conv2D) (None, 1, 1, 128) 147584 \n",
|
180 |
+
"_________________________________________________________________\n",
|
181 |
+
"conv2d_23 (Conv2D) (None, 1, 1, 8) 1032 \n",
|
182 |
+
"_________________________________________________________________\n",
|
183 |
+
"flatten_2 (Flatten) (None, 8) 0 \n",
|
184 |
+
"=================================================================\n",
|
185 |
+
"Total params: 692,968\n",
|
186 |
+
"Trainable params: 692,968\n",
|
187 |
+
"Non-trainable params: 0\n",
|
188 |
+
"_________________________________________________________________\n",
|
189 |
+
"None\n"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"name": "stderr",
|
194 |
+
"output_type": "stream",
|
195 |
+
"text": [
|
196 |
+
"augmenting images: 100%|██████████| 122888/122888 [00:01<00:00, 73026.99it/s]\n"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"name": "stdout",
|
201 |
+
"output_type": "stream",
|
202 |
+
"text": [
|
203 |
+
"Train on 110599 samples, validate on 12289 samples\n",
|
204 |
+
"Epoch 1/20\n",
|
205 |
+
"110599/110599 [==============================] - 41s 366us/step - loss: 1.7037 - acc: 0.5477 - val_loss: 1.1374 - val_acc: 0.8593\n",
|
206 |
+
"\n",
|
207 |
+
"Epoch 00001: val_acc improved from -inf to 0.85931, saving model to example/weights.hdf5\n",
|
208 |
+
"Epoch 2/20\n",
|
209 |
+
"110599/110599 [==============================] - 44s 399us/step - loss: 0.9412 - acc: 0.9016 - val_loss: 0.9102 - val_acc: 0.9249\n",
|
210 |
+
"\n",
|
211 |
+
"Epoch 00002: val_acc improved from 0.85931 to 0.92489, saving model to example/weights.hdf5\n",
|
212 |
+
"Epoch 3/20\n",
|
213 |
+
"110599/110599 [==============================] - 44s 397us/step - loss: 0.8229 - acc: 0.9410 - val_loss: 0.8023 - val_acc: 0.9571\n",
|
214 |
+
"\n",
|
215 |
+
"Epoch 00003: val_acc improved from 0.92489 to 0.95712, saving model to example/weights.hdf5\n",
|
216 |
+
"Epoch 4/20\n",
|
217 |
+
"110599/110599 [==============================] - 44s 395us/step - loss: 0.7664 - acc: 0.9590 - val_loss: 0.7577 - val_acc: 0.9675\n",
|
218 |
+
"\n",
|
219 |
+
"Epoch 00004: val_acc improved from 0.95712 to 0.96745, saving model to example/weights.hdf5\n",
|
220 |
+
"Epoch 5/20\n",
|
221 |
+
"110599/110599 [==============================] - 44s 397us/step - loss: 0.7245 - acc: 0.9712 - val_loss: 0.7225 - val_acc: 0.9788\n",
|
222 |
+
"\n",
|
223 |
+
"Epoch 00005: val_acc improved from 0.96745 to 0.97876, saving model to example/weights.hdf5\n",
|
224 |
+
"Epoch 6/20\n",
|
225 |
+
"110599/110599 [==============================] - 44s 400us/step - loss: 0.6950 - acc: 0.9795 - val_loss: 0.6946 - val_acc: 0.9841\n",
|
226 |
+
"\n",
|
227 |
+
"Epoch 00006: val_acc improved from 0.97876 to 0.98413, saving model to example/weights.hdf5\n",
|
228 |
+
"Epoch 7/20\n",
|
229 |
+
"110599/110599 [==============================] - 45s 404us/step - loss: 0.6772 - acc: 0.9846 - val_loss: 0.6740 - val_acc: 0.9900\n",
|
230 |
+
"\n",
|
231 |
+
"Epoch 00007: val_acc improved from 0.98413 to 0.98999, saving model to example/weights.hdf5\n",
|
232 |
+
"Epoch 8/20\n",
|
233 |
+
"110599/110599 [==============================] - 45s 404us/step - loss: 0.6574 - acc: 0.9896 - val_loss: 0.6548 - val_acc: 0.9941\n",
|
234 |
+
"\n",
|
235 |
+
"Epoch 00008: val_acc improved from 0.98999 to 0.99406, saving model to example/weights.hdf5\n",
|
236 |
+
"Epoch 9/20\n",
|
237 |
+
"110599/110599 [==============================] - 45s 409us/step - loss: 0.6461 - acc: 0.9918 - val_loss: 0.6478 - val_acc: 0.9924\n",
|
238 |
+
"\n",
|
239 |
+
"Epoch 00009: val_acc did not improve from 0.99406\n",
|
240 |
+
"Epoch 10/20\n",
|
241 |
+
"110599/110599 [==============================] - 46s 415us/step - loss: 0.6434 - acc: 0.9918 - val_loss: 0.6347 - val_acc: 0.9934\n",
|
242 |
+
"\n",
|
243 |
+
"Epoch 00010: val_acc did not improve from 0.99406\n",
|
244 |
+
"Epoch 11/20\n",
|
245 |
+
"110599/110599 [==============================] - 43s 389us/step - loss: 0.6213 - acc: 0.9961 - val_loss: 0.6205 - val_acc: 0.9970\n",
|
246 |
+
"\n",
|
247 |
+
"Epoch 00011: val_acc improved from 0.99406 to 0.99699, saving model to example/weights.hdf5\n",
|
248 |
+
"Epoch 12/20\n",
|
249 |
+
"110599/110599 [==============================] - 43s 391us/step - loss: 0.6118 - acc: 0.9969 - val_loss: 0.6206 - val_acc: 0.9928\n",
|
250 |
+
"\n",
|
251 |
+
"Epoch 00012: val_acc did not improve from 0.99699\n",
|
252 |
+
"Epoch 13/20\n",
|
253 |
+
"110599/110599 [==============================] - 44s 394us/step - loss: 0.6026 - acc: 0.9976 - val_loss: 0.6020 - val_acc: 0.9972\n",
|
254 |
+
"\n",
|
255 |
+
"Epoch 00013: val_acc improved from 0.99699 to 0.99723, saving model to example/weights.hdf5\n",
|
256 |
+
"Epoch 14/20\n",
|
257 |
+
"110599/110599 [==============================] - 42s 384us/step - loss: 0.5938 - acc: 0.9982 - val_loss: 0.5922 - val_acc: 0.9980\n",
|
258 |
+
"\n",
|
259 |
+
"Epoch 00014: val_acc improved from 0.99723 to 0.99805, saving model to example/weights.hdf5\n",
|
260 |
+
"Epoch 15/20\n",
|
261 |
+
"110599/110599 [==============================] - 45s 407us/step - loss: 0.5853 - acc: 0.9986 - val_loss: 0.5822 - val_acc: 0.9988\n",
|
262 |
+
"\n",
|
263 |
+
"Epoch 00015: val_acc improved from 0.99805 to 0.99878, saving model to example/weights.hdf5\n",
|
264 |
+
"Epoch 16/20\n",
|
265 |
+
"110599/110599 [==============================] - 45s 410us/step - loss: 0.5771 - acc: 0.9989 - val_loss: 0.5746 - val_acc: 0.9989\n",
|
266 |
+
"\n",
|
267 |
+
"Epoch 00016: val_acc improved from 0.99878 to 0.99886, saving model to example/weights.hdf5\n",
|
268 |
+
"Epoch 17/20\n",
|
269 |
+
"110599/110599 [==============================] - 44s 399us/step - loss: 0.5692 - acc: 0.9990 - val_loss: 0.5686 - val_acc: 0.9985\n",
|
270 |
+
"\n",
|
271 |
+
"Epoch 00017: val_acc did not improve from 0.99886\n",
|
272 |
+
"Epoch 18/20\n",
|
273 |
+
"110599/110599 [==============================] - 44s 402us/step - loss: 0.5615 - acc: 0.9993 - val_loss: 0.5595 - val_acc: 0.9987\n",
|
274 |
+
"\n",
|
275 |
+
"Epoch 00018: val_acc did not improve from 0.99886\n",
|
276 |
+
"Epoch 19/20\n",
|
277 |
+
"110599/110599 [==============================] - 46s 411us/step - loss: 0.5538 - acc: 0.9993 - val_loss: 0.5559 - val_acc: 0.9981\n",
|
278 |
+
"\n",
|
279 |
+
"Epoch 00019: val_acc did not improve from 0.99886\n",
|
280 |
+
"Epoch 20/20\n",
|
281 |
+
"110599/110599 [==============================] - 43s 389us/step - loss: 0.5464 - acc: 0.9995 - val_loss: 0.5446 - val_acc: 0.9991\n",
|
282 |
+
"\n",
|
283 |
+
"Epoch 00020: val_acc improved from 0.99886 to 0.99910, saving model to example/weights.hdf5\n"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"data": {
|
288 |
+
"text/plain": [
|
289 |
+
"<tensorflow.python.keras.callbacks.History at 0x7f3f1c6d8ba8>"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
"execution_count": 6,
|
293 |
+
"metadata": {},
|
294 |
+
"output_type": "execute_result"
|
295 |
+
}
|
296 |
+
],
|
297 |
+
"source": [
|
298 |
+
"import numpy as np\n",
|
299 |
+
"import h5py as h5\n",
|
300 |
+
"from tqdm import tqdm, trange\n",
|
301 |
+
"import os\n",
|
302 |
+
"import sys\n",
|
303 |
+
"\n",
|
304 |
+
"import tensorflow as tf\n",
|
305 |
+
"from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
|
306 |
+
"from tensorflow.keras.optimizers import SGD, Adam\n",
|
307 |
+
"\n",
|
308 |
+
"from sklearn.decomposition import PCA\n",
|
309 |
+
"from joblib import dump, load\n",
|
310 |
+
"from sklearn.utils.class_weight import compute_class_weight\n",
|
311 |
+
"from sklearn.model_selection import train_test_split\n",
|
312 |
+
"\n",
|
313 |
+
"from canopy.model import PatchClassifier\n",
|
314 |
+
"from experiment.paths import *\n",
|
315 |
+
"\n",
|
316 |
+
"from tensorflow.keras import backend as K\n",
|
317 |
+
"import tensorflow as tf\n",
|
318 |
+
"config = tf.ConfigProto()\n",
|
319 |
+
"config.gpu_options.allow_growth = True\n",
|
320 |
+
"sess = tf.Session(config=config)\n",
|
321 |
+
"K.set_session(sess)\n",
|
322 |
+
"\n",
|
323 |
+
"np.random.seed(0)\n",
|
324 |
+
"tf.set_random_seed(0)\n",
|
325 |
+
"\n",
|
326 |
+
"out = 'example'\n",
|
327 |
+
"lr = 0.0001\n",
|
328 |
+
"epochs = 20\n",
|
329 |
+
"\n",
|
330 |
+
"x_all = train_image_patches\n",
|
331 |
+
"y_all = train_patch_labels\n",
|
332 |
+
"\n",
|
333 |
+
"class_weights = compute_class_weight('balanced',range(8),y_all)\n",
|
334 |
+
"print('class weights: ',class_weights)\n",
|
335 |
+
"class_weight_dict = {}\n",
|
336 |
+
"for i in range(8):\n",
|
337 |
+
" class_weight_dict[i] = class_weights[i]\n",
|
338 |
+
"\n",
|
339 |
+
"def estimate_pca():\n",
|
340 |
+
" x_samples = x_all[:,7,7]\n",
|
341 |
+
" pca = PCA(32,whiten=True)\n",
|
342 |
+
" pca.fit(x_samples)\n",
|
343 |
+
" return pca\n",
|
344 |
+
"\n",
|
345 |
+
"\"\"\"Normalize training data\"\"\"\n",
|
346 |
+
"pca = estimate_pca()\n",
|
347 |
+
"dump(pca,out + '/pca.joblib')\n",
|
348 |
+
"\n",
|
349 |
+
"x_shape = x_all.shape[1:]\n",
|
350 |
+
"x_dtype = x_all.dtype\n",
|
351 |
+
"y_shape = y_all.shape[1:]\n",
|
352 |
+
"y_dtype = y_all.dtype\n",
|
353 |
+
"x_shape = x_shape[:-1] + (pca.n_components_,)\n",
|
354 |
+
"\n",
|
355 |
+
"print(x_shape, x_dtype)\n",
|
356 |
+
"print(y_shape, y_dtype)\n",
|
357 |
+
"\n",
|
358 |
+
"classifier = PatchClassifier(num_classes=8)\n",
|
359 |
+
"model = classifier.get_patch_model(x_shape)\n",
|
360 |
+
"\n",
|
361 |
+
"print(model.summary())\n",
|
362 |
+
"\n",
|
363 |
+
"model.compile(optimizer=SGD(lr,momentum=0.9), loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
|
364 |
+
"\n",
|
365 |
+
"def apply_pca(x):\n",
|
366 |
+
" N,H,W,C = x.shape\n",
|
367 |
+
" x = np.reshape(x,(-1,C))\n",
|
368 |
+
" x = pca.transform(x)\n",
|
369 |
+
" x = np.reshape(x,(-1,H,W,x.shape[-1]))\n",
|
370 |
+
" return x\n",
|
371 |
+
"\n",
|
372 |
+
"checkpoint = ModelCheckpoint(filepath=out + '/' + weights_uri, monitor='val_acc', verbose=True, save_best_only=True, save_weights_only=True)\n",
|
373 |
+
"reducelr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=10, verbose=1, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)\n",
|
374 |
+
"\n",
|
375 |
+
"x_all = apply_pca(x_all)\n",
|
376 |
+
"\n",
|
377 |
+
"def augment_images(x,y):\n",
|
378 |
+
" x_aug = []\n",
|
379 |
+
" y_aug = []\n",
|
380 |
+
" with tqdm(total=len(x)*8,desc='augmenting images') as pbar:\n",
|
381 |
+
" for rot in range(4):\n",
|
382 |
+
" for flip in range(2):\n",
|
383 |
+
" for patch,label in zip(x,y):\n",
|
384 |
+
" patch = np.rot90(patch,rot)\n",
|
385 |
+
" if flip:\n",
|
386 |
+
" patch = np.flip(patch,axis=0)\n",
|
387 |
+
" patch = np.flip(patch,axis=1)\n",
|
388 |
+
" x_aug.append(patch)\n",
|
389 |
+
" y_aug.append(label)\n",
|
390 |
+
" pbar.update(1)\n",
|
391 |
+
" return np.stack(x_aug,axis=0), np.stack(y_aug,axis=0)\n",
|
392 |
+
"\n",
|
393 |
+
"x_all, y_all = augment_images(x_all,y_all)\n",
|
394 |
+
"\n",
|
395 |
+
"train_inds, val_inds = train_test_split(range(len(x_all)),test_size=0.1,random_state=0)\n",
|
396 |
+
"x_train = np.stack([x_all[i] for i in train_inds],axis=0)\n",
|
397 |
+
"y_train = np.stack([y_all[i] for i in train_inds],axis=0)\n",
|
398 |
+
"x_val = np.stack([x_all[i] for i in val_inds],axis=0)\n",
|
399 |
+
"y_val = np.stack([y_all[i] for i in val_inds],axis=0)\n",
|
400 |
+
"\n",
|
401 |
+
"batch_size = 32\n",
|
402 |
+
"\n",
|
403 |
+
"model.fit( x_train, y_train,\n",
|
404 |
+
" epochs=epochs,\n",
|
405 |
+
" batch_size=batch_size,\n",
|
406 |
+
" validation_data=(x_val,y_val),\n",
|
407 |
+
" verbose=1,\n",
|
408 |
+
" callbacks=[checkpoint,reducelr],\n",
|
409 |
+
" class_weight=class_weight_dict)\n",
|
410 |
+
"\n",
|
411 |
+
"\n"
|
412 |
+
]
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"cell_type": "markdown",
|
416 |
+
"metadata": {},
|
417 |
+
"source": [
|
418 |
+
"Now we run the trained model on the full image in tiles."
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": 7,
|
424 |
+
"metadata": {},
|
425 |
+
"outputs": [
|
426 |
+
{
|
427 |
+
"name": "stderr",
|
428 |
+
"output_type": "stream",
|
429 |
+
"text": [
|
430 |
+
"\r\n",
|
431 |
+
" 0%| | 0/774 [00:00<?, ?it/s]"
|
432 |
+
]
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"name": "stdout",
|
436 |
+
"output_type": "stream",
|
437 |
+
"text": [
|
438 |
+
"Metadata for image\n",
|
439 |
+
"nodata:\n",
|
440 |
+
"None\n",
|
441 |
+
"\n",
|
442 |
+
"transform:\n",
|
443 |
+
"| 1.00, 0.00, 319344.00|\n",
|
444 |
+
"| 0.00,-1.00, 4101691.00|\n",
|
445 |
+
"| 0.00, 0.00, 1.00|\n",
|
446 |
+
"\n",
|
447 |
+
"width:\n",
|
448 |
+
"1028\n",
|
449 |
+
"\n",
|
450 |
+
"count:\n",
|
451 |
+
"426\n",
|
452 |
+
"\n",
|
453 |
+
"height:\n",
|
454 |
+
"10948\n",
|
455 |
+
"\n",
|
456 |
+
"dtype:\n",
|
457 |
+
"int16\n",
|
458 |
+
"\n",
|
459 |
+
"crs:\n",
|
460 |
+
"+init=epsg:32611\n",
|
461 |
+
"\n",
|
462 |
+
"driver:\n",
|
463 |
+
"GTiff\n",
|
464 |
+
"\n"
|
465 |
+
]
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"name": "stderr",
|
469 |
+
"output_type": "stream",
|
470 |
+
"text": [
|
471 |
+
" 89%|████████▉ | 688/774 [02:58<00:20, 4.23it/s]\n"
|
472 |
+
]
|
473 |
+
}
|
474 |
+
],
|
475 |
+
"source": [
|
476 |
+
"import numpy as np\n",
|
477 |
+
"import cv2\n",
|
478 |
+
"from math import floor, ceil\n",
|
479 |
+
"import tqdm\n",
|
480 |
+
"from joblib import dump, load\n",
|
481 |
+
"\n",
|
482 |
+
"import rasterio\n",
|
483 |
+
"from rasterio.windows import Window\n",
|
484 |
+
"from rasterio.enums import Resampling\n",
|
485 |
+
"from rasterio.vrt import WarpedVRT\n",
|
486 |
+
"\n",
|
487 |
+
"from canopy.model import PatchClassifier\n",
|
488 |
+
"from experiment.paths import *\n",
|
489 |
+
"\n",
|
490 |
+
"from tensorflow.keras import backend as K\n",
|
491 |
+
"import tensorflow as tf\n",
|
492 |
+
"config = tf.ConfigProto()\n",
|
493 |
+
"config.gpu_options.allow_growth = True\n",
|
494 |
+
"sess = tf.Session(config=config)\n",
|
495 |
+
"K.set_session(sess)\n",
|
496 |
+
"\n",
|
497 |
+
"pca = load(out + '/pca.joblib')\n",
|
498 |
+
"\n",
|
499 |
+
"# \"no data value\" for labels\n",
|
500 |
+
"label_ndv = 255\n",
|
501 |
+
"\n",
|
502 |
+
"# radius of square patch (side of patch = 2*radius+1)\n",
|
503 |
+
"patch_radius = 7\n",
|
504 |
+
"\n",
|
505 |
+
"# height threshold for CHM -- pixels at or below this height will be discarded\n",
|
506 |
+
"height_threshold = 5\n",
|
507 |
+
"\n",
|
508 |
+
"# tile size for processing\n",
|
509 |
+
"tile_size = 128\n",
|
510 |
+
"\n",
|
511 |
+
"# tile size with padding\n",
|
512 |
+
"padded_tile_size = tile_size + 2*patch_radius\n",
|
513 |
+
"\n",
|
514 |
+
"# open the hyperspectral or RGB image\n",
|
515 |
+
"image = rasterio.open(image_uri)\n",
|
516 |
+
"image_meta = image.meta.copy()\n",
|
517 |
+
"image_ndv = image.meta['nodata']\n",
|
518 |
+
"image_width = image.meta['width']\n",
|
519 |
+
"image_height = image.meta['height']\n",
|
520 |
+
"image_channels = image.meta['count']\n",
|
521 |
+
"\n",
|
522 |
+
"# load model\n",
|
523 |
+
"input_shape = (padded_tile_size,padded_tile_size,pca.n_components_)\n",
|
524 |
+
"tree_classifier = PatchClassifier(num_classes=8)\n",
|
525 |
+
"training_model = tree_classifier.get_patch_model(input_shape)\n",
|
526 |
+
"training_model.load_weights(out + '/' + weights_uri)\n",
|
527 |
+
"model = tree_classifier.get_convolutional_model(input_shape)\n",
|
528 |
+
"\n",
|
529 |
+
"# calculate number of tiles\n",
|
530 |
+
"num_tiles_y = ceil(image_height / float(tile_size))\n",
|
531 |
+
"num_tiles_x = ceil(image_width / float(tile_size))\n",
|
532 |
+
"\n",
|
533 |
+
"print('Metadata for image')\n",
|
534 |
+
"for key in image_meta.keys():\n",
|
535 |
+
" print('%s:'%key)\n",
|
536 |
+
" print(image_meta[key])\n",
|
537 |
+
" print()\n",
|
538 |
+
"\n",
|
539 |
+
"# create predicted label raster\n",
|
540 |
+
"predict_meta = image_meta.copy()\n",
|
541 |
+
"predict_meta['dtype'] = 'uint8'\n",
|
542 |
+
"predict_meta['nodata'] = label_ndv\n",
|
543 |
+
"predict_meta['count'] = 1\n",
|
544 |
+
"predict = rasterio.open(out + '/' + predict_uri, 'w', compress='lzw', **predict_meta)\n",
|
545 |
+
"\n",
|
546 |
+
"# open the CHM\n",
|
547 |
+
"chm = rasterio.open(chm_uri)\n",
|
548 |
+
"chm_vrt = WarpedVRT(chm, crs=image.meta['crs'], transform=image.meta['transform'], width=image.meta['width'], height=image.meta['height'],\n",
|
549 |
+
" resampling=Resampling.bilinear)\n",
|
550 |
+
"\n",
|
551 |
+
"# dilation kernel\n",
|
552 |
+
"kernel = np.ones((patch_radius*2+1,patch_radius*2+1),dtype=np.uint8)\n",
|
553 |
+
"\n",
|
554 |
+
"def apply_pca(x):\n",
|
555 |
+
" N,H,W,C = x.shape\n",
|
556 |
+
" x = np.reshape(x,(-1,C))\n",
|
557 |
+
" x = pca.transform(x)\n",
|
558 |
+
" x = np.reshape(x,(-1,H,W,x.shape[-1]))\n",
|
559 |
+
" return x\n",
|
560 |
+
"\n",
|
561 |
+
"# go through all tiles of input image\n",
|
562 |
+
"# run convolutional model on tile\n",
|
563 |
+
"# write labels to output label raster\n",
|
564 |
+
"with tqdm.tqdm(total=num_tiles_y*num_tiles_x) as pbar:\n",
|
565 |
+
" for y in range(patch_radius,image_height-patch_radius,tile_size):\n",
|
566 |
+
" for x in range(patch_radius,image_width-patch_radius,tile_size):\n",
|
567 |
+
" pbar.update(1)\n",
|
568 |
+
"\n",
|
569 |
+
" window = Window(x-patch_radius,y-patch_radius,padded_tile_size,padded_tile_size)\n",
|
570 |
+
"\n",
|
571 |
+
" # get tile from chm\n",
|
572 |
+
" chm_tile = chm_vrt.read(1,window=window)\n",
|
573 |
+
" if chm_tile.shape[0] != padded_tile_size or chm_tile.shape[1] != padded_tile_size:\n",
|
574 |
+
" pad = ((0,padded_tile_size-chm_tile.shape[0]),(0,padded_tile_size-chm_tile.shape[1]))\n",
|
575 |
+
" chm_tile = np.pad(chm_tile,pad,mode='constant',constant_values=0)\n",
|
576 |
+
" \n",
|
577 |
+
" chm_tile = np.expand_dims(chm_tile,axis=0)\n",
|
578 |
+
" chm_bad = chm_tile <= height_threshold\n",
|
579 |
+
"\n",
|
580 |
+
" # get tile from image\n",
|
581 |
+
" image_tile = image.read(window=window)\n",
|
582 |
+
" image_pad_y = padded_tile_size-image_tile.shape[1]\n",
|
583 |
+
" image_pad_x = padded_tile_size-image_tile.shape[2]\n",
|
584 |
+
" output_window = Window(x,y,tile_size-image_pad_x,tile_size-image_pad_y)\n",
|
585 |
+
" if image_tile.shape[1] != padded_tile_size or image_tile.shape[2] != padded_tile_size:\n",
|
586 |
+
" pad = ((0,0),(0,image_pad_y),(0,image_pad_x))\n",
|
587 |
+
" image_tile = np.pad(image_tile,pad,mode='constant',constant_values=-1)\n",
|
588 |
+
"\n",
|
589 |
+
" # re-order image tile to have height,width,channels\n",
|
590 |
+
" image_tile = np.transpose(image_tile,axes=[1,2,0])\n",
|
591 |
+
"\n",
|
592 |
+
" # add batch axis\n",
|
593 |
+
" image_tile = np.expand_dims(image_tile,axis=0)\n",
|
594 |
+
" image_bad = np.any(image_tile < 0,axis=-1)\n",
|
595 |
+
"\n",
|
596 |
+
" image_tile = image_tile.astype('float32')\n",
|
597 |
+
" image_tile = apply_pca(image_tile)\n",
|
598 |
+
" \n",
|
599 |
+
" # run tile through network\n",
|
600 |
+
" predict_tile = np.argmax(model.predict(image_tile),axis=-1).astype('uint8')\n",
|
601 |
+
"\n",
|
602 |
+
" # dilate mask\n",
|
603 |
+
" image_bad = cv2.dilate(image_bad.astype('uint8'),kernel).astype('bool')\n",
|
604 |
+
"\n",
|
605 |
+
" # set bad pixels to NDV\n",
|
606 |
+
" predict_tile[chm_bad[:,patch_radius:-patch_radius,patch_radius:-patch_radius]] = label_ndv\n",
|
607 |
+
" predict_tile[image_bad[:,patch_radius:-patch_radius,patch_radius:-patch_radius]] = label_ndv\n",
|
608 |
+
"\n",
|
609 |
+
" # undo padding\n",
|
610 |
+
" if image_pad_y > 0:\n",
|
611 |
+
" predict_tile = predict_tile[:,:-image_pad_y,:]\n",
|
612 |
+
" if image_pad_x > 0:\n",
|
613 |
+
" predict_tile = predict_tile[:,:,:-image_pad_x]\n",
|
614 |
+
"\n",
|
615 |
+
" # write to file\n",
|
616 |
+
" predict.write(predict_tile,window=output_window)\n",
|
617 |
+
"\n",
|
618 |
+
"image.close()\n",
|
619 |
+
"chm.close()\n",
|
620 |
+
"predict.close()"
|
621 |
+
]
|
622 |
+
},
|
623 |
+
{
|
624 |
+
"cell_type": "markdown",
|
625 |
+
"metadata": {},
|
626 |
+
"source": [
|
627 |
+
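"Finally, we evaluate classification performance on the test set: each test polygon is assigned the majority label of its predicted pixels, and we report per-class precision, recall, and F1-score together with a confusion matrix."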
|
628 |
+
]
|
629 |
+
},
|
630 |
+
{
|
631 |
+
"cell_type": "code",
|
632 |
+
"execution_count": 8,
|
633 |
+
"metadata": {},
|
634 |
+
"outputs": [
|
635 |
+
{
|
636 |
+
"name": "stdout",
|
637 |
+
"output_type": "stream",
|
638 |
+
"text": [
|
639 |
+
"classification report:\n",
|
640 |
+
" precision recall f1-score support\n",
|
641 |
+
"\n",
|
642 |
+
" 0 0.62 0.89 0.73 9\n",
|
643 |
+
" 1 0.00 0.00 0.00 1\n",
|
644 |
+
" 2 0.82 1.00 0.90 9\n",
|
645 |
+
" 3 1.00 0.88 0.93 16\n",
|
646 |
+
" 4 0.88 1.00 0.93 7\n",
|
647 |
+
" 5 0.00 0.00 0.00 2\n",
|
648 |
+
" 6 0.56 0.71 0.63 7\n",
|
649 |
+
" 7 1.00 0.67 0.80 21\n",
|
650 |
+
"\n",
|
651 |
+
"avg / total 0.83 0.79 0.80 72\n",
|
652 |
+
"\n",
|
653 |
+
"confusion matrix:\n",
|
654 |
+
"[[ 8 1 0 0 0 0 0 0]\n",
|
655 |
+
" [ 1 0 0 0 0 0 0 0]\n",
|
656 |
+
" [ 0 0 9 0 0 0 0 0]\n",
|
657 |
+
" [ 1 0 0 14 1 0 0 0]\n",
|
658 |
+
" [ 0 0 0 0 7 0 0 0]\n",
|
659 |
+
" [ 0 0 0 0 0 0 2 0]\n",
|
660 |
+
" [ 0 0 0 0 0 2 5 0]\n",
|
661 |
+
" [ 3 0 2 0 0 0 2 14]]\n"
|
662 |
+
]
|
663 |
+
}
|
664 |
+
],
|
665 |
+
"source": [
|
666 |
+
"import numpy as np\n",
|
667 |
+
"\n",
|
668 |
+
"import rasterio\n",
|
669 |
+
"from rasterio.windows import Window\n",
|
670 |
+
"from rasterio.enums import Resampling\n",
|
671 |
+
"from rasterio.vrt import WarpedVRT\n",
|
672 |
+
"from rasterio.mask import mask\n",
|
673 |
+
"\n",
|
674 |
+
"from shapely.geometry import Polygon\n",
|
675 |
+
"from shapely.geometry import Point\n",
|
676 |
+
"from shapely.geometry import mapping\n",
|
677 |
+
"\n",
|
678 |
+
"import tqdm\n",
|
679 |
+
"\n",
|
680 |
+
"from math import floor, ceil\n",
|
681 |
+
"\n",
|
682 |
+
"from experiment.paths import *\n",
|
683 |
+
"\n",
|
684 |
+
"from canopy.vector_utils import *\n",
|
685 |
+
"from canopy.extract import *\n",
|
686 |
+
"\n",
|
687 |
+
"import sklearn.metrics\n",
|
688 |
+
"from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score\n",
|
689 |
+
"\n",
|
690 |
+
"train_inds = np.loadtxt(out + '/' + train_ids_uri,dtype='int32')\n",
|
691 |
+
"test_inds = np.loadtxt(out + '/' + test_ids_uri,dtype='int32')\n",
|
692 |
+
"\n",
|
693 |
+
"# Load the metadata from the image.\n",
|
694 |
+
"with rasterio.open(image_uri) as src:\n",
|
695 |
+
" image_meta = src.meta.copy()\n",
|
696 |
+
"\n",
|
697 |
+
"# Load the shapefile and transform it to the hyperspectral image's CRS.\n",
|
698 |
+
"polygons, labels = load_and_transform_shapefile(labels_shp_uri,'SP',image_meta['crs'])\n",
|
699 |
+
"\n",
|
700 |
+
"train_labels = [labels[ind] for ind in train_inds]\n",
|
701 |
+
"test_labels = [labels[ind] for ind in test_inds]\n",
|
702 |
+
"\n",
|
703 |
+
"# open predicted label raster\n",
|
704 |
+
"predict = rasterio.open(out + '/' + predict_uri)\n",
|
705 |
+
"predict_raster = predict.read(1)\n",
|
706 |
+
"ndv = predict.meta['nodata']\n",
|
707 |
+
"\n",
|
708 |
+
"def get_predictions(inds):\n",
|
709 |
+
" preds = []\n",
|
710 |
+
" for ind in inds:\n",
|
711 |
+
" poly = [mapping(Polygon(polygons[ind]['coordinates'][0]))]\n",
|
712 |
+
" out_image, out_transform = mask(predict, poly, crop=False)\n",
|
713 |
+
" out_image = out_image[0]\n",
|
714 |
+
" \n",
|
715 |
+
" label = labels[ind]\n",
|
716 |
+
"\n",
|
717 |
+
" rows, cols = np.where(out_image != ndv)\n",
|
718 |
+
" predict_labels = []\n",
|
719 |
+
" for row, col in zip(rows,cols):\n",
|
720 |
+
" predict_labels.append(predict_raster[row,col])\n",
|
721 |
+
" predict_labels = np.array(predict_labels)\n",
|
722 |
+
" \n",
|
723 |
+
" hist = [np.count_nonzero(predict_labels==i) for i in range(8)]\n",
|
724 |
+
" majority_label = np.argmax(hist)\n",
|
725 |
+
" preds.append(majority_label)\n",
|
726 |
+
" return preds\n",
|
727 |
+
"\n",
|
728 |
+
"def calculate_confusion_matrix(labels,preds):\n",
|
729 |
+
" mat = np.zeros((8,8),dtype='int32')\n",
|
730 |
+
" for label,pred in zip(labels,preds):\n",
|
731 |
+
" mat[label,pred] += 1\n",
|
732 |
+
" return mat\n",
|
733 |
+
"\n",
|
734 |
+
"def calculate_fscore(labels,preds):\n",
|
735 |
+
" return sklearn.metrics.f1_score(labels,preds,average='micro')\n",
|
736 |
+
"\n",
|
737 |
+
"test_preds = get_predictions(test_inds)\n",
|
738 |
+
" \n",
|
739 |
+
"report = classification_report(test_labels, test_preds)\n",
|
740 |
+
"mat = confusion_matrix(test_labels,test_preds)\n",
|
741 |
+
"print('classification report:')\n",
|
742 |
+
"print(report)\n",
|
743 |
+
"print('confusion matrix:')\n",
|
744 |
+
"print(mat)"
|
745 |
+
]
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "code",
|
749 |
+
"execution_count": null,
|
750 |
+
"metadata": {},
|
751 |
+
"outputs": [],
|
752 |
+
"source": []
|
753 |
+
}
|
754 |
+
],
|
755 |
+
"metadata": {
|
756 |
+
"kernelspec": {
|
757 |
+
"display_name": "Python 3",
|
758 |
+
"language": "python",
|
759 |
+
"name": "python3"
|
760 |
+
},
|
761 |
+
"language_info": {
|
762 |
+
"codemirror_mode": {
|
763 |
+
"name": "ipython",
|
764 |
+
"version": 3
|
765 |
+
},
|
766 |
+
"file_extension": ".py",
|
767 |
+
"mimetype": "text/x-python",
|
768 |
+
"name": "python",
|
769 |
+
"nbconvert_exporter": "python",
|
770 |
+
"pygments_lexer": "ipython3",
|
771 |
+
"version": "3.7.16"
|
772 |
+
}
|
773 |
+
},
|
774 |
+
"nbformat": 4,
|
775 |
+
"nbformat_minor": 2
|
776 |
+
}
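The tiling loop in the prediction cell of example.ipynb can be sketched without rasterio as plain window arithmetic. This is a minimal sketch rather than the notebook's code: `tile_windows` is a hypothetical helper, `padded_tile_size = tile_size + 2*patch_radius` is an assumption (its definition is not shown in this part of the notebook), and the sketch assumes, as the notebook's shape checks do, that a read past the raster edge returns a truncated array.

```python
def tile_windows(image_width, image_height, tile_size, patch_radius):
    """Yield read/pad/write geometry for each tile (hypothetical helper).

    Windows are (col_off, row_off, width, height) tuples, mirroring the
    argument order of rasterio.windows.Window.
    """
    # Assumption: the padded tile adds patch_radius pixels on every side.
    padded = tile_size + 2 * patch_radius
    for y in range(patch_radius, image_height - patch_radius, tile_size):
        for x in range(patch_radius, image_width - patch_radius, tile_size):
            col_off, row_off = x - patch_radius, y - patch_radius
            # Near the right/bottom edge fewer pixels are available than requested.
            read_w = min(padded, image_width - col_off)
            read_h = min(padded, image_height - row_off)
            # Missing pixels are filled with padding before the model runs.
            pad_x, pad_y = padded - read_w, padded - read_h
            yield {
                "read": (col_off, row_off, read_w, read_h),
                "pad": (pad_x, pad_y),
                # The written tile shrinks by the amount that was padded.
                "write": (x, y, tile_size - pad_x, tile_size - pad_y),
            }


for tile in tile_windows(image_width=12, image_height=10, tile_size=4, patch_radius=1):
    print(tile)
```

Each interior pixel is written exactly once, while the `patch_radius` border of the image is never written and so keeps the output raster's no-data value.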
|
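The evaluation cell of example.ipynb labels each test polygon with the most frequent predicted class among its pixels and then tallies a confusion matrix. A minimal standard-library sketch of that vote follows. `majority_label` and `confusion` are hypothetical names; ties go to the lowest class index, matching `np.argmax` on a histogram; and returning `None` for a polygon with no valid pixels is an assumption of this sketch, whereas the notebook's all-zero histogram would fall back to class 0.

```python
from collections import Counter


def majority_label(pixel_labels, nodata=None):
    """Return the most frequent label, ignoring nodata pixels (hypothetical helper)."""
    counts = Counter(label for label in pixel_labels if label != nodata)
    if not counts:
        # Assumption: report "no prediction" rather than defaulting to class 0.
        return None
    # max() keeps the first maximum it sees, so iterating over the labels in
    # ascending order breaks ties toward the lowest class index.
    return max(sorted(counts), key=counts.get)


def confusion(true_labels, pred_labels, n_classes):
    """Rows are true classes, columns are predicted classes."""
    mat = [[0] * n_classes for _ in range(n_classes)]
    for true, pred in zip(true_labels, pred_labels):
        if pred is not None:  # skip polygons that received no prediction
            mat[true][pred] += 1
    return mat


# Toy data: three polygons, 255 used as the nodata value.
polygons = [[3, 3, 7, 255], [255, 255], [2, 1, 1, 2]]
preds = [majority_label(p, nodata=255) for p in polygons]
print(preds)
```

Skipping polygons without a prediction keeps them out of the matrix; whether they should instead count as errors depends on how the results are reported.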
gitignore.txt
ADDED
@@ -0,0 +1,106 @@
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.coverage
|
42 |
+
.coverage.*
|
43 |
+
.cache
|
44 |
+
nosetests.xml
|
45 |
+
coverage.xml
|
46 |
+
*.cover
|
47 |
+
.hypothesis/
|
48 |
+
.pytest_cache/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
db.sqlite3
|
58 |
+
|
59 |
+
# Flask stuff:
|
60 |
+
instance/
|
61 |
+
.webassets-cache
|
62 |
+
|
63 |
+
# Scrapy stuff:
|
64 |
+
.scrapy
|
65 |
+
|
66 |
+
# Sphinx documentation
|
67 |
+
docs/_build/
|
68 |
+
|
69 |
+
# PyBuilder
|
70 |
+
target/
|
71 |
+
|
72 |
+
# Jupyter Notebook
|
73 |
+
.ipynb_checkpoints
|
74 |
+
|
75 |
+
# pyenv
|
76 |
+
.python-version
|
77 |
+
|
78 |
+
# celery beat schedule file
|
79 |
+
celerybeat-schedule
|
80 |
+
|
81 |
+
# SageMath parsed files
|
82 |
+
*.sage.py
|
83 |
+
|
84 |
+
# Environments
|
85 |
+
.env
|
86 |
+
.venv
|
87 |
+
env/
|
88 |
+
venv/
|
89 |
+
ENV/
|
90 |
+
env.bak/
|
91 |
+
venv.bak/
|
92 |
+
|
93 |
+
# Spyder project settings
|
94 |
+
.spyderproject
|
95 |
+
.spyproject
|
96 |
+
|
97 |
+
# Rope project settings
|
98 |
+
.ropeproject
|
99 |
+
|
100 |
+
# mkdocs documentation
|
101 |
+
/site
|
102 |
+
|
103 |
+
# mypy
|
104 |
+
.mypy_cache/
|
105 |
+
|
106 |
+
.DS_Store
|