# This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/data.py # MIT License # Copyright (c) 2022 Lorenzo Breschi # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
from collections.abc import Iterable
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple

import pytorch_lightning as pl
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


# image reader writer


def read_image(filepath: Path, mode: Optional[str] = None) -> Image.Image:
    """Open the image stored at *filepath* and return it converted to *mode*.

    :param filepath: location of the image file
    :param mode: mode forwarded to ``Image.convert`` (e.g. ``'RGB'``);
        ``None`` is passed through unchanged
    :return: the converted PIL image
    """
    # The conversion happens inside the ``with`` block, i.e. while the
    # underlying file handle is still open.
    with open(filepath, 'rb') as file:
        image = Image.open(file)
        return image.convert(mode)


image2tensor = transforms.ToTensor()
tensor2image = transforms.ToPILImage()


def write_image(image: Image.Image, filepath: Path):
    """Save *image* to *filepath*, creating missing parent folders first.

    :param image: PIL image to store
    :param filepath: destination path
    """
    filepath.parent.mkdir(parents=True, exist_ok=True)
    image.save(str(filepath))


def read_image_tensor(filepath: Path, mode: str = 'RGB') -> torch.Tensor:
    """Read the image at *filepath* in *mode* and convert it with ``ToTensor``.

    :param filepath: location of the image file
    :param mode: mode the image is converted to before tensor conversion
    :return: the image as a tensor
    """
    return image2tensor(read_image(filepath, mode))


def write_image_tensor(input: torch.Tensor, filepath: Path):
    """Convert *input* with ``ToPILImage`` and save it to *filepath*.

    :param input: image tensor to store
    :param filepath: destination path (parent folders are created)
    """
    write_image(tensor2image(input), filepath)


def get_valid_indices(H: int, W: int, patch_size: int, random_overlap: int = 0):
    """Return the top-left corners of a grid of patches inside an H x W image.

    Corners are laid out every ``patch_size`` pixels, leaving a margin of
    ``random_overlap`` pixels on each side.  When ``random_overlap > 0`` every
    grid row and every grid column is shifted by a random offset drawn from
    ``[-random_overlap, random_overlap)``.

    :param H: image height
    :param W: image width
    :param patch_size: side length of the (square) patches
    :param random_overlap: margin and maximal random shift in pixels
    :return: integer tensor of shape ``(num_patches, 2)`` holding
        ``(top, left)`` pairs
    """
    vih = torch.arange(random_overlap, H - patch_size - random_overlap + 1,
                       patch_size)
    viw = torch.arange(random_overlap, W - patch_size - random_overlap + 1,
                       patch_size)
    if random_overlap > 0:
        # Draw the row offsets before the column offsets so that seeded runs
        # consume the random stream in a fixed order.
        rih = torch.randint_like(vih, -random_overlap, random_overlap)
        riw = torch.randint_like(viw, -random_overlap, random_overlap)
        vih += rih
        viw += riw
    # meshgrid is called without ``indexing`` on purpose: that keeps the call
    # compatible with torch versions that do not know the argument.
    vi = torch.stack(torch.meshgrid(vih, viw)).view(2, -1).t()
    return vi


def cut_patches(input: torch.Tensor, indices: Tuple[Tuple[int, int]],
                patch_size: int, padding: int = 0):
    """Crop one square patch per entry of *indices* and concatenate them.

    :param input: tensor whose last two dimensions are cropped
    :param indices: ``(top, left)`` pairs; a tensor of shape ``(n, 2)`` or any
        sequence of integer pairs
    :param patch_size: side length of a patch without padding
    :param padding: number of extra pixels taken on every side of a patch
    :return: all patches concatenated along dimension 0, in the order of
        *indices*
    """
    # TODO use slices to get all patches at the same time ?
    size = patch_size + 2 * padding
    patches_l = []
    for index in indices:
        # Plain ints work for tensor rows and for tuples alike.
        top, left = (int(coord) - padding for coord in index)
        patches_l.append(F.crop(input, top, left, size, size))
    return torch.cat(patches_l, dim=0)


def prepare_data(data_path: Path, read_func: Callable = read_image_tensor) -> Dict:
    """
    Takes a data_path of a folder which contains subfolders with input, target, etc.
    labelled by the same names.

    Only the subfolders ``target``, ``input`` and ``mask`` are read, and only
    the file names present in ``target`` are considered.

    :param data_path: Path of the folder containing data
    :param read_func: function that reads data and returns a tensor
    :return: dictionary with one stacked tensor per existing subfolder plus the
        entries ``'name'`` (list of file names), ``'len'``, ``'H'`` and ``'W'``
    :raises RuntimeError: if the ``target`` subfolder contains no files
    """
    data_dict = {}
    subdir_names = ["target", "input", "mask"]  # ,"helper"

    # checks only files for which there is a target
    # TODO check for images
    name_ls = [file.name
               for file in (data_path / "target").iterdir()
               if file.is_file()]
    if not name_ls:
        # Without this guard the failure would surface later as an opaque
        # error from torch.stack on an empty list.
        raise RuntimeError(f"no files found in {data_path / 'target'}")

    for subdir_name in subdir_names:
        sd = data_path / subdir_name
        if not sd.is_dir():
            continue
        data_ls = []
        for name in name_ls:
            tensor = read_func(sd / name)
            # H and W end up holding the size of the last tensor read.
            H, W = tensor.shape[-2:]
            data_ls.append(tensor)
        # TODO check that all sizes match
        data_dict[sd.name] = torch.stack(data_ls, dim=0)

    data_dict['name'] = name_ls
    data_dict['len'] = len(data_dict['name'])
    data_dict['H'] = H
    data_dict['W'] = W
    return data_dict


# TODO an image is loaded whenever a patch is needed, this may be a bottleneck


class DataDictLoader():
    """Iterate over a data dictionary in batches.

    Every value of the dictionary that is an ``Iterable`` is sliced along its
    first dimension; all other values are copied into each batch unchanged.
    """

    def __init__(self, data_dict: Dict,
                 batch_size: int = 16,
                 max_length: int = 128,
                 shuffle: bool = False):
        """
        :param data_dict: dictionary to iterate over; must contain ``'len'``
        :param batch_size: number of samples per batch
        :param max_length: upper bound on the number of samples served per
            pass; ``None`` means no bound
        :param shuffle: whether to permute the samples at the start of a pass
        """
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.data_dict = data_dict
        self.dataset_len = data_dict['len']
        if max_length is None:
            self.len = self.dataset_len
        else:
            self.len = min(self.dataset_len, max_length)

        # Calculate # batches (the last one may be partial)
        num_batches, remainder = divmod(self.len, self.batch_size)
        if remainder > 0:
            num_batches += 1
        self.num_batches = num_batches

    @staticmethod
    def _permute(value, permutation: torch.Tensor):
        """Reorder *value* by *permutation*; non-iterables are returned as is."""
        if not isinstance(value, Iterable):
            return value
        if isinstance(value, (list, tuple)):
            # Python sequences cannot be indexed by an index tensor.
            return type(value)(value[i] for i in permutation.tolist())
        return value[permutation]

    def __iter__(self):
        """Start a new pass, reshuffling the stored data if requested."""
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.data_dict = {k: self._permute(v, r)
                              for k, v in self.data_dict.items()}
        self.i = 0
        return self

    def __next__(self):
        """Return the next batch as a dictionary with the same keys."""
        if self.i >= self.len:
            raise StopIteration
        # Clip the slice so the final batch never reaches past ``self.len``.
        stop = min(self.i + self.batch_size, self.len)
        batch = {k: v[self.i:stop] if isinstance(v, Iterable) else v
                 for k, v in self.data_dict.items()}
        self.i += self.batch_size
        return batch

    def __len__(self):
        """Return the number of batches per pass."""
        return self.num_batches


class PatchDataModule(pl.LightningDataModule):
    """Data module that trains on patches and validates/tests on full images."""

    def __init__(self, data_dict,
                 patch_size: int = 2**5,
                 batch_size: int = 2**4,
                 patch_num: int = 2**6):
        """
        :param data_dict: dictionary with the entries ``'H'``, ``'W'`` and
            ``'len'`` plus the data tensors
        :param patch_size: side length of the training patches
        :param batch_size: number of patches per training batch
        :param patch_num: maximal number of patches served per training pass
        """
        super().__init__()
        self.data_dict = data_dict
        self.H, self.W = data_dict['H'], data_dict['W']
        self.len = data_dict['len']

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.patch_num = patch_num

    def dataloader(self, data_dict, **kwargs):
        """Wrap *data_dict* in a ``DataDictLoader`` built with *kwargs*."""
        return DataDictLoader(data_dict, **kwargs)

    def train_dataloader(self):
        """Cut a fresh set of patches and return a shuffled loader over them."""
        patches = self.cut_patches()
        return self.dataloader(patches, batch_size=self.batch_size,
                               shuffle=True, max_length=self.patch_num)

    def val_dataloader(self):
        """Return a loader over the full images, one image per batch."""
        return self.dataloader(self.data_dict, batch_size=1)

    def test_dataloader(self):
        """Return a loader over the full images with the loader defaults."""
        return self.dataloader(self.data_dict)  # TODO batch size

    def cut_patches(self):
        """Cut every tensor of the data dictionary into patches.

        Patches whose mean mask value is not above 0.1 are dropped; without a
        ``'mask'`` entry all patches are kept.

        :return: dictionary with the patch tensors plus ``'len'``, ``'H'`` and
            ``'W'`` describing the patches
        """
        # TODO cycle once
        patch_indices = get_valid_indices(
            self.H, self.W, self.patch_size, self.patch_size // 4)
        dd = {k: cut_patches(v, patch_indices, self.patch_size)
              for k, v in self.data_dict.items()
              if isinstance(v, torch.Tensor)}

        threshold = 0.1
        # Build the all-ones fallback only when there is no mask.
        mask = dd['mask'] if 'mask' in dd else torch.ones_like(dd['input'])
        mask_p = torch.mean(mask, dim=(-1, -2, -3))
        masked_idx = (mask_p > threshold).nonzero(as_tuple=True)[0]
        dd = {k: v[masked_idx] for k, v in dd.items()}

        dd['len'] = len(masked_idx)
        dd['H'], dd['W'] = (self.patch_size,) * 2
        return dd


class ImageDataset(Dataset):
    """Dataset that reads one image file per item."""

    def __init__(self, file_paths: Iterable, read_func: Callable = read_image_tensor):
        """
        :param file_paths: indexable collection of ``Path`` objects
        :param read_func: function that reads a file and returns a tensor
        """
        self.file_paths = file_paths
        self.read_func = read_func

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        """Return the data read from file number *idx* and the file's name."""
        file = self.file_paths[idx]
        return self.read_func(file), file.name

    def __len__(self) -> int:
        """Return the number of files."""
        return len(self.file_paths)