anonymous8/RPD-Demo
initial commit
4943752
"""
HotFlip
===========
(HotFlip: White-Box Adversarial Examples for Text Classification)
"""
from textattack import Attack
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import BeamSearch
from textattack.transformations import WordSwapGradientBased
from .attack_recipe import AttackRecipe
class HotFlipEbrahimi2017(AttackRecipe):
    """Word-level HotFlip attack recipe.

    Ebrahimi, J. et al. (2017)

    HotFlip: White-Box Adversarial Examples for Text Classification

    https://arxiv.org/abs/1712.06751

    Reproduces the word-level HotFlip attack described in section 5 of the
    paper.
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the HotFlip attack for ``model_wrapper``.

        ``model_wrapper`` is handed to both the gradient-based word swap and
        the untargeted-classification goal function.

        Returns the ``Attack`` built from the goal function, the constraint
        list, the transformation and the beam search configured below.
        """
        # "HotFlip ... uses the gradient with respect to a one-hot input
        # representation to efficiently estimate which individual change has
        # the highest estimated loss."
        gradient_swap = WordSwapGradientBased(model_wrapper, top_n=1)

        attack_constraints = [
            # Don't modify the same word twice or stopwords.
            RepeatModification(),
            StopwordModification(),
            # 0. "We were able to create only 41 examples (2% of the
            #    correctly-classified instances of the SST test set) with one
            #    or two flips."
            MaxWordsPerturbed(max_num_words=2),
            # 1. "The cosine similarity between the embedding of words is
            #    bigger than a threshold (0.8)."
            WordEmbeddingDistance(min_cos_sim=0.8),
            # 2. "The two words have the same part-of-speech."
            PartOfSpeech(),
        ]

        # The goal is untargeted classification.
        untargeted_goal = UntargetedClassification(model_wrapper)

        # "HotFlip ... uses a beam search to find a set of manipulations that
        # work well together to confuse a classifier ... The adversary uses a
        # beam size of 10."
        beam = BeamSearch(beam_width=10)

        return Attack(untargeted_goal, attack_constraints, gradient_swap, beam)