|
import argparse |
|
from mmcv import Config |
|
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,wrap_fp16_model) |
|
from mmseg.models import build_segmentor |
|
|
|
import matplotlib.pyplot as plt |
|
import mmcv |
|
import torch |
|
from mmcv.parallel import collate, scatter |
|
from mmcv.runner import load_checkpoint |
|
|
|
from mmseg.datasets.pipelines import Compose |
|
from mmseg.models import build_segmentor |
|
|
|
from mmseg.datasets import build_dataloader, build_dataset, load_flood_test_data |
|
import rasterio |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from torchvision import transforms |
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
|
|
|
from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor |
|
from . import custom |
|
import pdb |
|
|
|
import numpy as np |
|
import glob |
|
import os |
|
|
|
import time |
|
|
|
def parse_args():
    """Parse the command-line options of the inference script.

    Options keep their original single-dash spelling (``-config``, ``-ckpt``, ...).

    :return: ``argparse.Namespace`` with ``config``, ``ckpt``, ``input``, ``output``,
        ``input_type`` and ``gt_dir`` attributes
    """
    parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
    # These four paths are dereferenced unconditionally by the rest of the script, so a
    # missing one should be reported here rather than fail later with an unrelated error.
    parser.add_argument('-config', required=True, help='path to model configuration file')
    parser.add_argument('-ckpt', required=True, help='path to model checkpoint')
    parser.add_argument('-input', required=True, help='path to input images folder for inference')
    parser.add_argument('-output', required=True, help='directory path to save output images')
    parser.add_argument('-input_type', help='file type of input images', default="tif")
    # Defaults to the directory that used to be hard-coded for the metric evaluation.
    parser.add_argument('-gt_dir',
                        help='directory of ground-truth masks used for the metric evaluation',
                        default="/home/workdir/hls-foundation/data/burn_scars/validation")
    return parser.parse_args()
|
|
|
def open_tiff(fname):
    """Read all bands of the raster at ``fname`` and return them as one array.

    ``read()`` is called without band indexes; per rasterio's documentation that
    yields a band-first array.
    """
    with rasterio.open(fname, "r") as dataset:
        pixels = dataset.read()
    return pixels
|
|
|
def write_tiff(img_wrt, filename, metadata):
    """Write an image array to a raster file, band by band.

    :param img_wrt: numpy array to write; a 2-D array is written as a single band,
        otherwise the first axis indexes the bands
    :param filename: destination path of the raster
    :param metadata: keyword arguments forwarded to ``rasterio.open`` in write mode
        (e.g. ``count``, ``dtype``, ``compress``)
    :return: None
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        # Give a single-band 2-D array a leading band axis so one loop handles both cases.
        bands = img_wrt[None] if len(img_wrt.shape) == 2 else img_wrt
        # rasterio band indexes are 1-based.
        for band_number, band in enumerate(bands, start=1):
            dest.write(band, band_number)
|
|
|
|
|
def get_meta(fname):
    """Return the rasterio ``meta`` dict of the raster stored at ``fname``."""
    with rasterio.open(fname, "r") as dataset:
        raster_meta = dataset.meta
    return raster_meta
|
|
|
def preprocess_image(data, means, stds, nodata=-9999, num_bands=6):
    """Convert a raster array into the input dict consumed by the segmentor.

    Pixels equal to ``nodata`` are set to 0. The first ``num_bands`` bands are
    normalized per band as ``(x - mean) / std``, the same arithmetic as
    ``torchvision.transforms.Normalize``. A leading batch axis of size 1 is added.

    :param data: band-first image array of shape (bands, H, W), or an
        ``(image, label)`` tuple/list
    :param means: per-band means (one per used band, or a single value to broadcast)
    :param stds: per-band standard deviations; must all be non-zero
    :param nodata: sentinel pixel value to replace with 0; a NaN sentinel is
        matched with ``isnan`` because NaN never compares equal
    :param num_bands: number of leading bands fed to the model (default 6)
    :return: dict with ``img`` (float32 tensor of shape (1, num_bands, H, W)),
        ``img_metas`` (list holding one metadata dict) and ``gt_semantic_seg``
        (float64 tensor of shape (H, W), all -1 when no label was given)
    :raises ValueError: if the image is not 3-D with at least ``num_bands`` bands,
        or if any std is zero
    """
    # Only an explicit (image, label) pair is split; an ndarray is always the image
    # (its len() is the band count, not a pair marker).
    if isinstance(data, (tuple, list)) and len(data) == 2:
        image, label = data
    else:
        image, label = data, None

    image = np.asarray(image)
    if isinstance(nodata, float) and np.isnan(nodata):
        invalid = np.isnan(image)
    else:
        invalid = image == nodata
    image = np.where(invalid, 0, image).astype(np.float32)

    if image.ndim != 3 or image.shape[0] < num_bands:
        raise ValueError(
            f"expected a (bands, H, W) image with at least {num_bands} bands, got shape {image.shape}"
        )
    height, width = image.shape[-2:]

    if label is None:
        label = np.full((height, width), -1)
    # np.array copies, so the returned tensor never aliases the caller's label array.
    label = np.array(label, dtype=np.float64).squeeze()

    bands = torch.from_numpy(np.ascontiguousarray(image[:num_bands]))
    mean = torch.as_tensor(means, dtype=bands.dtype).view(-1, 1, 1)
    std = torch.as_tensor(stds, dtype=bands.dtype).view(-1, 1, 1)
    if (std == 0).any():
        raise ValueError("std evaluated to zero after conversion to float32, leading to division by zero")
    img = ((bands - mean) / std).unsqueeze(0)

    img_meta = {
        'ori_shape': (height, width),
        'img_shape': (height, width),
        'pad_shape': (height, width),
        'scale_factor': [1., 1., 1., 1.],
        'flip': False,
    }

    return {"img": img,
            "img_metas": [img_meta],
            "gt_semantic_seg": torch.from_numpy(np.ascontiguousarray(label))}
|
|
|
|
|
def load_model(config, ckpt):
    """Build the segmentor described by ``config`` and load the weights from ``ckpt``.

    The checkpoint is loaded with ``map_location='cpu'``. The model is then wrapped in
    ``MMDataParallel`` on device 0 and switched to eval mode.

    :param config: path to an mmcv config file
    :param ckpt: path to the model checkpoint
    :return: the wrapped model, ready for inference
    """
    print('Loading configuration...')
    cfg = Config.fromfile(config)

    print('Building model...')
    segmentor = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))

    print('Loading checkpoint...')
    # The returned checkpoint dict is not needed; only the in-place weight loading matters.
    load_checkpoint(segmentor, ckpt, map_location='cpu')

    print('Evaluating model...')
    wrapped = MMDataParallel(segmentor, device_ids=[0])
    wrapped.eval()
    return wrapped
|
|
|
|
|
def inference_on_file(model, target_image, output_image, means, stds):
    """Run the segmentation model on one raster and write the prediction to disk.

    Steps:

    1. Normalize the image with ``means``/``stds`` via ``preprocess_image``.
    2. Tile it into 224x224 chips with ``custom.split_and_pad``.
    3. Run all chips through ``model`` in a single forward call.
    4. Stitch the result back with ``custom.merge_and_unpad``.
    5. Write it as a single-band int16, LZW-compressed raster whose remaining metadata
       is copied from the input.

    Errors are reported and not re-raised, so a caller looping over many files can move
    on to the next one. KeyboardInterrupt/SystemExit still propagate.

    :param model: segmentor returned by ``load_model``
    :param target_image: path to the input raster
    :param output_image: path of the prediction raster to write
    :param means: per-band means forwarded to ``preprocess_image``
    :param stds: per-band standard deviations forwarded to ``preprocess_image``
    :return: None
    """
    try:
        st = time.time()
        data_orig = open_tiff(target_image)
        meta = get_meta(target_image)
        # Fall back to the same -9999 sentinel preprocess_image uses by default.
        nodata = meta['nodata'] if meta['nodata'] is not None else -9999

        data = preprocess_image(data_orig, means, stds, nodata)

        # data['img'] is (1, bands, H, W); insert a singleton axis at dim 2
        # (presumably the temporal axis the model expects -- confirm) before chipping.
        small_fixed_size_arrs = custom.split_and_pad(data['img'][:, :, None, :, :], (1, 6, 1, 224, 224))
        single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]
        print('Running inference...')
        with torch.no_grad():
            result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
        print("Result: Unique Values: ", np.unique(result))
        print("Output has shape: " + str(result[0].shape))

        # Reassemble the chip predictions at the input raster's height/width.
        result = custom.merge_and_unpad(result, (data_orig.shape[-2], data_orig.shape[-1]), (224, 224))
        print("Result: Unique Values: ", np.unique(result))

        meta["count"] = 1
        meta["dtype"] = "int16"
        meta["compress"] = "lzw"
        # The output reuses the input's nodata value (or -9999).
        # NOTE(review): a dead `meta["nodata"] = -1` used to precede this line, so -1
        # may have been the intended sentinel. If the input nodata coincides with a
        # predicted class value (e.g. 0), those predictions become indistinguishable
        # from nodata -- confirm.
        meta["nodata"] = nodata
        print('Saving output...')

        # Carry the input's nodata pixels (judged on band 0) into the prediction.
        result = np.where(data_orig[0] == nodata, nodata, result)

        write_tiff(result, output_image, meta)
        et = time.time()
        print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)
    except Exception as err:
        # Best-effort per file: report the actual failure instead of hiding it, and let
        # KeyboardInterrupt/SystemExit propagate (a bare `except:` swallowed them).
        print(f'Error on image {target_image} \nContinue to next input')
        print(f'  {type(err).__name__}: {err}')
|
|
|
def main():
    """Predict on every ``*merged.<input_type>`` image in ``-input`` and score the results.

    Each prediction is written to ``-output`` as ``<stem>_pred.<input_type>``. Here
    ``<stem>`` is the input file name up to ``_merged.``. Normalization statistics come
    from ``custom.calculate_band_statistics`` over the input folder. The Dice score is
    computed with ``custom.compute_metrics`` against the ground-truth directory.
    """
    default_gt_dir = "/home/workdir/hls-foundation/data/burn_scars/validation"

    args = parse_args()
    model = load_model(args.config, args.ckpt)

    image_pattern = "*merged"
    # Sorted so images are processed (and numbered in the log) in a stable order.
    target_images = sorted(glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type)))
    print('Identified images to predict on: ' + str(len(target_images)))

    # makedirs also creates missing parents; exist_ok makes re-runs into the same folder work.
    os.makedirs(args.output, exist_ok=True)

    means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])

    # Everything from "_merged." onward is replaced by "_pred.<ext>" in the output name.
    name_suffix = f"_{image_pattern[1:]}."
    for i, target_image in enumerate(target_images):
        print(f'Working on Image {i}')
        stem = os.path.basename(target_image).split(name_suffix)[0]
        output_image = os.path.join(args.output, stem + '_pred.' + args.input_type)
        inference_on_file(model, target_image, output_image, means, stds)

    print("Running metric eval")
    # Use a gt_dir value from the parsed arguments when one is provided; otherwise keep
    # the original hard-coded location.
    gt_dir = getattr(args, "gt_dir", None) or default_gt_dir
    pred_dir = args.output
    avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
    print("Average Dice score:", avg_dice_score)
|
|
|
|
|
# Script entry point: run inference over the input folder, then print the metric evaluation.
if __name__ == "__main__":
    main()
|
|