update inf script to use correct import, isntructions for dependencies and data
Browse files- README.md +31 -0
- burn_scar_batch_inference_script.py +1 -1
- custom.py +191 -0
README.md
CHANGED
@@ -33,6 +33,37 @@ Code for Finetuning is available through [github](https://github.com/NASA-IMPACT
|
|
33 |
Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
|
34 |
)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
### Results
|
38 |
|
|
|
33 |
Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
|
34 |
)
|
35 |
|
36 |
+
To run inference, first install dependencies
|
37 |
+
|
38 |
+
```
|
39 |
+
mamba create -n prithvi-burn-scar python=3.10 torchvision numpy matplotlib rasterio torchmetrics openmim
|
40 |
+
mamba activate prithvi-burn-scar
|
41 |
+
mim install mmcv-full==1.5
|
42 |
+
```
|
43 |
+
|
44 |
+
#### Instructions for downloading from [HuggingFace datasets](https://huggingface.co/datasets)
|
45 |
+
|
46 |
+
1. Create account on https://huggingface.co/join
|
47 |
+
2. Install `git` following https://git-scm.com/downloads
|
48 |
+
3. Install git-lfs with `sudo apt install git-lfs` and `git lfs install`
|
49 |
+
4. Run the following command to download the HLS datasets. You may need to
|
50 |
+
enter your HuggingFace username/password to do the `git clone`.
|
51 |
+
|
52 |
+
```
|
53 |
+
mkdir data
|
54 |
+
cd data/
|
55 |
+
git clone https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars burn_scars
|
56 |
+
tar -xzvf burn_scars/hls_burn_scars.tar.gz -C data/
|
57 |
+
ls -lh data/
|
58 |
+
```
|
59 |
+
|
60 |
+
|
61 |
+
With the datasets and the environment, you can now run the inference script.
|
62 |
+
|
63 |
+
```
|
64 |
+
|
65 |
+
|
66 |
+
```
|
67 |
|
68 |
### Results
|
69 |
|
burn_scar_batch_inference_script.py
CHANGED
@@ -21,7 +21,7 @@ from torchvision import transforms
|
|
21 |
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
22 |
|
23 |
from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
|
24 |
-
from
|
25 |
import pdb
|
26 |
|
27 |
import numpy as np
|
|
|
21 |
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
22 |
|
23 |
from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
|
24 |
+
from . import custom # custom preprocessing for hls
|
25 |
import pdb
|
26 |
|
27 |
import numpy as np
|
custom.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# utils.py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import glob
|
5 |
+
import rasterio
|
6 |
+
from torchvision import transforms
|
7 |
+
import torch
|
8 |
+
import re
|
9 |
+
from torchmetrics import Dice
|
10 |
+
import os
|
11 |
+
|
12 |
+
def calculate_band_statistics(image_directory, image_pattern, bands=[0, 1, 2, 3, 4, 5]):
|
13 |
+
"""
|
14 |
+
Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
image_directory (str): Directory where the source GeoTIFF files are stored that are passed to model for training.
|
18 |
+
image_pattern (str): Pattern of the GeoTIFF file names that globs files for computing stats.
|
19 |
+
bands (list, optional): List of bands to calculate statistics for. Defaults to [0, 1, 2, 3, 4, 5].
|
20 |
+
|
21 |
+
Raises:
|
22 |
+
Exception: If no images are found in the given directory.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
tuple: Two lists containing the means and standard deviations of each band.
|
26 |
+
"""
|
27 |
+
# Initialize lists to store the means and standard deviations
|
28 |
+
all_means = []
|
29 |
+
all_stds = []
|
30 |
+
|
31 |
+
# Use glob to get a list of all .tif images in the directory
|
32 |
+
all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")
|
33 |
+
|
34 |
+
# Make sure there are images to process
|
35 |
+
if not all_images:
|
36 |
+
raise Exception("No images found")
|
37 |
+
|
38 |
+
# Get the number of bands
|
39 |
+
num_bands = len(bands)
|
40 |
+
|
41 |
+
# Initialize arrays to hold sums and sum of squares for each band
|
42 |
+
band_sums = np.zeros(num_bands)
|
43 |
+
band_sq_sums = np.zeros(num_bands)
|
44 |
+
pixel_counts = np.zeros(num_bands)
|
45 |
+
|
46 |
+
# Iterate over each image
|
47 |
+
for image_file in all_images:
|
48 |
+
with rasterio.open(image_file) as src:
|
49 |
+
# For each band, calculate the sum, square sum, and pixel count
|
50 |
+
for band in bands:
|
51 |
+
data = src.read(band + 1) # rasterio band index starts from 1
|
52 |
+
band_sums[band] += np.nansum(data)
|
53 |
+
band_sq_sums[band] += np.nansum(data**2)
|
54 |
+
pixel_counts[band] += np.count_nonzero(~np.isnan(data))
|
55 |
+
|
56 |
+
# Calculate means and standard deviations for each band
|
57 |
+
for i in bands:
|
58 |
+
mean = band_sums[i] / pixel_counts[i]
|
59 |
+
std = np.sqrt((band_sq_sums[i] / pixel_counts[i]) - (mean**2))
|
60 |
+
all_means.append(mean)
|
61 |
+
all_stds.append(std)
|
62 |
+
|
63 |
+
return all_means, all_stds
|
64 |
+
|
65 |
+
|
66 |
+
def split_and_pad(array, target_shape):
|
67 |
+
"""
|
68 |
+
Splits the input array into smaller arrays of the target shape, padding if necessary.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
array (numpy.ndarray): The input array. Must be shape (batch, band, time, height, width)
|
72 |
+
target_shape (tuple): The target shape of the smaller arrays. Must be of shape
|
73 |
+
(batch, band, time, height, width)
|
74 |
+
|
75 |
+
Raises:
|
76 |
+
ValueError: If target shape is larger than the array shape.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
list[numpy.ndarray]: A list of the smaller arrays.
|
80 |
+
"""
|
81 |
+
# Check if the target shape is smaller or equal to the array shape
|
82 |
+
if target_shape[-2:] > array.shape[-2:]:
|
83 |
+
raise ValueError('Target shape must be smaller or equal to the array shape.')
|
84 |
+
|
85 |
+
# Calculate how much padding is needed
|
86 |
+
pad_h = (target_shape[-2] - array.shape[-2] % target_shape[-2]) % target_shape[-2]
|
87 |
+
pad_w = (target_shape[-1] - array.shape[-1] % target_shape[-1]) % target_shape[-1]
|
88 |
+
|
89 |
+
# Apply padding to the array
|
90 |
+
padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))
|
91 |
+
|
92 |
+
# Split the array into smaller arrays of the target shape
|
93 |
+
result = []
|
94 |
+
for i in range(0, padded_array.shape[-2], target_shape[-2]):
|
95 |
+
for j in range(0, padded_array.shape[-1], target_shape[-1]):
|
96 |
+
result.append(padded_array[..., i:i+target_shape[-2], j:j+target_shape[-1]])
|
97 |
+
|
98 |
+
return result
|
99 |
+
|
100 |
+
def merge_and_unpad(np_array_list, original_shape, target_shape):
|
101 |
+
"""
|
102 |
+
Assembles smaller numpy arrays back into the original larger numpy array, removing padding if necessary.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays derived from split_and_pad.
|
106 |
+
original_shape (tuple): The original shape of the larger numpy array. Must be shape (Height, Width).
|
107 |
+
target_shape (tuple): The target shape of the smaller numpy arrays. Must be shape (Height, Width).
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
numpy.ndarray: The original larger numpy array.
|
111 |
+
"""
|
112 |
+
# Calculate how much padding was added
|
113 |
+
pad_h = (target_shape[0] - original_shape[0] % target_shape[0]) % target_shape[0]
|
114 |
+
pad_w = (target_shape[1] - original_shape[1] % target_shape[1]) % target_shape[1]
|
115 |
+
|
116 |
+
# Calculate the shape of the padded larger array
|
117 |
+
padded_shape = (original_shape[0] + pad_h, original_shape[1] + pad_w)
|
118 |
+
|
119 |
+
# Calculate the number of smaller arrays in each dimension
|
120 |
+
num_arrays_h = padded_shape[0] // target_shape[0]
|
121 |
+
num_arrays_w = padded_shape[1] // target_shape[1]
|
122 |
+
|
123 |
+
# Reshape the list of smaller arrays back into the shape of the padded larger array
|
124 |
+
merged_array = np.stack(np_array_list).reshape(num_arrays_h, num_arrays_w, *target_shape)
|
125 |
+
|
126 |
+
# Rearrange the array dimensions
|
127 |
+
merged_array = merged_array.transpose(0, 2, 1, 3).reshape(*padded_shape)
|
128 |
+
|
129 |
+
# Remove the padding
|
130 |
+
unpadded_array = merged_array[:original_shape[0], :original_shape[1]]
|
131 |
+
|
132 |
+
return unpadded_array
|
133 |
+
|
134 |
+
def compute_metrics(gt_dir, pred_dir):
|
135 |
+
"""
|
136 |
+
Compute the Dice similarity coefficient between the predicted and ground truth images.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
gt_dir (str): Directory where the ground truth images are stored.
|
140 |
+
pred_dir (str): Directory where the predicted images are stored.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Tensor: Dice similarity coefficient score.
|
144 |
+
"""
|
145 |
+
dice_metric = Dice()
|
146 |
+
|
147 |
+
# find all .tif files in the prediction directory
|
148 |
+
pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))
|
149 |
+
|
150 |
+
# iterate over each prediction file
|
151 |
+
for pred_file in pred_files:
|
152 |
+
# extract the unique_id from the file name
|
153 |
+
unique_id = re.search('HLS\..*\.v1\.4', os.path.basename(pred_file))
|
154 |
+
|
155 |
+
if unique_id is not None:
|
156 |
+
unique_id = unique_id.group()
|
157 |
+
|
158 |
+
# create the unique pattern for the gt directory
|
159 |
+
gt_file_pattern = os.path.join(gt_dir, f"*{unique_id}*mask.tif")
|
160 |
+
|
161 |
+
# glob the file pattern
|
162 |
+
gt_files = glob.glob(gt_file_pattern)
|
163 |
+
|
164 |
+
# if we found a matching gt file
|
165 |
+
if len(gt_files) == 1:
|
166 |
+
gt_file = gt_files[0]
|
167 |
+
|
168 |
+
# read the .tif files
|
169 |
+
with rasterio.open(gt_file) as src:
|
170 |
+
gt_img = src.read(1) # ground truth image
|
171 |
+
|
172 |
+
with rasterio.open(pred_file) as src:
|
173 |
+
pred_img = src.read(1) # predicted image
|
174 |
+
|
175 |
+
# make sure the images are binary (values are 0 or 1)
|
176 |
+
gt_img = (gt_img > 0).astype(np.uint8)
|
177 |
+
pred_img = (pred_img > 0).astype(np.uint8)
|
178 |
+
|
179 |
+
# convert numpy arrays to PyTorch tensors
|
180 |
+
gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
|
181 |
+
pred_img_tensor = torch.from_numpy(pred_img).long().flatten()
|
182 |
+
|
183 |
+
# update dice_metric
|
184 |
+
dice_metric.update(pred_img_tensor, gt_img_tensor)
|
185 |
+
|
186 |
+
else:
|
187 |
+
print(f"No matching ground truth file for prediction file {pred_file}.")
|
188 |
+
|
189 |
+
# compute the dice score
|
190 |
+
dice_score = dice_metric.compute()
|
191 |
+
return dice_score
|