Spaces:
Runtime error
Runtime error
import torch | |
class FGM(object):
    """Fast Gradient Method (FGM) perturbation of embedding parameters.

    Reference: "Adversarial Training Methods for Semi-Supervised Text
    Classification" (Miyato et al.).

    :meth:`attack` adds ``epsilon * grad / ||grad||`` to every trainable
    parameter whose name contains the filter substring, after saving its
    original data; :meth:`restore` puts the saved data back. The caller is
    expected to have populated ``.grad`` (e.g. via ``loss.backward()``)
    before calling :meth:`attack` — presumably the usual FGM loop; confirm
    against the training code.
    """

    def __init__(self, model):
        # ``model`` only needs ``named_parameters()`` (presumably a torch.nn.Module).
        self.model = model
        # Parameter name -> clone of its data taken just before perturbation.
        self.backup = {}

    def attack(self, epsilon=1e-6, emd_name="embedding", *, emb_name=None):
        """Perturb matching parameters along their normalized gradient.

        Args:
            epsilon: L2 norm of the perturbation added to each parameter.
                NOTE(review): a 1e-6 default makes the attack nearly a
                no-op; common FGM recipes use ~1.0 — confirm intent.
            emd_name: Substring selecting which parameters to perturb
                (original spelling, kept for backward compatibility).
            emb_name: Keyword-only alias for ``emd_name`` using the same
                spelling as ``PGD.attack``; takes precedence when given.
        """
        name_filter = emd_name if emb_name is None else emb_name
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and name_filter in name):
                continue
            # Back up unconditionally so restore() finds every matching
            # parameter, even ones we end up not perturbing below.
            self.backup[name] = param.data.clone()
            if param.grad is None:
                # Parameter did not take part in the backward pass.
                continue
            norm = torch.norm(param.grad)
            # A zero, NaN or infinite norm would inject NaN/inf into the weights.
            if norm != 0 and torch.isfinite(norm):
                param.data.add_(epsilon * param.grad / norm)

    def restore(self, emd_name="embedding", *, emb_name=None):
        """Restore the data saved by :meth:`attack` and clear the backup.

        Args:
            emd_name: Substring selecting which parameters to restore; should
                match the filter used in :meth:`attack`.
            emb_name: Keyword-only alias for ``emd_name``; takes precedence.

        Raises:
            AssertionError: if a matching parameter has no saved backup.
        """
        name_filter = emd_name if emb_name is None else emb_name
        for name, param in self.model.named_parameters():
            if param.requires_grad and name_filter in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
class PGD(object):
    """Projected Gradient Descent (PGD) perturbation of embedding parameters.

    Reference: "Towards Deep Learning Models Resistant to Adversarial
    Attacks" (Madry et al.).

    Each :meth:`attack` step moves matching parameters by ``alpha`` along
    their normalized gradient, then projects the accumulated offset back onto
    an L2 ball of radius ``epsilon`` centred on the data saved during the
    first step. :meth:`backup_grad` / :meth:`restore_grad` save and reinstate
    the ``.grad`` of every trainable parameter around the inner steps.
    """

    def __init__(self, model):
        # ``model`` only needs ``named_parameters()`` (presumably a torch.nn.Module).
        self.model = model
        # Parameter name -> data saved on the first attack step (ball centre).
        self.emb_backup = {}
        # Parameter name -> gradient clone saved by backup_grad(), or None
        # when the parameter had no gradient at that time.
        self.grad_backup = {}

    def attack(self, epsilon=1., alpha=0.3, emb_name="embedding", is_first_attack=False):
        """Take one projected ascent step on matching parameters.

        Args:
            epsilon: Radius of the L2 ball the total offset is projected onto.
            alpha: L2 norm of this step's move.
            emb_name: Substring selecting which parameters to perturb.
            is_first_attack: Save the current data as the ball centre; must
                be True on the first step, otherwise :meth:`project` raises
                KeyError.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            if is_first_attack:
                self.emb_backup[name] = param.data.clone()
            if param.grad is None:
                # Parameter did not take part in the backward pass.
                continue
            norm = torch.norm(param.grad)
            # A zero, NaN or infinite norm would inject NaN/inf into the weights.
            if norm != 0 and torch.isfinite(norm):
                param.data.add_(alpha * param.grad / norm)
                param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name="embedding"):
        """Restore the data saved on the first attack step and clear it.

        Raises:
            AssertionError: if a matching parameter has no saved backup.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        """Return ``param_data`` clipped to L2 distance ``epsilon`` of its backup.

        Raises:
            KeyError: if ``param_name`` has no first-step backup.
        """
        centre = self.emb_backup[param_name]
        r = param_data - centre
        r_norm = torch.norm(r)
        if r_norm > epsilon:
            r = epsilon * r / r_norm
        return centre + r

    def backup_grad(self):
        """Save a clone of every trainable parameter's gradient.

        Parameters without a gradient (e.g. unused in the forward pass) are
        recorded as None instead of raising, so restore_grad() can put them
        back into that state.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = None if param.grad is None else param.grad.clone()

    def restore_grad(self):
        """Reassign each trainable parameter's ``.grad`` from :meth:`backup_grad`.

        Raises:
            KeyError: if a trainable parameter was never backed up.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]