from typing import List, Iterable, Tuple
from functools import partial

import numpy as np
import torch
import json

from spacyface.simple_spacy_token import SimpleSpacyToken
from utils.token_processing import fix_byte_spaces
from utils.gen_utils import map_nlist


def round_return_value(attentions, ndigits=5):
    """Rounding must happen right before the response is passed back to the frontend,
    because a small numerical error is introduced when converting tensors back to lists.

    attentions: {
        'aa': {
            'left': {'embeddings', 'contexts'},
            'right': {'embeddings', 'contexts'},
            'att'
        }
    }
    """
    rounder = partial(round, ndigits=ndigits)
    nested_rounder = partial(map_nlist, rounder)

    new_out = attentions  # Modify values in place to save memory
    new_out["aa"]["att"] = nested_rounder(attentions["aa"]["att"])
    new_out["aa"]["left"]["embeddings"] = nested_rounder(
        attentions["aa"]["left"]["embeddings"]
    )
    new_out["aa"]["left"]["contexts"] = nested_rounder(
        attentions["aa"]["left"]["contexts"]
    )
    new_out["aa"]["right"]["embeddings"] = nested_rounder(
        attentions["aa"]["right"]["embeddings"]
    )
    new_out["aa"]["right"]["contexts"] = nested_rounder(
        attentions["aa"]["right"]["contexts"]
    )
    return new_out
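
# A minimal sketch of the nested rounding (assumes `map_nlist(fn, nested)` applies
# `fn` to every leaf of an arbitrarily nested list, as its use above implies):
#
#     nested_rounder = partial(map_nlist, partial(round, ndigits=2))
#     nested_rounder([[0.123456, 0.987654], [0.5]])  # -> [[0.12, 0.99], [0.5]]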


def flatten_batch(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Remove the batch dimension from every tensor inside the iterable container `x`."""
    return tuple([x_.squeeze(0) for x_ in x])


def squeeze_contexts(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Combine the last two dimensions (heads, head_dim) of each context tensor."""
    shape = x[0].shape
    new_shape = shape[:-2] + (-1,)
    return tuple([x_.view(new_shape) for x_ in x])


def add_blank(xs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Prepend a blank (zero) tensor so the tuple has n_layers + 1 entries, matching
    the embeddings, which include the final output embedding."""
    return (torch.zeros_like(xs[0]),) + xs
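
# Shape sketch for the helpers above (hypothetical sizes: batch=1, seq=4,
# heads=12, head_dim=64):
#
#     ctx = (torch.zeros(1, 4, 12, 64),)
#     flatten_batch(ctx)[0].shape                     # torch.Size([4, 12, 64])
#     squeeze_contexts(flatten_batch(ctx))[0].shape   # torch.Size([4, 768])
#     len(add_blank(ctx))                             # 2: zero buffer + 1 layer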


class TransformerOutputFormatter:
    def __init__(
        self,
        sentence: str,
        tokens: List[SimpleSpacyToken],
        special_tokens_mask: List[int],
        att: Tuple[torch.Tensor],
        embeddings: Tuple[torch.Tensor],
        contexts: Tuple[torch.Tensor],
        topk_words: List[List[str]],
        topk_probs: List[List[float]],
    ):
        assert len(tokens) > 0, "Cannot have an empty token output!"

        modified_embeddings = flatten_batch(embeddings)
        modified_att = flatten_batch(att)
        modified_contexts = flatten_batch(contexts)

        self.sentence = sentence
        self.tokens = tokens
        self.special_tokens_mask = special_tokens_mask
        self.embeddings = modified_embeddings
        self.attentions = modified_att
        self.raw_contexts = modified_contexts
        self.topk_words = topk_words
        self.topk_probs = topk_probs
        self.n_layers = len(contexts)  # Includes the +1 buffer layer at the beginning

        _, self.__len, self.n_heads, self.hidden_dim = contexts[0].shape
    @property
    def contexts(self):
        """Combine the head and hidden dimensions, as the values are passed forward in the model."""
        return squeeze_contexts(self.raw_contexts)

    @property
    def normed_embeddings(self):
        """Normalize each embedding to unit length along the hidden dimension."""
        ens = tuple([torch.norm(e, dim=-1) for e in self.embeddings])
        normed_es = tuple([e / en.unsqueeze(-1) for e, en in zip(self.embeddings, ens)])
        return normed_es

    @property
    def normed_contexts(self):
        """Normalize each context to unit length, per head, then squeeze the head dimension."""
        cs = self.raw_contexts
        cns = tuple([torch.norm(c, dim=-1) for c in cs])
        normed_cs = tuple([c / cn.unsqueeze(-1) for c, cn in zip(cs, cns)])
        squeezed_normed_cs = squeeze_contexts(normed_cs)
        return squeezed_normed_cs
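
    # A quick check of the per-head normalization above (hypothetical shapes, with
    # `fmt` an instance whose raw contexts are (seq=4, heads=12, head_dim=64)):
    #
    #     squeezed = fmt.normed_contexts[0]             # torch.Size([4, 768])
    #     torch.norm(squeezed.view(4, 12, 64), dim=-1)  # every entry ~1.0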
    def to_json(self, layer: int, ndigits=5):
        """Format one layer as the response the original API expects:

        aa: {
            att: number[][][]
            left: <FullSingleTokenInfo[]>
            right: <FullSingleTokenInfo[]>
        }

        FullSingleTokenInfo: {
            text: string
            embeddings: number[]
            contexts: number[]
            bpe_token: string
            bpe_pos: string
            bpe_dep: string
            bpe_is_ent: boolean
        }
        """
        # Convert the embeddings, attentions, and contexts into lists, rounding as we go
        rounder = partial(round, ndigits=ndigits)
        nested_rounder = partial(map_nlist, rounder)

        def tolist(tens):
            return [t.tolist() for t in tens]

        def to_resp(tok: SimpleSpacyToken, embeddings: List[float], contexts: List[float], topk_words, topk_probs):
            return {
                "text": tok.token,
                "bpe_token": tok.token,
                "bpe_pos": tok.pos,
                "bpe_dep": tok.dep,
                "bpe_is_ent": tok.is_ent,
                "embeddings": nested_rounder(embeddings),
                "contexts": nested_rounder(contexts),
                "topk_words": topk_words,
                "topk_probs": nested_rounder(topk_probs),
            }

        side_info = [
            to_resp(t, e, c, w, p)
            for t, e, c, w, p in zip(
                self.tokens,
                tolist(self.embeddings[layer]),
                tolist(self.contexts[layer]),
                self.topk_words,
                self.topk_probs,
            )
        ]

        out = {"aa": {
            "att": nested_rounder(tolist(self.attentions[layer])),
            "left": side_info,
            "right": side_info,
        }}
        return out
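
    # Usage sketch (hypothetical `formatter` instance; `json` is imported above):
    #
    #     payload = formatter.to_json(layer=1, ndigits=5)
    #     body = json.dumps(payload)
    #
    # Note that "left" and "right" reference the same `side_info` list, so the token
    # info is duplicated for both sides of the attention view when serialized.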
    def display_tokens(self, tokens):
        return fix_byte_spaces(tokens)
    def to_hdf5_meta(self):
        """Output metadata, formatted for storage as hdf5 attributes on a group."""
        token_dtype = self.tokens[0].hdf5_token_dtype
        out = {k: np.array([t[k] for t in self.tokens], dtype=np.dtype(dtype)) for k, dtype in token_dtype}
        out["sentence"] = self.sentence
        return out
    def to_hdf5_content(self, do_norm=True):
        """Return a dict of {embeddings, contexts, attentions} formatted as arrays for an hdf5 file."""

        def get_embeds(c):
            if do_norm:
                return c.normed_embeddings
            return c.embeddings

        def get_contexts(c):
            if do_norm:
                return c.normed_contexts
            return c.contexts

        embeddings = to_numpy(get_embeds(self))
        contexts = to_numpy(get_contexts(self))
        atts = to_numpy(self.attentions)

        return {
            "embeddings": embeddings,
            "contexts": contexts,
            "attentions": atts,
        }
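
    # A minimal sketch of writing one sentence to hdf5 (assumes the `h5py` package,
    # with `f` an open h5py.File and `formatter` an instance of this class):
    #
    #     grp = f.create_group("sentence_0")
    #     for k, v in formatter.to_hdf5_meta().items():
    #         grp.attrs[k] = v
    #     for k, v in formatter.to_hdf5_content().items():
    #         grp.create_dataset(k, data=v)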
    @property
    def searchable_embeddings(self):
        return np.array(list(map(to_searchable, self.embeddings)))

    @property
    def searchable_contexts(self):
        return np.array(list(map(to_searchable, self.contexts)))
    def __repr__(self):
        lim = 50
        if len(self.sentence) > lim:
            s = self.sentence[: lim - 3] + "..."
        else:
            s = self.sentence[:lim]
        return f"TransformerOutput({s})"

    def __len__(self):
        return self.__len


def to_numpy(x):
    """Embeddings, contexts, and attentions are stored as tuples of torch.Tensors.
    Convert these to a single numpy array for storage in hdf5."""
    return np.array([x_.detach().numpy() for x_ in x])


def to_searchable(t: torch.Tensor):
    """Convert a single tensor to a float32 numpy array."""
    return t.detach().numpy().astype(np.float32)
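
# Downstream-use sketch (an assumption, not defined in this module): the float32 cast
# in `to_searchable` matches what vector-search libraries such as faiss expect, e.g.:
#
#     import faiss
#     embs = formatter.searchable_embeddings[-1]   # (seq_len, hidden_dim), float32
#     index = faiss.IndexFlatIP(embs.shape[-1])
#     index.add(embs)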