LennardZuendorf committed
Commit • d4dd3c5
1 Parent(s): 229e14c
feat/fix: several minor fixes and additions
- explanation/interpret_captum.py +55 -0
- explanation/interpret_shap.py +82 -0
- explanation/plotting.py +58 -0
- main.py +44 -38
- model/mistral.py +1 -1
- requirements.txt +0 -1
explanation/interpret_captum.py
CHANGED
@@ -0,0 +1,55 @@
+# external imports
+from captum.attr import LLMAttribution, TextTokenInput, KernelShap
+import torch
+
+# internal imports
+from utils import formatting as fmt
+from .markup import markup_text
+
+
+# function to extract sequence attribution
+def cpt_extract_seq_att(attr):
+
+    # getting values from captum
+    values = attr.seq_attr.to(torch.device("cpu")).numpy()
+
+    # format the input tokens nicely and check for mismatch
+    input_tokens = fmt.format_tokens(attr.input_tokens)
+    if len(attr.input_tokens) != len(values):
+        raise RuntimeError("values and input len mismatch")
+
+    # return a list of tuples with token and value
+    return list(zip(input_tokens, values))
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # creating llm attribution class with KernelSHAP and Mistral model, tokenizer
+    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
+
+    # generation attribution
+    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
+    attribution_result = llm_attribution.attribute(
+        attribution_input, gen_args=model.CONFIG.to_dict()
+    )
+
+    # extracting values and input tokens
+    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
+    input_tokens = fmt.format_tokens(attribution_result.input_tokens)
+
+    # raising error if mismatch occurs
+    if len(attribution_result.input_tokens) != len(values):
+        raise RuntimeError("values and input len mismatch")
+
+    # getting response text, graphic placeholder and marked text object
+    response_text = fmt.format_output_text(attribution_result.output_tokens)
+    graphic = (
+        "<div style='text-align: center; font-family:arial;'><h4>Attention"
+        " Interpretation with Captum doesn't support an interactive graphic.</h4></div>"
+    )
+    marked_text = markup_text(input_tokens, values, variant="captum")
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text
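The module above drives Captum's LLMAttribution with a KernelShap attribution method and returns the answer, a static HTML placeholder, and a marked-up token array. A minimal usage sketch (not part of this commit), assuming the model wrapper module added elsewhere in this repo, e.g. model.mistral, exposes MODEL, TOKENIZER, CONFIG and set_config as used above:

# hypothetical usage sketch; the prompt and the wrapper module name are illustrative
from model import mistral
from explanation import interpret_captum

# runs generation plus KernelSHAP input attribution in one call
response_text, graphic, marked_text = interpret_captum.chat_explained(
    mistral, "Why is the sky blue?"
)

print(response_text)    # the plain generated answer
print(marked_text[:5])  # a slice of the marked-up token array used for highlighting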
explanation/interpret_shap.py
CHANGED
@@ -0,0 +1,82 @@
+# interpret module that implements the interpretability method
+
+# external imports
+from shap import models, maskers, plots, PartitionExplainer
+import torch
+
+# internal imports
+from utils import formatting as fmt
+from .markup import markup_text
+
+# global variables
+TEACHER_FORCING = None
+TEXT_MASKER = None
+
+
+# function to extract summarized sequence-wise attribution
+def extract_seq_att(shap_values):
+
+    # extracting summed up shap values
+    values = fmt.flatten_attribution(shap_values.values[0], 1)
+
+    # returning list of tuples of token and value
+    return list(zip(shap_values.data[0], values))
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # create the shap explainer
+    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
+
+    # get the shap values for the prompt
+    shap_values = shap_explainer([prompt])
+
+    # create the explanation graphic and marked text array
+    graphic = create_graphic(shap_values)
+    marked_text = markup_text(
+        shap_values.data[0], shap_values.values[0], variant="shap"
+    )
+
+    # create the response text
+    response_text = fmt.format_output_text(shap_values.output_names)
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text
+
+
+# function used to wrap the model with a shap model
+def wrap_shap(model):
+    # calling global variables
+    global TEXT_MASKER, TEACHER_FORCING
+
+    # set the device to cuda if gpu is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # updating the model settings
+    model.set_config()
+
+    # (re)initialize the shap models and masker
+    # creating a shap text_generation model
+    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
+    # wrapping the text generation model in a teacher forcing model
+    TEACHER_FORCING = models.TeacherForcing(
+        text_generation,
+        model.TOKENIZER,
+        device=str(device),
+        similarity_model=model.MODEL,
+        similarity_tokenizer=model.TOKENIZER,
+    )
+    # setting the text masker as an empty string
+    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
+
+
+# graphic plotting function that creates an HTML graphic (as string) for the explanation
+def create_graphic(shap_values):
+
+    # create the html graphic using shap text plot function
+    graphic_html = plots.text(shap_values, display=False)
+
+    # return the html graphic as string to display in iFrame
+    return str(graphic_html)
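Note that chat_explained builds a PartitionExplainer directly, while wrap_shap prepares the TEACHER_FORCING and TEXT_MASKER globals separately. A minimal usage sketch (not part of this commit), assuming a model wrapper module such as model.godel with the same MODEL/TOKENIZER/set_config interface:

# hypothetical usage sketch; the wrapper module name (model.godel) is an assumption
from explanation import interpret_shap
from model import godel

# optionally (re)build the teacher-forcing model and text masker globals
interpret_shap.wrap_shap(godel)

# run generation plus partition-SHAP attribution in one call
response_text, graphic_html, marked_text = interpret_shap.chat_explained(
    godel, "Why is the sky blue?"
)

# graphic_html holds the shap.plots.text output as a string, ready for an iFrame
print(graphic_html[:200])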
explanation/plotting.py
CHANGED
@@ -0,0 +1,58 @@
+# plotting functions
+
+# external imports
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_seq(seq_values: list, method_model: tuple = ("", "")):
+
+    # Separate the tokens and their corresponding importance values
+    tokens, importance = zip(*seq_values)
+
+    # Convert importance values to numpy array for conditional coloring
+    importance = np.array(importance)
+
+    # Determine the colors based on the sign of the importance values
+    colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
+
+    # Create a bar plot
+    plt.figure(figsize=(len(tokens) * 0.9, np.max(importance)))
+    x_positions = range(len(tokens))  # Positions for the bars
+
+    # Creating vertical bar plot
+    bar_width = 0.8  # Increase this value to make the bars wider
+    plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
+    plt.yscale("symlog")
+
+    # Annotating each bar with its value
+    padding = 0.1  # Padding for text annotation
+    for x, (y, color) in enumerate(zip(importance, colors)):
+        sign = "+" if y > 0 else ""
+        plt.annotate(
+            f"{sign}{y:.2f}",  # Format the value with sign
+            xy=(x, y + padding if y > 0 else y - padding),
+            ha="center",
+            color=color,
+            va="bottom" if y > 0 else "top",  # Vertical alignment
+            fontweight="bold",  # Bold text
+            bbox={
+                "facecolor": "white",
+                "edgecolor": "none",
+                "boxstyle": "round,pad=0.1",
+            },  # White background
+        )
+
+    plt.axhline(0, color="black", linewidth=1)
+    plt.title(f"Input Token Attribution with {method_model[0]} on {method_model[1]}")
+    plt.xlabel("Input Tokens", labelpad=0.5)
+    plt.ylabel("Attribution")
+    plt.xticks(x_positions, tokens, rotation=45)
+
+    # Adjust y-axis limits to ensure there's enough space for labels
+    y_min, y_max = plt.ylim()
+    y_range = y_max - y_min
+    plt.ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range)
+
+    return plt
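plot_seq expects the (token, value) tuples produced by extract_seq_att / cpt_extract_seq_att and returns the pyplot module, which a gr.Plot component can render. A standalone sketch with made-up attribution values:

# standalone sketch; the tokens and values below are made up for illustration
from explanation.plotting import plot_seq

seq_values = [
    ("Why", 0.12), ("is", -0.03), ("the", -0.01),
    ("sky", 0.85), ("blue", 0.42), ("?", 0.05),
]

plot = plot_seq(seq_values, method_model=("SHAP", "GODEL"))
plot.show()  # or plot.savefig(...), since the pyplot module is returned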
main.py
CHANGED
@@ -102,45 +102,46 @@ with gr.Blocks(
     """)
     # row with columns for the different settings
     with gr.Row(equal_height=True):
-        [34 removed lines: the previous settings layout is not captured in this view]
+        with gr.Accordion("Application Settings", open=False):
+            # column that takes up 3/4 of the row
+            with gr.Column(scale=2):
+                # textbox to enter the system prompt
+                system_prompt = gr.Textbox(
+                    label="System Prompt",
+                    info="Set the model's system prompt, dictating how it answers.",
+                    # default system prompt is set to this in the backend
+                    placeholder=(
+                        "You are a helpful, respectful and honest assistant. Always"
+                        " answer as helpfully as possible, while being safe."
+                    ),
+                )
+            # column that takes up 1/4 of the row
+            with gr.Column(scale=1):
+                # radio group to select the xai method
+                xai_selection = gr.Radio(
+                    ["None", "SHAP", "Attention"],
+                    label="Interpretability Settings",
+                    info="Select an Interpretability Implementation to use.",
+                    value="None",
+                    interactive=True,
+                    show_label=True,
+                )
+            # column that takes up 1/4 of the row
+            with gr.Column(scale=1):
+                # radio group to select the model
+                model_selection = gr.Radio(
+                    ["GODEL", "Mistral"],
+                    label="Model Settings",
+                    info="Select a Model to use.",
+                    value="GODEL",
+                    interactive=True,
+                    show_label=True,
+                )
 
-        [4 removed lines not captured in this view]
+    # calling info functions on inputs/submits for different settings
+    system_prompt.submit(system_prompt_info, [system_prompt])
+    xai_selection.input(xai_info, [xai_selection])
+    model_selection.input(model_info, [model_selection])
 
     # row with chatbot ui displaying "conversation" with the model
     with gr.Row(equal_height=True):

@@ -251,6 +252,11 @@
            show_label=True,
            height="400px",
        )
+        with gr.Row():
+            with gr.Accordion("Explanation Plot", open=False):
+                xai_plot = gr.Plot(
+                    label="Input Sequence Attribution Plot", show_label=True
+                )
 
    # functions to trigger the controller
    ## takes information for the chat and the xai selection
model/mistral.py
CHANGED
@@ -25,7 +25,6 @@ else:
 MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
 MODEL.to(device)
 TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-TOKENIZER.pad_token = TOKENIZER.eos_token
 
 # default model config
 CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

@@ -103,6 +102,7 @@ def format_answer(answer: str):
     # Return an empty string if there are fewer than two occurrences of [/INST]
     formatted_answer = ""
 
+    print(f"Cut {answer} into {formatted_answer}.")
     return formatted_answer
 
 
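The added print logs how the raw generation was cut down to the final answer; the rest of format_answer is not part of this diff. A speculative reconstruction of its likely shape, based only on the surviving comment about [/INST] occurrences:

# speculative reconstruction; only the print() call and the comment about "[/INST]"
# below are actually part of this commit
def format_answer(answer: str) -> str:
    segments = answer.split("[/INST]")

    if len(segments) > 2:
        # keep only the text generated after the second [/INST] marker
        formatted_answer = segments[2].strip()
    else:
        # Return an empty string if there are fewer than two occurrences of [/INST]
        formatted_answer = ""

    print(f"Cut {answer} into {formatted_answer}.")
    return formatted_answer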
requirements.txt
CHANGED
@@ -10,7 +10,6 @@ markdown~=3.5.1
 huggingface_hub~=0.19.4
 fastapi~=0.104.1
 uvicorn~=0.24.0
-tinydb~=4.8.0
 black~=23.12.0
 pylint~=3.0.0
 numpy