# utils.py
import numpy as np
import glob
import rasterio
from torchvision import transforms
import torch
import re
from torchmetrics import Dice
import os

def calculate_band_statistics(image_directory, image_pattern, bands=[0, 1, 2, 3, 4, 5]):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Args:
        image_directory (str): Directory where the source GeoTIFF files used for model training are stored.
        image_pattern (str): Pattern of the GeoTIFF file names, used to glob files for computing stats.
        bands (list, optional): List of zero-based band indices to calculate statistics for.
            Defaults to [0, 1, 2, 3, 4, 5].

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each band.
    """
    # Initialize lists to store the means and standard deviations
    all_means = []
    all_stds = []

    # Use glob to get a list of all .tif images in the directory
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")

    # Make sure there are images to process
    if not all_images:
        raise Exception("No images found")

    # Get the number of bands
    num_bands = len(bands)

    # Initialize arrays to hold sums, sums of squares, and pixel counts for each band
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    # Iterate over each image
    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # For each band, accumulate the sum, sum of squares, and pixel count
            for i, band in enumerate(bands):
                data = src.read(band + 1)  # rasterio band indexing starts from 1
                band_sums[i] += np.nansum(data)
                band_sq_sums[i] += np.nansum(data**2)
                pixel_counts[i] += np.count_nonzero(~np.isnan(data))

    # Calculate means and standard deviations for each band
    for i in range(num_bands):
        mean = band_sums[i] / pixel_counts[i]
        std = np.sqrt((band_sq_sums[i] / pixel_counts[i]) - (mean**2))
        all_means.append(mean)
        all_stds.append(std)

    return all_means, all_stds
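
# Example usage (illustrative sketch, not part of the module; the "chips/" directory and
# "*_merged" pattern are hypothetical, and feeding the stats into torchvision's Normalize
# is an assumption about how they are consumed downstream):
#
#   means, stds = calculate_band_statistics("chips", "*_merged")
#   normalize = transforms.Normalize(mean=means, std=stds)
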
def split_and_pad(array, target_shape):
    """
    Splits the input array into smaller arrays of the target shape, padding if necessary.

    Args:
        array (numpy.ndarray): The input array. Must be of shape (batch, band, time, height, width).
        target_shape (tuple): The target shape of the smaller arrays. Must be of shape
            (batch, band, time, height, width).

    Raises:
        ValueError: If the target height or width is larger than the array height or width.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    # Check that the target height and width do not exceed the array height and width
    if any(t > s for t, s in zip(target_shape[-2:], array.shape[-2:])):
        raise ValueError("Target shape must be smaller than or equal to the array shape.")

    # Calculate how much padding is needed
    pad_h = (target_shape[-2] - array.shape[-2] % target_shape[-2]) % target_shape[-2]
    pad_w = (target_shape[-1] - array.shape[-1] % target_shape[-1]) % target_shape[-1]

    # Apply zero padding to the bottom and right edges of the array
    padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))

    # Split the array into smaller arrays of the target shape
    result = []
    for i in range(0, padded_array.shape[-2], target_shape[-2]):
        for j in range(0, padded_array.shape[-1], target_shape[-1]):
            result.append(padded_array[..., i:i + target_shape[-2], j:j + target_shape[-1]])

    return result
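
# Example usage (illustrative sketch with hypothetical dimensions): split a
# (batch, band, time, H, W) array into 224x224 chips, zero-padding the edges.
#
#   x = np.random.rand(1, 6, 1, 500, 500)
#   chips = split_and_pad(x, (1, 6, 1, 224, 224))
#   # 500 pads up to 672 = 3 * 224, so 3 x 3 = 9 chips of shape (1, 6, 1, 224, 224)
#   assert len(chips) == 9
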
def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Assembles smaller numpy arrays back into the original larger numpy array, removing padding if necessary.

    Args:
        np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays derived from split_and_pad.
        original_shape (tuple): The original shape of the larger numpy array. Must be of shape (Height, Width).
        target_shape (tuple): The target shape of the smaller numpy arrays. Must be of shape (Height, Width).

    Returns:
        numpy.ndarray: The original larger numpy array.
    """
    # Calculate how much padding was added
    pad_h = (target_shape[0] - original_shape[0] % target_shape[0]) % target_shape[0]
    pad_w = (target_shape[1] - original_shape[1] % target_shape[1]) % target_shape[1]

    # Calculate the shape of the padded larger array
    padded_shape = (original_shape[0] + pad_h, original_shape[1] + pad_w)

    # Calculate the number of smaller arrays in each dimension
    num_arrays_h = padded_shape[0] // target_shape[0]
    num_arrays_w = padded_shape[1] // target_shape[1]

    # Reshape the list of smaller arrays back into the shape of the padded larger array
    merged_array = np.stack(np_array_list).reshape(num_arrays_h, num_arrays_w, *target_shape)

    # Rearrange the array dimensions
    merged_array = merged_array.transpose(0, 2, 1, 3).reshape(*padded_shape)

    # Remove the padding
    unpadded_array = merged_array[:original_shape[0], :original_shape[1]]

    return unpadded_array
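
# Example usage (illustrative sketch with hypothetical shapes): round-trip a single-band
# 500x500 array through split_and_pad and back. merge_and_unpad expects each chip reduced
# to a 2D (Height, Width) array, e.g. a predicted mask per chip.
#
#   mask = np.random.rand(1, 1, 1, 500, 500)
#   chips = split_and_pad(mask, (1, 1, 1, 224, 224))
#   tiles_2d = [c[0, 0, 0] for c in chips]
#   restored = merge_and_unpad(tiles_2d, (500, 500), (224, 224))
#   assert restored.shape == (500, 500) and np.allclose(restored, mask[0, 0, 0])
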
def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score.
    """
    dice_metric = Dice()

    # find all .tif files in the prediction directory
    pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))

    # iterate over each prediction file
    for pred_file in pred_files:
        # extract the unique_id from the file name
        unique_id = re.search(r"HLS\..*\.v1\.4", os.path.basename(pred_file))
        if unique_id is not None:
            unique_id = unique_id.group()

        # create the unique pattern for the gt directory
        gt_file_pattern = os.path.join(gt_dir, f"*{unique_id}*mask.tif")

        # glob the file pattern
        gt_files = glob.glob(gt_file_pattern)

        # if we found a matching gt file
        if len(gt_files) == 1:
            gt_file = gt_files[0]

            # read the .tif files
            with rasterio.open(gt_file) as src:
                gt_img = src.read(1)  # ground truth image
            with rasterio.open(pred_file) as src:
                pred_img = src.read(1)  # predicted image

            # make sure the images are binary (values are 0 or 1)
            gt_img = (gt_img > 0).astype(np.uint8)
            pred_img = (pred_img > 0).astype(np.uint8)

            # convert numpy arrays to PyTorch tensors
            gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
            pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

            # update dice_metric
            dice_metric.update(pred_img_tensor, gt_img_tensor)
        else:
            print(f"No matching ground truth file for prediction file {pred_file}.")

    # compute the dice score
    dice_score = dice_metric.compute()
    return dice_score
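
# Example usage (illustrative sketch; the directory names are hypothetical): compare
# predicted masks in pred_dir against ground truth files named like "*<HLS id>*mask.tif".
#
#   score = compute_metrics("data/ground_truth", "data/predictions")
#   print(f"Dice score: {score.item():.4f}")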