PFEemp2024's picture
solving GPU error for previous version
4a1df2e
"""
Word Embedding Distance
--------------------------
"""
from textattack.constraints import Constraint
from textattack.shared import AbstractWordEmbedding, WordEmbedding
from textattack.shared.validators import transformation_consists_of_word_swaps
class WordEmbeddingDistance(Constraint):
    """A constraint on word substitutions which places a maximum distance
    between the embedding of the word being deleted and the word being
    inserted.

    Exactly one of ``min_cos_sim`` or ``max_mse_dist`` must be given.

    Args:
        embedding (obj): Wrapper for word embedding. Must be an instance of
            ``textattack.shared.AbstractWordEmbedding``. When ``None``,
            ``WordEmbedding.counterfitted_GLOVE_embedding()`` is used.
        include_unknown_words (bool): Whether or not the constraint is fulfilled if the embedding of x or x_adv is unknown.
        min_cos_sim (:obj:`float`, optional): The minimum cosine similarity between word embeddings.
        max_mse_dist (:obj:`float`, optional): The maximum distance between word embeddings, as computed by ``embedding.get_mse_dist``.
        cased (bool): Whether embedding supports uppercase & lowercase (defaults to False, or just lowercase).
        compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. Otherwise, compare it against the previous `x_adv`.

    Raises:
        ValueError: If neither or both of ``min_cos_sim`` and ``max_mse_dist``
            are given, or if ``embedding`` is not an ``AbstractWordEmbedding``.
    """

    def __init__(
        self,
        embedding=None,
        include_unknown_words=True,
        min_cos_sim=None,
        max_mse_dist=None,
        cased=False,
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        # Validate the thresholds before (possibly) loading the default
        # embedding so a misconfiguration fails fast. Compare against None
        # rather than truthiness: a threshold of 0.0 is a real value.
        if (min_cos_sim is None) == (max_mse_dist is None):
            raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.")
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.include_unknown_words = include_unknown_words
        self.cased = cased
        self.min_cos_sim = min_cos_sim
        self.max_mse_dist = max_mse_dist
        self.embedding = embedding

    def get_cos_sim(self, a, b):
        """Returns the cosine similarity of words with IDs a and b."""
        return self.embedding.get_cos_sim(a, b)

    def get_mse_dist(self, a, b):
        """Returns the MSE distance of words with IDs a and b."""
        return self.embedding.get_mse_dist(a, b)

    def _check_constraint(self, transformed_text, reference_text):
        """Returns true if (``transformed_text`` and ``reference_text``) are
        closer than ``self.min_cos_sim`` or ``self.max_mse_dist``.

        Only the positions listed in
        ``transformed_text.attack_attrs["newly_modified_indices"]`` are
        compared.

        Raises:
            KeyError: If ``newly_modified_indices`` is missing from
                ``transformed_text.attack_attrs``.
        """
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
            raise KeyError(
                "Cannot apply word embedding distance constraint without `newly_modified_indices`"
            ) from None

        ref_words = reference_text.words
        transformed_words = transformed_text.words

        # FIXME The index i is sometimes larger than the number of tokens - 1
        if any(i >= len(ref_words) or i >= len(transformed_words) for i in indices):
            return False

        for i in indices:
            ref_word = ref_words[i]
            transformed_word = transformed_words[i]
            if not self.cased:
                # If embedding vocabulary is all lowercase, lowercase words.
                ref_word = ref_word.lower()
                transformed_word = transformed_word.lower()

            try:
                ref_id = self.embedding.word2index(ref_word)
                transformed_id = self.embedding.word2index(transformed_word)
            except KeyError:
                # This error is thrown if x or x_adv has no corresponding ID.
                if self.include_unknown_words:
                    continue
                return False

            # Check cosine similarity (explicit None test so 0.0 is honoured).
            if self.min_cos_sim is not None:
                cos_sim = self.get_cos_sim(ref_id, transformed_id)
                if cos_sim < self.min_cos_sim:
                    return False

            # Check MSE distance (explicit None test so 0.0 is honoured).
            if self.max_mse_dist is not None:
                mse_dist = self.get_mse_dist(ref_id, transformed_id)
                if mse_dist > self.max_mse_dist:
                    return False

        return True

    def check_compatibility(self, transformation):
        """WordEmbeddingDistance requires a word being both deleted and
        inserted at the same index in order to compare their embeddings,
        therefore it's restricted to word swaps."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.
        """
        # Only the threshold that was actually configured is shown.
        if self.min_cos_sim is None:
            metric = "max_mse_dist"
        else:
            metric = "min_cos_sim"
        return [
            "embedding",
            metric,
            "cased",
            "include_unknown_words",
        ] + super().extra_repr_keys()