# Snapshot of a Hugging Face Space source file (Space status at capture: Sleeping).
# File size: 10,351 bytes; commit caab23d.
"""
Utilities for interfacing with the attentions from the front end.
"""
import torch
from typing import List, Union
from abc import ABC, abstractmethod
from transformer_formatter import TransformerOutputFormatter
from utils.token_processing import reshape
from spacyface import (
BertAligner,
GPT2Aligner,
RobertaAligner,
DistilBertAligner,
auto_aligner
)
from transformers import (
BertForMaskedLM,
GPT2LMHeadModel,
RobertaForMaskedLM,
DistilBertForMaskedLM,
)
from utils.f import delegates, pick, memoize
def get_cls(class_name):
    """Return the ``*Details`` wrapper class that handles the checkpoint ``class_name``.

    Raises:
        KeyError: if ``class_name`` is not one of the supported checkpoint names.
    """
    # Supported checkpoint names, grouped by the wrapper class that loads them.
    families = (
        (BertDetails, (
            'bert-base-uncased',
            'bert-base-cased',
            'bert-large-uncased',
            'bert-large-cased',
        )),
        (GPT2Details, (
            'gpt2',
            'gpt2-medium',
            'gpt2-large',
            'distilgpt2',
        )),
        (RobertaDetails, (
            'roberta-base',
            'roberta-large',
            'roberta-large-mnli',
            'roberta-base-openai-detector',
            'roberta-large-openai-detector',
            'distilroberta-base',
        )),
        (DistilBertDetails, (
            'distilbert-base-uncased',
            'distilbert-base-uncased-distilled-squad',
        )),
    )
    lookup = {name: details_cls for details_cls, names in families for name in names}
    return lookup[class_name]
@memoize
def from_pretrained(model_name):
    """Convert a model name into the matching, fully loaded ``*Details`` instance.

    The result is wrapped by ``memoize`` (from ``utils.f``); presumably this caches
    one instance per ``model_name`` -- TODO confirm against ``utils.f.memoize``.

    Args:
        model_name: A checkpoint name supported by ``get_cls``.

    Returns:
        The object returned by the selected class's ``from_pretrained``.

    Raises:
        KeyError: if ``model_name`` is not a supported checkpoint name.
    """
    try:
        details_cls = get_cls(model_name)
    except KeyError as err:
        raise KeyError(
            f"The model name of '{model_name}' either does not exist or is currently not supported"
        ) from err
    # Load outside the try: a KeyError raised while loading weights/config must not
    # be misreported as an unsupported model name.
    return details_cls.from_pretrained(model_name)
class TransformerBaseDetails(ABC):
    """All API calls interact with this class to get the hidden states and attentions for any input sentence.

    Subclasses supply the concrete model and aligner through ``from_pretrained`` and
    may override ``select_outputs`` when their model returns outputs in another order.
    """

    def __init__(self, model, aligner):
        """Store the model and aligner and switch the model to evaluation mode.

        Args:
            model: The transformer model whose outputs are inspected.
            aligner: Tokenizer/aligner used to tokenize text and attach spaCy metadata.
        """
        self.model = model
        self.aligner = aligner
        # Inference-only usage: eval() turns off training-time behaviour such as dropout.
        self.model.eval()
        # Keys of the parsed input that `format_model_input` keeps.
        self.forward_inputs = ['input_ids', 'attention_mask']

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Build an instance for ``model_name``; must be implemented by subclasses."""
        raise NotImplementedError(
            """Inherit from this class and specify the Model and Aligner to use"""
        )

    def att_from_sentence(self, s: str, mask_attentions=False) -> "TransformerOutputFormatter":
        """Get formatted attention from a single raw sentence.

        The sentence is tokenized with the aligner and special tokens are added.
        """
        tokens = self.aligner.tokenize(s)
        return self.att_from_tokens(tokens, s, add_special_tokens=True, mask_attentions=mask_attentions)

    def att_from_tokens(
        self, tokens: List[str], orig_sentence, add_special_tokens=False, mask_attentions=False
    ) -> "TransformerOutputFormatter":
        """Get formatted attention from a list of tokens.

        Args:
            tokens: Tokens already produced by the aligner.
            orig_sentence: The original sentence, used to recover spaCy metadata.
            add_special_tokens: Whether to add the model's special tokens (and a BOS
                token, when the model has one distinct from CLS).
            mask_attentions: Whether to zero the attention mask on special tokens.

        Returns:
            The result of ``format_model_output`` for this input.
        """
        ids = self.aligner.convert_tokens_to_ids(tokens)
        # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
        bos_id = self.aligner.bos_token_id
        cls_id = self.aligner.cls_token_id
        if add_special_tokens and bos_id is not None and bos_id != cls_id:
            ids.insert(0, bos_id)
        inputs = self.aligner.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
        parsed_input = self.format_model_input(inputs, mask_attentions=mask_attentions)
        # Pure inference: skip building an autograd graph for every query.
        with torch.no_grad():
            output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])
        return self.format_model_output(inputs, orig_sentence, output)

    def format_model_output(self, inputs, sentence: str, output, topk=5):
        """Convert model output to the desired format.

        The formatter additionally needs the tokens and the original sentence.

        Args:
            inputs: Output of ``prepare_for_model``; must contain ``input_ids`` and
                ``special_tokens_mask`` (``parse_inputs`` fills the latter in if missing).
            sentence: The original sentence.
            output: Raw model output, unpacked by ``select_outputs``.
            topk: Number of top predicted words kept per position.
        """
        hidden_state, attentions, contexts, logits = self.select_outputs(output)
        words, probs = self.logits2words(logits, topk)
        tokens = self.view_ids(inputs["input_ids"])
        toks = self.aligner.meta_from_tokens(sentence, tokens, perform_check=False)
        formatted_output = TransformerOutputFormatter(
            sentence,
            toks,
            inputs["special_tokens_mask"],
            attentions,
            hidden_state,
            contexts,
            words,
            probs.tolist()
        )
        return formatted_output

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output.

        The default expects ``output`` to unpack as ``(logits, hidden_state, attentions, contexts)``
        and returns them as ``(hidden_state, attentions, contexts, logits)``. In all cases, we care for:
            - hidden state embeddings (tuple of n_layers + 1)
            - attentions (tuple of n_layers)
            - contexts (tuple of n_layers)
            - logits, from which top predicted words and their probabilities are derived
        """
        logits, hidden_state, attentions, contexts = output
        return hidden_state, attentions, contexts, logits

    def format_model_input(self, inputs, mask_attentions=False):
        """Parse the input for the model according to what is expected in the forward pass.

        Returns only the keys listed in ``self.forward_inputs``; by default
        ``{'input_ids', 'attention_mask'}``.
        """
        return pick(self.forward_inputs, self.parse_inputs(inputs, mask_attentions=mask_attentions))

    def logits2words(self, logits, topk=5):
        """Return the ``topk`` most likely tokens and their probabilities for each position.

        ``logits.squeeze(0)`` removes a leading batch dimension and softmax runs over
        dim 1; this assumes logits shaped (1, seq_len, vocab) -- TODO confirm for all models.

        Returns:
            (words, probs): ``words`` is a list (one per position) of token lists;
            ``probs`` is the matching tensor of probabilities.
        """
        probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
        words = [self.aligner.convert_ids_to_tokens(i) for i in idxs]
        return words, probs

    def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """View what the tokenizer thinks certain ids are.

        A tensor argument has its leading batch dimension squeezed away first.
        """
        if isinstance(ids, torch.Tensor):
            # Remove batch dimension
            ids = ids.squeeze(0).tolist()
        return self.aligner.convert_ids_to_tokens(ids)

    def parse_inputs(self, inputs, mask_attentions=False):
        """Parse the output from `tokenizer.prepare_for_model` into an attention mask derived from special tokens.

        Args:
            - inputs: The output of `tokenizer.prepare_for_model`.
                A dict with keys: {'special_tokens_mask', 'token_type_ids', 'input_ids'}.
                If 'special_tokens_mask' is absent it is computed from the aligner's
                special token ids and written back into `inputs` as well.
            - mask_attentions: Flag indicating whether to mask the attentions on special tokens or not

        Returns:
            Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}
            ('token_type_ids' only when present in `inputs`). 'attention_mask' has shape (1, seq_len).

        Usage:
            ```
            s = "test sentence"

            # from raw sentence to tokens
            tokens = tokenizer.tokenize(s)

            # From tokens to ids
            ids = tokenizer.convert_tokens_to_ids(tokens)

            # From ids to input
            inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')

            # Parse the input. Optionally mask the special tokens from the analysis.
            parsed_input = parse_inputs(inputs)

            # Run the model, pick from this output whatever inputs you want
            from utils.f import pick
            out = model(**pick(['input_ids'], parse_inputs(inputs)))
            ```
        """
        out = inputs.copy()

        # DEFINE SPECIAL TOKENS MASK
        if "special_tokens_mask" not in inputs:
            special_ids = {
                self.aligner.unk_token_id,
                self.aligner.cls_token_id,
                self.aligner.sep_token_id,
                self.aligner.bos_token_id,
                self.aligner.eos_token_id,
                self.aligner.pad_token_id,
            }
            # Only the first (and only) row of the batch is inspected.
            first_row = inputs['input_ids'][0]
            # Written back into the caller's `inputs` on purpose: format_model_output
            # reads inputs["special_tokens_mask"] from the original dict.
            inputs['special_tokens_mask'] = [1 if int(i) in special_ids else 0 for i in first_row]

        special_mask = inputs["special_tokens_mask"]
        out['special_tokens_mask'] = special_mask

        if mask_attentions:
            attention = [int(not i) for i in special_mask]
        else:
            attention = [1] * len(special_mask)
        out["attention_mask"] = torch.tensor(attention).unsqueeze(0)

        return out
class BertDetails(TransformerBaseDetails):
    """Pairs a BERT masked-LM model with the BERT aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the BERT masked-LM and aligner for ``model_name`` and wrap them.

        ``output_additional_info`` is not a standard ``transformers`` keyword;
        presumably it comes from a patched ``transformers`` build -- TODO confirm.
        """
        load_options = dict(
            output_attentions=True,
            output_hidden_states=True,
            output_additional_info=True,
        )
        model = BertForMaskedLM.from_pretrained(model_name, **load_options)
        aligner = BertAligner.from_pretrained(model_name)
        return cls(model, aligner)
class GPT2Details(TransformerBaseDetails):
    """Pairs a GPT-2 LM-head model with the GPT-2 aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the GPT-2 LM-head model and aligner for ``model_name`` and wrap them.

        ``output_additional_info`` is not a standard ``transformers`` keyword;
        presumably it comes from a patched ``transformers`` build -- TODO confirm.
        """
        load_options = dict(
            output_attentions=True,
            output_hidden_states=True,
            output_additional_info=True,
        )
        model = GPT2LMHeadModel.from_pretrained(model_name, **load_options)
        aligner = GPT2Aligner.from_pretrained(model_name)
        return cls(model, aligner)

    def select_outputs(self, output):
        """Reorder GPT-2 outputs into ``(hidden_states, attentions, contexts, logits)``.

        GPT-2 output unpacks into five elements; the second one is discarded
        (presumably cached key/value states -- TODO confirm).
        """
        logits, _skipped, hidden_states, attentions, contexts = output
        return hidden_states, attentions, contexts, logits
class RobertaDetails(TransformerBaseDetails):
    """Pairs a RoBERTa masked-LM model with the RoBERTa aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the RoBERTa masked-LM and aligner for ``model_name`` and wrap them.

        ``output_additional_info`` is not a standard ``transformers`` keyword;
        presumably it comes from a patched ``transformers`` build -- TODO confirm.
        """
        load_options = dict(
            output_attentions=True,
            output_hidden_states=True,
            output_additional_info=True,
        )
        model = RobertaForMaskedLM.from_pretrained(model_name, **load_options)
        aligner = RobertaAligner.from_pretrained(model_name)
        return cls(model, aligner)
class DistilBertDetails(TransformerBaseDetails):
    """Pairs a DistilBERT masked-LM model with the DistilBERT aligner.

    Uses the base class ``__init__`` unchanged: the previous override only
    re-assigned the same ``forward_inputs`` (['input_ids', 'attention_mask'])
    that the base class already sets.
    """

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the DistilBERT masked-LM and aligner for ``model_name`` and wrap them.

        ``output_additional_info`` is not a standard ``transformers`` keyword;
        presumably it comes from a patched ``transformers`` build -- TODO confirm.
        """
        return cls(
            DistilBertForMaskedLM.from_pretrained(
                model_name,
                output_attentions=True,
                output_hidden_states=True,
                output_additional_info=True,
            ),
            DistilBertAligner.from_pretrained(model_name),
        )

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output.

        Expects ``output`` to unpack as ``(logits, hidden_states, attentions, contexts)``
        and returns ``(hidden_states, attentions, contexts, logits)``. In all cases, we care for:
            - hidden state embeddings (tuple of n_layers + 1)
            - attentions (tuple of n_layers)
            - contexts (tuple of n_layers)
            - logits
        """
        logits, hidden_states, attentions, contexts = output
        # Swap axes 1 and 2 of every 4-D context tensor, presumably so DistilBERT's
        # layout matches the other models' contexts -- TODO confirm the layouts.
        contexts = tuple(c.permute(0, 2, 1, 3).contiguous() for c in contexts)
        return hidden_states, attentions, contexts, logits