""" | |
BERT-Attack: | |
============================================================ | |
(BERT-Attack: Adversarial Attack Against BERT Using BERT) | |
.. warning:: | |
This attack is super slow | |
(see https://github.com/QData/TextAttack/issues/586) | |
Consider using smaller values for "max_candidates". | |
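
For example, a minimal usage sketch (this assumes ``model_wrapper`` is any
``textattack.models.wrappers.ModelWrapper`` and ``dataset`` is a
``textattack.datasets.Dataset`` you have already constructed)::

    import textattack
    from textattack.attack_recipes import BERTAttackLi2020

    attack = BERTAttackLi2020.build(model_wrapper)
    attacker = textattack.Attacker(attack, dataset)
    results = attacker.attack_dataset()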
""" | |
from textattack import Attack
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapMaskedLM

from .attack_recipe import AttackRecipe


class BERTAttackLi2020(AttackRecipe):
    """Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    This recipe implements the paper's word-replacement attack: target words
    are replaced with candidates generated by BERT masked language modeling.
    """
    @staticmethod
    def build(model_wrapper):
        # [from correspondence with the author]
        # Candidate size K is set to 48 for all data-sets.
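        #
        # (Note: with method="bert-attack", WordSwapMaskedLM masks the target
        # word and takes the masked language model's top-`max_candidates`
        # token predictions as replacement candidates; per the paper, words
        # that tokenize into multiple sub-words are handled by searching over
        # combinations of the top sub-word predictions.)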
        transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
        #
        # Don't modify the same word twice or stopwords.
        #
        constraints = [RepeatModification(), StopwordModification()]
        # "We only take ε percent of the most important words since we tend to keep
        # perturbations minimum."
        #
        # [from correspondence with the author]
        # "Word percentage allowed to change is set to 0.4 for most data-sets, this
        # parameter is trivial since most attacks only need a few changes. This
        # epsilon is only used to avoid too much queries on those very hard samples."
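        #
        # (In TextAttack terms, this constraint rejects any candidate text in
        # which more than 40% of the original words have been perturbed.)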
        constraints.append(MaxWordsPerturbed(max_percent=0.4))
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence | |
# Encoder (Cer et al., 2018) to measure the semantic consistency between the | |
# adversarial sample and the original sequence. To balance between semantic | |
# preservation and attack success rate, we set up a threshold of semantic | |
# similarity score to filter the less similar examples." | |
# | |
# [from correspondence with author] | |
# "Over the full texts, after generating all the adversarial samples, we filter | |
# out low USE score samples. Thus the success rate is lower but the USE score | |
# can be higher. (actually USE score is not a golden metric, so we simply | |
# measure the USE score over the final texts for a comparison with TextFooler). | |
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for | |
# datasets like MNLI, we set threshold between 0-0.2." | |
# | |
        # Since the threshold in the real world can't be determined from the training
        # data, the TextAttack implementation uses a fixed threshold of 0.2, chosen
        # to be most fair across datasets.
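        #
        # (Per TextAttack's sentence-encoder constraint semantics:
        # compare_against_original=True scores each candidate against the
        # original input rather than against the previous perturbation, and
        # window_size=None encodes the full text instead of a window around
        # the swapped word.)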
        use_constraint = UniversalSentenceEncoder(
            threshold=0.2,
            metric="cosine",
            compare_against_original=True,
            window_size=None,
        )
        constraints.append(use_constraint)
        #
        # Goal is untargeted classification.
        #
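        # (UntargetedClassification counts the attack as successful as soon as
        # the model's predicted label differs from the ground-truth label.)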
        goal_function = UntargetedClassification(model_wrapper)
        #
        # "We first select the words in the sequence which have a high significance
        # influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
        # the input sentence, and oy(S) denote the logit output by the target model
        # for correct label y, the importance score Iwi is defined as
        # Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
        # is the sentence after replacing wi with [MASK]. Then we rank all the words
        # according to the ranking score Iwi in descending order to create word list
        # L."
        search_method = GreedyWordSwapWIR(wir_method="unk")

        return Attack(goal_function, constraints, transformation, search_method)