"""
Sentence Encoder Class
------------------------
"""
from abc import ABC
import math
import numpy as np
import torch
from textattack.constraints import Constraint
class SentenceEncoder(Constraint, ABC):
"""Constraint using cosine similarity between sentence encodings of x and
x_adv.
Args:
        threshold (:obj:`float`, optional): The threshold for the constraint to be met.
            Defaults to 0.8.
        metric (:obj:`str`, optional): The similarity metric to use. Defaults to
            'cosine'. Options: ['cosine', 'angular', 'max_euclidean']
compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
Otherwise, compare it against the previous `x_adv`.
        window_size (int): The number of words to use in the similarity
            comparison. `None` indicates no windowing (encoding is based on the
            full input).
        skip_text_shorter_than_window (bool): Whether to skip the constraint
            (i.e. automatically pass) for texts shorter than ``window_size``.
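
    Example:
        A minimal, illustrative sketch using a concrete subclass (here
        ``UniversalSentenceEncoder``; the import path and defaults are
        assumptions based on the TextAttack package layout)::

            from textattack.constraints.semantics.sentence_encoders import (
                UniversalSentenceEncoder,
            )

            constraint = UniversalSentenceEncoder(threshold=0.8, metric="angular")
            # `constraint` can then be included in an attack's constraint list.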
"""
def __init__(
self,
threshold=0.8,
metric="cosine",
compare_against_original=True,
window_size=None,
skip_text_shorter_than_window=False,
):
super().__init__(compare_against_original)
self.metric = metric
self.threshold = threshold
self.window_size = window_size
self.skip_text_shorter_than_window = skip_text_shorter_than_window
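        # A window size of None means no windowing: use an effectively
        # infinite window so comparisons always cover the full text.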
if not self.window_size:
self.window_size = float("inf")
if metric == "cosine":
self.sim_metric = torch.nn.CosineSimilarity(dim=1)
elif metric == "angular":
self.sim_metric = get_angular_sim
elif metric == "max_euclidean":
            # The Euclidean distance must be *at most* the threshold for the
            # constraint to be met, so negate both the distance and the
            # threshold; the check can then still use >=.
self.threshold = -threshold
self.sim_metric = get_neg_euclidean_dist
else:
raise ValueError(f"Unsupported metric {metric}.")
def encode(self, sentences):
"""Encodes a list of sentences.
To be implemented by subclasses.
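
        A minimal sketch of a subclass override, assuming the subclass wraps an
        encoder object exposing a ``sentence_transformers``-style ``encode``
        method (an assumption for illustration)::

            def encode(self, sentences):
                # One embedding per input sentence; numpy output is fine,
                # since callers convert the result to ``torch.Tensor``.
                return self.model.encode(sentences)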
"""
raise NotImplementedError()
def _sim_score(self, starting_text, transformed_text):
"""Returns the metric similarity between the embedding of the starting
text and the transformed text.
Args:
            starting_text: The ``AttackedText`` to use as a starting point.
            transformed_text: A transformed ``AttackedText``.
Returns:
The similarity between the starting and transformed text using the metric.
"""
try:
modified_index = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
except KeyError:
raise KeyError(
"Cannot apply sentence encoder constraint without `newly_modified_indices`"
)
starting_text_window = starting_text.text_window_around_index(
modified_index, self.window_size
)
transformed_text_window = transformed_text.text_window_around_index(
modified_index, self.window_size
)
        starting_embedding, transformed_embedding = self.encode(
            [starting_text_window, transformed_text_window]
        )
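        # Some encoders return numpy arrays or lists; convert to tensors and
        # add a batch dimension so the similarity metric sees shape (1, dim).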
if not isinstance(starting_embedding, torch.Tensor):
starting_embedding = torch.tensor(starting_embedding)
if not isinstance(transformed_embedding, torch.Tensor):
transformed_embedding = torch.tensor(transformed_embedding)
starting_embedding = torch.unsqueeze(starting_embedding, dim=0)
transformed_embedding = torch.unsqueeze(transformed_embedding, dim=0)
return self.sim_metric(starting_embedding, transformed_embedding)
def _score_list(self, starting_text, transformed_texts):
"""Returns the metric similarity between the embedding of the starting
text and a list of transformed texts.
Args:
            starting_text: The ``AttackedText`` to use as a starting point.
            transformed_texts: A list of transformed ``AttackedText``.
Returns:
A list with the similarity between the ``starting_text`` and each of
``transformed_texts``. If ``transformed_texts`` is empty,
            an empty tensor is returned.
"""
# Return an empty tensor if transformed_texts is empty.
# This prevents us from calling .repeat(x, 0), which throws an
# error on machines with multiple GPUs (pytorch 1.2).
if len(transformed_texts) == 0:
return torch.tensor([])
if self.window_size:
starting_text_windows = []
transformed_text_windows = []
for transformed_text in transformed_texts:
# @TODO make this work when multiple indices have been modified
try:
modified_index = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
except KeyError:
raise KeyError(
"Cannot apply sentence encoder constraint without `newly_modified_indices`"
)
starting_text_windows.append(
starting_text.text_window_around_index(
modified_index, self.window_size
)
)
transformed_text_windows.append(
transformed_text.text_window_around_index(
modified_index, self.window_size
)
)
embeddings = self.encode(starting_text_windows + transformed_text_windows)
if not isinstance(embeddings, torch.Tensor):
embeddings = torch.tensor(embeddings)
starting_embeddings = embeddings[: len(transformed_texts)]
transformed_embeddings = embeddings[len(transformed_texts) :]
else:
starting_raw_text = starting_text.text
transformed_raw_texts = [t.text for t in transformed_texts]
embeddings = self.encode([starting_raw_text] + transformed_raw_texts)
if not isinstance(embeddings, torch.Tensor):
embeddings = torch.tensor(embeddings)
starting_embedding = embeddings[0]
transformed_embeddings = embeddings[1:]
            # Repeat the original embedding so there is one copy per
            # transformed embedding.
starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat(
len(transformed_embeddings), 1
)
return self.sim_metric(starting_embeddings, transformed_embeddings)
def _check_constraint_many(self, transformed_texts, reference_text):
"""Filters the list ``transformed_texts`` so that the similarity
between the ``reference_text`` and the transformed text is greater than
the ``self.threshold``."""
scores = self._score_list(reference_text, transformed_texts)
for i, transformed_text in enumerate(transformed_texts):
# Optionally ignore similarity score for sentences shorter than the
# window size.
if (
self.skip_text_shorter_than_window
and len(transformed_text.words) < self.window_size
):
scores[i] = 1
transformed_text.attack_attrs["similarity_score"] = scores[i].item()
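        # Keep only the transformed texts whose similarity score meets the
        # threshold.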
mask = (scores >= self.threshold).cpu().numpy().nonzero()
return np.array(transformed_texts)[mask]
def _check_constraint(self, transformed_text, reference_text):
if (
self.skip_text_shorter_than_window
and len(transformed_text.words) < self.window_size
):
score = 1
else:
score = self._sim_score(reference_text, transformed_text)
transformed_text.attack_attrs["similarity_score"] = score
return score >= self.threshold
def extra_repr_keys(self):
return [
"metric",
"threshold",
"window_size",
"skip_text_shorter_than_window",
] + super().extra_repr_keys()
def get_angular_sim(emb1, emb2):
"""Returns the _angular_ similarity between a batch of vector and a batch
of vectors."""
cos_sim = torch.nn.CosineSimilarity(dim=1)(emb1, emb2)
return 1 - (torch.acos(cos_sim) / math.pi)
def get_neg_euclidean_dist(emb1, emb2):
"""Returns the Euclidean distance between a batch of vectors and a batch of
vectors."""
return -torch.sum((emb1 - emb2) ** 2, dim=1)
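

# Illustrative sanity check (not part of the library API): for orthogonal unit
# vectors the cosine similarity is 0, so the angular similarity is
# 1 - (pi/2) / pi = 0.5, and the negative squared Euclidean distance is -2.
#
#     e1 = torch.tensor([[1.0, 0.0]])
#     e2 = torch.tensor([[0.0, 1.0]])
#     get_angular_sim(e1, e2)         # tensor([0.5000])
#     get_neg_euclidean_dist(e1, e2)  # tensor([-2.])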