# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import os
import os.path
import random
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import numpy as np
import torch
from PIL import Image
from torchvision.datasets.vision import VisionDataset


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on the filename side only, so
    ``extensions`` must already be lowercase.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Collects ``(path, class_index)`` pairs for every valid file under ``directory``.

    Classes are visited in sorted name order; inside each class folder the
    tree is walked recursively (following symlinks) and files are visited in
    sorted order, so the returned list is deterministic for a given tree.
    Class names without a matching sub-directory are silently skipped.

    Args:
        directory (string): Root directory; ``~`` is expanded.
        class_to_idx (dict): Maps class (sub-directory) name to class index.
        extensions (tuple of strings, optional): Allowed (lowercase) extensions.
        is_valid_file (callable, optional): Predicate taking a file path.

    Returns:
        list: ``(sample path, class_index)`` tuples.

    Raises:
        ValueError: If ``extensions`` and ``is_valid_file`` are both given or
            both omitted; exactly one of them must be passed.
    """
    instances = []
    directory = os.path.expanduser(directory)

    # Exactly one of the two filters has to be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
    return instances


def _scan_class_dirs(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Lists the immediate sub-directories of ``directory`` as classes.

    Shared by the ``_find_classes`` methods of both dataset classes. The
    ``os.scandir`` iterator is used as a context manager so its directory
    handle is released deterministically.

    Returns:
        tuple: (sorted class names, dict mapping class name to its index).
    """
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 logs in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        return _scan_class_dirs(dir)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.

        Note:
            Loading is best-effort: if the loader raises for the requested
            index, the exception is printed and a uniformly random other
            index is tried instead, repeating until a sample loads. The
            returned item may therefore not correspond to ``index``.
        """
        while True:
            try:
                path, target = self.samples[index]
                sample = self.loader(path)
                break
            except Exception as e:
                print(e)
                index = random.randint(0, len(self.samples) - 1)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


class MultiTaskDatasetFolder(VisionDataset):
    """A generic multi-task dataset loader where the samples are arranged in this way: ::

        root/task_a/class_x/xxx.ext
        root/task_a/class_y/xxy.ext
        root/task_a/class_z/xxz.ext

        root/task_b/class_x/xxx.ext
        root/task_b/class_y/xxy.ext
        root/task_b/class_z/xxz.ext

    The class list is discovered from the folder of the first task only and
    reused for every task.

    Args:
        root (string): Root directory path.
        tasks (list): List of tasks as strings
        loader (callable): A function to load a sample given its path.
            It is stored on the instance, but ``__getitem__`` of this class
            reads files with ``pil_loader`` directly.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a dict mapping task name to sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
        prefixes (dict, optional): Maps task name to a string that is prepended
            (without separator) to the task name when building the task
            directory, i.e. ``root/{prefix}{task}``. Tasks that are not
            listed use an empty prefix. The passed dict is not modified.
        max_images (int, optional): If an int, only a fixed pseudo-random
            subset (seed 0) of at most this many samples is kept; the same
            indices are selected for every task.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (dict): Maps each task name to its list of
            (sample path, class_index) tuples
        cache (dict): Optional index -> (sample_dict, target) cache consulted by
            ``__getitem__``; nothing in this class populates it.
    """

    def __init__(
            self,
            root: str,
            tasks: List[str],
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
            prefixes: Optional[Dict[str, str]] = None,
            max_images: Optional[int] = None
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.tasks = tasks
        classes, class_to_idx = self._find_classes(os.path.join(self.root, self.tasks[0]))

        # Work on a copy: filling in the defaults must not leak into the
        # dict owned by the caller.
        prefixes = {} if prefixes is None else dict(prefixes)
        for task in tasks:
            prefixes.setdefault(task, '')

        task_dirs = {
            task: os.path.join(self.root, f'{prefixes[task]}{task}')
            for task in self.tasks
        }
        samples = {
            task: make_dataset(task_dirs[task], class_to_idx, extensions, is_valid_file)
            for task in self.tasks
        }

        for task, task_samples in samples.items():
            if len(task_samples) == 0:
                # Report the directory that was actually scanned (prefix included).
                msg = "Found 0 logs in subfolders of: {}\n".format(task_dirs[task])
                if extensions is not None:
                    msg += "Supported extensions are: {}".format(",".join(extensions))
                raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        # Select random subset of dataset if so specified
        if isinstance(max_images, int):
            # The sample count of the first task defines the index range.
            # NOTE(review): this presumes every task holds the same number of
            # samples in the same order - confirm against the data layout.
            total_samples = len(next(iter(self.samples.values())))
            # A private RandomState seeded with 0 yields the same permutation
            # as seeding the global NumPy RNG with 0, without resetting the
            # global random state of the whole process.
            keep = np.random.RandomState(0).permutation(total_samples)[:max_images]
            for task in samples:
                task_samples = self.samples[task]
                self.samples[task] = [task_samples[i] for i in keep]

        self.cache = {}

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        return _scan_class_dirs(dir)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample_dict, target) where sample_dict maps each task name
            to its loaded sample and target is class_index of the target class
            (taken from the last task in ``self.tasks``).
        """
        if index in self.cache:
            # Hand out a copy so transforms cannot alter the cached entry.
            sample_dict, target = deepcopy(self.cache[index])
        else:
            sample_dict = {}
            for task in self.tasks:
                path, target = self.samples[task][index]
                # Only the 'rgb' task is converted to RGB; every other task
                # keeps the mode stored in the file.
                sample_dict[task] = pil_loader(path, convert_rgb=(task == 'rgb'))

        if self.transform is not None:
            sample_dict = self.transform(sample_dict)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample_dict, target

    def __len__(self) -> int:
        # Length is defined by the sample list of the first task.
        return len(next(iter(self.samples.values())))


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.jpx')


def pil_loader(path: str, convert_rgb=True) -> Image.Image:
    """Loads the image at ``path`` with PIL.

    Args:
        path (string): Path to the image file.
        convert_rgb (bool): If True the image is converted to mode ``'RGB'``;
            otherwise it is returned in the mode stored in the file.

    Returns:
        PIL.Image.Image: The fully loaded image.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        if convert_rgb:
            return img.convert('RGB')
        # Image.open is lazy: force the pixel data to be read while the file
        # handle is still open so the returned image stays usable.
        img.load()
        return img


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Loads an image with ``accimage``, falling back to ``pil_loader`` on IOError."""
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    """Loads an image with the backend reported by ``torchvision.get_image_backend``."""
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Files are filtered by ``IMG_EXTENSIONS`` unless ``is_valid_file`` is given.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                         transform=transform,
                         target_transform=target_transform,
                         is_valid_file=is_valid_file)
        # Alias of ``samples`` (same list object).
        self.imgs = self.samples


class MultiTaskImageFolder(MultiTaskDatasetFolder):
    """A generic multi-task dataset loader where the images are arranged in this way: ::

        root/task_a/class_x/xxx.ext
        root/task_a/class_y/xxy.ext
        root/task_a/class_z/xxz.ext

        root/task_b/class_x/xxx.ext
        root/task_b/class_y/xxy.ext
        root/task_b/class_z/xxz.ext

    Files are filtered by ``IMG_EXTENSIONS`` unless ``is_valid_file`` is given.

    Args:
        root (string): Root directory path.
        tasks (list): List of tasks as strings
        transform (callable, optional): A function/transform that takes in a dict
            mapping task name to PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
        prefixes (dict, optional): Maps task name to a string prepended to the
            task name when building the task directory (``root/{prefix}{task}``).
        max_images (int, optional): If an int, keep only a fixed pseudo-random
            subset of at most this many samples per task.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (dict): Alias of ``samples``: maps each task name to its list of
            (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            tasks: List[str],
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = pil_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
            prefixes: Optional[Dict[str, str]] = None,
            max_images: Optional[int] = None
    ):
        super().__init__(root, tasks, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                         transform=transform,
                         target_transform=target_transform,
                         is_valid_file=is_valid_file,
                         prefixes=prefixes,
                         max_images=max_images)
        # Alias of ``samples`` (same dict object).
        self.imgs = self.samples