Upload burn_scar_batch_inference_script.py

#3
by rbavery - opened
Files changed (4) hide show
  1. README.md +36 -0
  2. burn_scar_batch_inference_script.py +219 -0
  3. custom.py +191 -0
  4. requirements.txt +47 -0
README.md CHANGED
@@ -33,6 +33,42 @@ Code for Finetuning is available through [github](https://github.com/NASA-IMPACT
33
  Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  ### Results
38
 
 
33
  Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
34
  )
35
 
36
+ To run inference, first install dependencies
37
+
38
+ ```
39
+ mamba create -n prithvi-burn-scar python=3.10 pycocotools ncurses
40
+ mamba activate prithvi-burn-scar
41
+ pip install --upgrade pip && \
42
+ pip install -r requirements.txt && \
43
+ mim install mmcv-full==1.5.0
44
+ ```
45
+
46
+ #### Instructions for downloading from [HuggingFace datasets](https://huggingface.co/datasets)
47
+
48
+ 1. Create an account at https://huggingface.co/join
49
+ 2. Install `git` following https://git-scm.com/downloads
50
+ 3. Install git-lfs with `sudo apt install git-lfs` and `git lfs install`
51
+ 4. Run the following command to download the HLS datasets. You may need to
52
+ enter your HuggingFace username/password to do the `git clone`.
53
+
54
+ ```
55
+ mkdir -p data
56
+ cd data/
57
+ git clone https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars burn_scars
58
+ tar -xzvf burn_scars/hls_burn_scars.tar.gz -C ./
59
+ ```
60
+
61
+
62
+ With the datasets and the environment, you can now run the inference script.
63
+
64
+ ```
65
+ python burn_scar_batch_inference_script.py \
66
+ -config burn_scars_Prithvi_100M.py \
67
+ -ckpt burn_scars_Prithvi_100M.pth \
68
+ -input data/burn_scars/validation \
69
+ -output data/burn_scars/inference_output \
70
+ -input_type tif
71
+ ```
72
 
73
  ### Results
74
 
burn_scar_batch_inference_script.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from mmcv import Config
3
+ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,wrap_fp16_model)
4
+ from mmseg.models import build_segmentor
5
+
6
+ import matplotlib.pyplot as plt
7
+ import mmcv
8
+ import torch
9
+ from mmcv.parallel import collate, scatter
10
+ from mmcv.runner import load_checkpoint
11
+
12
+ from mmseg.datasets.pipelines import Compose
13
+ from mmseg.models import build_segmentor
14
+
15
+ from mmseg.datasets import build_dataloader, build_dataset, load_flood_test_data
16
+ import rasterio
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from torchvision import transforms
21
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
22
+
23
+ from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
24
+ from . import custom # custom preprocessing for hls
25
+ import pdb
26
+
27
+ import numpy as np
28
+ import glob
29
+ import os
30
+
31
+ import time
32
+
33
def parse_args(argv=None):
    """Parse the command-line options of the burn scar inference script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            None, in which case argparse reads the process command line.

    Returns:
        argparse.Namespace: Namespace with the attributes ``config``,
        ``ckpt``, ``input``, ``output`` and ``input_type``.
    """
    parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
    # The four paths have no sensible default; marking them required makes
    # argparse report a missing one immediately instead of passing None on.
    parser.add_argument('-config', required=True, help='path to model configuration file')
    parser.add_argument('-ckpt', required=True, help='path to model checkpoint')
    parser.add_argument('-input', required=True, help='path to input images folder for inference')
    parser.add_argument('-output', required=True, help='directory path to save output images')
    parser.add_argument('-input_type', help='file type of input images', default="tif")

    return parser.parse_args(argv)
45
+
46
def open_tiff(fname):
    """Open the raster at ``fname`` with rasterio and return what ``read()`` yields."""
    with rasterio.open(fname, "r") as raster:
        return raster.read()
53
+
54
def write_tiff(img_wrt, filename, metadata):
    """
    Write a raster image to disk, one band at a time.

    :param img_wrt: numpy array holding the data (2D for a single band, 3D for several bands)
    :param filename: path of the file to create
    :param metadata: keyword arguments forwarded to ``rasterio.open`` in write mode
    :return: None
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        # A 2D image is treated as a one-band stack.
        bands = img_wrt[None] if len(img_wrt.shape) == 2 else img_wrt

        band_total = bands.shape[0]
        for band_idx in range(band_total):
            # rasterio band numbers are 1-based.
            dest.write(bands[band_idx, :, :], band_idx + 1)
73
+
74
+
75
def get_meta(fname):
    """Open the raster at ``fname`` with rasterio and return its ``meta`` attribute."""
    with rasterio.open(fname, "r") as raster:
        return raster.meta
82
+
83
def preprocess_image(data, means, stds, nodata=-9999):
    """Convert a raw multi-band raster array into the dict handed to the model.

    Pixels equal to ``nodata`` are set to 0, the first six bands are
    normalised per band as ``(value - mean) / std``, and a placeholder label
    filled with -1 is created because no ground truth is available here.

    Args:
        data (numpy.ndarray): Raster of shape (bands, height, width) with at
            least six bands; only bands 0-5 are used.
        means (sequence of float): Per-band means, one per used band.
        stds (sequence of float): Per-band standard deviations, one per used
            band. None may be zero.
        nodata (number, optional): Value marking missing pixels. Defaults to
            -9999.

    Returns:
        dict: ``"img"`` is a float32 tensor of shape (1, 6, height, width),
        ``"img_metas"`` is a one-element list holding the shape metadata dict,
        and ``"gt_semantic_seg"`` is a float64 tensor of shape
        (height, width) filled with -1.

    Raises:
        ValueError: If ``data`` is not 3-D, has fewer than six bands, or any
            standard deviation is zero.
    """
    data = np.where(data == nodata, 0, data).astype(np.float32)

    if data.ndim != 3 or data.shape[0] < 6:
        raise ValueError(
            f"Expected an array of shape (bands>=6, height, width), got {data.shape}"
        )

    height, width = data.shape[-2], data.shape[-1]

    # One conversion for the six-band stack replaces the per-band
    # ToTensor()/squeeze()/stack sequence; float32 input is not rescaled.
    bands = torch.from_numpy(np.ascontiguousarray(data[:6]))

    # Same arithmetic as a per-channel Normalize: broadcast (C, 1, 1) stats.
    mean = torch.as_tensor(means, dtype=bands.dtype).view(-1, 1, 1)
    std = torch.as_tensor(stds, dtype=bands.dtype).view(-1, 1, 1)
    if (std == 0).any():
        raise ValueError("Standard deviations must be non-zero to normalise the image.")
    ims = ((bands - mean) / std).unsqueeze(0)

    # Placeholder label: -1 everywhere.
    label = torch.full((height, width), -1.0, dtype=torch.float64)

    # Report the real (height, width); using the width for both is only
    # correct for square rasters.
    _img_metas = {
        'ori_shape': (height, width),
        'img_shape': (height, width),
        'pad_shape': (height, width),
        'scale_factor': [1., 1., 1., 1.],
        'flip': False,  # needs flip direction specified
    }

    return {"img": ims,
            "img_metas": [_img_metas],
            "gt_semantic_seg": label}
130
+
131
+
132
def load_model(config, ckpt):
    """Build the segmentor described by ``config`` and load ``ckpt`` into it.

    Args:
        config: Path passed to ``Config.fromfile``.
        ckpt: Checkpoint path passed to ``load_checkpoint``.

    Returns:
        The segmentor wrapped in ``MMDataParallel`` on device 0, in eval mode.
    """
    print('Loading configuration...')
    cfg = Config.fromfile(config)

    print('Building model...')
    segmentor = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))

    print('Loading checkpoint...')
    # The value returned by load_checkpoint is not needed afterwards.
    load_checkpoint(segmentor, ckpt, map_location='cpu')

    print('Evaluating model...')
    wrapped = MMDataParallel(segmentor, device_ids=[0])
    wrapped.eval()
    return wrapped
145
+
146
+
147
def inference_on_file(model, target_image, output_image, means, stds):
    """Run the model on one raster and write the prediction to ``output_image``.

    The raster is normalised, cut into 224x224 chips, pushed through the
    model as one batch, reassembled to the original height/width, masked
    with the nodata value wherever band 0 of the input is nodata, and
    written as a single-band int16 LZW-compressed raster.

    Failures are reported and swallowed so a batch run can continue with the
    next image.

    Args:
        model: Callable model as returned by ``load_model``.
        target_image (str): Path of the raster to predict on.
        output_image (str): Path of the prediction raster to write.
        means: Per-band means forwarded to ``preprocess_image``.
        stds: Per-band standard deviations forwarded to ``preprocess_image``.
    """
    try:
        st = time.time()
        data_orig = open_tiff(target_image)
        meta = get_meta(target_image)
        nodata = meta['nodata'] if meta['nodata'] is not None else -9999

        data = preprocess_image(data_orig, means, stds, nodata)

        # Insert a singleton time axis so the chips are (batch, band, time, h, w).
        small_fixed_size_arrs = custom.split_and_pad(data['img'][:, :, None, :, :], (1, 6, 1, 224, 224))
        single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]
        print('Running inference...')
        with torch.no_grad():
            result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
        print("Result: Unique Values: ", np.unique(result))

        print("Output has shape: " + str(result[0].shape))
        #### TO DO: Post process (e.g. morphological operations)

        result = custom.merge_and_unpad(result, (data_orig.shape[-2], data_orig.shape[-1]), (224, 224))

        print("Result: Unique Values: ", np.unique(result))

        ##### Save file to disk
        meta["count"] = 1
        meta["dtype"] = "int16"
        meta["compress"] = "lzw"
        # The output keeps the input's nodata value (an earlier -1 assignment
        # was immediately overwritten and has been dropped).
        meta["nodata"] = nodata
        print('Saving output...')
        result = np.where(data_orig[0] == nodata, nodata, result)

        write_tiff(result, output_image, meta)
        et = time.time()
        print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)

    except Exception as exc:
        # Best effort per image: report the cause and continue, but let
        # KeyboardInterrupt/SystemExit propagate so the run can be aborted.
        print(f'Error on image {target_image}: {exc!r} \nContinue to next input')
187
+
188
def main():
    """Predict on every ``*merged.<input_type>`` raster in the input folder, then score.

    Predictions are written to the output folder as ``<name>_pred.<input_type>``
    and afterwards compared against the ground truth masks with
    ``custom.compute_metrics``.
    """
    # NOTE(review): the module-level `from . import custom` only resolves when
    # this file is imported as part of a package; running it as a plain script
    # (as the README shows) would need `import custom` - confirm and adjust.
    args = parse_args()

    model = load_model(args.config, args.ckpt)
    image_pattern = "*merged"
    target_images = glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type))

    print('Identified images to predict on: ' + str(len(target_images)))

    # Creates missing parent directories too and tolerates an existing folder.
    os.makedirs(args.output, exist_ok=True)

    means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])

    # Everything before "_merged." in the file name identifies the scene.
    name_suffix = f"_{image_pattern[1:]}."
    for i, target_image in enumerate(target_images):

        print(f'Working on Image {i}')
        scene_name = os.path.basename(target_image).split(name_suffix)[0]
        output_image = os.path.join(args.output, scene_name + '_pred.' + args.input_type)

        inference_on_file(model, target_image, output_image, means, stds)

    print("Running metric eval")

    # Look for the ground truth masks in the input folder instead of a
    # developer-specific absolute path.
    # NOTE(review): assumes the `*mask.tif` files sit next to the merged
    # images in the input folder - confirm against the dataset layout.
    gt_dir = args.input
    pred_dir = args.output
    avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
    print("Average Dice score:", avg_dice_score)


if __name__ == "__main__":
    main()
custom.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # custom.py
2
+
3
+ import numpy as np
4
+ import glob
5
+ import rasterio
6
+ from torchvision import transforms
7
+ import torch
8
+ import re
9
+ from torchmetrics import Dice
10
+ import os
11
+
12
def calculate_band_statistics(image_directory, image_pattern, bands=(0, 1, 2, 3, 4, 5)):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Statistics are pooled over all pixels of all matching files, using running
    sums and sums of squares accumulated in float64.

    Args:
        image_directory (str): Directory where the source GeoTIFF files are stored.
        image_pattern (str): Glob pattern of the file names, without the ``.tif`` extension.
        bands (sequence of int, optional): Zero-based band indices to calculate
            statistics for. Defaults to (0, 1, 2, 3, 4, 5).

    Raises:
        FileNotFoundError: If no images are found in the given directory
            (a subclass of Exception, so broad handlers keep working).

    Returns:
        tuple: Two lists containing the means and standard deviations of each
        band, in the order given by ``bands``.
    """
    # Use glob to get a list of all .tif images in the directory
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")

    # Make sure there are images to process
    if not all_images:
        raise FileNotFoundError("No images found")

    num_bands = len(bands)
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    # NOTE(review): nodata pixels are not masked here, so they contribute to
    # the statistics - confirm this is intended.
    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # Accumulators are indexed by position in `bands`, not by the
            # band number itself, so any band selection works.
            for slot, band in enumerate(bands):
                # Cast before squaring: squaring in the raster's integer
                # dtype can overflow.
                data = src.read(band + 1).astype(np.float64)  # rasterio band index starts from 1
                band_sums[slot] += np.nansum(data)
                band_sq_sums[slot] += np.nansum(data * data)
                pixel_counts[slot] += np.count_nonzero(~np.isnan(data))

    means = band_sums / pixel_counts
    # Rounding can push a near-zero variance slightly below zero; clamp so
    # the square root does not turn it into NaN.
    variances = np.maximum(band_sq_sums / pixel_counts - means ** 2, 0.0)
    stds = np.sqrt(variances)

    return list(means), list(stds)
64
+
65
+
66
def split_and_pad(array, target_shape):
    """
    Splits the input array into smaller arrays of the target shape, padding if necessary.

    The last two axes are zero-padded at the bottom/right up to the next
    multiple of the target height/width and then cut into tiles in
    row-major order (all tiles of the first tile row first).

    Args:
        array (array-like): The input array. Must be 5-D with shape
            (batch, band, time, height, width).
        target_shape (tuple): The target shape of the smaller arrays. Must be of shape
            (batch, band, time, height, width); only height and width are used.

    Raises:
        ValueError: If the target height or width is larger than the array's.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    array = np.asarray(array)
    tile_h, tile_w = target_shape[-2], target_shape[-1]
    height, width = array.shape[-2], array.shape[-1]

    # Compare each dimension separately: comparing the (h, w) tuples directly
    # is lexicographic and lets a too-narrow array through when it is tall enough.
    if tile_h > height or tile_w > width:
        raise ValueError('Target shape must be smaller or equal to the array shape.')

    # Padding needed to reach the next multiple of the tile size (0 if aligned)
    pad_h = (tile_h - height % tile_h) % tile_h
    pad_w = (tile_w - width % tile_w) % tile_w

    padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))

    return [
        padded_array[..., i:i + tile_h, j:j + tile_w]
        for i in range(0, padded_array.shape[-2], tile_h)
        for j in range(0, padded_array.shape[-1], tile_w)
    ]
99
+
100
def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Rebuild the full-size array from its tiles and crop away the padding.

    Args:
        np_array_list (list[numpy.ndarray]): Tiles in the order produced by split_and_pad.
        original_shape (tuple): (Height, Width) of the array before it was padded and split.
        target_shape (tuple): (Height, Width) of each tile.

    Returns:
        numpy.ndarray: Array of shape ``original_shape``.
    """
    def _round_up(size, tile):
        # Size after padding up to the next multiple of the tile size.
        return size + (tile - size % tile) % tile

    full_h = _round_up(original_shape[0], target_shape[0])
    full_w = _round_up(original_shape[1], target_shape[1])
    padded_shape = (full_h, full_w)

    # Number of tiles along each axis
    tile_rows = full_h // target_shape[0]
    tile_cols = full_w // target_shape[1]

    # Lay the tiles out on a (row, col, h, w) grid ...
    grid = np.stack(np_array_list).reshape(tile_rows, tile_cols, *target_shape)

    # ... then interleave tile rows with pixel rows and flatten to 2-D.
    mosaic = np.transpose(grid, (0, 2, 1, 3)).reshape(*padded_shape)

    # Drop the padded rows and columns
    return mosaic[:original_shape[0], :original_shape[1]]
133
+
134
def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Each ``*.tif`` in ``pred_dir`` is paired with the single file in ``gt_dir``
    matching ``*<unique_id>*mask.tif``, where the unique id is the
    ``HLS.<...>.v1.4`` part of the prediction's file name. Both rasters are
    binarised (value > 0) before the metric is updated. Predictions that
    cannot be paired with exactly one mask are reported and skipped.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score.
    """
    dice_metric = Dice()

    # Raw string so the regex escapes are not read as (invalid) string
    # escapes; compiled once outside the loop.
    unique_id_pattern = re.compile(r'HLS\..*\.v1\.4')

    # find all .tif files in the prediction directory
    pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))

    for pred_file in pred_files:
        # extract the unique_id from the file name
        id_match = unique_id_pattern.search(os.path.basename(pred_file))

        # glob for the mask carrying the same id (none if the name has no id)
        gt_files = []
        if id_match is not None:
            gt_file_pattern = os.path.join(gt_dir, f"*{id_match.group()}*mask.tif")
            gt_files = glob.glob(gt_file_pattern)

        # Only score a prediction that pairs with exactly one mask; say so
        # otherwise rather than dropping it silently.
        if len(gt_files) != 1:
            print(f"No matching ground truth file for prediction file {pred_file}.")
            continue

        # read the .tif files
        with rasterio.open(gt_files[0]) as src:
            gt_img = src.read(1)  # ground truth image

        with rasterio.open(pred_file) as src:
            pred_img = src.read(1)  # predicted image

        # make sure the images are binary (values are 0 or 1)
        gt_img = (gt_img > 0).astype(np.uint8)
        pred_img = (pred_img > 0).astype(np.uint8)

        # convert numpy arrays to PyTorch tensors
        gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
        pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

        dice_metric.update(pred_img_tensor, gt_img_tensor)

    # compute the dice score
    return dice_metric.compute()
requirements.txt ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ boxsdk==3.6.2
2
+ cityscapesscripts==2.2.1
3
+ codecov
4
+ detail==0.2.2
5
+ docutils==0.16.0
6
+ einops==0.6.0
7
+ flake8
8
+ interrogate
9
+ jupyterlab==4.0.1
10
+ matplotlib==3.5.1
11
+ mmcls>=0.20.1
12
+ mmdet==2.22.0
13
+ model_archiver==1.0.3
14
+ myst-parser
15
+ -e git+https://github.com/gaotongxiao/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
16
+ natsort==8.3.1
17
+ numpy==1.21.6
18
+ onnx==1.13.1
19
+ onnxruntime==1.14.1
20
+ onnx2torch
21
+ opencv-python==4.7.0.72
22
+ openmim
23
+ packaging==21.3
24
+ pandas==1.3.5
25
+ pavi==0.0.1
26
+ Pillow==9.4.0
27
+ pip-tools
28
+ prettytable==3.6.0
29
+ pytest==7.1.3
30
+ rasterio==1.3.4
31
+ requests==2.28.2
32
+ scikit-learn
33
+ scipy==1.7.3
34
+ scikit-image
35
+ seaborn==0.12.2
36
+ sphinx==4.0.2
37
+ sphinx_copybutton
38
+ sphinx_markdown_tables
39
+ tensorrt==8.5.3.1
40
+ timm==0.4.12
41
+ torch==1.9.0+cu111
42
+ -f https://download.pytorch.org/whl/torch_stable.html
43
+ torchvision==0.10.0
44
+ torchmetrics
45
+ ts==0.5.1
46
+ xdoctest>=0.10.0
47
+ yapf