File size: 2,356 Bytes
223d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import os
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from .preprocessor import normalize_params

class ImageNetDataset(Dataset):
    """Image-folder dataset laid out as ``root/<class_name>/<image_file>``.

    Every immediate sub-directory of ``root`` is a class; class indices are
    assigned in sorted order of the sub-directory names. Each item is returned
    as ``(image, label)`` where ``image`` is a channel-first float32 tensor
    (pixels scaled by 1/255, then passed through ``post_normalize``) and
    ``label`` is the integer class index.

    Args:
        root: Dataset root directory.
        transform: Optional callable invoked as ``transform(image=...)`` that
            must return a mapping with an ``"image"`` entry holding an
            H x W x C array. ``None`` skips augmentation.
        convert_to_numpy: If True, the decoded RGB PIL image is converted to a
            ``uint8`` numpy array before ``transform`` is applied.
        post_normalize: Key into ``normalize_params``; the selected entry is
            passed as keyword arguments to ``transforms.Normalize``.

    Attributes:
        samples: List of ``(path, class_idx)`` tuples, sorted by class then
            file name.
        extensions: File extensions (without the leading dot) seen while
            scanning, in first-seen order.
    """

    def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
        self.root = root
        self.transform = transform
        self.convert_to_numpy = convert_to_numpy
        self.post_normalize = transforms.Normalize(
            **normalize_params[post_normalize]
        )
        self.samples, self.extensions = self._make_dataset(root)

    @staticmethod
    def _make_dataset(root):
        """Scan ``root`` and return ``(samples, extensions)``.

        ``samples`` is a list of ``(file_path, class_idx)`` for every regular
        file directly inside a class sub-directory; ``extensions`` lists the
        distinct file extensions (without the dot) in first-seen order.
        """
        with os.scandir(root) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())

        samples = []
        extensions = []
        for class_idx, class_name in enumerate(classes):
            class_dir = os.path.join(root, class_name)
            for fname in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, fname)
                # Nested directories (or other non-files) cannot be decoded as
                # images and would only fail later in __getitem__.
                if not os.path.isfile(path):
                    continue
                samples.append((path, class_idx))
                # Take the extension from the file name only, so dots in
                # directory names cannot leak into it.
                ext = os.path.splitext(fname)[1][1:]
                if ext not in extensions:
                    extensions.append(ext)
        return samples, extensions

    def __len__(self):
        """Return the number of samples found under ``root``."""
        return len(self.samples)

    def _to_tensor(self, image):
        """Augment ``image`` and turn it into a normalized CHW float32 tensor.

        Args:
            image: H x W x C image (numpy array, or any object ``np.asarray``
                accepts, e.g. a PIL image when ``convert_to_numpy`` is False).

        Returns:
            ``post_normalize`` applied to a ``(C, H, W)`` float32 tensor whose
            values are the (augmented) pixels divided by 255.
        """
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        image = np.asarray(image)
        # torch.from_numpy rejects negative strides (e.g. from a flip), so
        # force a C-contiguous float32 buffer.
        image = np.ascontiguousarray(image / 255, dtype=np.float32)
        tensor = torch.from_numpy(image).permute(2, 0, 1)
        return self.post_normalize(tensor)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image_tensor, class_idx)``."""
        path, label = self.samples[index]
        # Open inside a context manager so the file handle is released.
        # convert() always returns a new, fully loaded image (a copy when the
        # mode is already RGB), so it stays valid after the file is closed.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        if self.convert_to_numpy:
            image = np.array(image, dtype=np.uint8)
        return self._to_tensor(image), label