# File size: 7,503 Bytes
# 551ee08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Cloud Mask Prediction and Visualization Module

This script processes Sentinel-2 satellite imagery bands to predict cloud masks
using the omnicloudmask library. It reads blue, red, green, and near-infrared bands,
resamples them as needed, creates a stacked array for prediction, and visualizes
the cloud mask overlaid on the original RGB image.
"""

import rasterio
import numpy as np
from rasterio.enums import Resampling
from omnicloudmask import predict_from_array
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as float32, optionally resampled.

    Only band 1 is requested from rasterio, so any additional bands in the
    file are not read just to be discarded.

    Args:
        file_path (str): Path to the raster file
        resample (bool): Whether to resample the band while reading
        target_height (int, optional): Target height for resampling
        target_width (int, optional): Target width for resampling

    Returns:
        numpy.ndarray: 2-D band data as float32 array. Bilinear resampling is
        applied only when ``resample`` is true and both target dimensions are
        given; otherwise the band is read at its native size.
    """
    read_kwargs = {}
    if resample and target_height is not None and target_width is not None:
        # A leading 1 in out_shape matches the single band requested below.
        read_kwargs = {
            "out_shape": (1, target_height, target_width),
            "resampling": Resampling.bilinear,
        }

    with rasterio.open(file_path) as src:
        # A scalar band index makes rasterio return a 2-D (height, width)
        # array, equivalent to the previous ``src.read(...)[0]``.
        return src.read(1, **read_kwargs).astype(np.float32)

def prepare_input_array(base_path="jp2s/"):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    Loads the blue (B02), green (B03), red (B04) and near-infrared (B8A)
    band files from ``base_path``. The NIR band is resampled to the red
    band's pixel dimensions, then red, green and NIR are stacked in CHW
    (channel, height, width) format for prediction. A separate RGB image,
    scaled and clipped to the 0-1 range, is built for visualization.

    Args:
        base_path (str): Base directory containing the JP2 band files.
            A trailing path separator is optional.

    Returns:
        tuple: (stacked_array, rgb_image)
            - stacked_array: numpy.ndarray with red, green, NIR stacked in CHW format
            - rgb_image: numpy.ndarray (HWC) with RGB values clipped to 0-1
    """
    # Local import keeps this function self-contained; os is stdlib.
    import os

    # File names of the bands inside base_path
    band_files = {
        'blue': "B02.jp2",   # Blue band (10m)
        'green': "B03.jp2",  # Green band (10m)
        'red': "B04.jp2",    # Red band (10m)
        'nir': "B8A.jp2",    # Near-infrared band (20m)
    }
    # os.path.join inserts a separator only when base_path lacks one, so
    # both "jp2s" and "jp2s/" resolve to the same files.
    band_paths = {
        name: os.path.join(base_path, file_name)
        for name, file_name in band_files.items()
    }

    # Load the red band first: its shape is the resampling target for NIR,
    # which avoids opening the red file a second time just for its size.
    red_data = load_band(band_paths['red'])
    target_height, target_width = red_data.shape

    blue_data = load_band(band_paths['blue'])
    green_data = load_band(band_paths['green'])
    nir_data = load_band(
        band_paths['nir'],
        resample=True,
        target_height=target_height,
        target_width=target_width,
    )

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Create RGB image for visualization (scale to 0-1 range)
    # Adjust scaling factor based on your data's bit depth (e.g., 10000 for 16-bit Sentinel-2)
    scale_factor = 10000.0  # Adjust based on your data
    rgb_image = np.stack(
        [band / scale_factor for band in (red_data, green_data, blue_data)],
        axis=-1,
    )

    # Clip values to 0-1 range so bright pixels do not exceed the display range
    rgb_image = np.clip(rgb_image, 0, 1)

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)

    return prediction_array, rgb_image

def visualize_cloud_mask(rgb_image, cloud_mask, output_path="cloud_mask_visualization.png"):
    """
    Draw a three-panel figure: the RGB image, the class mask, and the mask
    drawn semi-transparently over the RGB image.

    The figure is saved to ``output_path`` and then displayed.

    Args:
        rgb_image (numpy.ndarray): RGB image array (HWC format)
        cloud_mask (numpy.ndarray): Predicted cloud mask; singleton axes are
            squeezed away when it has more than two dimensions
        output_path (str): Path to save the visualization
    """
    # Drop singleton axes so the mask can be shown as a 2-D image
    if cloud_mask.ndim > 2:
        print(f"Original cloud mask shape: {cloud_mask.shape}")
        cloud_mask = np.squeeze(cloud_mask)
        print(f"Squeezed cloud mask shape: {cloud_mask.shape}")

    # One row per class id (0..3): colour name, legend label, overlay RGBA.
    # 0=Clear, 1=Thick Cloud, 2=Thin Cloud, 3=Cloud Shadow
    class_styles = (
        ('green', 'Clear', [0, 1, 0, 0.3]),          # low opacity for clear pixels
        ('red', 'Thick Cloud', [1, 0, 0, 0.5]),
        ('yellow', 'Thin Cloud', [1, 1, 0, 0.5]),
        ('blue', 'Cloud Shadow', [0, 0, 1, 0.5]),
    )
    cloud_cmap = ListedColormap([colour for colour, _, _ in class_styles])
    legend_patches = [
        mpatches.Patch(color=colour, label=label)
        for colour, label, _ in class_styles
    ]
    legend_options = {
        'handles': legend_patches,
        'bbox_to_anchor': (1.05, 1),
        'loc': 'upper left',
    }

    # Per-pixel RGBA overlay; pixels matching no class stay fully transparent
    cloud_mask_rgba = np.zeros(cloud_mask.shape + (4,))
    for class_id, (_, _, rgba) in enumerate(class_styles):
        cloud_mask_rgba[cloud_mask == class_id] = rgba

    # One row of three panels
    fig, (rgb_ax, mask_ax, overlay_ax) = plt.subplots(1, 3, figsize=(18, 6))

    # Panel 1: original RGB image
    rgb_ax.imshow(rgb_image)
    rgb_ax.set_title("Original RGB Image")
    rgb_ax.axis('off')

    # Panel 2: class mask with a fixed 0-3 colour range
    mask_ax.imshow(cloud_mask, cmap=cloud_cmap, vmin=0, vmax=3)
    mask_ax.set_title("Cloud Mask")
    mask_ax.axis('off')
    mask_ax.legend(**legend_options)

    # Panel 3: RGB image with the semi-transparent mask on top
    overlay_ax.imshow(rgb_image)
    overlay_ax.imshow(cloud_mask_rgba)
    overlay_ax.set_title("RGB with Cloud Mask Overlay")
    overlay_ax.axis('off')
    overlay_ax.legend(**legend_options)

    # Adjust layout, write the file, then display
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Visualization saved to {output_path}")

def main():
    """
    Run the full workflow: load the bands, predict the cloud mask, print a
    summary of the predicted classes, and render the visualization.
    """
    # Band stack for the model plus an RGB image for plotting
    model_input, preview_rgb = prepare_input_array()

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(model_input)

    # Report the raw prediction
    print("Cloud mask prediction results:")
    print(f"Cloud mask shape: {pred_mask.shape}")
    print(f"Unique classes in mask: {np.unique(pred_mask)}")

    # Squeeze singleton axes only when the mask has more than two dimensions
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask

    # Pixel counts for class ids 0..3
    clear, thick, thin, shadow = (
        np.sum(flat_mask == class_id) for class_id in range(4)
    )
    print(f"Class distribution: Clear: {clear}, Thick Cloud: {thick}, "
          f"Thin Cloud: {thin}, Cloud Shadow: {shadow}")

    # Visualize the (unsqueezed) prediction on the original image
    visualize_cloud_mask(preview_rgb, pred_mask)

# Run the workflow only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()