# utils.py
import glob
import os
import re

import numpy as np
import rasterio
import torch
from torchmetrics import Dice
from torchvision import transforms


def calculate_band_statistics(image_directory, image_pattern, bands=(0, 1, 2, 3, 4, 5)):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Statistics are accumulated over every file matching
    ``{image_directory}/{image_pattern}.tif``. NaN pixels are excluded from the
    sums and from the pixel counts. The standard deviation is the population
    standard deviation, computed as ``sqrt(E[x^2] - E[x]^2)``.

    Args:
        image_directory (str): Directory where the source GeoTIFF files are stored
            that are passed to model for training.
        image_pattern (str): Pattern of the GeoTIFF file names that globs files for
            computing stats (the ``.tif`` suffix is appended automatically).
        bands (sequence of int, optional): Zero-based band indices to calculate
            statistics for. Need not be contiguous or start at 0.
            Defaults to (0, 1, 2, 3, 4, 5).

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each
        band, in the same order as ``bands``.
    """
    # Use glob to get a list of all matching .tif images in the directory
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")

    # Make sure there are images to process
    if not all_images:
        raise Exception("No images found")

    bands = list(bands)
    num_bands = len(bands)

    # Running totals, one slot per requested band. Slots are addressed by the
    # band's POSITION in `bands`, not by the band number itself, so arbitrary
    # band selections such as [3, 4, 5] stay inside the arrays.
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    for image_file in all_images:
        with rasterio.open(image_file) as src:
            for slot, band in enumerate(bands):
                # rasterio band index starts from 1. Promote to float64 before
                # squaring so integer rasters cannot overflow their native dtype.
                data = src.read(band + 1).astype(np.float64)
                band_sums[slot] += np.nansum(data)
                band_sq_sums[slot] += np.nansum(data * data)
                pixel_counts[slot] += np.count_nonzero(~np.isnan(data))

    means = band_sums / pixel_counts
    # Floating-point rounding can push the variance marginally below zero for
    # (near-)constant bands; clamp so sqrt does not return NaN in that case.
    variances = np.maximum(band_sq_sums / pixel_counts - means**2, 0.0)
    stds = np.sqrt(variances)

    return list(means), list(stds)


def split_and_pad(array, target_shape):
    """
    Splits the input array into smaller arrays of the target shape, padding if necessary.

    Only the last two axes (height, width) are tiled. Zero padding is appended
    at the bottom and right edges so that both become multiples of the tile
    size. Tiles are returned in row-major order (all tiles of the first tile
    row, left to right, then the next row, ...).

    Args:
        array (numpy.ndarray): The input array, with height and width as its
            last two axes, e.g. (batch, band, time, height, width).
        target_shape (tuple): The target shape of the smaller arrays. Only its
            last two entries (height, width) are used.

    Raises:
        ValueError: If the target height or width is larger than the
            corresponding array dimension.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    tile_h, tile_w = target_shape[-2], target_shape[-1]
    height, width = array.shape[-2], array.shape[-1]

    # Compare each dimension on its own: comparing the (h, w) tuples directly
    # is lexicographic and would let an over-wide tile through whenever its
    # height happens to be smaller than the array's.
    if tile_h > height or tile_w > width:
        raise ValueError('Target shape must be smaller or equal to the array shape.')

    # Calculate how much padding is needed to reach the next multiple
    pad_h = (tile_h - height % tile_h) % tile_h
    pad_w = (tile_w - width % tile_w) % tile_w

    # Pad only the two spatial axes, leaving every leading axis untouched
    pad_width = [(0, 0)] * (array.ndim - 2) + [(0, pad_h), (0, pad_w)]
    padded_array = np.pad(array, pad_width)

    # Split the array into smaller arrays of the target shape
    return [
        padded_array[..., row:row + tile_h, col:col + tile_w]
        for row in range(0, padded_array.shape[-2], tile_h)
        for col in range(0, padded_array.shape[-1], tile_w)
    ]


def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Assembles smaller numpy arrays back into the original larger numpy array,
    removing padding if necessary.

    The tiles must be in the row-major order produced by split_and_pad, and
    together must contain exactly one (height, width) tile per grid cell
    (i.e. any leading axes must be of size 1 or already removed).

    Args:
        np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays
            derived from split_and_pad.
        original_shape (tuple): The original shape of the larger numpy array.
            Must be shape (Height, Width).
        target_shape (tuple): The target shape of the smaller numpy arrays.
            Must be shape (Height, Width).

    Returns:
        numpy.ndarray: The original larger numpy array.
    """
    orig_h, orig_w = original_shape[0], original_shape[1]
    tile_h, tile_w = target_shape

    # Calculate how much padding was added
    pad_h = (tile_h - orig_h % tile_h) % tile_h
    pad_w = (tile_w - orig_w % tile_w) % tile_w

    # Shape of the padded larger array and the tile grid that covers it
    padded_h, padded_w = orig_h + pad_h, orig_w + pad_w
    grid_rows = padded_h // tile_h
    grid_cols = padded_w // tile_w

    # (tile, h, w) -> (grid_row, grid_col, h, w)
    tile_grid = np.stack(np_array_list).reshape(grid_rows, grid_cols, tile_h, tile_w)

    # Interleave to (grid_row, h, grid_col, w) so collapsing adjacent axes
    # yields the full (padded_h, padded_w) mosaic
    merged_array = tile_grid.transpose(0, 2, 1, 3).reshape(padded_h, padded_w)

    # Remove the padding
    return merged_array[:orig_h, :orig_w]


def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Every ``*.tif`` file in ``pred_dir`` is paired with the single file in
    ``gt_dir`` matching ``*<id>*mask.tif``, where ``<id>`` is the
    ``HLS.<...>.v1.4`` portion of the prediction file name. Both rasters are
    binarised (value > 0) before being added to the metric. Prediction files
    without an id, or without exactly one matching ground truth file, are
    reported on stdout and skipped.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score accumulated over all matched
        prediction / ground truth pairs.
    """
    dice_metric = Dice()

    # Raw string: '\.' in a normal string literal is an invalid escape sequence
    id_pattern = re.compile(r"HLS\..*\.v1\.4")

    # iterate over each .tif file in the prediction directory
    for pred_file in glob.glob(os.path.join(pred_dir, "*.tif")):
        # extract the unique_id from the file name
        id_match = id_pattern.search(os.path.basename(pred_file))

        gt_files = []
        if id_match is not None:
            # escape the id so any glob metacharacters in it are matched literally
            unique_id = glob.escape(id_match.group())
            gt_files = glob.glob(os.path.join(gt_dir, f"*{unique_id}*mask.tif"))

        # need exactly one ground truth file to form an unambiguous pair
        if len(gt_files) != 1:
            print(f"No matching ground truth file for prediction file {pred_file}.")
            continue

        # read the first band of each .tif file
        with rasterio.open(gt_files[0]) as src:
            gt_img = src.read(1)  # ground truth image
        with rasterio.open(pred_file) as src:
            pred_img = src.read(1)  # predicted image

        # make sure the images are binary (values are 0 or 1)
        gt_img = (gt_img > 0).astype(np.uint8)
        pred_img = (pred_img > 0).astype(np.uint8)

        # convert numpy arrays to flat PyTorch integer tensors
        gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
        pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

        # update dice_metric
        dice_metric.update(pred_img_tensor, gt_img_tensor)

    # compute the dice score over everything accumulated above
    dice_score = dice_metric.compute()
    return dice_score