|
|
|
|
| import os |
| import torch.nn as nn |
|
|
| from continuum import ClassIncremental, InstanceIncremental |
| from continuum.datasets import ( |
| CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50, |
| fgvc_aircraft, Caltech101, DTD, EuroSAT, flowers102, food101, |
| MNIST, OxfordPet, SUN397 |
|
|
| ) |
| from .utils import get_dataset_class_names |
|
|
|
|
class ImageNet1000(ImageFolderDataset):
    """Image-folder dataset whose root holds a ``train`` and a ``val`` sub-folder.

    :param data_path: Root folder that contains the ``train`` and ``val``
        sub-folders.
    :param train: Read from ``train`` when true, from ``val`` otherwise.
    :param download: Dummy parameter, forwarded unchanged to the parent class.
    """

    def __init__(
        self,
        data_path: str,
        train: bool = True,
        download: bool = False,
    ):
        super().__init__(data_path=data_path, train=train, download=download)

    def get_data(self):
        """Point ``data_path`` at the split folder and delegate to the parent.

        The root folder is remembered on the first call so that calling this
        method again does not append the split name a second time
        (``<root>/train/train``).

        :return: Whatever the parent class' ``get_data`` returns.
        """
        root = getattr(self, "_split_root", None)
        if root is None:
            # First call: ``data_path`` still holds the root passed by the caller.
            root = self._split_root = self.data_path

        split = "train" if self.train else "val"
        self.data_path = os.path.join(root, split)
        return super().get_data()
|
|
|
|
def get_dataset(cfg, is_train, transforms=None):
    """Build the continuum dataset selected by ``cfg.dataset``.

    :param cfg: Configuration object. Reads ``cfg.dataset`` and
        ``cfg.dataset_root``; ``cfg.workdir`` is read for the datasets whose
        class names come from ``get_dataset_class_names``.
    :param is_train: Select the training split when true, otherwise the
        evaluation split.
    :param transforms: Unused in this function; kept so existing callers that
        pass it keep working.
    :return: Tuple ``(dataset, classes_names)``.
    :raises ValueError: If ``cfg.dataset`` is not a supported dataset name.
    """
    if cfg.dataset == "cifar100":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = CIFAR100(
            data_path=data_path,
            download=True,
            train=is_train,
        )
        classes_names = dataset.dataset.classes

    elif cfg.dataset == "tinyimagenet":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = TinyImageNet200(
            data_path,
            train=is_train,
            download=True,
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet100":
        # Unlike the other branches, the root is used as-is (no dataset sub-folder).
        data_path = cfg.dataset_root
        # NOTE(review): machine-specific absolute path; it should probably come
        # from the configuration -- confirm where the split files live.
        splits_dir = '/home/kangborui/ClProject/MoE-Adapters4CL-cross-guild-fusion/cil/dataset_reqs/imagenet100_splits'
        split_file = "train_100.txt" if is_train else "val_100.txt"
        dataset = ImageNet100(
            data_path,
            train=is_train,
            data_subset=os.path.join(splits_dir, split_file),
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet1000":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = ImageNet1000(
            data_path,
            train=is_train,
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "core50":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = Core50(
            data_path,
            scenario="domains",
            classification="category",
            train=is_train,
        )
        classes_names = [
            "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
            "glasses", "balls", "markers", "cups", "remote controls",
        ]

    else:
        # The exception used to be constructed but never raised, which let the
        # function fall through to the return and fail with UnboundLocalError.
        raise ValueError(f"'{cfg.dataset}' is a invalid dataset.")

    return dataset, classes_names
|
|
|
|
def build_cl_scenarios(cfg, is_train, transforms) -> tuple:
    """Build the continual-learning scenario selected by ``cfg.scenario``.

    :param cfg: Configuration object. Reads ``cfg.scenario``; for the
        ``"class"`` scenario also ``cfg.initial_increment``, ``cfg.increment``
        and ``cfg.class_order``. It is also passed on to ``get_dataset``.
    :param is_train: Forwarded to ``get_dataset`` to select the split.
    :param transforms: Object whose ``transforms`` attribute is handed to the
        scenario as its ``transformations``.
    :return: Tuple ``(scenario, classes_names)``.
    :raises NotImplementedError: If ``cfg.scenario`` is ``"task-agnostic"``.
    :raises ValueError: If ``cfg.scenario`` is not a known scenario name.
    """
    # Validate the scenario first so a bad configuration fails before the
    # dataset is loaded. These exceptions used to be constructed but never
    # raised, so the function ended in UnboundLocalError on ``scenario``.
    if cfg.scenario == "task-agnostic":
        raise NotImplementedError("Method has not been implemented. Soon be added.")
    if cfg.scenario not in ("class", "domain"):
        raise ValueError(
            f"You have entered `{cfg.scenario}` which is not a defined scenario, "
            "please choose from {'class', 'domain', 'task-agnostic'}."
        )

    dataset, classes_names = get_dataset(cfg, is_train)

    if cfg.scenario == "class":
        scenario = ClassIncremental(
            dataset,
            initial_increment=cfg.initial_increment,
            increment=cfg.increment,
            transformations=transforms.transforms,
            class_order=cfg.class_order,
        )
    else:  # "domain"
        scenario = InstanceIncremental(
            dataset,
            transformations=transforms.transforms,
        )

    return scenario, classes_names