Ttius's picture
Upload 192 files
998bb30 verified
import numpy as np
import torch
from torch.utils.data import Dataset
from anti_kd_backdoor.data.dataset.base import (
DatasetInterface,
IndexDataset,
IndexRatioDataset,
PoisonLabelDataset,
RangeRatioDataset,
RangeRatioPoisonLabelDataset,
RatioDataset,
RatioPoisonLabelDataset,
)
from anti_kd_backdoor.data.dataset.types import XY_TYPE
class FakeDataset(DatasetInterface, Dataset):
    """In-memory dataset of random samples with evenly cycled labels.

    ``data`` is drawn from ``torch.rand`` with shape ``(nums, *x_shape)`` and
    ``targets`` is a list of ``nums`` integer labels cycling through the
    inclusive label range ``y_range``.

    Args:
        x_shape: Shape of a single sample.
        y_range: Inclusive ``(low, high)`` bounds of the generated labels.
        nums: Number of samples to generate.
        cache_xy: When true, the generated ``(data, targets)`` pair is stored
            in the class-level ``cache`` and handed to every later instance
            created with the same ``(x_shape, y_range, nums)``. The cached
            objects are shared between those instances, not copied.

    Raises:
        ValueError: If ``y_range`` contains no labels (``high < low``).
    """

    # Shared by all instances: (x_shape, y_range, nums) -> (data, targets).
    cache: dict[tuple, tuple[torch.Tensor, list[int]]] = {}

    def __init__(self,
                 *,
                 x_shape: tuple[int, int, int] = (3, 32, 32),
                 y_range: tuple[int, int] = (0, 9),
                 nums: int = 10000,
                 cache_xy: bool = False) -> None:
        self._nums = nums
        self._x_shape = x_shape
        self._y_range = y_range
        self._raw_num_classes = y_range[1] - y_range[0] + 1
        if self._raw_num_classes <= 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # (or silently wrong labels) later in `_prepare_xy`.
            raise ValueError(
                f'`y_range` must satisfy low <= high, but got {y_range}')

        if cache_xy:
            # Normalise to tuples so that list-valued settings (e.g. coming
            # from a config file) are still usable as a dictionary key.
            cache_key = (tuple(x_shape), tuple(y_range), nums)
            if cache_key not in FakeDataset.cache:
                FakeDataset.cache[cache_key] = self._prepare_xy()
            x, y = FakeDataset.cache[cache_key]
        else:
            x, y = self._prepare_xy()

        self.data, self.targets = x, y

    def __len__(self) -> int:
        """Return the number of samples, i.e. the number of labels."""
        return len(self.targets)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(sample, label)`` pair stored at ``index``.

        NOTE(review): the label comes straight from ``targets`` (a list of
        ints as built by ``_prepare_xy``), so it is not a tensor despite the
        return annotation.
        """
        x = self.data[index]
        y = self.targets[index]
        return x, y

    def get_xy(self) -> XY_TYPE:
        """Return the samples as a list plus a copy of the labels."""
        return list(self.data), self.targets.copy()

    def set_xy(self, xy: XY_TYPE) -> None:
        """Replace the stored samples and labels.

        The samples are stacked along a new first axis with ``np.stack``, so
        ``data`` is a NumPy array after this call. The labels are copied.
        """
        x, y = xy
        assert len(x) == len(y)
        self.data = np.stack(x, axis=0)
        self.targets = y.copy()

    @property
    def num_classes(self) -> int:
        """Number of distinct labels currently present in ``targets``."""
        return len(set(self.targets))

    @property
    def raw_num_classes(self) -> int:
        """Number of labels spanned by the ``y_range`` given at creation."""
        return self._raw_num_classes

    def _prepare_xy(self) -> tuple[torch.Tensor, list[int]]:
        """Generate random samples and labels cycling through ``y_range``.

        Returns:
            A tensor of shape ``(nums, *x_shape)`` and a list of ``nums``
            labels. The labels repeat ``low, low + 1, ..., high`` as many full
            rounds as fit; remaining slots are filled with ``low``.
        """
        x = torch.rand((self._nums, *self._x_shape))

        low, high = self._y_range[0], self._y_range[1]
        # Labels must start at `low`, not at 0, so that they stay inside the
        # requested range and agree with the padding value used below.
        labels = range(low, high + 1)
        full_rounds = self._nums // self._raw_num_classes
        y = [label for _ in range(full_rounds) for label in labels]
        y.extend([low] * (self._nums - len(y)))

        return x, y
class IndexFakeDataset(FakeDataset, IndexDataset):
    """``FakeDataset`` combined with the ``IndexDataset`` mixin.

    ``start_idx`` and ``end_idx`` are forwarded to ``IndexDataset``
    (presumably an index window over the samples -- confirm against
    ``IndexDataset``); all other keyword arguments go to ``FakeDataset``.
    """

    def __init__(self, *, start_idx: int, end_idx: int, **kwargs) -> None:
        index_options = {'start_idx': start_idx, 'end_idx': end_idx}
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        IndexDataset.__init__(self, **index_options)
class RatioFakeDataset(FakeDataset, RatioDataset):
    """``FakeDataset`` combined with the ``RatioDataset`` mixin.

    ``ratio`` is forwarded to ``RatioDataset`` (presumably the fraction of
    samples to keep -- confirm against ``RatioDataset``); all other keyword
    arguments go to ``FakeDataset``.
    """

    def __init__(self, *, ratio: float, **kwargs) -> None:
        ratio_options = {'ratio': ratio}
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        RatioDataset.__init__(self, **ratio_options)
class RangeRatioFakeDataset(FakeDataset, RangeRatioDataset):
    """``FakeDataset`` combined with the ``RangeRatioDataset`` mixin.

    ``range_ratio`` is forwarded to ``RangeRatioDataset`` (presumably a
    fractional ``(start, end)`` window -- confirm against
    ``RangeRatioDataset``); all other keyword arguments go to ``FakeDataset``.
    """

    def __init__(self, *, range_ratio: tuple[float, float], **kwargs) -> None:
        range_options = {'range_ratio': range_ratio}
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        RangeRatioDataset.__init__(self, **range_options)
class IndexRatioFakeDataset(FakeDataset, IndexRatioDataset):
    """``FakeDataset`` combined with the ``IndexRatioDataset`` mixin.

    ``start_idx``, ``end_idx`` and ``ratio`` are forwarded to
    ``IndexRatioDataset``; all other keyword arguments go to ``FakeDataset``.
    """

    def __init__(self, *, start_idx: int, end_idx: int, ratio: float,
                 **kwargs) -> None:
        index_ratio_options = {
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ratio': ratio,
        }
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        IndexRatioDataset.__init__(self, **index_ratio_options)
class PoisonLabelFakeDataset(FakeDataset, PoisonLabelDataset):
    """``FakeDataset`` combined with the ``PoisonLabelDataset`` mixin.

    ``poison_label`` is forwarded to ``PoisonLabelDataset``; all other keyword
    arguments go to ``FakeDataset``.
    """

    def __init__(self, *, poison_label: int, **kwargs) -> None:
        poison_options = {'poison_label': poison_label}
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        PoisonLabelDataset.__init__(self, **poison_options)
class RatioPoisonLabelFakeDataset(FakeDataset, RatioPoisonLabelDataset):
    """``FakeDataset`` combined with the ``RatioPoisonLabelDataset`` mixin.

    ``ratio`` and ``poison_label`` are forwarded to
    ``RatioPoisonLabelDataset``; all other keyword arguments go to
    ``FakeDataset``.
    """

    def __init__(self, *, ratio: float, poison_label: int, **kwargs) -> None:
        ratio_poison_options = {'ratio': ratio, 'poison_label': poison_label}
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        RatioPoisonLabelDataset.__init__(self, **ratio_poison_options)
class RangeRatioPoisonLabelFakeDataset(FakeDataset,
                                       RangeRatioPoisonLabelDataset):
    """``FakeDataset`` combined with ``RangeRatioPoisonLabelDataset``.

    ``range_ratio`` and ``poison_label`` are forwarded to
    ``RangeRatioPoisonLabelDataset``; all other keyword arguments go to
    ``FakeDataset``.
    """

    def __init__(self, *, range_ratio: tuple[float, float], poison_label: int,
                 **kwargs) -> None:
        range_poison_options = {
            'range_ratio': range_ratio,
            'poison_label': poison_label,
        }
        # Both bases are initialised explicitly, FakeDataset first.
        FakeDataset.__init__(self, **kwargs)
        RangeRatioPoisonLabelDataset.__init__(self, **range_poison_options)
# Registry used by `build_fake_dataset`: every fake dataset class is keyed by
# its own class name, in the order listed here.
FAKE_DATASETS_MAPPING = {
    dataset_cls.__name__: dataset_cls
    for dataset_cls in (
        FakeDataset,
        IndexFakeDataset,
        RatioFakeDataset,
        RangeRatioFakeDataset,
        IndexRatioFakeDataset,
        PoisonLabelFakeDataset,
        RatioPoisonLabelFakeDataset,
        RangeRatioPoisonLabelFakeDataset,
    )
}
def build_fake_dataset(dataset_cfg: dict) -> FakeDataset:
    """Instantiate a fake dataset from a configuration dictionary.

    Args:
        dataset_cfg: Must contain a ``type`` entry naming one of the keys of
            ``FAKE_DATASETS_MAPPING``. Every other entry is passed to the
            selected class as a keyword argument. The dictionary itself is
            not modified, so the same config can be reused.

    Returns:
        An instance of the class registered under ``dataset_cfg['type']``.

    Raises:
        ValueError: If ``type`` is missing or names an unknown dataset.
    """
    if 'type' not in dataset_cfg:
        raise ValueError('Dataset config must have `type` field')

    # Pop from a shallow copy: popping from the caller's dict would strip
    # `type` from it and make a second build from the same config fail.
    init_kwargs = dict(dataset_cfg)
    dataset_type = init_kwargs.pop('type')

    if dataset_type not in FAKE_DATASETS_MAPPING:
        raise ValueError(
            f'Dataset `{dataset_type}` is not support, '
            f'available datasets: {list(FAKE_DATASETS_MAPPING.keys())}')

    dataset = FAKE_DATASETS_MAPPING[dataset_type]
    return dataset(**init_kwargs)