# formatting utility module for model input and output

# external imports
import re
import torch
import numpy as np
from numpy import ndarray

# globally defined tokens that are removed from the output
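# ("▁" is the sentencepiece word-boundary marker, "Ġ" the GPT-2 style BPE
# space marker, "</w>" a classic BPE end-of-word marker, and the <0x..>
# entries are byte tokens for newline, carriage return, and tab)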
SPECIAL_TOKENS = [
    "[CLS]",
    "[SEP]",
    "[PAD]",
    "[UNK]",
    "[MASK]",
    "▁",
    "Ġ",
    "</w>",
    "<0x0A>",
    "<0x0D>",
    "<0x09>",
    "<s>",
    "</s>",
]

# function to format the model response nicely
# takes a list of token strings and returns a combined string
def format_output_text(output: list):
    # remove special tokens from the list using the helper below
    formatted_output = format_tokens(output)

    # start the string with the first list item (empty if the list is empty)
    output_str = formatted_output[0] if formatted_output else ""

    # add all remaining list items with a space in between
    for txt in formatted_output[1:]:
        # check if the token is a punctuation mark or other special character
        if txt in [
            ".", ",", "!", "?", ":", ";",
            ")", "]", "}", "'", '"',
            "[", "{", "(", "<",
        ]:
            # append punctuation without a leading space
            output_str += txt
        # append non-empty tokens with a leading space
        elif txt != "":
            output_str += " " + txt

    # return the combined string with repeated whitespace collapsed
    return re.sub(r"\s+", " ", output_str).strip()
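# illustrative call (hypothetical tokenizer output, not taken from a real model):
#   format_output_text(["▁Hello", ",", "▁world", "!"]) -> "Hello, world!"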

# format the tokens by removing special tokens and marker characters
def format_tokens(tokens: list):
    # initialize the result list
    updated_tokens = []
    # loop through the tokens
    for t in tokens:
        # remove every special token and marker from the current token
        # (this also strips the leading sentencepiece "▁" marker)
        for s in SPECIAL_TOKENS:
            t = t.replace(s, "")
        # add the cleaned token to the result list
        updated_tokens.append(t)
    # return the cleaned list of tokens
    return updated_tokens
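# illustrative call (made-up token strings):
#   format_tokens(["[CLS]", "▁good", "bye</w>"]) -> ["", "good", "bye"]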

# function to flatten attribution values (e.g. SHAP) into a 2d array by summing over the given axis
def flatten_attribution(values: ndarray, axis: int = 0):
    return np.sum(values, axis=axis)
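# illustrative shapes (assumed, not tied to a specific explainer): summing a
# (layers, input_len, output_len) array over axis 0 yields a single
# (input_len, output_len) attribution matrix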

# function to flatten attention values into a 2d array by averaging over the given axis
def flatten_attention(values: ndarray, axis: int = 0):
    return np.mean(values, axis=axis)
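# illustrative shapes (assumed): averaging a (heads, seq_len, seq_len) array
# over axis 0 collapses the head dimension into one (seq_len, seq_len) matrix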

# function to get averaged attention values for the supported models
def avg_attention(attention_values, model: str):
    # godel: use the first encoder attention layer for the first batch item
    if model == "godel":
        attention = attention_values.encoder_attentions[0][0].detach().numpy()
        # average the attention values over axis 1
        return np.mean(attention, axis=1)
    # mistral: move the attention tensor to the cpu and convert to numpy
    attention = attention_values.to(torch.device("cpu")).detach().numpy()
    # drop the last dimension by keeping only its first entry
    attention = attention[:, :, :, 0]
    # return the averaged attention values
    return np.mean(attention, axis=1)
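
# minimal smoke-test sketch; the token strings and tensor shapes below are
# assumptions for illustration, not real godel/mistral outputs (the godel
# branch needs a full model output object and is not exercised here)
if __name__ == "__main__":
    # hypothetical tokenizer output for the formatting helpers
    print(format_output_text(["▁Hello", ",", "▁world", "!"]))  # Hello, world!

    # toy attribution array: (layers, input_len, output_len)
    print(flatten_attribution(np.ones((2, 3, 4))).shape)  # (3, 4)

    # toy attention tensor for the mistral branch; the trailing dimension is
    # assumed to hold the values that avg_attention slices at index 0
    print(avg_attention(torch.rand(2, 4, 6, 6), model="mistral").shape)  # (2, 6)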