LennardZuendorf committed
Commit d4dd3c5
Parent: 229e14c

feat/fix: several minor fixes and additions

explanation/interpret_captum.py CHANGED
@@ -0,0 +1,55 @@
+ # external imports
+ from captum.attr import LLMAttribution, TextTokenInput, KernelShap
+ import torch
+
+ # internal imports
+ from utils import formatting as fmt
+ from .markup import markup_text
+
+
+ # function to extract sequence attribution
+ def cpt_extract_seq_att(attr):
+
+     # getting values from captum
+     values = attr.seq_attr.to(torch.device("cpu")).numpy()
+
+     # format the input tokens nicely and check for mismatch
+     input_tokens = fmt.format_tokens(attr.input_tokens)
+     if len(attr.input_tokens) != len(values):
+         raise RuntimeError("values and input len mismatch")
+
+     # return a list of tuples with token and value
+     return list(zip(input_tokens, values))
+
+
+ # main explain function that returns a chat with explanations
+ def chat_explained(model, prompt):
+     model.set_config({})
+
+     # creating the llm attribution class with KernelSHAP and the Mistral model and tokenizer
+     llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
+
+     # generating the attribution
+     attribution_input = TextTokenInput(prompt, model.TOKENIZER)
+     attribution_result = llm_attribution.attribute(
+         attribution_input, gen_args=model.CONFIG.to_dict()
+     )
+
+     # extracting values and input tokens
+     values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
+     input_tokens = fmt.format_tokens(attribution_result.input_tokens)
+
+     # raising an error if a mismatch occurs
+     if len(attribution_result.input_tokens) != len(values):
+         raise RuntimeError("values and input len mismatch")
+
+     # getting response text, graphic placeholder and marked text object
+     response_text = fmt.format_output_text(attribution_result.output_tokens)
+     graphic = (
+         "<div style='text-align: center; font-family:arial;'><h4>Attention "
+         "Interpretation with Captum doesn't support an interactive graphic.</h4></div>"
+     )
+     marked_text = markup_text(input_tokens, values, variant="captum")
+
+     # return response, graphic and marked_text array
+     return response_text, graphic, marked_text
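
For orientation, a minimal standalone sketch of the same Captum calls used in chat_explained above. The "gpt2" checkpoint, the prompt, and the gen_args dict are illustrative stand-ins for model.MODEL, model.TOKENIZER and model.CONFIG, not part of this repository:

import torch
from captum.attr import KernelShap, LLMAttribution, TextTokenInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# small causal LM as a stand-in for the wrapped model (assumption)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# KernelSHAP-based attribution over the generated sequence
llm_attribution = LLMAttribution(KernelShap(model), tokenizer)
attribution_input = TextTokenInput("The movie was great because", tokenizer)
attribution_result = llm_attribution.attribute(
    attribution_input, gen_args={"max_new_tokens": 8}
)

# per-input-token attribution toward the whole generated sequence
for token, value in zip(
    attribution_result.input_tokens,
    attribution_result.seq_attr.to(torch.device("cpu")).numpy(),
):
    print(f"{token!r}: {value:+.3f}")
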
explanation/interpret_shap.py CHANGED
@@ -0,0 +1,82 @@
+ # interpret module that implements the interpretability method
+
+ # external imports
+ from shap import models, maskers, plots, PartitionExplainer
+ import torch
+
+ # internal imports
+ from utils import formatting as fmt
+ from .markup import markup_text
+
+ # global variables
+ TEACHER_FORCING = None
+ TEXT_MASKER = None
+
+
+ # function to extract summarized sequence-wise attribution
+ def extract_seq_att(shap_values):
+
+     # extracting summed-up shap values
+     values = fmt.flatten_attribution(shap_values.values[0], 1)
+
+     # returning a list of tuples of token and value
+     return list(zip(shap_values.data[0], values))
+
+
+ # main explain function that returns a chat with explanations
+ def chat_explained(model, prompt):
+     model.set_config({})
+
+     # create the shap explainer
+     shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
+
+     # get the shap values for the prompt
+     shap_values = shap_explainer([prompt])
+
+     # create the explanation graphic and marked text array
+     graphic = create_graphic(shap_values)
+     marked_text = markup_text(
+         shap_values.data[0], shap_values.values[0], variant="shap"
+     )
+
+     # create the response text
+     response_text = fmt.format_output_text(shap_values.output_names)
+
+     # return response, graphic and marked_text array
+     return response_text, graphic, marked_text
+
+
+ # function used to wrap the model with a shap model
+ def wrap_shap(model):
+     # referencing the global variables
+     global TEXT_MASKER, TEACHER_FORCING
+
+     # set the device to cuda if a gpu is available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # updating the model settings
+     model.set_config()
+
+     # (re)initialize the shap models and masker
+     # creating a shap text_generation model
+     text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
+     # wrapping the text generation model in a teacher forcing model
+     TEACHER_FORCING = models.TeacherForcing(
+         text_generation,
+         model.TOKENIZER,
+         device=str(device),
+         similarity_model=model.MODEL,
+         similarity_tokenizer=model.TOKENIZER,
+     )
+     # setting the text masker with a single space as the mask token
+     TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
+
+
+ # graphic plotting function that creates an html graphic (as a string) for the explanation
+ def create_graphic(shap_values):
+
+     # create the html graphic using the shap text plot function
+     graphic_html = plots.text(shap_values, display=False)
+
+     # return the html graphic as a string to display in an iFrame
+     return str(graphic_html)
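
A minimal sketch of the same SHAP flow (explainer over a text-generation model, per-token values, HTML text plot), assuming shap's standard text-generation support; the "gpt2" model and tokenizer are illustrative stand-ins for model.MODEL and model.TOKENIZER:

import shap
from transformers import AutoModelForCausalLM, AutoTokenizer

# small causal LM as a stand-in for the wrapped model (assumption)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# shap wraps the model/tokenizer pair in a text-generation model and text masker
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(["The movie was great because"])

# per-input-token attribution, summed over the generated tokens
# (roughly what extract_seq_att does via fmt.flatten_attribution)
seq_att = list(zip(shap_values.data[0], shap_values.values[0].sum(axis=1)))

# HTML string rendering, as returned by create_graphic above
graphic_html = shap.plots.text(shap_values, display=False)
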
explanation/plotting.py CHANGED
@@ -0,0 +1,58 @@
+ # plotting functions
+
+ # external imports
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ def plot_seq(seq_values: list, method_model: tuple = ("", "")):
+
+     # Separate the tokens and their corresponding importance values
+     tokens, importance = zip(*seq_values)
+
+     # Convert importance values to a numpy array for conditional coloring
+     # (no log transform here; the y-axis below uses a symmetric log scale)
+     importance = np.array(importance)
+
+     # Determine the colors based on the sign of the importance values
+     colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
+
+     # Create a bar plot
+     plt.figure(figsize=(len(tokens) * 0.9, np.max(importance)))
+     x_positions = range(len(tokens))  # Positions for the bars
+
+     # Creating vertical bar plot
+     bar_width = 0.8  # Increase this value to make the bars wider
+     plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
+     plt.yscale("symlog")
+
+     # Annotating each bar with its value
+     padding = 0.1  # Padding for text annotation
+     for x, (y, color) in enumerate(zip(importance, colors)):
+         sign = "+" if y > 0 else ""
+         plt.annotate(
+             f"{sign}{y:.2f}",  # Format the value with sign
+             xy=(x, y + padding if y > 0 else y - padding),
+             ha="center",
+             color=color,
+             va="bottom" if y > 0 else "top",  # Vertical alignment
+             fontweight="bold",  # Bold text
+             bbox={
+                 "facecolor": "white",
+                 "edgecolor": "none",
+                 "boxstyle": "round,pad=0.1",
+             },  # White background
+         )
+
+     plt.axhline(0, color="black", linewidth=1)
+     plt.title(f"Input Token Attribution with {method_model[0]} on {method_model[1]}")
+     plt.xlabel("Input Tokens", labelpad=0.5)
+     plt.ylabel("Attribution")
+     plt.xticks(x_positions, tokens, rotation=45)
+
+     # Adjust y-axis limits to ensure there's enough space for labels
+     y_min, y_max = plt.ylim()
+     y_range = y_max - y_min
+     plt.ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range)
+
+     return plt
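
A quick usage sketch for plot_seq with made-up (token, attribution) pairs; the values, the method/model labels, and the headless Agg backend are illustrative assumptions only:

import matplotlib
matplotlib.use("Agg")  # headless backend, assumption for this sketch

from explanation.plotting import plot_seq  # assumes the repo root is on sys.path

# toy (token, attribution) pairs, purely illustrative
seq_values = [("The", 1.2), ("movie", -0.4), ("was", 0.1), ("great", 2.3)]

plot = plot_seq(seq_values, method_model=("SHAP", "GODEL"))
plot.savefig("seq_attribution.png", bbox_inches="tight")
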
main.py CHANGED
@@ -102,45 +102,46 @@ with gr.Blocks(
     """)
     # row with columns for the different settings
     with gr.Row(equal_height=True):
-        # column that takes up 3/4 of the row
-        with gr.Column(scale=2):
-            # textbox to enter the system prompt
-            system_prompt = gr.Textbox(
-                label="System Prompt",
-                info="Set the models system prompt, dictating how it answers.",
-                # default system prompt is set to this in the backend
-                placeholder=(
-                    "You are a helpful, respectful and honest assistant. Always"
-                    " answer as helpfully as possible, while being safe."
-                ),
-            )
-        # column that takes up 1/4 of the row
-        with gr.Column(scale=1):
-            # checkbox group to select the xai method
-            xai_selection = gr.Radio(
-                ["None", "SHAP", "Attention"],
-                label="Interpretability Settings",
-                info="Select a Interpretability Implementation to use.",
-                value="None",
-                interactive=True,
-                show_label=True,
-            )
-        # column that takes up 1/4 of the row
-        with gr.Column(scale=1):
-            # checkbox group to select the xai method
-            model_selection = gr.Radio(
-                ["GODEL", "Mistral"],
-                label="Model Settings",
-                info="Select a Model to use.",
-                value="GODEL",
-                interactive=True,
-                show_label=True,
-            )
+        with gr.Accordion("Application Settings", open=False):
+            # column that takes up 3/4 of the row
+            with gr.Column(scale=2):
+                # textbox to enter the system prompt
+                system_prompt = gr.Textbox(
+                    label="System Prompt",
+                    info="Set the models system prompt, dictating how it answers.",
+                    # default system prompt is set to this in the backend
+                    placeholder=(
+                        "You are a helpful, respectful and honest assistant. Always"
+                        " answer as helpfully as possible, while being safe."
+                    ),
+                )
+            # column that takes up 1/4 of the row
+            with gr.Column(scale=1):
+                # checkbox group to select the xai method
+                xai_selection = gr.Radio(
+                    ["None", "SHAP", "Attention"],
+                    label="Interpretability Settings",
+                    info="Select a Interpretability Implementation to use.",
+                    value="None",
+                    interactive=True,
+                    show_label=True,
+                )
+            # column that takes up 1/4 of the row
+            with gr.Column(scale=1):
+                # checkbox group to select the xai method
+                model_selection = gr.Radio(
+                    ["GODEL", "Mistral"],
+                    label="Model Settings",
+                    info="Select a Model to use.",
+                    value="GODEL",
+                    interactive=True,
+                    show_label=True,
+                )

-    # calling info functions on inputs/submits for different settings
-    system_prompt.submit(system_prompt_info, [system_prompt])
-    xai_selection.input(xai_info, [xai_selection])
-    model_selection.input(model_info, [model_selection])
+    # calling info functions on inputs/submits for different settings
+    system_prompt.submit(system_prompt_info, [system_prompt])
+    xai_selection.input(xai_info, [xai_selection])
+    model_selection.input(model_info, [model_selection])

     # row with chatbot ui displaying "conversation" with the model
     with gr.Row(equal_height=True):
@@ -251,6 +252,11 @@
             show_label=True,
             height="400px",
         )
+    with gr.Row():
+        with gr.Accordion("Explanation Plot", open=False):
+            xai_plot = gr.Plot(
+                label="Input Sequence Attribution Plot", show_label=True
+            )

     # functions to trigger the controller
     ## takes information for the chat and the xai selection
model/mistral.py CHANGED
@@ -25,7 +25,6 @@ else:
     MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
     MODEL.to(device)
     TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-    TOKENIZER.pad_token = TOKENIZER.eos_token

     # default model config
     CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
@@ -103,6 +102,7 @@ def format_answer(answer: str):
         # Return an empty string if there are fewer than two occurrences of [/INST]
         formatted_answer = ""

+    print(f"Cut {answer} into {formatted_answer}.")
     return formatted_answer


requirements.txt CHANGED
@@ -10,7 +10,6 @@ markdown~=3.5.1
 huggingface_hub~=0.19.4
 fastapi~=0.104.1
 uvicorn~=0.24.0
-tinydb~=4.8.0
 black~=23.12.0
 pylint~=3.0.0
 numpy