LennardZuendorf committed on
Commit
f5ebee7
1 Parent(s): 940d70a

feat: removing plots, updating iframe height, minor changes

Browse files
backend/controller.py CHANGED
@@ -40,7 +40,7 @@ def interference(
40
  raise RuntimeError("There was an error in the selected XAI approach.")
41
 
42
  # call the explained chat function
43
- prompt_output, history_output, xai_graphic, xai_plot, xai_markup = (
44
  explained_chat(
45
  model=godel,
46
  xai=xai,
@@ -61,17 +61,16 @@ def interference(
61
  knowledge=knowledge,
62
  )
63
  # set XAI outputs to disclaimer html/none
64
- xai_graphic, xai_plot, xai_markup = (
65
  """
66
  <div style="text-align: center"><h4>Without Selected XAI Approach,
67
  no graphic will be displayed</h4></div>
68
  """,
69
- None,
70
  [("", "")],
71
  )
72
 
73
  # return the outputs
74
- return prompt_output, history_output, xai_graphic, xai_plot, xai_markup
75
 
76
 
77
  # simple chat function that calls the model
@@ -98,10 +97,10 @@ def explained_chat(
98
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
99
 
100
  # generating an answer using the xai methods explain and respond function
101
- answer, xai_graphic, xai_plot, xai_markup = xai.chat_explained(model, prompt)
102
 
103
  # updating the chat history with the new answer
104
  history.append((message, answer))
105
 
106
  # returning the updated history, xai graphic and xai plot elements
107
- return "", history, xai_graphic, xai_plot, xai_markup
 
40
  raise RuntimeError("There was an error in the selected XAI approach.")
41
 
42
  # call the explained chat function
43
+ prompt_output, history_output, xai_graphic, xai_markup = (
44
  explained_chat(
45
  model=godel,
46
  xai=xai,
 
61
  knowledge=knowledge,
62
  )
63
  # set XAI outputs to disclaimer html / empty markup
64
+ xai_graphic, xai_markup = (
65
  """
66
  <div style="text-align: center"><h4>Without Selected XAI Approach,
67
  no graphic will be displayed</h4></div>
68
  """,
 
69
  [("", "")],
70
  )
71
 
72
  # return the outputs
73
+ return prompt_output, history_output, xai_graphic, xai_markup
74
 
75
 
76
  # simple chat function that calls the model
 
97
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
98
 
99
  # generating an answer using the XAI method's explain-and-respond function
100
+ answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
101
 
102
  # updating the chat history with the new answer
103
  history.append((message, answer))
104
 
105
  # returning the updated history, xai graphic and xai markup elements
106
+ return "", history, xai_graphic, xai_markup
explanation/interpret_shap.py CHANGED
@@ -26,18 +26,13 @@ def chat_explained(model, prompt):
26
 
27
  # create the explanation graphic and plot
28
  graphic = create_graphic(shap_values)
29
- plot = create_plot(
30
- values=shap_values.values[0],
31
- output_names=shap_values.output_names,
32
- input_names=shap_values.data[0],
33
- )
34
  marked_text = markup_text(
35
  shap_values.data[0], shap_values.values[0], variant="shap"
36
  )
37
 
38
  # create the response text
39
  response_text = fmt.format_output_text(shap_values.output_names)
40
- return response_text, graphic, plot, marked_text
41
 
42
 
43
  def wrap_shap(model):
@@ -67,55 +62,4 @@ def create_graphic(shap_values):
67
  graphic_html = plots.text(shap_values, display=False)
68
 
69
  # return the html graphic as string
70
- return str(graphic_html)
71
-
72
-
73
- # creating an attention heatmap plot using matplotlib/seaborn
74
- # CREDIT: adopted from official Matplotlib documentation
75
- ## see https://matplotlib.org/stable/
76
- def create_plot(values, output_names, input_names):
77
-
78
- # Set seaborn style to dark
79
- sns.set(style="white")
80
- fig, ax = plt.subplots()
81
-
82
- # Setting figure size
83
- fig.set_size_inches(
84
- max(values.shape[1] * 2, 10),
85
- max(values.shape[0] * 1, 5),
86
- )
87
-
88
- # Plotting the heatmap with Seaborn's color palette
89
- im = ax.imshow(
90
- values,
91
- vmax=values.max(),
92
- vmin=values.min(),
93
- cmap=sns.color_palette("vlag_r", as_cmap=True),
94
- aspect="auto",
95
- )
96
-
97
- # Creating colorbar
98
- cbar = ax.figure.colorbar(im, ax=ax)
99
- cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
100
- cbar.ax.yaxis.set_tick_params(color="black")
101
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
102
-
103
- # Setting ticks and labels with white color for visibility
104
- ax.set_yticks(np.arange(len(input_names)), labels=input_names)
105
- ax.set_xticks(np.arange(len(output_names)), labels=output_names)
106
- plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
107
- plt.setp(ax.get_yticklabels(), color="black")
108
-
109
- # Adjusting tick labels
110
- ax.tick_params(
111
- top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
112
- )
113
-
114
- # Adding text annotations with appropriate contrast
115
- for i in range(values.shape[0]):
116
- for j in range(values.shape[1]):
117
- val = values[i, j]
118
- color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
119
- ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
120
-
121
- return plt
 
26
 
27
  # create the explanation graphic and plot
28
  graphic = create_graphic(shap_values)
 
 
 
 
 
29
  marked_text = markup_text(
30
  shap_values.data[0], shap_values.values[0], variant="shap"
31
  )
32
 
33
  # create the response text
34
  response_text = fmt.format_output_text(shap_values.output_names)
35
+ return response_text, graphic, marked_text
36
 
37
 
38
  def wrap_shap(model):
 
62
  graphic_html = plots.text(shap_values, display=False)
63
 
64
  # return the html graphic as string
65
+ return str(graphic_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
explanation/markup.py CHANGED
@@ -11,10 +11,13 @@ from utils import formatting as fmt
11
  def markup_text(input_text: list, text_values: ndarray, variant: str):
12
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
13
 
14
- # Flatten the explanations values
 
15
  if variant == "shap":
16
  text_values = np.transpose(text_values)
17
- text_values = fmt.flatten_values(text_values)
 
 
18
 
19
  # Determine the minimum and maximum values
20
  min_val, max_val = np.min(text_values), np.max(text_values)
 
11
  def markup_text(input_text: list, text_values: ndarray, variant: str):
12
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
13
 
14
+ # Flatten the values depending on the source
15
+ # attention is averaged, SHAP summed up
16
  if variant == "shap":
17
  text_values = np.transpose(text_values)
18
+ text_values = fmt.flatten_attribution(text_values)
19
+ else:
20
+ text_values = fmt.flatten_attention(text_values)
21
 
22
  # Determine the minimum and maximum values
23
  min_val, max_val = np.min(text_values), np.max(text_values)
explanation/visualize.py CHANGED
@@ -34,74 +34,11 @@ def chat_explained(model, prompt):
34
  output_attentions=True,
35
  )
36
 
37
- averaged_attention = avg_attention(attention_output)
38
 
39
- # create the response text, graphic and plot
40
  response_text = fmt.format_output_text(decoder_text)
41
- plot = create_plot(averaged_attention, (encoder_text, decoder_text))
42
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
43
 
44
- return response_text, "", plot, marked_text
45
 
46
-
47
- # creating an attention heatmap plot using matplotlib/seaborn
48
- # CREDIT: adopted from official Matplotlib documentation
49
- ## see https://matplotlib.org/stable/
50
- def create_plot(averaged_attention_weights, enc_dec_texts: tuple):
51
- # transpose the attention weights
52
- averaged_attention_weights = np.transpose(averaged_attention_weights)
53
-
54
- # get the encoder and decoder tokens in text form
55
- encoder_tokens = enc_dec_texts[0]
56
- decoder_tokens = enc_dec_texts[1]
57
-
58
- # set seaborn style to dark and initialize figure and axis
59
- sns.set(style="white")
60
- fig, ax = plt.subplots()
61
-
62
- # Setting figure size
63
- fig.set_size_inches(
64
- max(averaged_attention_weights.shape[1] * 2, 10),
65
- max(averaged_attention_weights.shape[0] * 1, 5),
66
- )
67
-
68
- # Plotting the heatmap with seaborn's color palette
69
- im = ax.imshow(
70
- averaged_attention_weights,
71
- vmax=averaged_attention_weights.max(),
72
- vmin=-averaged_attention_weights.min(),
73
- cmap=sns.color_palette("rocket", as_cmap=True),
74
- aspect="auto",
75
- )
76
-
77
- # Creating colorbar
78
- cbar = ax.figure.colorbar(im, ax=ax)
79
- cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
80
- cbar.ax.yaxis.set_tick_params(color="black")
81
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
82
-
83
- # Setting ticks and labels with black color for visibility
84
- ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
85
- ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
86
- ax.set_title("Attention Weights by Token")
87
- plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
88
- plt.setp(ax.get_yticklabels(), color="black")
89
-
90
- # Adding text annotations with appropriate contrast
91
- for i in range(averaged_attention_weights.shape[0]):
92
- for j in range(averaged_attention_weights.shape[1]):
93
- val = averaged_attention_weights[i, j]
94
- color = (
95
- "white"
96
- if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
97
- else "black"
98
- )
99
- ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
100
-
101
- # return the plot
102
- return plt
103
-
104
-
105
- def avg_attention(attention_values):
106
- attention = attention_values.cross_attentions[0][0].detach().numpy()
107
- return np.mean(attention, axis=0)
 
34
  output_attentions=True,
35
  )
36
 
37
+ averaged_attention = fmt.avg_attention(attention_output)
38
 
39
+ # create the response text and marked text for ui
40
  response_text = fmt.format_output_text(decoder_text)
 
41
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
42
 
43
+ return response_text, "", marked_text
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -180,6 +180,10 @@ with gr.Blocks(
180
  " scripts: hieroglyphs, Demotic, and Greek."
181
  ),
182
  ],
 
 
 
 
183
  ],
184
  inputs=[user_prompt, knowledge_input],
185
  )
@@ -197,22 +201,14 @@ with gr.Blocks(
197
  with gr.Row(variant="panel"):
198
  # wraps the explanation html to display it statically
199
  xai_interactive = iFrame(
200
- label="Static Explanation",
201
  value=(
202
  '<div style="text-align: center"><h4>No Graphic to Display'
203
  " (Yet)</h4></div>"
204
  ),
 
205
  show_label=True,
206
  )
207
- # row and accordion to display an explanation plot (if applicable)
208
- with gr.Row():
209
- with gr.Accordion("Token Wise Explanation Plot", open=False):
210
- gr.Markdown("""
211
- #### Plotted Values
212
- Values have been excluded for readability. See colorbar for value indication.
213
- """)
214
- # plot component that takes a matplotlib figure as input
215
- xai_plot = gr.Plot(label="Token Level Explanation")
216
 
217
  # functions to trigger the controller
218
  ## takes information for the chat and the xai selection
@@ -221,13 +217,13 @@ with gr.Blocks(
221
  submit_btn.click(
222
  interference,
223
  [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
224
- [user_prompt, chatbot, xai_interactive, xai_plot, xai_text],
225
  )
226
  # function triggered by the enter key
227
  user_prompt.submit(
228
  interference,
229
  [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
230
- [user_prompt, chatbot, xai_interactive, xai_plot, xai_text],
231
  )
232
 
233
  # final row to show legal information
 
180
  " scripts: hieroglyphs, Demotic, and Greek."
181
  ),
182
  ],
183
+ [
184
+ "Does money buy happiness?",
185
+ ""
186
+ ],
187
  ],
188
  inputs=[user_prompt, knowledge_input],
189
  )
 
201
  with gr.Row(variant="panel"):
202
  # wraps the explanation html to display it statically
203
  xai_interactive = iFrame(
204
+ label="Interactive Explanation",
205
  value=(
206
  '<div style="text-align: center"><h4>No Graphic to Display'
207
  " (Yet)</h4></div>"
208
  ),
209
+ height="600px",
210
  show_label=True,
211
  )
 
 
 
 
 
 
 
 
 
212
 
213
  # functions to trigger the controller
214
  ## takes information for the chat and the xai selection
 
217
  submit_btn.click(
218
  interference,
219
  [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
220
+ [user_prompt, chatbot, xai_interactive, xai_text],
221
  )
222
  # function triggered by the enter key
223
  user_prompt.submit(
224
  interference,
225
  [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
226
+ [user_prompt, chatbot, xai_interactive, xai_text],
227
  )
228
 
229
  # final row to show legal information
public/about.md CHANGED
@@ -7,7 +7,7 @@ This research tackles the rise of LLM based applications such as chatbots and exp
7
  ## Links
8
 
9
  - [GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp) - The GitHub repository of this project.
10
- - [HTW Berlin](https://www.htw-berlin.de/) - The University I have built this project for, as part of my thesis.
11
 
12
 
13
  ## Implementation
 
7
  ## Links
8
 
9
  - [GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp) - The GitHub repository of this project.
10
+ - [HTW Berlin](https://www.htw-berlin.de/en/) - The University I have built this project for, as part of my thesis.
11
 
12
 
13
  ## Implementation
utils/formatting.py CHANGED
@@ -66,5 +66,12 @@ def format_tokens(tokens: list):
66
 
67
 
68
  # function to flatten values into a 2d list by averaging the explanation values
69
- def flatten_values(values: ndarray, axis: int = 0):
 
 
 
70
  return np.mean(values, axis=axis)
 
 
 
 
 
66
 
67
 
68
  # function to flatten values into a 2d list by averaging the explanation values
69
+ def flatten_attribution(values: ndarray, axis: int = 0):
70
+ return np.sum(values, axis=axis)
71
+
72
+ def flatten_attention(values: ndarray, axis: int = 0):
73
  return np.mean(values, axis=axis)
74
+
75
+ def avg_attention(attention_values):
76
+ attention = attention_values.cross_attentions[0][0].detach().numpy()
77
+ return np.mean(attention, axis=0)