|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from anti_kd_backdoor.data.dataset.base import ( |
|
|
DatasetInterface, |
|
|
IndexDataset, |
|
|
IndexRatioDataset, |
|
|
PoisonLabelDataset, |
|
|
RangeRatioDataset, |
|
|
RangeRatioPoisonLabelDataset, |
|
|
RatioDataset, |
|
|
RatioPoisonLabelDataset, |
|
|
) |
|
|
from anti_kd_backdoor.data.dataset.types import XY_TYPE |
|
|
|
|
|
|
|
|
class FakeDataset(DatasetInterface, Dataset):
    """Synthetic dataset of random tensors with cyclic integer labels.

    Samples are drawn with `torch.rand` and labels cycle through the
    inclusive range `y_range`, which makes the dataset handy for tests that
    must not depend on real data.

    Args:
        x_shape: Shape of a single sample.
        y_range: Inclusive `(low, high)` range of labels to generate.
        nums: Total number of samples.
        cache_xy: When true, the generated `(x, y)` pair is stored in the
            class-level `cache` and reused by later instances created with
            the same `(x_shape, y_range, nums)`. Instances served from the
            cache share the very same tensor and label list objects.
    """

    # Class-level store of generated `(x, y)` pairs, keyed by
    # `(x_shape, y_range, nums)`; only consulted when `cache_xy=True`.
    cache: dict[tuple, tuple[torch.Tensor, list[int]]] = dict()

    def __init__(self,
                 *,
                 x_shape: tuple[int, int, int] = (3, 32, 32),
                 y_range: tuple[int, int] = (0, 9),
                 nums: int = 10000,
                 cache_xy: bool = False) -> None:
        self._nums = nums
        self._x_shape = x_shape
        self._y_range = y_range
        # `y_range` is inclusive on both ends, hence the `+ 1`.
        self._raw_num_classes = y_range[1] - y_range[0] + 1

        if cache_xy:
            # Normalise to tuples so configs that supply lists still yield
            # a hashable key.
            cache_key = (tuple(x_shape), tuple(y_range), nums)
            if cache_key not in FakeDataset.cache:
                FakeDataset.cache[cache_key] = self._prepare_xy()
            x, y = FakeDataset.cache[cache_key]
        else:
            x, y = self._prepare_xy()

        self.data, self.targets = x, y

    def __len__(self) -> int:
        """Return the number of samples currently held."""
        return len(self.targets)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
        """Return the `(sample, label)` pair stored at `index`.

        The sample is a `torch.Tensor` for freshly generated data; after
        `set_xy` has been called, `data` is a NumPy array and so is the
        returned sample.
        """
        x = self.data[index]
        y = self.targets[index]

        return x, y

    def get_xy(self) -> XY_TYPE:
        """Return the samples as a list plus a copy of the labels."""
        return list(self.data), self.targets.copy()

    def set_xy(self, xy: XY_TYPE) -> None:
        """Replace the stored samples and labels.

        Args:
            xy: Pair `(x, y)` of equally long sequences. `x` is stacked
                along a new leading axis with `np.stack`, so `data` becomes
                a NumPy array; `y` is copied.
        """
        x, y = xy
        assert len(x) == len(y)

        self.data = np.stack(x, axis=0)
        self.targets = y.copy()

    @property
    def num_classes(self) -> int:
        """Number of distinct labels currently present in `targets`."""
        return len(set(self.targets))

    @property
    def raw_num_classes(self) -> int:
        """Number of classes implied by `y_range` at construction time."""
        return self._raw_num_classes

    def _prepare_xy(self) -> tuple[torch.Tensor, list[int]]:
        """Generate `nums` random samples and their cyclic labels.

        Returns:
            `(x, y)` where `x` has shape `(nums, *x_shape)` and `y` is a
            list of `nums` labels.
        """
        x = torch.rand((self._nums, *self._x_shape))
        num_per_class = self._nums // self._raw_num_classes

        # Labels span the inclusive `y_range` (not 0..n-1) so that a
        # non-zero lower bound is honoured and matches the padding below.
        labels = list(range(self._y_range[0], self._y_range[1] + 1))
        y = labels * num_per_class
        # `nums` may not divide evenly; pad the tail with the lowest label.
        y.extend([self._y_range[0]] * (self._nums - len(y)))

        return x, y
|
|
|
|
|
|
|
|
class IndexFakeDataset(FakeDataset, IndexDataset):
    """`FakeDataset` combined with the `IndexDataset` mixin.

    The synthetic data is generated first; `IndexDataset.__init__` is then
    run with `start_idx` and `end_idx`.
    NOTE(review): presumably the mixin keeps only the samples inside that
    index window -- confirm against `IndexDataset`.
    """

    def __init__(self, *, start_idx: int, end_idx: int, **kwargs) -> None:
        # Everything that is not an index argument belongs to the
        # fake-data generator.
        FakeDataset.__init__(self, **kwargs)

        index_kwargs = {'start_idx': start_idx, 'end_idx': end_idx}
        IndexDataset.__init__(self, **index_kwargs)
|
|
|
|
|
|
|
|
class RatioFakeDataset(FakeDataset, RatioDataset):
    """`FakeDataset` combined with the `RatioDataset` mixin.

    The synthetic data is generated first; `RatioDataset.__init__` is then
    run with `ratio`.
    NOTE(review): presumably the mixin keeps only a `ratio` fraction of the
    samples -- confirm against `RatioDataset`.
    """

    def __init__(self, *, ratio: float, **kwargs) -> None:
        # Everything except `ratio` belongs to the fake-data generator.
        FakeDataset.__init__(self, **kwargs)

        ratio_kwargs = {'ratio': ratio}
        RatioDataset.__init__(self, **ratio_kwargs)
|
|
|
|
|
|
|
|
class RangeRatioFakeDataset(FakeDataset, RangeRatioDataset):
    """`FakeDataset` combined with the `RangeRatioDataset` mixin.

    The synthetic data is generated first; `RangeRatioDataset.__init__` is
    then run with `range_ratio`.
    NOTE(review): presumably the mixin keeps the samples between the two
    fractional bounds -- confirm against `RangeRatioDataset`.
    """

    def __init__(self, *, range_ratio: tuple[float, float], **kwargs) -> None:
        # Everything except `range_ratio` belongs to the fake-data
        # generator.
        FakeDataset.__init__(self, **kwargs)

        range_kwargs = {'range_ratio': range_ratio}
        RangeRatioDataset.__init__(self, **range_kwargs)
|
|
|
|
|
|
|
|
class IndexRatioFakeDataset(FakeDataset, IndexRatioDataset):
    """`FakeDataset` combined with the `IndexRatioDataset` mixin.

    The synthetic data is generated first; `IndexRatioDataset.__init__` is
    then run with `start_idx`, `end_idx` and `ratio`.
    NOTE(review): the exact selection rule lives in `IndexRatioDataset` --
    confirm there how the index window and the ratio interact.
    """

    def __init__(self, *, start_idx: int, end_idx: int, ratio: float,
                 **kwargs) -> None:
        # Everything that is not a selection argument belongs to the
        # fake-data generator.
        FakeDataset.__init__(self, **kwargs)

        selection_kwargs = {
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ratio': ratio,
        }
        IndexRatioDataset.__init__(self, **selection_kwargs)
|
|
|
|
|
|
|
|
class PoisonLabelFakeDataset(FakeDataset, PoisonLabelDataset):
    """`FakeDataset` combined with the `PoisonLabelDataset` mixin.

    The synthetic data is generated first; `PoisonLabelDataset.__init__` is
    then run with `poison_label`.
    NOTE(review): presumably the mixin rewrites labels to `poison_label` --
    confirm against `PoisonLabelDataset`.
    """

    def __init__(self, *, poison_label: int, **kwargs) -> None:
        # Everything except `poison_label` belongs to the fake-data
        # generator.
        FakeDataset.__init__(self, **kwargs)

        poison_kwargs = {'poison_label': poison_label}
        PoisonLabelDataset.__init__(self, **poison_kwargs)
|
|
|
|
|
|
|
|
class RatioPoisonLabelFakeDataset(FakeDataset, RatioPoisonLabelDataset):
    """`FakeDataset` combined with the `RatioPoisonLabelDataset` mixin.

    The synthetic data is generated first;
    `RatioPoisonLabelDataset.__init__` is then run with `ratio` and
    `poison_label`.
    NOTE(review): the selection/poisoning rule lives in
    `RatioPoisonLabelDataset` -- confirm the details there.
    """

    def __init__(self, *, ratio: float, poison_label: int, **kwargs) -> None:
        # Everything that is not a poisoning argument belongs to the
        # fake-data generator.
        FakeDataset.__init__(self, **kwargs)

        poison_kwargs = {'ratio': ratio, 'poison_label': poison_label}
        RatioPoisonLabelDataset.__init__(self, **poison_kwargs)
|
|
|
|
|
|
|
|
class RangeRatioPoisonLabelFakeDataset(FakeDataset,
                                       RangeRatioPoisonLabelDataset):
    """`FakeDataset` combined with the `RangeRatioPoisonLabelDataset` mixin.

    The synthetic data is generated first;
    `RangeRatioPoisonLabelDataset.__init__` is then run with `range_ratio`
    and `poison_label`.
    NOTE(review): the selection/poisoning rule lives in
    `RangeRatioPoisonLabelDataset` -- confirm the details there.
    """

    def __init__(self, *, range_ratio: tuple[float, float], poison_label: int,
                 **kwargs) -> None:
        # Everything that is not a poisoning argument belongs to the
        # fake-data generator.
        FakeDataset.__init__(self, **kwargs)

        poison_kwargs = {
            'range_ratio': range_ratio,
            'poison_label': poison_label,
        }
        RangeRatioPoisonLabelDataset.__init__(self, **poison_kwargs)
|
|
|
|
|
|
|
|
# Registry consulted by `build_fake_dataset`: maps the `type` string of a
# dataset config to the class to instantiate. Every key is exactly the name
# of the class it maps to, so the table is derived from the classes.
FAKE_DATASETS_MAPPING = {
    dataset_cls.__name__: dataset_cls
    for dataset_cls in (
        FakeDataset,
        IndexFakeDataset,
        RatioFakeDataset,
        RangeRatioFakeDataset,
        IndexRatioFakeDataset,
        PoisonLabelFakeDataset,
        RatioPoisonLabelFakeDataset,
        RangeRatioPoisonLabelFakeDataset,
    )
}
|
|
|
|
|
|
|
|
def build_fake_dataset(dataset_cfg: dict) -> FakeDataset:
    """Instantiate a fake dataset from a config dictionary.

    Args:
        dataset_cfg: Mapping with a `type` entry naming one of the keys of
            `FAKE_DATASETS_MAPPING`; all remaining entries are passed to the
            selected class as keyword arguments. The mapping itself is left
            unmodified.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If `type` is missing or names an unknown dataset.
    """
    if 'type' not in dataset_cfg:
        raise ValueError('Dataset config must have `type` field')

    # Pop from a shallow copy so the caller's config keeps its `type` entry
    # and can be reused for another build.
    init_kwargs = dict(dataset_cfg)
    dataset_type = init_kwargs.pop('type')
    if dataset_type not in FAKE_DATASETS_MAPPING:
        raise ValueError(
            f'Dataset `{dataset_type}` is not support, '
            f'available datasets: {list(FAKE_DATASETS_MAPPING.keys())}')
    dataset = FAKE_DATASETS_MAPPING[dataset_type]

    return dataset(**init_kwargs)
|
|
|