|
""" |
|
CoLA for Grammaticality |
|
-------------------------- |
|
|
|
""" |
|
import lru |
|
import nltk |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
from textattack.constraints import Constraint |
|
from textattack.models.wrappers import HuggingFaceModelWrapper |
|
|
|
|
|
class COLA(Constraint):
    """Constrains an attack to text that has a similar number of linguistically
    acceptable sentences as the original text. Linguistic acceptability is
    determined by a model pre-trained on the `CoLA dataset
    <https://nyu-mll.github.io/CoLA/>`_. By default a BERT model is used, see
    the `pre-trained models README
    <https://github.com/QData/TextAttack/tree/master/textattack/models>`_ for a
    full list of available models or provide your own model from the
    huggingface model hub.

    Args:
        max_diff (float or int): The absolute (if int or greater than or equal
            to 1) or percent (if float and less than 1) maximum difference
            allowed between the number of valid sentences in the reference
            text and the number of valid sentences in the attacked text.
        model_name (str): The name of the pre-trained model to use for
            classification. The model must be in huggingface model hub.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(max_diff, (int, float)):
            raise TypeError("max_diff must be a float or int")
        # Written as ``not >=`` rather than ``< 0.0`` so that NaN is rejected
        # too: a NaN threshold makes ``num_valid < threshold`` always False,
        # which would silently accept every transformation.
        if not max_diff >= 0.0:
            raise ValueError("max_diff must be a value greater than or equal to 0.0")

        self.max_diff = max_diff
        self.model_name = model_name
        # Reference text -> number of acceptable sentences. Bounded (2**10
        # entries) so the cache cannot grow without limit during long attacks.
        self._reference_score_cache = lru.LRU(2**10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        """Drop all cached reference-text scores."""
        self._reference_score_cache.clear()

    def _num_valid_sentences(self, text):
        """Count the sentences in ``text`` that the model predicts as label 1.

        Args:
            text (str): Raw text; it is split into sentences with
                ``nltk.sent_tokenize``.

        Returns:
            int: Sum of the per-sentence argmax labels. This equals the number
            of "acceptable" sentences assuming the model is a binary
            classifier whose label 1 means acceptable (NOTE(review): confirm
            for custom ``model_name`` values).
        """
        sentences = nltk.sent_tokenize(text)
        if not sentences:
            # No sentences means no valid sentences; skip calling the model
            # on an empty batch.
            return 0
        # ``argmax(axis=1)`` picks the predicted label for each sentence row.
        predictions = self.model(sentences)
        # Store/compare plain ints rather than framework scalars.
        return int(predictions.argmax(axis=1).sum())

    def _check_constraint(self, transformed_text, reference_text):
        """Return whether ``transformed_text`` keeps enough valid sentences.

        Args:
            transformed_text: Candidate text; its ``.text`` is scored.
            reference_text: Text to compare against; its score is cached
                keyed on the object itself.

        Returns:
            bool: ``False`` if the number of valid sentences in
            ``transformed_text`` falls below the threshold derived from the
            reference score and ``max_diff``, otherwise ``True``.
        """
        reference_score = self._reference_score_cache.get(reference_text)
        if reference_score is None:
            reference_score = self._num_valid_sentences(reference_text.text)
            self._reference_score_cache[reference_text] = reference_score

        num_valid = self._num_valid_sentences(transformed_text.text)

        if isinstance(self.max_diff, int) or self.max_diff >= 1:
            # Absolute tolerance: allow losing up to ``max_diff`` sentences.
            threshold = reference_score - self.max_diff
        else:
            # Relative tolerance: allow losing a fraction of the reference count.
            threshold = reference_score - (reference_score * self.max_diff)

        return num_valid >= threshold

    def extra_repr_keys(self):
        """Attribute names included in this constraint's repr."""
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        """Return picklable state with the LRU cache replaced by its capacity.

        Cached scores are dropped; only the cache size is kept so that
        ``__setstate__`` can rebuild an empty cache of the same capacity
        (presumably because ``lru.LRU`` itself is not picklable).
        """
        state = self.__dict__.copy()
        state["_reference_score_cache"] = self._reference_score_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restore state and rebuild an empty LRU cache of the saved capacity."""
        self.__dict__ = state
        self._reference_score_cache = lru.LRU(state["_reference_score_cache"])
|
|