Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# @Time : 2022/1/7 11:02 AM
# @Author : JianingWang
# @File : adversarial.py
import torch | |
class FGM:
    """Fast Gradient Method (FGM) perturbation of selected model parameters.

    ``attack`` shifts every trainable parameter whose name contains
    ``emb_name`` by ``epsilon * grad / ||grad||_2`` (using the gradient
    already stored in ``param.grad``), after saving a clean copy of it.
    ``restore`` puts the saved copies back and clears the backup.

    Expected calling sequence (an assumption; callers are not visible here):
    ``loss.backward()`` to populate gradients, ``attack()``, an extra
    forward/backward pass on the perturbed model, then ``restore()``.
    """

    def __init__(self, model):
        """Wrap *model*, whose ``named_parameters()`` will be perturbed.

        Args:
            model: Object exposing ``named_parameters()`` yielding
                ``(name, parameter)`` pairs (e.g. a ``torch.nn.Module``).
        """
        self.model = model
        # Parameter name -> clean copy of its data, written by attack() and
        # consumed (then cleared) by restore().
        self.backup = {}

    def attack(self, epsilon=1., emb_name="word_embeddings"):
        """Add an L2-normalized gradient step to the matching parameters.

        Args:
            epsilon: L2 norm of the perturbation added to each matching
                parameter.
            emb_name: Substring matched against parameter names; change it to
                the name of your model's embedding parameter.

        Parameters with no gradient (``grad is None``) or with a zero or
        non-finite gradient norm are backed up but left unperturbed.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Back up only once: a repeated attack() without an intervening
            # restore() must not overwrite the clean weights with perturbed ones.
            if name not in self.backup:
                self.backup[name] = param.data.clone()
            grad = param.grad
            if grad is None:
                # No gradient available for this parameter -- nothing to add.
                continue
            norm = torch.norm(grad)
            # Dividing by a zero, NaN or inf norm would write inf/NaN into the
            # weights, so such gradients are skipped.
            if norm != 0 and torch.isfinite(norm):
                param.data.add_(epsilon * grad / norm)

    def restore(self, emb_name="word_embeddings"):
        """Restore the parameters saved by :meth:`attack` and clear the backup.

        Args:
            emb_name: Substring matched against parameter names; must select
                the same parameters that were passed to :meth:`attack`.

        Raises:
            RuntimeError: A matching parameter has no saved copy (e.g.
                ``restore`` was called without a preceding ``attack`` or with
                a different ``emb_name``).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if name not in self.backup:
                    raise RuntimeError(
                        f"FGM.restore: no backup for parameter {name!r}; "
                        "call attack() with the same emb_name first"
                    )
                param.data = self.backup[name]
        self.backup = {}