""" | |
BAE (BAE: BERT-Based Adversarial Examples) | |
============================================ | |
""" | |
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapMaskedLM

from .attack_recipe import AttackRecipe


class BAEGarg2019(AttackRecipe):
    """Siddhant Garg and Goutham Ramakrishnan, 2019.

    BAE: BERT-based Adversarial Examples for Text Classification.

    https://arxiv.org/pdf/2004.01970

    This recipe implements "attack mode" 1 from the paper, BAE-R (word
    replacement). Quoting the paper:

    "We present 4 attack modes for BAE based on the R and I operations,
    where for each token t in S:

    • BAE-R: Replace token t (See Algorithm 1)
    • BAE-I: Insert a token to the left or right of t
    • BAE-R/I: Either replace token t or insert a token to the left or
      right of t
    • BAE-R+I: First replace token t, then insert a token to the left or
      right of t"
    """
    @staticmethod
    def build(model_wrapper):
# "In this paper, we present a simple yet novel technique: BAE (BERT-based | |
# Adversarial Examples), which uses a language model (LM) for token | |
# replacement to best fit the overall context. We perturb an input sentence | |
# by either replacing a token or inserting a new token in the sentence, by | |
# means of masking a part of the input and using a LM to fill in the mask." | |
# | |
# We only consider the top K=50 synonyms from the MLM predictions. | |
# | |
# [from email correspondance with the author] | |
# "When choosing the top-K candidates from the BERT masked LM, we filter out | |
# the sub-words and only retain the whole words (by checking if they are | |
# present in the GloVE vocabulary)" | |
# | |
        transformation = WordSwapMaskedLM(
            method="bae", max_candidates=50, min_confidence=0.0
        )
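        # Conceptually (a rough sketch, not the library's implementation), the
        # "bae" replacement candidates for token t_i come from masking t_i and
        # reading off the masked LM's top-K whole-word predictions:
        #
        #   masked = tokens[:i] + ["[MASK]"] + tokens[i + 1 :]
        #   candidates = mlm_top_k(" ".join(masked), k=50)
        #
        # where mlm_top_k is a hypothetical helper around a BERT masked LM.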
        #
        # Don't modify the same word twice or modify stopwords.
        #
        constraints = [RepeatModification(), StopwordModification()]
        # For the R operations we add an additional check for
        # grammatical correctness of the generated adversarial example by filtering
        # out predicted tokens that do not form the same part of speech (POS) as the
        # original token t_i in the sentence.
        constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
# "To ensure semantic similarity on introducing perturbations in the input | |
# text, we filter the set of top-K masked tokens (K is a pre-defined | |
# constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE) | |
# (Cer et al., 2018)-based sentence similarity scorer." | |
# | |
# "[We] set a threshold of 0.8 for the cosine similarity between USE-based | |
# embeddings of the adversarial and input text." | |
# | |
# [from email correspondence with the author] | |
# "For a fair comparison of the benefits of using a BERT-MLM in our paper, | |
# we retained the majority of TextFooler's specifications. Thus we: | |
# 1. Use the USE for comparison within a window of size 15 around the word | |
# being replaced/inserted. | |
# 2. Set the similarity score threshold to 0.1 for inputs shorter than the | |
# window size (this translates roughly to almost always accepting the new text). | |
# 3. Perform the USE similarity thresholding of 0.8 with respect to the text | |
# just before the replacement/insertion and not the original text (For | |
# example: at the 3rd R/I operation, we compute the USE score on a window | |
# of size 15 of the text obtained after the first 2 R/I operations and not | |
# the original text). | |
# ... | |
# To address point (3) from above, compare the USE with the original text | |
# at each iteration instead of the current one (While doing this change | |
# for the R-operation is trivial, doing it for the I-operation with the | |
# window based USE comparison might be more involved)." | |
# | |
# Finally, since the BAE code is based on the TextFooler code, we need to | |
# adjust the threshold to account for the missing / pi in the cosine | |
# similarity comparison. So the final threshold is 1 - (1 - 0.8) / pi | |
# = 1 - (0.2 / pi) = 0.936338023. | |
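        #
        # Quick check of that arithmetic:
        #
        #   >>> import math
        #   >>> 1 - (1 - 0.8) / math.pi
        #   0.936338...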
        use_constraint = UniversalSentenceEncoder(
            threshold=0.936338023,
            metric="cosine",
            compare_against_original=True,
            window_size=15,
            skip_text_shorter_than_window=True,
        )
        constraints.append(use_constraint)
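        # Note that compare_against_original=True follows the author's suggested
        # fix for point (3) above: similarity is always measured against the
        # original text rather than the intermediate perturbed text.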
        #
        # Goal is untargeted classification.
        #
        goal_function = UntargetedClassification(model_wrapper)
        #
        # "We estimate the token importance I_i of each token
        # t_i ∈ S = [t_1, ..., t_n], by deleting t_i from S and computing the
        # decrease in probability of predicting the correct label y, similar
        # to (Jin et al., 2019)."
        #
        # • "If there are multiple tokens that can cause C to misclassify S when
        #   they replace the mask, we choose the token which makes S_adv most
        #   similar to the original S based on the USE score."
        # • "If no token causes misclassification, we choose the perturbation
        #   that decreases the prediction probability P(C(S_adv)=y) the most."
        #
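        # A rough sketch of that deletion-based importance score (illustrative
        # only; GreedyWordSwapWIR implements the actual ranking):
        #
        #   def importance(prob, tokens, i, y):
        #       """Hypothetical helper: drop in P(y) when token i is deleted."""
        #       deleted = tokens[:i] + tokens[i + 1 :]
        #       return prob(tokens, y) - prob(deleted, y)
        #
        # wir_method="delete" selects this ranking scheme.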
        search_method = GreedyWordSwapWIR(wir_method="delete")

        return BAEGarg2019(goal_function, constraints, transformation, search_method)