# Timerns's picture
# Upload folder using huggingface_hub
# 984cdba verified
import os
import random
import numpy as np
from glob import glob
import rasterio
from rasterio.windows import Window
import torch
from torch.utils.data import Dataset
# -----------------------------
# Utility Functions
# -----------------------------
def normalize_band(img, mean, std):
    """Linearly rescale ``img`` so that [mean - 2*std, mean + 2*std] maps onto [0, 1].

    Values that fall outside that interval are clipped to 0 or 1 and the result
    is cast to float32. A 1e-6 term in the denominator keeps the division finite
    when ``std`` is zero.
    """
    lower = mean - 2 * std
    span = (mean + 2 * std) - lower + 1e-6
    scaled = (img - lower) / span
    return np.clip(scaled, 0, 1).astype(np.float32)
class GeoAugment:
    """Random flips and 90-degree rotations applied identically to an (x, y) pair.

    Both arrays are treated as channel-first: every flip and rotation acts on
    axes 1 and 2, so the leading axis is left untouched.

    Args:
        rotate: if True, rotate by a random multiple of 90 degrees (0-3 quarter turns).
        flip: if True, independently flip along axis 2 (horizontal) and then
            axis 1 (vertical), each with probability 0.5.
    """

    def __init__(self, rotate=True, flip=True):
        self.rotate = rotate
        self.flip = flip

    def __call__(self, x, y):
        # np.flip / np.rot90 return strided views; .copy() materializes them.
        if self.flip:
            # Horizontal (axis 2) is decided first, then vertical (axis 1).
            for axis in (2, 1):
                if random.random() < 0.5:
                    x = np.flip(x, axis=axis).copy()
                    y = np.flip(y, axis=axis).copy()
        if self.rotate:
            quarter_turns = random.choice((0, 1, 2, 3))
            if quarter_turns:
                x = np.rot90(x, quarter_turns, axes=(1, 2)).copy()
                y = np.rot90(y, quarter_turns, axes=(1, 2)).copy()
        return x, y
# -----------------------------
# Dataset Class
# -----------------------------
class SatellitePatchDataset(Dataset):
    """
    Multi-modal satellite dataset loader (S1, S2, DEM, hillshade, cloud mask).

    Train/val/test split must be performed by selecting `locations`.

    Files are looked up as ``root/<loc>/<name>``:

    * ``{date}_{loc}_lake_mask.tif``  - target mask (always required)
    * ``{date}_{loc}_s1.tif``         - S1 input (if ``ch_s1`` is non-empty)
    * ``{date}_{loc}_s2.tif``         - S2 input (if ``ch_s2`` is non-empty)
    * ``{loc}_dem.tif``               - DEM, shared by all dates (if ``ch_dem`` is non-empty)
    * ``{date}_{loc}_hillshade.tif``  - hillshade (if ``ch_hillshade`` is non-empty)
    * ``{date}_{loc}_cloud_mask.tif`` - cloud mask (if ``ch_cloudmask`` is non-empty)

    A (date, location) pair is used only if every required file exists. A
    modality is disabled by passing ``None`` or an empty sequence as its channel
    list. Each mask is tiled into square ``patch_size`` windows every ``stride``
    pixels; windows that overrun the raster edge are zero-padded.

    Args:
        root: directory containing one sub-directory per location.
        locations: location names (sub-directory names) to include.
        patch_size: side length of the square patches, in pixels.
        stride: step between patch origins; defaults to ``patch_size``.
        skip_empty: if True, drop every patch whose mask is all zero.
        empty_tile_ratio: used only when ``skip_empty`` is False. If > 0, keep at
            most ``int(n_non_empty * empty_tile_ratio)`` empty patches per sample
            (the first ones in raster order); if 0, keep all of them.
        task: ``'segmentation'`` -> items are ``(x, mask)``;
            ``'mae'`` -> items are ``(x with random 8x8 blocks zeroed, x)``.
        dates: optional collection of date strings; other dates are ignored.
        masking_ratio: fraction of 8x8 blocks zeroed in ``'mae'`` mode.
        transform: optional callable ``(x, y) -> (x, y)`` applied to the numpy
            arrays before conversion to tensors.
        band_stats: optional ``{key: {'mean': [...], 'std': [...]}}`` indexed by
            channel, with keys ``'S1'``, ``'S2'``, ``'DEM'``, ``'Hillshade'``,
            ``'Cloudmask'``; modalities without an entry are not normalized.
        ch_s1, ch_s2, ch_dem, ch_hillshade, ch_cloudmask: 0-based band indices to
            read from each modality's raster.

    Raises:
        ValueError: if ``task`` is not ``'segmentation'`` or ``'mae'``.
    """

    def __init__(
        self,
        root,
        locations,
        patch_size=256,
        stride=None,
        skip_empty=True,
        empty_tile_ratio=0.0,
        task='segmentation',
        dates=None,
        masking_ratio=0.5,
        transform=None,
        band_stats=None,  # {'S1': {'mean': [], 'std': []}, 'S2': {...}, 'DEM': {...}, 'Hillshade': {...}, 'Cloudmask': {...}}
        ch_s1=(0, 1),  # channel 0 is VV, channel 1 is VH
        ch_s2=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11),  # channels 0-11 are B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, NDWI, NDSI
        ch_dem=(0,),  # channel 0 is elevation
        ch_hillshade=(0,),  # channel 0 is hillshade
        ch_cloudmask=(0,),
    ):
        if task not in ('segmentation', 'mae'):
            raise ValueError(f"Unsupported task: {task}")
        self.root = root
        self.locations = locations
        self.patch_size = patch_size
        self.stride = stride or patch_size
        self.skip_empty = skip_empty
        self.empty_tile_ratio = empty_tile_ratio
        self.task = task
        self.masking_ratio = masking_ratio
        self.transform = transform
        self.dates = dates
        # Private list copies: defaults are immutable tuples and callers' lists are never aliased.
        self.ch_s1 = self._as_channel_list(ch_s1)
        self.ch_s2 = self._as_channel_list(ch_s2)
        self.ch_dem = self._as_channel_list(ch_dem)
        self.ch_hillshade = self._as_channel_list(ch_hillshade)
        self.ch_cloudmask = self._as_channel_list(ch_cloudmask)
        self.band_stats = band_stats
        self.samples = []  # (date, location)
        self.patch_index = []  # (sample_id, x, y)
        self._discover_samples()
        self._index_patches()

    @staticmethod
    def _as_channel_list(channels):
        """Copy a channel spec into a list; ``None`` stays ``None`` (modality disabled)."""
        return None if channels is None else list(channels)

    # ---------------------------------------------------------
    # File naming
    # ---------------------------------------------------------
    def _mask_path(self, date, loc):
        """Path of the lake mask for one (date, location) pair."""
        return os.path.join(self.root, loc, f"{date}_{loc}_lake_mask.tif")

    def _modality_specs(self, date, loc):
        """Return ``(path, band_stats_key, channels)`` for each enabled modality, in stacking order."""
        loc_dir = os.path.join(self.root, loc)
        candidates = (
            (f"{date}_{loc}_s1.tif", "S1", self.ch_s1),
            (f"{date}_{loc}_s2.tif", "S2", self.ch_s2),
            (f"{loc}_dem.tif", "DEM", self.ch_dem),  # the DEM file has no date
            (f"{date}_{loc}_hillshade.tif", "Hillshade", self.ch_hillshade),
            (f"{date}_{loc}_cloud_mask.tif", "Cloudmask", self.ch_cloudmask),
        )
        return [
            (os.path.join(loc_dir, name), key, channels)
            for name, key, channels in candidates
            if channels is not None and len(channels) > 0
        ]

    # ---------------------------------------------------------
    # Scan dataset and find all valid (date, location) pairs
    # ---------------------------------------------------------
    def _discover_samples(self):
        """Append every (date, loc) whose mask and enabled-modality rasters all exist to ``self.samples``."""
        for loc in self.locations:
            loc_dir = os.path.join(self.root, loc)
            for mask_path in sorted(glob(os.path.join(loc_dir, "*_lake_mask.tif"))):
                date = os.path.basename(mask_path).split("_")[0]
                if self.dates is not None and date not in self.dates:
                    continue
                # Everything downstream re-derives the mask path from (date, loc), so a
                # mask named after a different location would be mis-resolved; skip it.
                if os.path.basename(mask_path) != f"{date}_{loc}_lake_mask.tif":
                    continue
                if all(os.path.exists(path) for path, _, _ in self._modality_specs(date, loc)):
                    self.samples.append((date, loc))

    # ---------------------------------------------------------
    # Build patch index (optionally skip empty mask)
    # ---------------------------------------------------------
    def _grid_origins(self, height, width):
        """All patch origins ``(x, y)`` for a raster, row by row (y outer, x inner)."""
        return [
            (x, y)
            for y in range(0, height, self.stride)
            for x in range(0, width, self.stride)
        ]

    def _select_empty_to_keep(self, empty_flags):
        """Return the positions (into ``empty_flags``) of empty patches that should be indexed."""
        empty_positions = [pos for pos, is_empty in enumerate(empty_flags) if is_empty]
        if self.skip_empty:
            return set()
        if self.empty_tile_ratio > 0:
            n_non_empty = len(empty_flags) - len(empty_positions)
            k = min(int(n_non_empty * self.empty_tile_ratio), len(empty_positions))
            return set(empty_positions[:k])  # deterministic: first k in raster order
        return set(empty_positions)

    def _index_patches(self):
        """Append ``(sample_id, x, y)`` for every kept patch of every sample to ``self.patch_index``."""
        for sample_id, (date, loc) in enumerate(self.samples):
            with rasterio.open(self._mask_path(date, loc)) as msk:
                origins = self._grid_origins(msk.height, msk.width)
                empty_flags = [
                    bool(np.all(
                        msk.read(
                            1,
                            window=Window(x, y, self.patch_size, self.patch_size),
                            boundless=True,
                            fill_value=0,
                        ) == 0
                    ))
                    for x, y in origins
                ]
            keep_empty = self._select_empty_to_keep(empty_flags)
            for pos, ((x, y), is_empty) in enumerate(zip(origins, empty_flags)):
                if is_empty and pos not in keep_empty:
                    continue
                self.patch_index.append((sample_id, x, y))

    # ---------------------------------------------------------
    # Reconstruction
    # ---------------------------------------------------------
    @staticmethod
    def _stitch_patches(patches, origins, height, width, patch_size):
        """Average ``patches`` placed at ``origins`` onto a ``(C, height, width)`` canvas.

        Accumulates in the patch dtype when it is floating point, else float32
        (in-place division of an integer canvas would fail). Pixels covered by
        no patch stay 0. Extra patches or extra origins are ignored.
        """
        acc_dtype = patches.dtype if np.issubdtype(patches.dtype, np.floating) else np.float32
        # Oversized canvas so patches hanging over the right/bottom edge fit; cropped on return.
        canvas = np.zeros(
            (patches.shape[1], height + patch_size - 1, width + patch_size - 1),
            dtype=acc_dtype,
        )
        counts = np.zeros(canvas.shape[1:], dtype=np.float32)
        for patch, (x, y) in zip(patches, origins):
            canvas[:, y:y + patch_size, x:x + patch_size] += patch
            counts[y:y + patch_size, x:x + patch_size] += 1.0
        counts[counts == 0] = 1.0  # avoid division by zero
        canvas /= counts[None, :, :]
        return canvas[:, :height, :width]

    def reconstruct_image(self, patches, sample_id):
        """Reconstruct the full image of ``sample_id`` by averaging overlapping patches.

        Args:
            patches: array-like of shape ``(N, C, patch_size, patch_size)``.
                Accepted orderings:

                * one patch per stride-grid position, row by row (every
                  position, including ones the index skipped as empty), or
                * exactly the patches this dataset indexed for ``sample_id``, in
                  dataset order (e.g. predictions when ``skip_empty`` dropped some).

                If ``N`` matches neither count, patches are laid out in grid
                order; surplus patches are ignored and missing positions stay 0.
            sample_id: index into ``self.samples``.

        Returns:
            ``np.ndarray`` of shape ``(C, H, W)`` matching the mask raster size,
            float-typed (the patch dtype if floating, else float32).
        """
        date, loc = self.samples[sample_id]
        with rasterio.open(self._mask_path(date, loc)) as src:
            height, width = src.height, src.width
        patches = np.asarray(patches)
        origins = self._grid_origins(height, width)
        if len(patches) != len(origins):
            indexed = [(x, y) for sid, x, y in self.patch_index if sid == sample_id]
            if len(patches) == len(indexed):
                origins = indexed
        return self._stitch_patches(patches, origins, height, width, self.patch_size)

    # ---------------------------------------------------------
    # PyTorch Dataset API
    # ---------------------------------------------------------
    def __len__(self):
        return len(self.patch_index)

    @staticmethod
    def _block_mask(height, width, ratio, block_size=8):
        """Random binary mask (1 = masked) of shape ``(height, width)`` built from square blocks.

        ``int(n_blocks * ratio)`` of the ``ceil(height/block_size) * ceil(width/block_size)``
        blocks are set to 1 using NumPy's global RNG. Edge blocks are cropped when a
        side is not a multiple of ``block_size``.
        """
        rows = -(-height // block_size)
        cols = -(-width // block_size)
        n_blocks = rows * cols
        n_masked = int(n_blocks * ratio)
        flat = np.hstack([
            np.ones(n_masked, dtype=np.float32),
            np.zeros(n_blocks - n_masked, dtype=np.float32),
        ])
        np.random.shuffle(flat)
        # Upsample block decisions to pixel level.
        mask = np.kron(flat.reshape(rows, cols), np.ones((block_size, block_size), dtype=np.float32))
        return mask[:height, :width]

    def __getitem__(self, idx):
        """Load patch ``idx``.

        Returns:
            ``'segmentation'``: ``(x, y)``. ``x`` is a float32 tensor of the enabled
            modalities stacked along axis 0 (S1, S2, DEM, hillshade, cloud mask).
            ``y`` is a ``(1, H, W)`` float32 mask, 1 where the lake mask is > 0.
            ``'mae'``: ``(masked_x, x)``. ``masked_x`` is ``x`` with a random
            ``masking_ratio`` fraction of 8x8 blocks set to zero.
        """
        sample_id, x0, y0 = self.patch_index[idx]
        date, loc = self.samples[sample_id]
        window = Window(x0, y0, self.patch_size, self.patch_size)

        x = np.concatenate(
            [
                self._load_and_normalize(path, key, channels, window)
                for path, key, channels in self._modality_specs(date, loc)
            ],
            axis=0,
        )
        # Same zero padding as the inputs; any positive mask value counts as foreground.
        with rasterio.open(self._mask_path(date, loc)) as src:
            y = src.read(1, window=window, boundless=True, fill_value=0)[None, ...]
        y = (y > 0).astype(np.float32)

        if self.transform:
            x, y = self.transform(x, y)

        if self.task == 'mae':
            _, height, width = x.shape
            visible = 1.0 - self._block_mask(height, width, self.masking_ratio)
            return torch.from_numpy(x * visible[None, :, :]), torch.from_numpy(x)
        return torch.from_numpy(x), torch.from_numpy(y)

    # ---------------------------------------------------------
    # Loading with optional per-band normalization
    # ---------------------------------------------------------
    def _load_and_normalize(self, path, key, channels, window):
        """Read 0-based ``channels`` of ``path`` inside ``window`` as float32 ``(len(channels), h, w)``.

        Out-of-raster pixels are filled with 0 and non-finite values are set to 0.
        If ``band_stats`` has an entry for ``key``, each band is passed through
        ``normalize_band`` with that channel's mean/std.
        """
        with rasterio.open(path) as src:
            arr = src.read(
                [c + 1 for c in channels],  # rasterio band indexes are 1-based
                window=window,
                boundless=True,
                fill_value=0,
            ).astype(np.float32)
        arr[~np.isfinite(arr)] = 0
        if self.band_stats and key in self.band_stats:
            stats = self.band_stats[key]
            for band, c in enumerate(channels):
                arr[band] = normalize_band(arr[band], stats["mean"][c], stats["std"][c])
        return arr