# thesis/utils/formatting.py
# author: LennardZuendorf
# last commit: fix/chore: final fix of attention, preparing for last release
# commit bf15c20 (unverified)
# formatting util module providing formatting functions for the model input and output
# external imports
import re
import torch
import numpy as np
from numpy import ndarray
# Markers that format_tokens deletes from every token before output is shown.
# Order matters: format_tokens removes the entries one after another in this
# exact sequence.
SPECIAL_TOKENS = [
    # bracketed special tokens
    "[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]",
    # word prefix/suffix markers (presumably SentencePiece / BPE artefacts)
    "▁", "Ġ", "</w>",
    # byte tokens for hex 0x0A (newline), 0x0D (carriage return), 0x09 (tab)
    "<0x0A>", "<0x0D>", "<0x09>",
    # sequence start / end markers
    "<s>", "</s>",
]
# Tokens that are appended directly to the preceding text (no separating space).
_NO_SPACE_BEFORE = frozenset(
    {".", ",", "!", "?", ":", ";", ")", "]", "}", "'", '"', "[", "{", "(", "<"}
)
# Opening brackets additionally glue the following token to themselves.
_NO_SPACE_AFTER = frozenset({"[", "{", "(", "<"})


# function to format the model response nicely
# takes a list of strings and returns a combined string
def format_output_text(output: list):
    """Join a list of raw model tokens into one readable string.

    Special tokens are first removed with ``format_tokens``; tokens that end up
    empty are skipped. Remaining tokens are separated by a single space, except:

    * the first visible token never gets a leading space,
    * punctuation / bracket tokens are appended directly to the preceding text,
    * the token following an opening bracket is appended directly to it.

    Args:
        output: list of token strings.

    Returns:
        The combined string with every whitespace run collapsed to one space;
        ``""`` when no non-empty token remains (including empty input).
    """
    parts = []
    # True -> next visible token is appended without a leading space
    attach = True
    for txt in format_tokens(output):
        if txt == "":
            continue
        parts.append(txt if attach or txt in _NO_SPACE_BEFORE else " " + txt)
        attach = txt in _NO_SPACE_AFTER
    return re.sub(r"\s+", " ", "".join(parts))
# strip special tokens / marker characters from every token
def format_tokens(tokens: list):
    """Return a new list with the special markers removed from each token.

    Leading "▁" characters are stripped first, then every entry of
    ``SPECIAL_TOKENS`` is deleted (in list order) wherever it occurs inside the
    token. Tokens made up only of markers become ``""`` but are kept, so the
    result has the same length as ``tokens``.
    """

    def _clean(token):
        # lstrip is a no-op when the token does not start with "▁"
        token = token.lstrip("▁")
        for marker in SPECIAL_TOKENS:
            token = token.replace(marker, "")
        return token

    return [_clean(token) for token in tokens]
# reduce attribution values (e.g. SHAP) along one axis by summation
def flatten_attribution(values: ndarray, axis: int = 0):
    """Sum ``values`` over ``axis``.

    Args:
        values: array of attribution values.
        axis: axis to reduce over (default 0).

    Returns:
        The result of ``np.sum(values, axis=axis)``, i.e. an array with the
        given axis summed away.
    """
    reduced = np.sum(values, axis=axis)
    return reduced
# reduce attention values along one axis by averaging
def flatten_attention(values: ndarray, axis: int = 0):
    """Average ``values`` over ``axis``.

    Args:
        values: array of attention values.
        axis: axis to reduce over (default 0).

    Returns:
        The result of ``np.mean(values, axis=axis)``, i.e. an array with the
        given axis averaged away.
    """
    averaged = np.mean(values, axis=axis)
    return averaged
# detach a torch tensor, move it to the CPU and convert it to numpy
def _attention_to_numpy(tensor):
    """Return ``tensor`` as a CPU numpy array, detached from autograd."""
    tensor = tensor.detach().to(torch.device("cpu"))
    # numpy has no bfloat16 dtype, so Tensor.numpy() would raise; upcast first
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


# function to get averaged attention values for the godel or mistral model
def avg_attention(attention_values, model: str):
    """Average attention values into a numpy array.

    Args:
        attention_values: for ``model == "godel"`` an object exposing
            ``encoder_attentions`` whose ``[0][0]`` entry is a tensor; for any
            other model a tensor with at least four dimensions.
        model: model identifier; ``"godel"`` selects the encoder-attention
            path, every other value the mistral path.

    Returns:
        numpy array holding the mean over axis 1 of the selected attention
        slice.
    """
    if model == "godel":
        # first entry of encoder_attentions, first element of that entry
        # NOTE(review): presumably layer 0, batch item 0 — confirm
        attention = _attention_to_numpy(attention_values.encoder_attentions[0][0])
    else:
        # keep only index 0 of the fourth dimension (that dimension is dropped)
        attention = _attention_to_numpy(attention_values)[:, :, :, 0]
    return np.mean(attention, axis=1)