IE101TW / models /adversarial.py
DeepLearning101's picture
Upload 2 files
a2fef5f
raw
history blame
No virus
1.06 kB
# -*- coding: utf-8 -*-
# @Time : 2022/1/7 11:02 上午
# @Author : JianingWang
# @File : adversarial.py
import torch
class FGM:
    """Fast Gradient Method (FGM) adversarial perturbation of embedding weights.

    ``attack`` saves a copy of every matching embedding parameter and then
    shifts it in place by ``epsilon * grad / ||grad||`` (its L2-normalized
    gradient). ``restore`` puts the saved weights back.

    ``attack`` reads ``param.grad``, so a backward pass must have populated the
    gradients first. A typical call pattern (presumably how callers use it)::

        loss.backward()          # gradients on the clean input
        fgm.attack()             # perturb the embeddings in place
        loss_adv = ...           # forward pass on the perturbed embeddings
        loss_adv.backward()      # accumulate the adversarial gradients
        fgm.restore()            # undo the perturbation before stepping
    """

    def __init__(self, model):
        """Wrap ``model``.

        Args:
            model: Object whose ``named_parameters()`` yields
                ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.
        """
        self.model = model
        # Parameter name -> clone of its weights taken before perturbation.
        self.backup = {}

    def attack(self, epsilon=1., emb_name="word_embeddings"):
        """Perturb embedding weights along their L2-normalized gradient.

        Args:
            epsilon: L2 norm of the perturbation added to each matching
                parameter.
            emb_name: Substring matched against parameter names to select the
                embedding weights; change it to match the embedding parameter
                name in your model.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Keep the first (clean) copy: a second attack() before restore()
            # must not overwrite it with already-perturbed weights.
            if name not in self.backup:
                self.backup[name] = param.data.clone()
            if param.grad is None:
                # No gradient reached this parameter; there is no direction to
                # follow, so leave it unperturbed (it is still backed up so
                # restore() stays consistent).
                continue
            norm = torch.norm(param.grad)
            # Skip a zero norm (division by zero) and NaN/inf norms, which
            # would otherwise write NaN/inf into the embedding weights.
            if torch.isfinite(norm) and norm != 0:
                r_at = epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(self, emb_name="word_embeddings"):
        """Put back the weights saved by ``attack`` and clear the backup.

        Args:
            emb_name: Must select the same parameters as the ``emb_name``
                given to ``attack``.

        Raises:
            RuntimeError: If a matching parameter has no saved copy, i.e.
                ``attack`` was not called or used a different ``emb_name``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                # Explicit check rather than ``assert`` so it survives ``-O``.
                if name not in self.backup:
                    raise RuntimeError(
                        f"FGM.restore: no backup for parameter {name!r}; "
                        "call attack() with the same emb_name first"
                    )
                param.data = self.backup[name]
        self.backup = {}