"""
BERT-Attack:
============================================================
(BERT-Attack: Adversarial Attack Against BERT Using BERT)
.. warning::
This attack is super slow
(see https://github.com/QData/TextAttack/issues/586)
Consider using smaller values for "max_candidates".
"""
from textattack import Attack
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapMaskedLM

from .attack_recipe import AttackRecipe


class BERTAttackLi2020(AttackRecipe):
    """Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    This is "attack mode" 1 from the paper: word replacement via a BERT
    masked language model.
    """

    @staticmethod
    def build(model_wrapper):
        # [from correspondence with the author]
        # Candidate size K is set to 48 for all data-sets.
        transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
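        # Note: per the warning in the module docstring, generating masked-LM
        # candidates dominates runtime; lowering max_candidates (e.g. to 8) is
        # a plausible speed/strength trade-off, though K = 48 matches the paper.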
        #
        # Don't modify the same word twice or stopwords.
        #
        constraints = [RepeatModification(), StopwordModification()]
# "We only take ε percent of the most important words since we tend to keep
# perturbations minimum."
#
# [from correspondence with the author]
# "Word percentage allowed to change is set to 0.4 for most data-sets, this
# parameter is trivial since most attacks only need a few changes. This
# epsilon is only used to avoid too much queries on those very hard samples."
constraints.append(MaxWordsPerturbed(max_percent=0.4))
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
# Encoder (Cer et al., 2018) to measure the semantic consistency between the
# adversarial sample and the original sequence. To balance between semantic
# preservation and attack success rate, we set up a threshold of semantic
# similarity score to filter the less similar examples."
#
# [from correspondence with author]
# "Over the full texts, after generating all the adversarial samples, we filter
# out low USE score samples. Thus the success rate is lower but the USE score
# can be higher. (actually USE score is not a golden metric, so we simply
# measure the USE score over the final texts for a comparison with TextFooler).
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
# datasets like MNLI, we set threshold between 0-0.2."
#
# Since the threshold in the real world can't be determined from the training
# data, the TextAttack implementation uses a fixed threshold - determined to
# be 0.2 to be most fair.
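        # (Note: in TextAttack, window_size=None scores the full texts against
        # each other rather than a window around the modified word, matching the
        # quoted full-text comparison, and compare_against_original=True compares
        # each candidate to the original input rather than to the previous iterate.)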
        use_constraint = UniversalSentenceEncoder(
            threshold=0.2,
            metric="cosine",
            compare_against_original=True,
            window_size=None,
        )
        constraints.append(use_constraint)
        #
        # Goal is untargeted classification.
        #
        goal_function = UntargetedClassification(model_wrapper)
        #
# "We first select the words in the sequence which have a high significance
# influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
# the input sentence, and oy(S) denote the logit output by the target model
# for correct label y, the importance score Iwi is defined as
# Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
# is the sentence after replacing wi with [MASK]. Then we rank all the words
# according to the ranking score Iwi in descending order to create word list
# L."
        search_method = GreedyWordSwapWIR(wir_method="unk")

        return Attack(goal_function, constraints, transformation, search_method)
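

# Example usage (illustrative sketch, not part of the original recipe). It
# assumes the `transformers` package and a downloadable sequence-classification
# checkpoint; "textattack/bert-base-uncased-imdb" is one plausible choice.
#
#     import transformers
#     from textattack import Attacker, AttackArgs
#     from textattack.datasets import HuggingFaceDataset
#     from textattack.models.wrappers import HuggingFaceModelWrapper
#
#     model = transformers.AutoModelForSequenceClassification.from_pretrained(
#         "textattack/bert-base-uncased-imdb"
#     )
#     tokenizer = transformers.AutoTokenizer.from_pretrained(
#         "textattack/bert-base-uncased-imdb"
#     )
#     attack = BERTAttackLi2020.build(HuggingFaceModelWrapper(model, tokenizer))
#
#     dataset = HuggingFaceDataset("imdb", split="test")
#     # Keep num_examples small; see the slowness warning in the module docstring.
#     attacker = Attacker(attack, dataset, AttackArgs(num_examples=5))
#     attacker.attack_dataset()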