LennardZuendorf committed · Commit 67a34bd · Parent: 6ff516d

feat/fix: fixing attention bug, fixing other mistral bugs
Changed files:
- explanation/attention.py +33 -16
- explanation/interpret_captum.py +7 -2
- explanation/interpret_shap.py +6 -2
- explanation/markup.py +14 -11
- explanation/plotting.py +2 -2
- main.py +28 -17
- pyproject.toml +1 -0
- utils/formatting.py +31 -8
- utils/modelling.py +11 -0
explanation/attention.py
CHANGED
@@ -2,7 +2,8 @@


 # internal imports
-from utils import formatting as fmt
+from utils import formatting as fmt, modelling as mdl
+from model import mistral
 from .markup import markup_text


@@ -10,36 +11,52 @@ from .markup import markup_text
 # and marked text based on attention
 def chat_explained(model, prompt):

-    model.set_config({"return_dict": True})
-
     # get encoded input
-    …
+    input_ids = model.TOKENIZER(
         prompt, return_tensors="pt", add_special_tokens=True
     ).input_ids
-    …
-    …
-    …
-    )
+
+    # generate output of the model
+    decoder_ids = model.MODEL.generate(input_ids, generation_config=model.CONFIG)

     # get input and output text as list of strings
-    …
-    …
-    …
-    decoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
+    input_text = fmt.format_tokens(model.TOKENIZER.convert_ids_to_tokens(input_ids[0]))
+    output_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(decoder_ids[0])
     )

-    …
+    # checking if model is mistral
+    if type(model.MODEL) == type(mistral.MODEL):
+
+        # get attention values for the input vectors
+        attention_output = model.MODEL(input_ids, output_attentions=True).attentions
+
+        # averaging attention across layers and heads
+        attention_output = mdl.format_mistral_attention(attention_output)
+        averaged_attention = fmt.avg_attention(attention_output, model="mistral")
+
+    # attention visualization for godel
+    else:
+        # get attention values for the input and output vectors
+        # using already generated input and output
+        attention_output = model.MODEL(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_ids,
+            output_attentions=True,
+        )
+
+        # averaging attention across layers
+        averaged_attention = fmt.avg_attention(attention_output, model="godel")

     # format response text for clean output
-    response_text = fmt.format_output_text(…
+    response_text = fmt.format_output_text(output_text)
     # setting placeholder for iFrame graphic
     graphic = (
         "<div style='text-align: center; font-family:arial;'><h4>Attention"
         " Visualization doesn't support an interactive graphic.</h4></div>"
     )
     # creating marked text using markup_text function and attention
-    marked_text = markup_text(…
+    marked_text = markup_text(input_text, averaged_attention, variant="visualizer")

     # returning response, graphic and marked text array
     return response_text, graphic, marked_text, None
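For reference, a minimal self-contained sketch of the two averaging paths that chat_explained now dispatches between, using random tensors in place of real model outputs (all shapes here are illustrative assumptions, not taken from the commit):

import numpy as np
import torch

# Mistral path: after format_mistral_attention the tensor is assumed to be
# (layers, heads, seq, seq); averaging over axes (0, 1, 2) leaves one
# averaged weight per input token, which is what markup_text consumes.
mistral_attention = torch.rand(3, 8, 5, 5)  # dummy: 3 layers, 8 heads, 5 tokens
per_token = np.mean(mistral_attention.detach().numpy(), axis=(0, 1, 2))
print(per_token.shape)  # (5,)

# GODEL path: decoder_attentions[0][0] is assumed to be
# (heads, target_len, source_len); averaging over heads keeps the
# position-wise structure.
godel_attention = torch.rand(8, 4, 5)  # dummy: 8 heads, 4 output x 5 input tokens
per_position = np.mean(godel_attention.detach().numpy(), axis=0)
print(per_position.shape)  # (4, 5)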
explanation/interpret_captum.py
CHANGED
@@ -4,6 +4,7 @@ import torch

 # internal imports
 from utils import formatting as fmt
+from .plotting import plot_seq
 from .markup import markup_text


@@ -26,7 +27,7 @@ def cpt_extract_seq_att(attr):
 def chat_explained(model, prompt):
     model.set_config({})

-    # creating llm attribution class with KernelSHAP and
+    # creating llm attribution class with KernelSHAP and Mistral Model, Tokenizer
     llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)

     # generation attribution
@@ -48,7 +49,11 @@ def chat_explained(model, prompt):
     graphic = """<div style='text-align: center; font-family:arial;'><h4>
     Intepretation with Captum doesn't support an interactive graphic.</h4></div>
     """
+    # create the explanation marked text array
     marked_text = markup_text(input_tokens, values, variant="captum")

+    # creating sequence attribution plot
+    plot = plot_seq(cpt_extract_seq_att(attribution_result), "KernelSHAP")
+
     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text,
+    return response_text, graphic, marked_text, plot
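The body of cpt_extract_seq_att is not shown in these hunks. A plausible sketch of what it could feed to plot_seq, using a dummy stand-in for Captum's LLMAttributionResult (which exposes seq_attr and input_tokens) instead of a real attribution run; the values and the pairing logic are guesses:

from types import SimpleNamespace

import torch

# dummy stand-in for an LLMAttribution result; values are made up
attribution_result = SimpleNamespace(
    seq_attr=torch.tensor([0.1, 0.9, -0.2]),
    input_tokens=["Does", "money", "buy"],
)

def cpt_extract_seq_att(attr):
    # pair each input token with its sequence-level attribution value
    return list(zip(attr.input_tokens, attr.seq_attr.tolist()))

print(cpt_extract_seq_att(attribution_result))
# roughly [('Does', 0.1), ('money', 0.9), ('buy', -0.2)]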
explanation/interpret_shap.py
CHANGED
@@ -6,6 +6,7 @@ import torch

 # internal imports
 from utils import formatting as fmt
+from .plotting import plot_seq
 from .markup import markup_text

 # global variables
@@ -14,7 +15,7 @@ TEXT_MASKER = None


 # function to extract summarized sequence wise attribution
-def …
+def shap_extract_seq_att(shap_values):

     # extracting summed up shap values
     values = fmt.flatten_attribution(shap_values.values[0], 1)
@@ -78,5 +79,8 @@ def chat_explained(model, prompt):
     # create the response text
     response_text = fmt.format_output_text(shap_values.output_names)

+    # creating sequence attribution plot
+    plot = plot_seq(shap_extract_seq_att(shap_values), "PartitionSHAP")
+
     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text,
+    return response_text, graphic, marked_text, plot
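A small sketch of what shap_extract_seq_att presumably does with shap_values.values[0], an (input tokens x output tokens) matrix: collapse the output axis so each input token gets one summed value. flatten_attribution is not shown in the hunks, so plain np.sum stands in for it here, and the matrix is made up:

import numpy as np

# made-up SHAP matrix: 3 input tokens x 2 output tokens
values_matrix = np.array([[0.2, -0.1], [0.5, 0.4], [-0.3, 0.1]])
input_tokens = ["Does", "money", "buy"]

# collapse axis 1 (the output tokens), leaving one value per input token
summed = np.sum(values_matrix, axis=1)
seq_values = list(zip(input_tokens, summed))
print(seq_values)  # roughly [('Does', 0.1), ('money', 0.9), ('buy', -0.2)]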
explanation/markup.py
CHANGED
@@ -25,12 +25,12 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     min_val, max_val = np.min(text_values), np.max(text_values)

     # separate the threshold calculation for negative and positive values
-    # visualization negative thresholds are all 0 since
+    # visualization negative thresholds are all 0 since attention always positive
     if variant == "visualizer":
         neg_thresholds = np.linspace(
             0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
         )[1:]
-    #
+    # standard config for 5 negative buckets
     else:
         neg_thresholds = np.linspace(
             min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
@@ -45,16 +45,19 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):

     # looping over each text snippet and attribution value
     for text, value in zip(input_text, text_values):
-        # setting inital bucket at lowest
-        bucket = "-5"

-        # …
-        …
-        # …
-        …
-        …
-        …
-        …
+        # validating text and skipping empty text/special tokens
+        if text not in ("", fmt.SPECIAL_TOKENS):
+            # setting initial bucket at lowest
+            bucket = "-5"
+
+            # looping over all bucket and their threshold
+            for i, threshold in zip(bucket_tags, thresholds):
+                # updating assigned bucket if value is above threshold
+                if value >= threshold:
+                    bucket = i
+            # finally adding text and bucket assignment to list of tuples
+            marked_text.append((text, str(bucket)))

     # returning list of marked text snippets as list of tuples
     return marked_text
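The restored loop is easier to follow in isolation. A toy version of the bucket assignment, with invented tags and thresholds (markup.py derives its real thresholds with np.linspace as shown above):

bucket_tags = ["-2", "-1", "+1", "+2"]  # invented; the real list has 10 tags
thresholds = [-1.0, -0.5, 0.0, 0.5]     # invented; np.linspace builds the real ones

def assign_bucket(value):
    # start at the lowest bucket, then climb while the value clears thresholds
    bucket = bucket_tags[0]
    for tag, threshold in zip(bucket_tags, thresholds):
        if value >= threshold:
            bucket = tag
    return bucket

print([assign_bucket(v) for v in (-0.8, 0.1, 0.7)])  # ['-2', '+1', '+2']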
explanation/plotting.py
CHANGED
@@ -5,7 +5,7 @@ import numpy as np
 import matplotlib.pyplot as plt


-def plot_seq(seq_values: list, method_model: tuple = ("", "")):
+def plot_seq(seq_values: list, method: str = ""):

     # Separate the tokens and their corresponding importance values
     tokens, importance = zip(*seq_values)
@@ -45,7 +45,7 @@ def plot_seq(seq_values: list, method_model: tuple = ("", "")):
     )

     plt.axhline(0, color="black", linewidth=1)
-    plt.title(f"Input Token Attribution with {…
+    plt.title(f"Input Token Attribution with {method}")
     plt.xlabel("Input Tokens", labelpad=0.5)
     plt.ylabel("Attribution")
     plt.xticks(x_positions, tokens, rotation=45)
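A usage sketch of the new signature. The drawing code between the two hunks is not part of the diff, so the plt.bar call below is an assumption; only the surrounding lines mirror the file:

import numpy as np
import matplotlib.pyplot as plt

def plot_seq(seq_values: list, method: str = ""):
    # separate the tokens and their corresponding importance values
    tokens, importance = zip(*seq_values)
    x_positions = np.arange(len(tokens))

    plt.bar(x_positions, importance)  # assumed; the diff omits the drawing code
    plt.axhline(0, color="black", linewidth=1)
    plt.title(f"Input Token Attribution with {method}")
    plt.xlabel("Input Tokens", labelpad=0.5)
    plt.ylabel("Attribution")
    plt.xticks(x_positions, tokens, rotation=45)
    return plt.gcf()

plot = plot_seq([("Does", 0.1), ("money", 0.9), ("buy", -0.2)], "KernelSHAP")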
main.py
CHANGED
@@ -155,7 +155,7 @@ with gr.Blocks(
     The explanations are based on 10 buckets that range between the
     lowest negative value (1 to 5) and the highest positive attribution value (6 to 10).
     **The legend shows the color for each bucket.**
-
+
     *HINT*: This works best in light mode.
     """)
     xai_text = gr.HighlightedText(
@@ -210,12 +210,34 @@ with gr.Blocks(
         gr.Examples(
             label="Example Questions",
             examples=[
-                ["Does money buy happiness?", "", "Mistral", "…
-                ["Does money buy happiness?", "", "Mistral", "…
+                ["Does money buy happiness?", "", "", "Mistral", "None"],
+                ["Does money buy happiness?", "", "", "Mistral", "SHAP"],
+                ["Does money buy happiness?", "", "", "Mistral", "Attention"],
+                [
+                    "Does money buy happiness?",
+                    "",
+                    (
+                        "Respond from the perspective of a billionaire enjoying"
+                        " life in Dubai"
+                    ),
+                    "Mistral",
+                    "None",
+                ],
+                [
+                    "Does money buy happiness?",
+                    "",
+                    (
+                        "Respond from the perspective of a billionaire enjoying"
+                        " life in Dubai"
+                    ),
+                    "Mistral",
+                    "SHAP",
+                ],
             ],
             inputs=[
                 user_prompt,
                 knowledge_input,
+                system_prompt,
                 model_selection,
                 xai_selection,
             ],
@@ -227,32 +249,21 @@ with gr.Blocks(
             label="Example Questions",
            examples=[
                 [
-                    "…
+                    "Does money buy happiness?",
                     (
                         "Black holes are created when a massive star's core"
                         " collapses after a supernova, forming an object with"
                         " gravity so intense that even light cannot escape."
                     ),
+                    "",
                     "GODEL",
                     "SHAP",
                 ],
-                [
-                    (
-                        "Explain the importance of the Rosetta Stone in"
-                        " understanding ancient languages."
-                    ),
-                    (
-                        "The Rosetta Stone, an ancient Egyptian artifact, was"
-                        " key in decoding hieroglyphs, featuring the same text"
-                        " in three scripts: hieroglyphs, Demotic, and Greek."
-                    ),
-                    "GODEL",
-                    "Attention",
-                ],
             ],
             inputs=[
                 user_prompt,
                 knowledge_input,
+                system_prompt,
                 model_selection,
                 xai_selection,
             ],
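The example rows gain an extra empty string because gr.Examples rows must match the inputs list one-to-one, and that list now includes system_prompt. A stripped-down sketch of the wiring; component labels and choice lists are guesses, not copied from main.py:

import gradio as gr

with gr.Blocks() as demo:
    user_prompt = gr.Textbox(label="User Prompt")
    knowledge_input = gr.Textbox(label="Knowledge")
    system_prompt = gr.Textbox(label="System Prompt")
    model_selection = gr.Radio(["Mistral", "GODEL"], label="Model")
    xai_selection = gr.Radio(["None", "SHAP", "Attention"], label="XAI")

    # each example row supplies one value per component listed in inputs
    gr.Examples(
        label="Example Questions",
        examples=[["Does money buy happiness?", "", "", "Mistral", "None"]],
        inputs=[user_prompt, knowledge_input, system_prompt,
                model_selection, xai_selection],
    )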
pyproject.toml
CHANGED
@@ -21,6 +21,7 @@ exclude = '''

 [tool.pylint.messages_control]
 disable = [
+    "unidiomatic-typecheck",
     "not-a-mapping",
     "arguments-differ",
     "attribute-defined-outside-init",
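unidiomatic-typecheck is the pylint message triggered by the exact-type comparison added in attention.py. A two-line illustration of the flagged pattern versus the idiom pylint would otherwise suggest (Dummy is a placeholder class):

class Dummy:
    pass

obj = Dummy()
type(obj) == type(Dummy())  # flagged as unidiomatic-typecheck, now allowed
isinstance(obj, Dummy)      # the isinstance idiom pylint normally recommends

Disabling the check project-wide rather than rewriting the comparison suggests the exact-type match on the model class is intentional here.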
utils/formatting.py
CHANGED
@@ -2,12 +2,31 @@

 # external imports
 import re
+import torch
 import numpy as np
 from numpy import ndarray


-# …
-…
+# globally defined tokens that are removed from the output
+SPECIAL_TOKENS = [
+    "[CLS]",
+    "[SEP]",
+    "[PAD]",
+    "[UNK]",
+    "[MASK]",
+    "▁",
+    "Ġ",
+    "</w>",
+    "<0x0A>",
+    "<0x0D>",
+    "<0x09>",
+    "<s>",
+    "</s>",
+]
+
+
+# function to format the model repose nicely
+# takes a list of strings and returning a combined string
 def format_output_text(output: list):

     # remove special tokens from list using other function
@@ -36,8 +55,6 @@ def format_output_text(output: list):

 # format the tokens by removing special tokens and special characters
 def format_tokens(tokens: list):
-    # define special tokens to remove
-    special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]

     # initialize empty list
     updated_tokens = []
@@ -49,7 +66,7 @@ def format_tokens(tokens: list):
         t = t.lstrip("▁")

         # loop through special tokens list and remove from current token if matched
-        for s in special_tokens:
+        for s in SPECIAL_TOKENS:
             t = t.replace(s, "")

         # add token to list
@@ -70,6 +87,12 @@ def flatten_attention(values: ndarray, axis: int = 0):


 # function to get averaged decoder attention from attention values
-def avg_attention(attention_values):
-    …
-    …
+def avg_attention(attention_values, model: str):
+    # check if model is godel
+    if model == "godel":
+        # get attention values for the input and output vectors
+        attention = attention_values.decoder_attentions[0][0].detach().numpy()
+        return np.mean(attention, axis=0)
+    # extracting attention values for mistral
+    attention_np = attention_values.to(torch.device("cpu")).detach().numpy()
+    return np.mean(attention_np, axis=(0, 1, 2))
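A self-contained check of the token cleanup with the promoted SPECIAL_TOKENS list. The loop skeleton mirrors the hunks above; the sample tokens are made up:

SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]",
                  "▁", "Ġ", "</w>", "<0x0A>", "<0x0D>", "<0x09>", "<s>", "</s>"]

def format_tokens(tokens: list):
    updated_tokens = []
    for t in tokens:
        # strip the sentencepiece word-boundary marker, then any special token
        t = t.lstrip("▁")
        for s in SPECIAL_TOKENS:
            t = t.replace(s, "")
        updated_tokens.append(t)
    return updated_tokens

print(format_tokens(["<s>", "▁Does", "▁money", "<0x0A>", "</s>"]))
# ['', 'Does', 'money', '', '']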
utils/modelling.py
CHANGED
@@ -97,3 +97,14 @@ def gpu_loading_config(max_memory: str = "15000MB"):
     )

     return n_gpus, max_memory, bnb_config
+
+
+# formatting mistral attention values
+# CREDIT: copied and adapted from BERTViz
+# see https://github.com/jessevig/bertviz
+def format_mistral_attention(attention_values):
+    squeezed = []
+    for layer_attention in attention_values:
+        layer_attention = layer_attention.squeeze(0)
+        squeezed.append(layer_attention)
+    return torch.stack(squeezed)
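A quick shape check for the new helper, assuming Hugging Face-style attentions (a tuple with one (batch=1, heads, seq, seq) tensor per layer; the dimensions below are illustrative):

import torch

def format_mistral_attention(attention_values):
    # drop the batch dimension per layer, then stack layers into one tensor
    squeezed = []
    for layer_attention in attention_values:
        squeezed.append(layer_attention.squeeze(0))
    return torch.stack(squeezed)

dummy_attentions = tuple(torch.rand(1, 8, 5, 5) for _ in range(3))  # 3 layers
print(format_mistral_attention(dummy_attentions).shape)  # torch.Size([3, 8, 5, 5])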