Spaces:
Runtime error
Runtime error
File size: 3,061 Bytes
2492536 d2116db 2492536 d2116db 2492536 58a02af d2116db 2492536 f5ebee7 d2116db f5ebee7 5d99c07 f5ebee7 d2116db 2492536 d2116db 58a02af 2492536 d2116db 58a02af 2492536 58a02af 2492536 58a02af d2116db 2492536 d2116db 2492536 d2116db 2492536 58a02af 2492536 58a02af 2492536 58a02af d2116db 2492536 d2116db 2492536 d2116db 2492536 d2116db 2492536 58a02af d2116db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# markup module that provides marked up text as an array
# external imports
import numpy as np
from numpy import ndarray
# internal imports
from utils import formatting as fmt
# main function that assigns each text snippet a marked bucket
def markup_text(input_text: list, text_values: ndarray, variant: str):
    """Assign each text snippet one of eleven signed intensity buckets.

    The buckets are named "-5" .. "+5". Negative buckets split the range
    between the smallest value and 0, positive buckets split the range
    between 0 and the largest value. A snippet receives the tag of the last
    threshold its value reaches, or "-5" if it reaches none.

    Args:
        input_text: Text snippets, paired positionally with ``text_values``.
        text_values: Attribution or attention values. For the variants
            "shap" and "visualizer" they are first reduced by the matching
            ``fmt`` helper; for any other variant they are used as given and
            are expected to yield one scalar per snippet when iterated.
        variant: "shap", "visualizer", or anything else for no flattening.

    Returns:
        A list of ``(text, bucket_tag)`` tuples, one per zipped pair. An
        empty list is returned when there are no values to bucket.
    """
    # naming of the 11 buckets, ordered from most negative to most positive
    bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
    # number of buckets on each side of the neutral "0" bucket
    half = (len(bucket_tags) - 1) // 2

    # flatten the values depending on the source
    # NOTE(review): per the original comment, attention is averaged and SHAP
    # values are summed inside the fmt helpers - not visible from here
    if variant == "shap":
        text_values = np.transpose(text_values)
        text_values = fmt.flatten_attribution(text_values)
    elif variant == "visualizer":
        text_values = fmt.flatten_attention(text_values)

    # nothing to bucket; np.min / np.max would raise on an empty array
    if np.size(text_values) == 0:
        return []

    # determine the minimum and maximum values
    min_val, max_val = np.min(text_values), np.max(text_values)

    # negative thresholds: the visualizer variant never uses negative buckets,
    # and without any negative value every snippet passes them anyway
    if variant == "visualizer" or min_val >= 0:
        neg_thresholds = np.zeros(half)
    else:
        neg_thresholds = np.linspace(min_val, 0, num=half + 1, endpoint=False)[1:]

    # positive thresholds between 0 and the maximum; when no value is positive
    # they are made unreachable so that zero or negative values can never be
    # promoted into a positive bucket
    if max_val > 0:
        pos_thresholds = np.linspace(0, max_val, num=half + 1)[1:]
    else:
        pos_thresholds = np.full(half, np.inf)

    # combining thresholds, one per bucket tag
    thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])

    marked_text = []
    for text, value in zip(input_text, text_values):
        # start at the lowest bucket and move up for every threshold reached
        bucket = bucket_tags[0]
        for tag, threshold in zip(bucket_tags, thresholds):
            if value >= threshold:
                bucket = tag
        marked_text.append((text, bucket))

    # list of (text, bucket) tuples
    return marked_text
# function that defines color codes
# coloring along SHAP style coloring for consistency
def color_codes():
    """Return the mapping from bucket tag ("-5" .. "+5") to a hex color."""
    palette = (
        # negative buckets: strong blue fading to a very light blue
        ("-5", "#3251a8"),
        ("-4", "#5A7FB2"),
        ("-3", "#8198BC"),
        ("-2", "#A8B1C6"),
        ("-1", "#E6F0FF"),
        # neutral bucket: white (assumes a light background)
        ("0", "#FFFFFF"),
        # positive buckets: very light pink deepening to strong magenta
        ("+1", "#FFE6F0"),
        ("+2", "#DF8CA3"),
        ("+3", "#D7708E"),
        ("+4", "#CF5480"),
        ("+5", "#A83273"),
    )
    return dict(palette)
|