szukevin's picture
upload
7900c16
raw
history blame
2.7 kB
import torch
class FGM(object):
    """Fast Gradient Method (FGM) perturbation of embedding parameters.

    Refer to the paper: "Adversarial Training Methods for Semi-Supervised
    Text Classification".

    ``attack`` saves each selected parameter and shifts it by
    ``epsilon * grad / ||grad||``; ``restore`` writes the saved values back.
    Parameters are selected by a substring match on their name as reported
    by ``model.named_parameters()``.
    """

    def __init__(self, model):
        """Wrap *model*; it must expose ``named_parameters()``."""
        self.model = model
        # Maps parameter name -> copy of its data taken before the attack.
        self.backup = {}

    def attack(self, epsilon=1e-6, emd_name="embedding"):
        """Perturb the selected parameters along their normalized gradient.

        Args:
            epsilon: Norm of the perturbation added to each selected
                parameter tensor.
            emd_name: Substring that a parameter name must contain to be
                selected. (The spelling ``emd_name`` is kept as-is for
                backward compatibility with keyword callers.)

        Every selected trainable parameter is backed up. The perturbation
        itself is skipped when the parameter has no gradient, or when the
        gradient norm is zero or NaN.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emd_name in name):
                continue
            # Back up unconditionally so restore() finds every selected name.
            self.backup[name] = param.data.clone()
            grad = param.grad
            if grad is None:
                # The parameter did not take part in the last backward pass;
                # there is no direction to move in.
                continue
            norm = torch.norm(grad)
            if norm != 0 and not torch.isnan(norm):
                r_at = epsilon * grad / norm
                param.data.add_(r_at)

    def restore(self, emd_name="embedding"):
        """Write the backed-up values into the selected parameters.

        Args:
            emd_name: Same selection substring that was passed to ``attack``.

        The backup table is emptied afterwards. An ``AssertionError`` is
        raised if a selected parameter was never backed up, i.e. ``attack``
        was not called first with a matching ``emd_name``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emd_name in name:
                assert name in self.backup, (
                    "no backup for parameter %r; call attack() first" % name
                )
                param.data = self.backup[name]
        self.backup = {}
class PGD(object):
    """Projected Gradient Descent (PGD) perturbation of embedding parameters.

    Refer to the paper: "Towards Deep Learning Models Resistant to
    Adversarial Attacks".

    ``attack`` may be called several times in a row: each call moves the
    selected parameters by ``alpha`` along the normalized gradient and then
    projects the accumulated perturbation back into a ball of radius
    ``epsilon`` around the values saved on the first call. ``backup_grad``
    and ``restore_grad`` save and reinstate the gradients of all trainable
    parameters around those steps.
    """

    def __init__(self, model):
        """Wrap *model*; it must expose ``named_parameters()``."""
        self.model = model
        # Parameter name -> copy of its data taken on the first attack step.
        self.emb_backup = {}
        # Parameter name -> copy of its gradient (or None when it had none).
        self.grad_backup = {}

    def attack(self, epsilon=1., alpha=0.3, emb_name="embedding", is_first_attack=False):
        """Take one projected gradient step on the selected parameters.

        Args:
            epsilon: Radius of the ball the total perturbation is kept in.
            alpha: Norm of the step taken on this call.
            emb_name: Substring that a parameter name must contain to be
                selected.
            is_first_attack: When true, back up the current parameter values
                first; they become the center used for projection.

        The step is skipped for parameters that have no gradient or whose
        gradient norm is zero or NaN. A ``KeyError`` is raised by the
        projection if no backup exists for a perturbed parameter, i.e. no
        earlier call used ``is_first_attack=True``.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            if is_first_attack:
                self.emb_backup[name] = param.data.clone()
            grad = param.grad
            if grad is None:
                # The parameter did not take part in the last backward pass.
                continue
            norm = torch.norm(grad)
            if norm != 0 and not torch.isnan(norm):
                r_at = alpha * grad / norm
                param.data.add_(r_at)
                param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name="embedding"):
        """Write the backed-up values into the selected parameters.

        Args:
            emb_name: Same selection substring that was passed to ``attack``.

        The backup table is emptied afterwards. An ``AssertionError`` is
        raised if a selected parameter was never backed up.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup, (
                    "no backup for parameter %r; call attack(is_first_attack=True) first" % name
                )
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        """Clamp the perturbation of *param_data* to norm at most *epsilon*.

        Args:
            param_name: Key of the parameter in the backup table.
            param_data: Current (perturbed) values of the parameter.
            epsilon: Maximum allowed norm of ``param_data - backup``.

        Returns:
            The backed-up values plus the (possibly rescaled) perturbation.
        """
        r = param_data - self.emb_backup[param_name]
        r_norm = torch.norm(r)  # computed once and reused for the rescale
        if r_norm > epsilon:
            r = epsilon * r / r_norm
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        """Save a copy of the gradient of every trainable parameter.

        Parameters with no gradient are recorded as ``None`` so that
        ``restore_grad`` can put them back in the same state.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                grad = param.grad
                self.grad_backup[name] = None if grad is None else grad.clone()

    def restore_grad(self):
        """Reinstate the gradients saved by ``backup_grad``.

        Raises ``KeyError`` if a trainable parameter has no saved entry,
        i.e. ``backup_grad`` was not called beforehand.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]