# utils.py

import numpy as np
import glob
import rasterio
import torch
import re
from torchmetrics import Dice
import os

def calculate_band_statistics(image_directory, image_pattern, bands=[0, 1, 2, 3, 4, 5]):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Args:
        image_directory (str): Directory containing the source GeoTIFF files that are passed to the model for training.
        image_pattern (str): Glob pattern (without the .tif extension) used to select the GeoTIFF files for computing statistics.
        bands (list, optional): List of bands to calculate statistics for. Defaults to [0, 1, 2, 3, 4, 5].

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each band.
    """
    # Initialize lists to store the means and standard deviations
    all_means = []
    all_stds = []

    # Use glob to get a list of all .tif images in the directory
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")

    # Make sure there are images to process
    if not all_images:
        raise Exception("No images found")

    # Get the number of bands
    num_bands = len(bands)

    # Initialize arrays to hold sums and sum of squares for each band
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    # Iterate over each image
    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # For each requested band, accumulate the sum, sum of squares, and count of valid (non-NaN) pixels
            for idx, band in enumerate(bands):
                data = src.read(band + 1)  # rasterio band index starts from 1
                band_sums[idx] += np.nansum(data)
                band_sq_sums[idx] += np.nansum(data**2)
                pixel_counts[idx] += np.count_nonzero(~np.isnan(data))

    # Calculate means and standard deviations for each band
    for i in range(num_bands):
        mean = band_sums[i] / pixel_counts[i]
        std = np.sqrt((band_sq_sums[i] / pixel_counts[i]) - (mean**2))
        all_means.append(mean)
        all_stds.append(std)

    return all_means, all_stds
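
# A minimal usage sketch for calculate_band_statistics. The directory and file
# pattern below are hypothetical placeholders; point them at wherever the
# training GeoTIFFs actually live. The returned per-band means and stds are the
# values a normalization step (e.g., torchvision.transforms.Normalize) expects.
#
#   means, stds = calculate_band_statistics(
#       image_directory="data/training_chips",   # hypothetical path
#       image_pattern="*_merged",                # matches *_merged.tif
#       bands=[0, 1, 2, 3, 4, 5],
#   )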


def split_and_pad(array, target_shape):
    """
    Splits the input array into smaller arrays of the target shape, padding if necessary.

    Args:
        array (numpy.ndarray): The input array. Must be shape (batch, band, time, height, width)
        target_shape (tuple): The target shape of the smaller arrays. Must be of shape
            (batch, band, time, height, width).

    Raises:
        ValueError: If target shape is larger than the array shape.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    # Check element-wise that the target spatial shape does not exceed the array's spatial shape
    if any(t > s for t, s in zip(target_shape[-2:], array.shape[-2:])):
        raise ValueError('Target shape must be smaller or equal to the array shape.')

    # Calculate how much padding is needed
    pad_h = (target_shape[-2] - array.shape[-2] % target_shape[-2]) % target_shape[-2]
    pad_w = (target_shape[-1] - array.shape[-1] % target_shape[-1]) % target_shape[-1]

    # Apply padding to the array
    padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))

    # Split the array into smaller arrays of the target shape
    result = []
    for i in range(0, padded_array.shape[-2], target_shape[-2]):
        for j in range(0, padded_array.shape[-1], target_shape[-1]):
            result.append(padded_array[..., i:i+target_shape[-2], j:j+target_shape[-1]])

    return result
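
# Illustrative sketch of split_and_pad (the shapes are made up for the example):
# a single (batch, band, time, height, width) scene of 500 x 500 pixels is
# padded up to a multiple of 224 and cut into nine 224 x 224 chips.
#
#   scene = np.zeros((1, 6, 3, 500, 500), dtype=np.float32)
#   chips = split_and_pad(scene, target_shape=(1, 6, 3, 224, 224))
#   assert len(chips) == 9 and chips[0].shape == (1, 6, 3, 224, 224)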

def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Assembles smaller numpy arrays back into the original larger numpy array, removing padding if necessary.

    Args:
        np_array_list (list[numpy.ndarray]): List of 2-D arrays (e.g., per-chip predictions), in the
            row-major order produced by split_and_pad.
        original_shape (tuple): The original (height, width) of the larger array.
        target_shape (tuple): The (height, width) of each smaller array.

    Returns:
        numpy.ndarray: The original larger numpy array.
    """
    # Calculate how much padding was added
    pad_h = (target_shape[0] - original_shape[0] % target_shape[0]) % target_shape[0]
    pad_w = (target_shape[1] - original_shape[1] % target_shape[1]) % target_shape[1]

    # Calculate the shape of the padded larger array
    padded_shape = (original_shape[0] + pad_h, original_shape[1] + pad_w)

    # Calculate the number of smaller arrays in each dimension
    num_arrays_h = padded_shape[0] // target_shape[0]
    num_arrays_w = padded_shape[1] // target_shape[1]

    # Reshape the list of smaller arrays back into the shape of the padded larger array
    merged_array = np.stack(np_array_list).reshape(num_arrays_h, num_arrays_w, *target_shape)

    # Rearrange the array dimensions
    merged_array = merged_array.transpose(0, 2, 1, 3).reshape(*padded_shape)

    # Remove the padding
    unpadded_array = merged_array[:original_shape[0], :original_shape[1]]

    return unpadded_array
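
# Illustrative sketch of merge_and_unpad, continuing the split_and_pad example
# above: each chip is assumed to have been reduced to a 2-D (224, 224)
# prediction (e.g., a per-pixel class map), and the chips are reassembled in
# the same row-major order in which split_and_pad produced them.
#
#   chip_preds = [chip[0, 0, 0] for chip in chips]          # nine (224, 224) arrays
#   full_pred = merge_and_unpad(chip_preds, original_shape=(500, 500),
#                               target_shape=(224, 224))
#   assert full_pred.shape == (500, 500)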

def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score.
    """
    dice_metric = Dice()

    # find all .tif files in the prediction directory
    pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))

    # iterate over each prediction file
    for pred_file in pred_files:
        # extract the unique_id from the file name
        unique_id = re.search(r'HLS\..*\.v1\.4', os.path.basename(pred_file))

        if unique_id is not None:
            unique_id = unique_id.group()

            # create the unique pattern for the gt directory
            gt_file_pattern = os.path.join(gt_dir, f"*{unique_id}*mask.tif")

            # glob the file pattern
            gt_files = glob.glob(gt_file_pattern)

            # if we found a matching gt file
            if len(gt_files) == 1:
                gt_file = gt_files[0]

                # read the .tif files
                with rasterio.open(gt_file) as src:
                    gt_img = src.read(1) # ground truth image

                with rasterio.open(pred_file) as src:
                    pred_img = src.read(1) # predicted image

                # make sure the images are binary (values are 0 or 1)
                gt_img = (gt_img > 0).astype(np.uint8)
                pred_img = (pred_img > 0).astype(np.uint8)

                # convert numpy arrays to PyTorch tensors
                gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
                pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

                # update dice_metric
                dice_metric.update(pred_img_tensor, gt_img_tensor)

            else:
                print(f"No matching ground truth file for prediction file {pred_file}.")

    # compute the dice score
    dice_score = dice_metric.compute()
    return dice_score
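
# A minimal usage sketch for compute_metrics. The directories are hypothetical;
# prediction files are expected to be .tif files whose names contain an HLS
# scene id matching HLS.<...>.v1.4, with ground truth files named
# *<scene id>*mask.tif in the ground truth directory.
#
#   dice = compute_metrics(gt_dir="data/ground_truth", pred_dir="data/predictions")
#   print(f"Dice score: {dice.item():.4f}")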