"""
HuggingFace Model Wrapper
--------------------------
"""
import torch
import transformers
import textattack
from textattack.models.helpers import T5ForTextToText
from textattack.models.tokenizers import T5Tokenizer
from .pytorch_model_wrapper import PyTorchModelWrapper
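# Free cached, unused GPU memory held by PyTorch's allocator so a freshly
# loaded model is less likely to hit an out-of-memory error.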
torch.cuda.empty_cache()
class HuggingFaceModelWrapper(PyTorchModelWrapper):
"""Loads a HuggingFace ``transformers`` model and tokenizer."""
def __init__(self, model, tokenizer):
assert isinstance(
model, (transformers.PreTrainedModel, T5ForTextToText)
), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
assert isinstance(
tokenizer,
(
transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast,
T5Tokenizer,
),
), f"`tokenizer` must of type `transformers.PreTrainedTokenizer` or `transformers.PreTrainedTokenizerFast`, but got type {type(tokenizer)}."
self.model = model
self.tokenizer = tokenizer
def __call__(self, text_input_list):
"""Passes inputs to HuggingFace models as keyword arguments.
(Regular PyTorch ``nn.Module`` models typically take inputs as
positional arguments.)
"""
# Default max length is set to be int(1e30), so we force 512 to enable batching.
max_length = (
512
if self.tokenizer.model_max_length == int(1e30)
else self.tokenizer.model_max_length
)
inputs_dict = self.tokenizer(
text_input_list,
add_special_tokens=True,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
model_device = next(self.model.parameters()).device
inputs_dict.to(model_device)
with torch.no_grad():
outputs = self.model(**inputs_dict)
if isinstance(outputs[0], str):
# HuggingFace sequence-to-sequence models return a list of
# string predictions as output. In this case, return the full
# list of outputs.
return outputs
else:
# HuggingFace classification models return a tuple as output
# where the first item in the tuple corresponds to the list of
# scores for each input.
return outputs.logits
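    # A sketch of what `__call__` returns (names are illustrative): for a
    # sequence classification model,
    #
    #     logits = model_wrapper(["an example sentence", "another one"])
    #
    # yields a tensor of shape (batch_size, num_labels); for a
    # `T5ForTextToText` model it yields a list of decoded output strings.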
def get_grad(self, text_input):
"""Get gradient of loss with respect to input tokens.
Args:
text_input (str): input string
Returns:
            Dict of input ids and gradient as numpy array.
"""
if isinstance(self.model, textattack.models.helpers.T5ForTextToText):
raise NotImplementedError(
"`get_grads` for T5FotTextToText has not been implemented yet."
)
self.model.train()
embedding_layer = self.model.get_input_embeddings()
original_state = embedding_layer.weight.requires_grad
embedding_layer.weight.requires_grad = True
emb_grads = []
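        # Backward hook that records the gradient flowing out of the embedding
        # layer; `grad_out[0]` is d(loss)/d(embedding output) for this batch.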
def grad_hook(module, grad_in, grad_out):
emb_grads.append(grad_out[0])
emb_hook = embedding_layer.register_backward_hook(grad_hook)
self.model.zero_grad()
model_device = next(self.model.parameters()).device
input_dict = self.tokenizer(
[text_input],
add_special_tokens=True,
return_tensors="pt",
padding="max_length",
truncation=True,
)
input_dict.to(model_device)
predictions = self.model(**input_dict).logits
try:
labels = predictions.argmax(dim=1)
loss = self.model(**input_dict, labels=labels)[0]
except TypeError:
raise TypeError(
f"{type(self.model)} class does not take in `labels` to calculate loss. "
"One cause for this might be if you instantiatedyour model using `transformer.AutoModel` "
"(instead of `transformers.AutoModelForSequenceClassification`)."
)
loss.backward()
        # gradient w.r.t. the word embeddings of the input
grad = emb_grads[0][0].cpu().numpy()
embedding_layer.weight.requires_grad = original_state
emb_hook.remove()
self.model.eval()
output = {"ids": input_dict["input_ids"], "gradient": grad}
return output
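    # A sketch of consuming `get_grad`'s output (variable names are
    # illustrative):
    #
    #     grad_output = model_wrapper.get_grad("an example sentence")
    #     input_ids = grad_output["ids"][0].tolist()
    #     token_grads = grad_output["gradient"]  # shape (seq_len, embedding_dim)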
def _tokenize(self, inputs):
"""Helper method that for `tokenize`
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
return [
self.tokenizer.convert_ids_to_tokens(
self.tokenizer([x], truncation=True)["input_ids"][0]
)
for x in inputs
]
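    # For example, with a BERT-style WordPiece tokenizer (the exact output is
    # tokenizer-dependent):
    #
    #     model_wrapper._tokenize(["hello world"])
    #     # -> [["[CLS]", "hello", "world", "[SEP]"]]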