""" | |
Thought Vector Class | |
--------------------- | |
""" | |
import functools

import torch

from textattack.shared import AbstractWordEmbedding, WordEmbedding, utils

from .sentence_encoder import SentenceEncoder


class ThoughtVector(SentenceEncoder):
    """A constraint on the distance between two sentences' thought vectors.

    Args:
        embedding (textattack.shared.AbstractWordEmbedding): The word embedding to use
    """

    def __init__(self, embedding=None, **kwargs):
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.word_embedding = embedding
        super().__init__(**kwargs)

    def clear_cache(self):
        self._get_thought_vector.cache_clear()

    # Cache per-text thought vectors; `clear_cache` relies on this decorator's
    # `cache_clear()` method, so it must be present for that call to work.
    @functools.lru_cache(maxsize=2**10)
    def _get_thought_vector(self, text):
        """Averages the embeddings of all the words in ``text`` into a
        "thought vector"."""
        embeddings = []
        for word in utils.words_from_text(text):
            embedding = self.word_embedding[word]
            if embedding is not None:  # out-of-vocab words do not have embeddings
                embeddings.append(embedding)
        embeddings = torch.tensor(embeddings)
        return torch.mean(embeddings, dim=0)

    def encode(self, raw_text_list):
        # Stack the per-sentence thought vectors into a (batch_size, dim) tensor.
        return torch.stack([self._get_thought_vector(text) for text in raw_text_list])

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys."""
        return ["word_embedding"] + super().extra_repr_keys()