rbavery committed on
Commit 996303b
1 Parent(s): c599a73

Upload burn_scar_batch_inference_script.py

Files changed (1)
  1. burn_scar_batch_inference_script.py +219 -0
burn_scar_batch_inference_script.py ADDED
@@ -0,0 +1,219 @@
+ import argparse
+ import glob
+ import os
+ import time
+
+ import numpy as np
+ import rasterio
+ import torch
+ from torchvision import transforms
+
+ from mmcv import Config
+ from mmcv.parallel import MMDataParallel
+ from mmcv.runner import load_checkpoint
+ from mmseg.models import build_segmentor
+ from mmseg.utils import custom  # custom preprocessing utilities for HLS imagery
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
+     parser.add_argument('-config', help='path to model configuration file')
+     parser.add_argument('-ckpt', help='path to model checkpoint')
+     parser.add_argument('-input', help='path to input images folder for inference')
+     parser.add_argument('-output', help='directory path to save output images')
+     parser.add_argument('-input_type', help='file type of input images', default="tif")
+
+     args = parser.parse_args()
+
+     return args
+
+ def open_tiff(fname):
+     """Read all bands of a raster into a (bands, H, W) numpy array."""
+     with rasterio.open(fname, "r") as src:
+         data = src.read()
+
+     return data
+
+ def write_tiff(img_wrt, filename, metadata):
+     """
+     Write a raster image to file.
+
+     :param img_wrt: numpy array containing the data (2D for a single band or 3D for multiple bands)
+     :param filename: file path to the output file
+     :param metadata: rasterio metadata dict used to write the raster to disk
+     :return:
+     """
+     with rasterio.open(filename, "w", **metadata) as dest:
+         if len(img_wrt.shape) == 2:
+             # promote single-band arrays to shape (1, H, W)
+             img_wrt = img_wrt[None]
+
+         for i in range(img_wrt.shape[0]):
+             dest.write(img_wrt[i, :, :], i + 1)
+
+
+ def get_meta(fname):
+     """Return the rasterio metadata of a raster file."""
+     with rasterio.open(fname, "r") as src:
+         meta = src.meta
+
+     return meta
+
+ def preprocess_image(data, means, stds, nodata=-9999):
+     """Convert a (bands, H, W) array into the normalized dict format expected by the mmseg model."""
+     data = np.where(data == nodata, 0, data)
+     data = data.astype(np.float32)
+
+     if len(data) == 2:
+         (x, y) = data
+     else:
+         # no label provided: use a dummy label filled with -1
+         x = data
+         y = np.full((x.shape[-2], x.shape[-1]), -1)
+
+     im, label = x.copy(), y.copy()
+     label = label.astype(np.float64)
+
+     im1 = im[0]  # red
+     im2 = im[1]  # green
+     im3 = im[2]  # blue
+     im4 = im[3]  # NIR narrow
+     im5 = im[4]  # SWIR 1
+     im6 = im[5]  # SWIR 2
+
+     dim = x.shape[-1]
+     label = label.squeeze()
+     norm = transforms.Normalize(means, stds)
+     ims = [torch.stack((transforms.ToTensor()(im1).squeeze(),
+                         transforms.ToTensor()(im2).squeeze(),
+                         transforms.ToTensor()(im3).squeeze(),
+                         transforms.ToTensor()(im4).squeeze(),
+                         transforms.ToTensor()(im5).squeeze(),
+                         transforms.ToTensor()(im6).squeeze()))]
+     ims = [norm(im) for im in ims]
+     ims = torch.stack(ims)
+
+     label = transforms.ToTensor()(label).squeeze()
+
+     _img_metas = {
+         'ori_shape': (dim, dim),
+         'img_shape': (dim, dim),
+         'pad_shape': (dim, dim),
+         'scale_factor': [1., 1., 1., 1.],
+         'flip': False,  # needs flip direction specified
+     }
+
+     img_metas = [_img_metas] * 1
+     return {"img": ims,
+             "img_metas": img_metas,
+             "gt_semantic_seg": label}
+
+
+ def load_model(config, ckpt):
+     print('Loading configuration...')
+     cfg = Config.fromfile(config)
+     print('Building model...')
+     model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
+     print('Loading checkpoint...')
+     load_checkpoint(model, ckpt, map_location='cpu')
+     print('Evaluating model...')
+     model = MMDataParallel(model, device_ids=[0])
+     model.eval()
+
+     return model
+
+
+ def inference_on_file(model, target_image, output_image, means, stds):
+     try:
+         st = time.time()
+         data_orig = open_tiff(target_image)
+         meta = get_meta(target_image)
+         nodata = meta['nodata'] if meta['nodata'] is not None else -9999
+
+         data = preprocess_image(data_orig, means, stds, nodata)
+
+         # split the scene into fixed-size 224x224 chips, padding at the edges
+         small_fixed_size_arrs = custom.split_and_pad(data['img'][:, :, None, :, :], (1, 6, 1, 224, 224))
+         single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]
+         print('Running inference...')
+         with torch.no_grad():
+             result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
+         print("Result: Unique Values: ", np.unique(result))
+
+         print("Output has shape: " + str(result[0].shape))
+         # TODO: post-process (e.g. morphological operations)
+
+         # stitch the per-chip predictions back to the original scene size
+         result = custom.merge_and_unpad(result, (data_orig.shape[-2], data_orig.shape[-1]), (224, 224))
+
+         print("Result: Unique Values: ", np.unique(result))
+
+         # save the prediction to disk as a single-band int16 GeoTIFF
+         meta["count"] = 1
+         meta["dtype"] = "int16"
+         meta["compress"] = "lzw"
+         meta["nodata"] = nodata
+         print('Saving output...')
+         result = np.where(data_orig[0] == nodata, nodata, result)
+
+         write_tiff(result, output_image, meta)
+         et = time.time()
+         print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)
+
+     except Exception as e:
+         print(f'Error on image {target_image}: {e}\nContinuing to next input')
+
+ def main():
+     args = parse_args()
+
+     model = load_model(args.config, args.ckpt)
+     image_pattern = "*merged"
+     target_images = glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type))
+
+     print('Identified images to predict on: ' + str(len(target_images)))
+
+     if not os.path.isdir(args.output):
+         os.mkdir(args.output)
+
+     means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])
+
+     for i, target_image in enumerate(target_images):
+         print(f'Working on Image {i}')
+         # name the output after the input scene, e.g. <scene>_merged.tif -> <scene>_pred.tif
+         output_image = os.path.join(args.output,
+                                     target_image.split("/")[-1].split(f"_{image_pattern[1:]}.")[0] + '_pred.' + args.input_type)
+
+         inference_on_file(model, target_image, output_image, means, stds)
+
+     print("Running metric eval")
+
+     gt_dir = "/home/workdir/hls-foundation/data/burn_scars/validation"
+     pred_dir = args.output
+     avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
+     print("Average Dice score:", avg_dice_score)
+
+
+ if __name__ == "__main__":
+     main()
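
For reference, a minimal sketch of how the uploaded script's helpers could be called directly from Python instead of through its CLI; the module name matches the uploaded file, but the config/checkpoint paths, image paths, and band statistics below are hypothetical placeholders, not values from this commit:

import burn_scar_batch_inference_script as bsi

# Hypothetical config/checkpoint paths; replace with real ones.
model = bsi.load_model("configs/burn_scars_config.py", "checkpoints/burn_scars.pth")

# Placeholder per-band statistics for the 6 HLS bands; main() normally derives
# these with custom.calculate_band_statistics over the input folder.
means = [0.0] * 6
stds = [1.0] * 6

# Hypothetical input scene and output prediction path.
bsi.inference_on_file(model, "scene_merged.tif", "scene_pred.tif", means, stds)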