Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------------------ | |
| # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport | |
| # Copyright (c) 2024 Borui Zhang. All Rights Reserved. | |
| # Licensed under the MIT License [see LICENSE for details] | |
| # ------------------------------------------------------------------------------ | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from .preprocessor import normalize_params | |
class ImageNetDataset(Dataset):
    """Image-folder dataset that yields ``(image, label)`` pairs.

    ``root`` must contain one sub-directory per class. Every entry inside a
    class directory is recorded as a sample of that class, and class indices
    follow the sorted order of the sub-directory names.

    Args:
        root: Directory holding the per-class sub-directories.
        transform: Optional callable invoked as ``transform(image=image)``
            that must return a mapping with an ``"image"`` entry. When
            ``None`` the augmentation step is skipped.
        convert_to_numpy: If true, the decoded RGB image is converted to a
            ``uint8`` NumPy array before ``transform`` is applied.
        post_normalize: Key into ``normalize_params`` selecting the keyword
            arguments forwarded to ``transforms.Normalize``.

    Raises:
        KeyError: If ``post_normalize`` is not a key of ``normalize_params``.
        OSError: If ``root`` or one of its class directories cannot be listed.
    """

    def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
        super().__init__()
        self.root = root
        self.transform = transform
        self.convert_to_numpy = convert_to_numpy
        self.post_normalize = transforms.Normalize(
            **normalize_params[post_normalize]
        )
        # samples: list of (path, class_idx); extensions: distinct file
        # extensions (without the dot) in first-seen order.
        self.samples, self.extensions = self._scan_samples(root)

    @staticmethod
    def _scan_samples(root):
        """Return ``(samples, extensions)`` collected from the class folders of ``root``."""
        with os.scandir(root) as entries:
            class_names = sorted(entry.name for entry in entries if entry.is_dir())

        samples = []
        # A dict keeps first-seen order while giving O(1) duplicate checks.
        seen_extensions = {}
        for class_idx, class_name in enumerate(class_names):
            class_dir = os.path.join(root, class_name)
            for fname in sorted(os.listdir(class_dir)):
                samples.append((os.path.join(class_dir, fname), class_idx))
                # Split the file name only, so dots in directory names can
                # never be mistaken for an extension separator.
                ext = os.path.splitext(fname)[1][1:]
                if ext:
                    seen_extensions.setdefault(ext, None)
        return samples, list(seen_extensions)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, augment and normalize the sample at ``index``.

        Returns:
            A ``(image, label)`` tuple where ``image`` is the float32 tensor
            produced by ``post_normalize`` after the array has been scaled by
            ``1 / 255`` and permuted from channel-last to channel-first, and
            ``label`` is the integer class index.
        """
        path, label = self.samples[index]
        # ``convert`` decodes the pixel data into a new image, so the file
        # handle can be released as soon as the ``with`` block exits.
        with Image.open(path) as handle:
            image = handle.convert("RGB")
        if self.convert_to_numpy:
            image = np.array(image).astype("uint8")
        # image augmentation (skipped when no transform was supplied)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        # to tensor and normalize; ``np.asarray`` is a no-op for arrays and
        # lets a PIL image through when ``convert_to_numpy`` is false.
        image = (np.asarray(image) / 255).astype(np.float32)
        image = torch.from_numpy(image).permute(2, 0, 1)
        image = self.post_normalize(image)
        return image, label