|
|
|
|
|
import numpy as np |
|
import glob |
|
import rasterio |
|
from torchvision import transforms |
|
import torch |
|
import re |
|
from torchmetrics import Dice |
|
import os |
|
|
|
def calculate_band_statistics(image_directory, image_pattern, bands=(0, 1, 2, 3, 4, 5)):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Statistics are pooled over every non-NaN pixel of every matching file. Running
    sums of x and x**2 are kept, so only one band of one file is in memory at a time.

    Args:
        image_directory (str): Directory where the source GeoTIFF files are stored that are passed to model for training.
        image_pattern (str): Glob pattern of the GeoTIFF file names (without the ``.tif``
            suffix, which is appended here) that selects the files used for the stats.
        bands (sequence of int, optional): Zero-based indices of the bands to calculate
            statistics for. Defaults to (0, 1, 2, 3, 4, 5).

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each requested
        band, in the same order as ``bands``.
    """
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")
    if not all_images:
        raise Exception("No images found")

    num_bands = len(bands)
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # Accumulate by position within `bands`, not by band number: the
            # accumulators have len(bands) slots, so e.g. bands=[1, 3] must land
            # in slots 0 and 1.
            for slot, band in enumerate(bands):
                # rasterio band indexes are 1-based. Cast to float64 so squaring an
                # integer-typed raster cannot overflow its native dtype.
                data = src.read(band + 1).astype(np.float64)
                band_sums[slot] += np.nansum(data)
                band_sq_sums[slot] += np.nansum(data ** 2)
                pixel_counts[slot] += np.count_nonzero(~np.isnan(data))

    all_means = []
    all_stds = []
    for slot in range(num_bands):
        mean = band_sums[slot] / pixel_counts[slot]
        # Var = E[x^2] - E[x]^2. Clamp tiny negative values caused by float rounding
        # so sqrt does not return NaN (a genuine NaN still propagates).
        variance = np.maximum(band_sq_sums[slot] / pixel_counts[slot] - mean ** 2, 0.0)
        all_means.append(mean)
        all_stds.append(np.sqrt(variance))

    return all_means, all_stds
|
|
|
|
|
def split_and_pad(array, target_shape):
    """
    Split the input array into tiles of the target spatial size, zero-padding if necessary.

    Only the last two axes (height, width) are tiled; every leading axis is kept whole
    in each tile. The bottom/right edges are zero-padded up to the next multiple of the
    tile size. Tiles are returned in row-major order (left to right along a tile-row,
    then the next tile-row down), the order merge_and_unpad reassembles.

    Args:
        array (numpy.ndarray): The input array, typically shaped
            (batch, band, time, height, width); any array with ndim >= 2 is accepted.
        target_shape (tuple): The target tile shape, e.g. (batch, band, time, height, width).
            Only its last two entries (height, width) are used.

    Raises:
        ValueError: If the target height or width is larger than the array's.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    tile_h, tile_w = target_shape[-2], target_shape[-1]
    height, width = array.shape[-2], array.shape[-1]

    # Check each spatial dimension on its own. A tuple comparison would be
    # lexicographic and miss e.g. a too-wide tile whose height fits.
    if tile_h > height or tile_w > width:
        raise ValueError('Target shape must be smaller or equal to the array shape.')

    # Padding needed to reach the next multiple of the tile size (0 if already aligned).
    pad_h = (tile_h - height % tile_h) % tile_h
    pad_w = (tile_w - width % tile_w) % tile_w

    # Leave every leading axis untouched; pad only the trailing spatial axes.
    pad_width = [(0, 0)] * (array.ndim - 2) + [(0, pad_h), (0, pad_w)]
    padded_array = np.pad(array, pad_width)

    return [
        padded_array[..., i:i + tile_h, j:j + tile_w]
        for i in range(0, padded_array.shape[-2], tile_h)
        for j in range(0, padded_array.shape[-1], tile_w)
    ]
|
|
|
def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Reassemble tiles derived from split_and_pad into one 2-D array and crop away the padding.

    The tiles must be in row-major order (all tiles of the first tile-row, left to right,
    then the next tile-row). Each tile holds target_shape[0] x target_shape[1] values;
    singleton leading axes are allowed because only the element count is reshaped.

    Args:
        np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays derived from split_and_pad.
        original_shape (tuple): (Height, Width) of the array before padding.
        target_shape (tuple): (Height, Width) of each tile.

    Returns:
        numpy.ndarray: The reassembled array, cropped to (Height, Width) of original_shape.
    """
    out_h, out_w = original_shape[0], original_shape[1]
    tile_h, tile_w = target_shape[0], target_shape[1]

    # Tiles per axis = ceil(size / tile), i.e. padded size / tile size.
    tiles_down = -(-out_h // tile_h)
    tiles_across = -(-out_w // tile_w)

    # Lay tiles out as a (tiles_down, tiles_across, tile_h, tile_w) grid, then move the
    # tile-row pixel axis next to the tile-row index so one reshape stitches the mosaic.
    grid = np.stack(np_array_list).reshape(tiles_down, tiles_across, *target_shape)
    mosaic = np.transpose(grid, (0, 2, 1, 3)).reshape(tiles_down * tile_h, tiles_across * tile_w)

    return mosaic[:out_h, :out_w]
|
|
|
def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Each ``*.tif`` in ``pred_dir`` is paired with the single file in ``gt_dir`` whose name
    contains the same identifier (the substring of the prediction's file name matching
    ``HLS\\..*\\.v1\\.4``) and ends in ``mask.tif``. The first band of both rasters is
    binarised (> 0) and accumulated into one Dice metric. A prediction without an
    identifier, or without exactly one matching ground-truth file, is reported and skipped.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score.
    """
    dice_metric = Dice()
    # Raw string: '\.' inside a plain string literal is an invalid escape sequence.
    id_pattern = re.compile(r'HLS\..*\.v1\.4')

    pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))

    for pred_file in pred_files:
        match = id_pattern.search(os.path.basename(pred_file))
        if match is None:
            # Without an identifier there is nothing to pair on; skip rather than
            # globbing for a meaningless "*None*" pattern.
            print(f"Could not extract an HLS identifier from prediction file {pred_file}.")
            continue
        unique_id = match.group()

        # Escape the identifier so any glob metacharacters in it are matched literally.
        gt_file_pattern = os.path.join(gt_dir, f"*{glob.escape(unique_id)}*mask.tif")
        gt_files = glob.glob(gt_file_pattern)

        if not gt_files:
            print(f"No matching ground truth file for prediction file {pred_file}.")
            continue
        if len(gt_files) > 1:
            print(f"Multiple matching ground truth files for prediction file {pred_file}: {gt_files}. Skipping.")
            continue

        with rasterio.open(gt_files[0]) as src:
            gt_img = src.read(1)
        with rasterio.open(pred_file) as src:
            pred_img = src.read(1)

        # Binarise: any positive value is treated as the foreground class.
        gt_img = (gt_img > 0).astype(np.uint8)
        pred_img = (pred_img > 0).astype(np.uint8)

        gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
        pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

        dice_metric.update(pred_img_tensor, gt_img_tensor)

    return dice_metric.compute()
|
|