# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ImageFolder-style datasets restricted to the first ``class_num`` classes.
Mostly copy-paste from torchvision references
"""
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from PIL import Image
from torchvision.datasets.vision import VisionDataset


def has_file_allowed_extension(
    filename: str, extensions: Union[str, Tuple[str, ...]]
) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (string or tuple of strings): extensions to consider
            (lowercase). Any other iterable of strings (e.g. a list) is
            converted to a tuple, because ``str.endswith`` only accepts a
            string or a tuple of strings.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # IMG_EXTENSIONS is defined further down in this module; it is resolved
    # at call time, so the definition order does not matter.
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: str, class_num: int) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Only the first ``class_num`` sub-directories of ``directory`` (in sorted
    name order) are kept.

    See :class:`DatasetFolder` for details.

    Args:
        directory (str): Root directory path.
        class_num (int): how many classes will be loaded

    Raises:
        FileNotFoundError: If ``directory`` has no class folders.

    Returns:
        (Tuple[List[str], Dict[str, int]]): List of the kept classes and
        dictionary mapping each class to an index.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    # The emptiness check above runs before truncation, so it reports a
    # missing dataset rather than a small ``class_num``.
    classes = classes[:class_num]
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    class_num: int = 10,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic
    of the ``find_classes`` function by default. ``class_num`` is only used
    in that case; an explicitly passed ``class_to_idx`` is taken as is.

    Raises:
        ValueError: In case ``class_to_idx`` is empty.
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or
            both are not None.
        FileNotFoundError: In case no valid file was found for any class.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory, class_num)
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(
                x, cast(Union[str, Tuple[str, ...]], extensions)
            )

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # A class without a directory is reported below as empty.
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = (
            f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        )
        if extensions is not None:
            # A bare string must not be joined character by character.
            supported = (
                extensions if isinstance(extensions, str) else ", ".join(extensions)
            )
            msg += f"Supported extensions are: {supported}"
        raise FileNotFoundError(msg)

    return instances


class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Note: indexing this dataset returns only the (transformed) sample, not a
    ``(sample, target)`` tuple; class indices are available in ``targets``.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it. It is forwarded to the base
            class but not applied by :meth:`__getitem__`.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
        class_num: how many classes will be loaded

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num: int = 10,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Go through the (overridable) methods so subclasses can customize
        # class discovery and sample collection.
        classes, class_to_idx = self.find_classes(self.root, class_num=class_num)
        samples = self.make_dataset(
            self.root, class_to_idx, extensions, is_valid_file, class_num=class_num
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num: int = 10,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file
        instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            class_num: how many classes will be loaded

        Raises:
            ValueError: In case ``class_to_idx`` is empty or None.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None
                or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory,
            class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )

    def find_classes(
        self, directory: str, class_num: int
    ) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``
            class_num(int): how many classes will be loaded

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of the kept classes and
            dictionary mapping each class to an index.
        """
        return find_classes(directory, class_num=class_num)

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            The loaded sample, after ``transform`` if one was given. The
            class index is NOT returned and ``target_transform`` is not
            applied; use ``self.targets[index]`` to get the label.
        """
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self) -> int:
        return len(self.samples)


IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)


def pil_loader(path: str) -> Image.Image:
    """Load the image at ``path`` with PIL and convert it to RGB."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        # convert() must run while the file is still open.
        return img.convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load the image at ``path`` with accimage, falling back to PIL on IOError."""
    import accimage

    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    """Load the image at ``path`` using torchvision's configured image backend."""
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`DatasetFolder` so the same methods can be
    overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
        class_num: how many classes will be loaded

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num: int = 10,
    ) -> None:
        super().__init__(
            root,
            loader,
            # Filter by extension only when no custom validator is given;
            # passing both would be rejected by make_dataset().
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )
        self.imgs = self.samples