"""Dataset utilities for watershed-style segmentation training.

Each raw image ``X.tif`` is expected to sit next to two label rasters in the
same directory: ``Ridge_X.tif`` (ridge map) and ``Basins_X.tif`` (basin depth
map that is thresholded into a binary mask at a randomly sampled water level).
"""

import os
import random

import imageio.v3 as imageio
import numpy as np
import skimage.morphology as morph
import torch
import torchvision.transforms.v2.functional as T_F
from pathlib import Path
from scipy.ndimage import zoom
from skimage.filters import sato
from torchvision.datasets.folder import has_file_allowed_extension


def make_dataset_t(image_dir, extensions=(".tif", ".tiff")):
    """Return sorted (image, ridge_label, basins_label) path triples.

    Files whose names start with ``Ridge_`` or ``Basins_`` are label rasters,
    not raw images, so they are excluded from the raw-image listing; their
    paths are derived from each raw image's name instead.
    """
    image_dir = Path(image_dir)
    return [
        (path, image_dir / f'Ridge_{path.name}', image_dir / f'Basins_{path.name}')
        for path in sorted(image_dir.iterdir())
        if has_file_allowed_extension(path.name, extensions)
        and not path.name.startswith(('Ridge_', 'Basins_'))
    ]


def make_dataset_t_v(image_dir, extensions=(".tif", ".tiff")):
    """Shuffle the path triples and split them 95% train / 5% validation."""
    images = make_dataset_t(image_dir, extensions)
    random.shuffle(images)
    split_idx = int(0.95 * len(images))
    return images[:split_idx], images[split_idx:]


def augmentations(image, label1, label2):
    """Apply the same random flips/rotation to an image and both its labels."""
    if random.random() < 0.5:
        image, label1, label2 = T_F.vflip(image), T_F.vflip(label1), T_F.vflip(label2)
    if random.random() < 0.5:
        # BUG FIX: label2 was vertically flipped here while image/label1 were
        # horizontally flipped, desynchronizing the basins label from its image.
        image, label1, label2 = T_F.hflip(image), T_F.hflip(label1), T_F.hflip(label2)
    # Draw the angle before the 0.75 gate to preserve the original RNG stream.
    angle = random.choice([90, 180, 270])
    if random.random() < 0.75:
        image = T_F.rotate(image, angle)
        label1 = T_F.rotate(label1, angle)
        label2 = T_F.rotate(label2, angle)
    return image, label1, label2


# Normalization statistics for the raw images (presumably computed offline
# over the training corpus — TODO confirm).
mean, std = (149.95293407563648, 330.8314960521203)
# Water-level thresholds are sampled uniformly (inclusive) from this range.
target_water_level_range = [-100, 300]


def _load_sample(triple):
    """Load one (img, ridge, basins_mask, water_level) sample from a path triple.

    The image is normalized with the module-level ``mean``/``std``; the basins
    raster is binarized at a freshly sampled random water level. All tensors
    gain a leading channel dimension.
    """
    img = torch.from_numpy(imageio.imread(str(triple[0])))[None, :]
    img = (img - mean) / std
    ridge = torch.from_numpy(imageio.imread(str(triple[1])))[None, :].to(torch.float16)
    basins = torch.from_numpy(imageio.imread(str(triple[2])))[None, :]
    water_level = random.randint(*target_water_level_range)
    basins = (basins >= water_level).to(torch.float16)
    return img, ridge, basins, water_level


class TrainDataset(torch.utils.data.Dataset):
    """Training dataset: loads a sample and applies random augmentations."""

    def __init__(self, train_split):
        self.train_split = train_split

    def __len__(self):
        return len(self.train_split)

    def __getitem__(self, index):
        img, ridge, basins, water_level = _load_sample(self.train_split[index])
        img, ridge, basins = augmentations(img, ridge, basins)
        return img, ridge, basins, torch.tensor(water_level, dtype=torch.float16)


class ValDataset(torch.utils.data.Dataset):
    """Validation dataset: loads a sample without augmentation.

    NOTE(review): the water level is still sampled randomly per call, so
    validation targets differ between epochs — confirm this is intentional.
    """

    def __init__(self, val_split):
        self.val_split = val_split

    def __len__(self):
        return len(self.val_split)

    def __getitem__(self, index):
        img, ridge, basins, water_level = _load_sample(self.val_split[index])
        return img, ridge, basins, torch.tensor(water_level, dtype=torch.float16)


if __name__ == '__main__':
    train_split, val_split = make_dataset_t_v('dataset')
    train_dataset = TrainDataset(train_split)
    val_dataset = ValDataset(val_split)
    print(train_dataset[0])
    print(val_dataset[0])