# exbert/server/transformer_formatter.py
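"""Format a transformer model's attentions, embeddings, and contexts for the exbert frontend and for hdf5 storage."""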
from typing import List, Iterable, Tuple
from functools import partial
import numpy as np
import torch
import json
from spacyface.simple_spacy_token import SimpleSpacyToken
from utils.token_processing import fix_byte_spaces
from utils.gen_utils import map_nlist
def round_return_value(attentions, ndigits=5):
"""Rounding must happen right before it's passed back to the frontend because there is a little numerical error that's introduced converting back to lists
attentions: {
'aa': {
left.embeddings & contexts
right.embeddings & contexts
att
}
}
"""
rounder = partial(round, ndigits=ndigits)
nested_rounder = partial(map_nlist, rounder)
    new_out = attentions  # Modify the input dict in place to save memory
new_out["aa"]["att"] = nested_rounder(attentions["aa"]["att"])
new_out["aa"]["left"]["embeddings"] = nested_rounder(
attentions["aa"]["left"]["embeddings"]
)
new_out["aa"]["left"]["contexts"] = nested_rounder(
attentions["aa"]["left"]["contexts"]
)
new_out["aa"]["right"]["embeddings"] = nested_rounder(
attentions["aa"]["right"]["embeddings"]
)
new_out["aa"]["right"]["contexts"] = nested_rounder(
attentions["aa"]["right"]["contexts"]
)
return new_out
def flatten_batch(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Remove the batch dimension of every tensor inside the Iterable container `x`"""
return tuple([x_.squeeze(0) for x_ in x])
def squeeze_contexts(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Combine the last two dimensions of the context."""
shape = x[0].shape
new_shape = shape[:-2] + (-1,)
return tuple([x_.view(new_shape) for x_ in x])
def add_blank(xs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Prepend a blank (zero) tensor so the embeddings have n_layers + 1 entries, the last entry being the final output embedding."""
return (torch.zeros_like(xs[0]),) + xs
class TransformerOutputFormatter:
def __init__(
self,
sentence: str,
tokens: List[SimpleSpacyToken],
special_tokens_mask: List[int],
att: Tuple[torch.Tensor],
embeddings: Tuple[torch.Tensor],
contexts: Tuple[torch.Tensor],
topk_words: List[List[str]],
topk_probs: List[List[float]]
):
assert len(tokens) > 0, "Cannot have an empty token output!"
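        # Drop the batch dimension (the formatter works on a single sentence at a time)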
modified_embeddings = flatten_batch(embeddings)
modified_att = flatten_batch(att)
modified_contexts = flatten_batch(contexts)
self.sentence = sentence
self.tokens = tokens
self.special_tokens_mask = special_tokens_mask
self.embeddings = modified_embeddings
self.attentions = modified_att
self.raw_contexts = modified_contexts
self.topk_words = topk_words
self.topk_probs = topk_probs
self.n_layers = len(contexts) # With +1 for buffer layer at the beginning
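        # contexts[0] has shape (batch, n_tokens, n_heads, hidden_dim)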
_, self.__len, self.n_heads, self.hidden_dim = contexts[0].shape
@property
def contexts(self):
"""Combine the head and the context dimension as it is passed forward in the model"""
return squeeze_contexts(self.raw_contexts)
@property
def normed_embeddings(self):
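        """Normalize each layer's embeddings to unit length along the hidden dimension"""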
ens = tuple([torch.norm(e, dim=-1) for e in self.embeddings])
normed_es = tuple([e / en.unsqueeze(-1) for e, en in zip(self.embeddings, ens)])
return normed_es
@property
def normed_contexts(self):
"""Normalize each by head"""
cs = self.raw_contexts
cns = tuple([torch.norm(c, dim=-1) for c in cs])
normed_cs = tuple([c / cn.unsqueeze(-1) for c, cn in zip(cs, cns)])
squeezed_normed_cs = squeeze_contexts(normed_cs)
return squeezed_normed_cs
def to_json(self, layer:int, ndigits=5):
"""The original API expects the following response:
aa: {
att: number[][][]
left: <FullSingleTokenInfo[]>
right: <FullSingleTokenInfo[]>
}
FullSingleTokenInfo:
{
text: string
embeddings: number[]
contexts: number[]
bpe_token: string
bpe_pos: string
bpe_dep: string
bpe_is_ent: boolean
}
"""
        # Convert the embeddings, attentions, and contexts into lists and round the values
rounder = partial(round, ndigits=ndigits)
nested_rounder = partial(map_nlist, rounder)
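        # Helper: convert a tuple of tensors into nested Python lists for JSON serialization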
def tolist(tens): return [t.tolist() for t in tens]
def to_resp(tok: SimpleSpacyToken, embeddings: List[float], contexts: List[float], topk_words, topk_probs):
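            """Build the FullSingleTokenInfo dict (see docstring above) for a single token"""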
return {
"text": tok.token,
"bpe_token": tok.token,
"bpe_pos": tok.pos,
"bpe_dep": tok.dep,
"bpe_is_ent": tok.is_ent,
"embeddings": nested_rounder(embeddings),
"contexts": nested_rounder(contexts),
"topk_words": topk_words,
"topk_probs": nested_rounder(topk_probs)
}
side_info = [to_resp(t, e, c, w, p) for t,e,c,w,p in zip(
self.tokens,
tolist(self.embeddings[layer]),
tolist(self.contexts[layer]),
self.topk_words,
self.topk_probs)]
out = {"aa": {
"att": nested_rounder(tolist(self.attentions[layer])),
"left": side_info,
"right": side_info
}}
return out
def display_tokens(self, tokens):
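        """Convert raw subword tokens into display-friendly strings by fixing byte-encoded spaces"""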
return fix_byte_spaces(tokens)
def to_hdf5_meta(self):
"""Output metadata information to store as hdf5 metadata for a group"""
token_dtype = self.tokens[0].hdf5_token_dtype
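        # Build one numpy array per token attribute, typed according to the token's hdf5 dtype spec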
out = {k: np.array([t[k] for t in self.tokens], dtype=np.dtype(dtype)) for k, dtype in token_dtype}
out['sentence'] = self.sentence
return out
def to_hdf5_content(self, do_norm=True):
"""Return dictionary of {attentions, embeddings, contexts} formatted as array for hdf5 file"""
def get_embeds(c):
if do_norm: return c.normed_embeddings
return c.embeddings
def get_contexts(c):
if do_norm: return c.normed_contexts
return c.contexts
embeddings = to_numpy(get_embeds(self))
contexts = to_numpy(get_contexts(self))
atts = to_numpy(self.attentions)
return {
"embeddings": embeddings,
"contexts": contexts,
"attentions": atts
}
@property
def searchable_embeddings(self):
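        """Embeddings for every layer as a float32 numpy array, suitable for building a search index"""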
return np.array(list(map(to_searchable, self.embeddings)))
@property
def searchable_contexts(self):
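        """Head-flattened contexts for every layer as a float32 numpy array, suitable for building a search index"""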
return np.array(list(map(to_searchable, self.contexts)))
def __repr__(self):
lim = 50
if len(self.sentence) > lim: s = self.sentence[:lim - 3] + "..."
else: s = self.sentence[:lim]
return f"TransformerOutput({s})"
def __len__(self):
return self.__len
def to_numpy(x):
"""Embeddings, contexts, and attentions are stored as torch.Tensors in a tuple. Convert this to a numpy array
for storage in hdf5"""
return np.array([x_.detach().numpy() for x_ in x])
def to_searchable(t: torch.Tensor):
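    """Convert a single tensor to a float32 numpy array for search indexing"""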
return t.detach().numpy().astype(np.float32)