|
|
import abc
from collections import OrderedDict
from pathlib import Path
from typing import Iterator, Union

import torch
|
|
|
|
|
|
|
|
|
class BaseGenerativeAttack(abc.ABC):
    """Base class for attacks that craft adversarial examples with a generator.

    Subclasses implement :meth:`set_adv_gen`, which must assign the generator
    network to ``self.adv_gen``, and :meth:`attack`, which maps a natural input
    to a raw adversarial candidate. Calling the instance runs :meth:`attack`,
    projects the result onto the L-infinity ball of radius ``epsilon`` around
    the natural input, and clips it to ``[0, 1]``.
    """

    def __init__(self,
                 device: Union[str, torch.device],
                 epsilon: float = 32 / 255) -> None:
        """Create the attack and build its generator in eval mode.

        Args:
            device: Device (or device string such as ``'cuda:0'``) used as the
                ``map_location`` when loading checkpoints and as the target
                device of the generator in :meth:`load_ckpt`.
            epsilon: L-infinity perturbation budget enforced in
                :meth:`__call__`, on the same scale as the ``[0, 1]`` clip.
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        # Assigned before the ``set_adv_gen`` hook runs so that subclasses may
        # read the budget while building their generator.
        self.epsilon = epsilon
        self.set_adv_gen()
        self.set_mode('eval')

    @abc.abstractmethod
    def set_adv_gen(self) -> None:
        """Build the generator network and assign it to ``self.adv_gen``.

        ``__init__`` (via :meth:`set_mode`) and most other methods read
        ``self.adv_gen``, so implementations must set it.
        """

    def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
        """Load generator weights and move the generator to ``self.device``.

        Args:
            ckpt: A path (``str`` or ``Path``) to a file written with
                :func:`torch.save`, or an already-loaded state dict.

        Raises:
            FileNotFoundError: If ``ckpt`` is a path that does not exist.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        if isinstance(ckpt, Path):
            if not ckpt.exists():
                raise FileNotFoundError(f'File not found: {ckpt}')
            # NOTE(security): ``torch.load`` unpickles the file, which can run
            # arbitrary code. Only load trusted checkpoints (consider
            # ``weights_only=True`` on PyTorch versions that support it).
            ckpt = torch.load(ckpt, map_location=self.device)
        self.adv_gen.load_state_dict(ckpt)
        self.adv_gen.to(self.device)

    def save_ckpt(self, ckpt: Union[str, Path]) -> None:
        """Save the generator's state dict to ``ckpt`` with tensors on CPU.

        Only the saved copy of the weights is moved to CPU. The generator
        itself stays on its current device. ``nn.Module.to`` works in place,
        so moving the module would leave the attack running on CPU after
        every save.

        Args:
            ckpt: Destination file path.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        state = self.adv_gen.state_dict()
        cpu_state = OrderedDict(
            (key, value.cpu() if isinstance(value, torch.Tensor) else value)
            for key, value in state.items()
        )
        # Keep the per-module version metadata that ``load_state_dict`` reads.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            cpu_state._metadata = metadata
        torch.save(cpu_state, ckpt)

    def get_params(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the generator's parameters."""
        return self.adv_gen.parameters()

    def get_model(self) -> torch.nn.Module:
        """Return the generator network."""
        return self.adv_gen

    def set_mode(self, mode: str) -> None:
        """Switch the generator to training or evaluation mode.

        Args:
            mode: Either ``'train'`` or ``'eval'``.

        Raises:
            ValueError: If ``mode`` is not one of the two accepted values.
        """
        if mode not in ('train', 'eval'):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
        # ``Module.train(False)`` is exactly what ``Module.eval()`` does.
        self.adv_gen.train(mode == 'train')

    @abc.abstractmethod
    def attack(self, *args) -> torch.Tensor:
        """Return a raw adversarial candidate.

        :meth:`__call__` invokes this as ``attack(x_nat, *extra_inputs)`` and
        applies the epsilon projection and ``[0, 1]`` clip afterwards.
        """

    def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
        """Run the attack and constrain the result.

        Args:
            x_nat: Natural (clean) input.
            *extra_inputs: Forwarded unchanged to :meth:`attack`.

        Returns:
            The adversarial example, within ``epsilon`` of ``x_nat``
            element-wise and clipped to ``[0, 1]``.
        """
        x_adv = self.attack(x_nat, *extra_inputs)
        # Element-wise projection onto the L-infinity ball around x_nat.
        x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon),
                          x_nat + self.epsilon)
        return x_adv.clamp(0.0, 1.0)
|
|
|
|