"""
Part of Speech Constraint
--------------------------
"""
import flair
from flair.data import Sentence
from flair.models import SequenceTagger
import lru
import nltk
import textattack
from textattack.constraints import Constraint
from textattack.shared.utils import LazyLoader, device
from textattack.shared.validators import transformation_consists_of_word_swaps
# Set global flair device to be TextAttack's current device
flair.device = device
stanza = LazyLoader("stanza", globals(), "stanza")
class PartOfSpeech(Constraint):
    """Constraints word swaps to only swap words with the same part of speech.

    Uses the NLTK universal part-of-speech tagger by default. An implementation
    of `<https://arxiv.org/abs/1907.11932>`_ adapted from
    `<https://github.com/jind11/TextFooler>`_.

    POS taggers from Flair `<https://github.com/flairNLP/flair>`_ and
    Stanza `<https://github.com/stanfordnlp/stanza>`_ are also available

    Args:
        tagger_type (str): Name of the tagger to use (available choices: "nltk", "flair", "stanza").
        tagset (str): tagset to use for POS tagging (e.g. "universal")
        allow_verb_noun_swap (bool): If `True`, allow verbs to be swapped with nouns and vice versa.
        compare_against_original (bool): If `True`, compare against the original text.
            Otherwise, compare against the most recent text.
        language_nltk: Language to be used for nltk POS-Tagger
            (available choices: "eng", "rus")
        language_stanza: Language to be used for stanza POS-Tagger
            (available choices: https://stanfordnlp.github.io/stanza/available_models.html)

    Raises:
        ValueError: If ``tagger_type`` is not one of the supported taggers.
    """

    # The only tagger backends `_get_pos` knows how to drive.
    SUPPORTED_TAGGERS = ("nltk", "flair", "stanza")

    def __init__(
        self,
        tagger_type="nltk",
        tagset="universal",
        allow_verb_noun_swap=True,
        compare_against_original=True,
        language_nltk="eng",
        language_stanza="en",
    ):
        super().__init__(compare_against_original)
        if tagger_type not in self.SUPPORTED_TAGGERS:
            # Fail fast: an unsupported tagger would otherwise be accepted
            # here and surface much later as a confusing NameError
            # (`word_list` never bound) inside `_get_pos`.
            raise ValueError(
                f"Unsupported tagger_type {tagger_type!r}; expected one of "
                f"{self.SUPPORTED_TAGGERS}."
            )
        self.tagger_type = tagger_type
        self.tagset = tagset
        self.allow_verb_noun_swap = allow_verb_noun_swap
        self.language_nltk = language_nltk
        self.language_stanza = language_stanza
        # Maps a context string -> (word_list, pos_list) so repeated queries
        # over the same window skip re-tagging. Bounded LRU to cap memory.
        self._pos_tag_cache = lru.LRU(2**14)
        if tagger_type == "flair":
            if tagset == "universal":
                self._flair_pos_tagger = SequenceTagger.load("upos-fast")
            else:
                self._flair_pos_tagger = SequenceTagger.load("pos-fast")
        elif tagger_type == "stanza":
            self._stanza_pos_tagger = stanza.Pipeline(
                lang=self.language_stanza,
                processors="tokenize, pos",
                tokenize_pretokenized=True,
            )

    def clear_cache(self):
        """Empty the POS-tag cache."""
        self._pos_tag_cache.clear()

    def _can_replace_pos(self, pos_a, pos_b):
        """Return True if a word tagged `pos_a` may be replaced by one tagged `pos_b`.

        Tags must match exactly, except that NOUN<->VERB swaps are allowed
        when `allow_verb_noun_swap` is set.
        """
        return (pos_a == pos_b) or (
            self.allow_verb_noun_swap and {pos_a, pos_b} <= {"NOUN", "VERB"}
        )

    def _get_pos(self, before_ctx, word, after_ctx):
        """Tag `word` inside its context window and return its POS tag.

        Args:
            before_ctx (list[str]): Words preceding `word` (up to 4).
            word (str): The word whose POS tag is wanted.
            after_ctx (list[str]): Words following `word` (up to 3).

        Raises:
            ValueError: If the tagger's tokenization does not reproduce `word`.
        """
        context_words = before_ctx + [word] + after_ctx
        context_key = " ".join(context_words)
        if context_key in self._pos_tag_cache:
            word_list, pos_list = self._pos_tag_cache[context_key]
        else:
            if self.tagger_type == "nltk":
                word_list, pos_list = zip(
                    *nltk.pos_tag(
                        context_words, tagset=self.tagset, lang=self.language_nltk
                    )
                )
            elif self.tagger_type == "flair":
                context_key_sentence = Sentence(
                    context_key,
                    use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
                )
                self._flair_pos_tagger.predict(context_key_sentence)
                word_list, pos_list = textattack.shared.utils.zip_flair_result(
                    context_key_sentence
                )
            else:
                # "stanza" — the only remaining value `__init__` accepts.
                word_list, pos_list = textattack.shared.utils.zip_stanza_result(
                    self._stanza_pos_tagger(context_key), tagset=self.tagset
                )
            self._pos_tag_cache[context_key] = (word_list, pos_list)

        # Raise explicitly rather than `assert`: asserts are stripped under
        # `python -O`, which would let a tokenizer mismatch slip through.
        if word not in word_list:
            raise ValueError("POS list not matched with original word list.")
        # idx of `word` in `context_words`
        word_idx = word_list.index(word)
        return pos_list[word_idx]

    def _check_constraint(self, transformed_text, reference_text):
        """Return True if every newly modified word keeps a compatible POS tag."""
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
            raise KeyError(
                "Cannot apply part-of-speech constraint without `newly_modified_indices`"
            )

        for i in indices:
            reference_word = reference_text.words[i]
            transformed_word = transformed_text.words[i]
            # Tag both the original and the replacement inside the same
            # context window so the tagger sees identical surroundings.
            before_ctx = reference_text.words[max(i - 4, 0) : i]
            after_ctx = reference_text.words[
                i + 1 : min(i + 4, len(reference_text.words))
            ]
            ref_pos = self._get_pos(before_ctx, reference_word, after_ctx)
            replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
            if not self._can_replace_pos(ref_pos, replace_pos):
                return False

        return True

    def check_compatibility(self, transformation):
        """POS comparison is only meaningful for word-swap transformations."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        return [
            "tagger_type",
            "tagset",
            "allow_verb_noun_swap",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        # `lru.LRU` instances are not picklable; persist only the capacity.
        state = self.__dict__.copy()
        state["_pos_tag_cache"] = self._pos_tag_cache.get_size()
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        # Rebuild an empty cache with the saved capacity.
        self._pos_tag_cache = lru.LRU(state["_pos_tag_cache"])