|
|
import os
|
|
|
import random
|
|
|
import numpy as np
|
|
|
from glob import glob
|
|
|
import rasterio
|
|
|
from rasterio.windows import Window
|
|
|
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_band(img, mean, std):
    """Rescale a band to [0, 1] using the window ``mean - 2*std .. mean + 2*std``.

    Values below the window map to 0, values above it map to 1. A small
    epsilon in the denominator keeps the division defined when ``std`` is 0.

    Args:
        img: Array-like band values.
        mean: Band mean used as the centre of the window.
        std: Band standard deviation; the window is two of these either side.

    Returns:
        ``np.float32`` array with the same shape as ``img``, clipped to [0, 1].
    """
    half_width = 2 * std
    lower = mean - half_width
    upper = mean + half_width
    scaled = (img - lower) / (upper - lower + 1e-6)
    clipped = np.clip(scaled, 0, 1)
    return clipped.astype(np.float32)
|
|
|
|
|
|
|
|
|
class GeoAugment:
    """Apply the same random mirror / quarter-turn to an input and its target.

    Both arrays are indexed on axes 1 and 2, so they need at least three
    dimensions (presumably ``(channels, height, width)`` - confirm with callers).
    """

    def __init__(self, rotate=True, flip=True):
        """Store which augmentation families are enabled.

        Args:
            rotate: Enable the random rotation by a multiple of 90 degrees.
            flip: Enable the two independent random flips.
        """
        self.rotate = rotate
        self.flip = flip

    def __call__(self, x, y):
        """Return augmented copies of ``x`` and ``y``.

        Each flip is decided by its own draw from ``random.random()`` (axis 2
        first, then axis 1); the rotation count is drawn with
        ``random.choice``. Results are copied so they are contiguous arrays
        rather than negative-stride views.
        """
        for axis in (2, 1):
            # No random number is consumed when flipping is disabled.
            if self.flip and random.random() < 0.5:
                x, y = (np.flip(arr, axis=axis).copy() for arr in (x, y))

        if not self.rotate:
            return x, y

        quarter_turns = random.choice([0, 1, 2, 3])
        if quarter_turns > 0:
            x, y = (
                np.rot90(arr, quarter_turns, axes=(1, 2)).copy() for arr in (x, y)
            )
        return x, y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SatellitePatchDataset(Dataset):
    """Multi-modal satellite dataset loader (S1, S2, DEM, hillshade, cloud mask).

    Train/val/test split must be performed by selecting `locations`.

    Files are looked up under ``<root>/<loc>/``:

    * ``<date>_<loc>_lake_mask.tif`` - label raster; band 1, ``> 0`` is foreground
    * ``<date>_<loc>_s1.tif`` and ``<date>_<loc>_s2.tif``
    * ``<loc>_dem.tif`` (no date component)
    * ``<date>_<loc>_hillshade.tif`` and ``<date>_<loc>_cloud_mask.tif``

    ``<date>`` is whatever precedes the first underscore in a file matching
    ``*_lake_mask.tif``. A (date, location) pair becomes a sample only when
    the file of every enabled modality exists.

    Args:
        root: Directory containing one sub-directory per location.
        locations: Iterable of location names to include.
        patch_size: Side length of the square patches, in pixels.
        stride: Step between patch origins; falls back to ``patch_size``
            when ``None`` (or 0).
        skip_empty: When true, patches whose mask is entirely zero are dropped.
        empty_tile_ratio: Only used when ``skip_empty`` is false. If > 0, at
            most ``int(non_empty * ratio)`` empty patches are kept per sample
            (the first ones in row-major order); otherwise all are kept.
        task: ``'segmentation'`` or ``'mae'``.
        dates: Optional container of date strings; other dates are ignored.
        masking_ratio: Fraction of 8x8 blocks zeroed out for the ``'mae'`` task.
        transform: Optional callable ``(x, y) -> (x, y)`` on NumPy arrays.
        band_stats: Optional mapping ``{key: {"mean": ..., "std": ...}}``
            indexed by channel, with keys ``"S1"``, ``"S2"``, ``"DEM"``,
            ``"Hillshade"``, ``"Cloudmask"``. Modalities without an entry are
            returned un-normalized.
        ch_s1, ch_s2, ch_dem, ch_hillshade, ch_cloudmask: Zero-based band
            indices to read per modality. ``None`` or an empty sequence
            disables that modality.

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """

    # Side length, in pixels, of one masking block for the 'mae' task.
    _MAE_BLOCK = 8

    def __init__(
        self,
        root,
        locations,
        patch_size=256,
        stride=None,
        skip_empty=True,
        empty_tile_ratio=0.0,
        task='segmentation',
        dates=None,
        masking_ratio=0.5,
        transform=None,
        band_stats=None,
        # Immutable defaults: list defaults would be shared by every instance.
        ch_s1=(0, 1),
        ch_s2=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11),
        ch_dem=(0,),
        ch_hillshade=(0,),
        ch_cloudmask=(0,),
    ):
        super().__init__()

        if task not in ('segmentation', 'mae'):
            raise ValueError(f"Unsupported task: {task}")

        self.root = root
        self.locations = locations
        self.patch_size = patch_size
        self.stride = stride or patch_size
        self.skip_empty = skip_empty
        self.empty_tile_ratio = empty_tile_ratio
        self.task = task
        self.masking_ratio = masking_ratio
        self.transform = transform
        self.dates = dates

        # Each instance owns its own channel lists.
        self.ch_s1 = self._channel_list(ch_s1)
        self.ch_s2 = self._channel_list(ch_s2)
        self.ch_dem = self._channel_list(ch_dem)
        self.ch_hillshade = self._channel_list(ch_hillshade)
        self.ch_cloudmask = self._channel_list(ch_cloudmask)
        self.band_stats = band_stats

        self.samples = []      # (date, loc) per usable acquisition
        self.patch_index = []  # (sample_id, x, y) per patch served by __getitem__

        self._discover_samples()
        self._index_patches()

    @staticmethod
    def _channel_list(channels):
        """Return a fresh list copy of ``channels``, keeping ``None`` as ``None``."""
        return None if channels is None else list(channels)

    def _mask_path(self, date, loc):
        """Return the label-mask path for one (date, location) pair."""
        return os.path.join(self.root, loc, f"{date}_{loc}_lake_mask.tif")

    def _modality_files(self, date, loc):
        """List ``(stats_key, channels, path)`` for every enabled modality.

        The order (S1, S2, DEM, Hillshade, Cloudmask) is the channel order of
        the tensors returned by ``__getitem__``.
        """
        loc_dir = os.path.join(self.root, loc)
        candidates = (
            ("S1", self.ch_s1, f"{date}_{loc}_s1.tif"),
            ("S2", self.ch_s2, f"{date}_{loc}_s2.tif"),
            ("DEM", self.ch_dem, f"{loc}_dem.tif"),
            ("Hillshade", self.ch_hillshade, f"{date}_{loc}_hillshade.tif"),
            ("Cloudmask", self.ch_cloudmask, f"{date}_{loc}_cloud_mask.tif"),
        )
        return [
            (key, channels, os.path.join(loc_dir, filename))
            for key, channels, filename in candidates
            if channels is not None and len(channels) > 0
        ]

    def _discover_samples(self):
        """Populate ``self.samples`` with every complete (date, location) pair."""
        for loc in self.locations:
            pattern = os.path.join(self.root, loc, "*_lake_mask.tif")
            for path in sorted(glob(pattern)):
                date = os.path.basename(path).split("_")[0]
                if self.dates is not None and date not in self.dates:
                    continue
                # Skip acquisitions where any requested modality is missing.
                if all(
                    os.path.exists(file_path)
                    for _, _, file_path in self._modality_files(date, loc)
                ):
                    self.samples.append((date, loc))

    def _index_patches(self):
        """Populate ``self.patch_index`` by tiling each sample's mask raster.

        Patch origins run row-major over ``range(0, size, stride)``; windows
        that overhang the raster are zero-filled when read.
        """
        size = self.patch_size
        for sample_id, (date, loc) in enumerate(self.samples):
            tiles = []  # (x, y, is_empty) in row-major order

            with rasterio.open(self._mask_path(date, loc)) as msk:
                for y in range(0, msk.height, self.stride):
                    for x in range(0, msk.width, self.stride):
                        tile = msk.read(
                            1,
                            window=Window(x, y, size, size),
                            boundless=True,
                            fill_value=0,
                        )
                        tiles.append((x, y, bool(np.all(tile == 0))))

            num_empty = sum(1 for _, _, is_empty in tiles if is_empty)
            if self.skip_empty:
                empty_quota = 0
            elif self.empty_tile_ratio > 0:
                # Cap empty patches relative to the number of non-empty ones.
                empty_quota = min(
                    int((len(tiles) - num_empty) * self.empty_tile_ratio),
                    num_empty,
                )
            else:
                empty_quota = num_empty

            for x, y, is_empty in tiles:
                if is_empty:
                    if empty_quota <= 0:
                        continue
                    empty_quota -= 1
                self.patch_index.append((sample_id, x, y))

    def _stitch_patches(self, patches, height, width):
        """Average overlapping patches onto a ``height`` x ``width`` canvas.

        Patches are consumed in the same row-major grid order that
        ``_index_patches`` generates; surplus patches are ignored and missing
        ones leave zeros.
        """
        if torch.is_tensor(patches):
            patches = patches.detach().cpu().numpy()
        patches = np.asarray(patches)

        size = self.patch_size
        # In-place true division needs a floating accumulator; integer or bool
        # patches would otherwise raise a casting error.
        if np.issubdtype(patches.dtype, np.floating):
            acc_dtype = patches.dtype
        else:
            acc_dtype = np.float32

        # Pad so that patches overhanging the bottom/right edge still fit.
        canvas_h = height + size - 1
        canvas_w = width + size - 1
        total = np.zeros((patches.shape[1], canvas_h, canvas_w), dtype=acc_dtype)
        hits = np.zeros((canvas_h, canvas_w), dtype=np.float32)

        origins = (
            (y, x)
            for y in range(0, height, self.stride)
            for x in range(0, width, self.stride)
        )
        for patch, (y, x) in zip(patches, origins):
            total[:, y:y + size, x:x + size] += patch
            hits[y:y + size, x:x + size] += 1.0

        hits[hits == 0] = 1.0  # avoid 0/0 where no patch landed
        total /= hits[None, :, :]
        return total[:, :height, :width]

    def reconstruct_image(self, patches, sample_id):
        """Reconstruct a full image from the patches of one sample.

        Args:
            patches: Array (or torch tensor) of shape
                ``(num_patches, channels, patch_size, patch_size)`` ordered
                row-major over the full patch grid of the sample. The grid
                positions are regenerated here, not taken from
                ``patch_index``, so the result is misaligned if patches were
                dropped (e.g. by ``skip_empty``).
            sample_id: Index into ``self.samples``.

        Returns:
            NumPy array ``(channels, height, width)`` where overlapping
            regions are averaged. Floating inputs keep their dtype; other
            dtypes are returned as ``float32``.
        """
        date, loc = self.samples[sample_id]
        # Copy the dimensions out while the dataset is still open.
        with rasterio.open(self._mask_path(date, loc)) as src:
            height, width = src.height, src.width
        return self._stitch_patches(patches, height, width)

    def __len__(self):
        """Return the number of indexed patches."""
        return len(self.patch_index)

    def _random_block_mask(self, height, width):
        """Return a ``(height, width)`` float32 mask of random 8x8 blocks.

        ``int(num_blocks * masking_ratio)`` blocks are set to 1 (masked), the
        rest to 0. The block grid is rounded up and then cropped so sizes that
        are not a multiple of the block size still get a mask of matching shape.
        """
        block = self._MAE_BLOCK
        rows = -(-height // block)  # ceil division
        cols = -(-width // block)
        num_blocks = rows * cols
        num_masked = int(num_blocks * self.masking_ratio)

        flags = np.hstack([
            np.ones(num_masked, dtype=np.float32),
            np.zeros(num_blocks - num_masked, dtype=np.float32),
        ])
        np.random.shuffle(flags)

        coarse = flags.reshape(rows, cols)
        fine = np.kron(coarse, np.ones((block, block), dtype=np.float32))
        return fine[:height, :width]

    def __getitem__(self, idx):
        """Load one patch.

        Returns:
            For ``task='segmentation'``: ``(x, y)`` where ``x`` stacks the
            enabled modalities along axis 0 and ``y`` is the binary mask with
            a leading singleton axis.
            For ``task='mae'``: ``(masked_x, x)`` where ``masked_x`` is ``x``
            with randomly chosen blocks set to zero.
        """
        sample_id, x0, y0 = self.patch_index[idx]
        date, loc = self.samples[sample_id]
        window = Window(x0, y0, self.patch_size, self.patch_size)

        channels = [
            self._load_and_normalize(path, key, bands, window)
            for key, bands, path in self._modality_files(date, loc)
        ]
        x = np.concatenate(channels, axis=0)

        with rasterio.open(self._mask_path(date, loc)) as src:
            raw_mask = src.read(1, window=window, boundless=True, fill_value=0)
        y = (raw_mask.astype(np.float32)[None, ...] > 0).astype(np.float32)

        if self.transform:
            x, y = self.transform(x, y)

        if self.task == 'mae':
            _, height, width = x.shape
            mask = self._random_block_mask(height, width)
            masked_image = x * (1 - mask[None, :, :])
            return torch.from_numpy(masked_image), torch.from_numpy(x)

        return torch.from_numpy(x), torch.from_numpy(y)

    def _load_and_normalize(self, path, key, channels, window):
        """Read the selected bands of ``path`` inside ``window`` as float32.

        Non-finite values are replaced by 0. If ``band_stats`` has an entry
        for ``key``, each band is rescaled with ``normalize_band`` using the
        mean/std stored for its channel index.

        Args:
            path: Raster file to open.
            key: Modality key used to look up ``band_stats``.
            channels: Zero-based band indices (rasterio bands are one-based).
            window: ``rasterio.windows.Window`` to read; areas outside the
                raster are filled with 0.

        Returns:
            ``np.float32`` array with one leading entry per requested band.
        """
        with rasterio.open(path) as src:
            arr = src.read(
                [c + 1 for c in channels],
                window=window,
                boundless=True,
                fill_value=0,
            ).astype(np.float32)

        arr[~np.isfinite(arr)] = 0

        if self.band_stats and key in self.band_stats:
            stats = self.band_stats[key]
            for band_pos, channel in enumerate(channels):
                arr[band_pos] = normalize_band(
                    arr[band_pos], stats["mean"][channel], stats["std"][channel]
                )

        return arr
|
|
|
|