LennardZuendorf committed on
Commit 36b45ee
2 Parent(s): 148ef57 fe1089d

merge remote

backend/controller.py CHANGED
@@ -10,6 +10,7 @@ from explanation import interpret, visualize
 
 
 # main interference function that calls chat functions depending on selections
+# TODO: Limit maximum tokens/model input
 def interference(
     prompt: str,
     history: list,
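
Note on the new TODO: the token limit is not implemented in this commit. A minimal sketch of one way to do it is shown below; it assumes the selected model wrapper exposes a Hugging Face tokenizer as `model.TOKENIZER` (as the other modules in this diff do), and `MAX_INPUT_TOKENS` / `truncate_prompt` are illustrative names, not code from the repository.

```python
# Hypothetical sketch for the "Limit maximum tokens/model input" TODO.
MAX_INPUT_TOKENS = 512  # illustrative limit, not taken from the repo


def truncate_prompt(model, prompt: str) -> str:
    # tokenize with truncation so the model never receives more than the limit
    input_ids = model.TOKENIZER(
        prompt,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
    ).input_ids

    # decode back to text so the downstream chat functions stay unchanged
    return model.TOKENIZER.decode(input_ids[0], skip_special_tokens=True)
```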
explanation/interpret.py CHANGED
@@ -3,11 +3,16 @@
 import seaborn as sns
 import matplotlib.pyplot as plt
 import numpy as np
-from shap import plots, PartitionExplainer
+from shap import models, maskers, plots, PartitionExplainer
+import torch
 
 # internal imports
 from utils import formatting as fmt
 
+# global variables
+TEACHER_FORCING = None
+TEXT_MASKER = None
+
 
 # main explain function that returns a chat with explanations
 def chat_explained(model, prompt):
@@ -27,6 +32,27 @@ def chat_explained(model, prompt):
     return response_text, graphic, plot
 
 
+def wrap_shap(model):
+    global TEXT_MASKER, TEACHER_FORCING
+
+    # set the device to cuda if gpu is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # updating the model settings again
+    model.set_config()
+
+    # (re)initialize the shap models and masker
+    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
+    TEACHER_FORCING = models.TeacherForcing(
+        text_generation,
+        model.TOKENIZER,
+        device=str(device),
+        similarity_model=model.MODEL,
+        similarity_tokenizer=model.TOKENIZER,
+    )
+    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
+
+
 # graphic plotting function that creates a html graphic (as string) for the explanation
 def create_graphic(shap_values):
     # create the html graphic using shap text plot function
@@ -42,21 +68,29 @@ def create_plot(shap_values):
     output_names = shap_values.output_names
     input_names = shap_values.data[0]
 
+    # Transpose the values for horizontal input names
+    transposed_values = np.transpose(values)
+
     # Set seaborn style to dark
-    sns.set(style="white")
+    sns.set(style="dark")
+
     fig, ax = plt.subplots()
 
+    # Making background transparent
+    ax.set_alpha(0)
+    fig.patch.set_alpha(0)
+
     # Setting figure size
     fig.set_size_inches(
-        max(values.shape[1] * 2, 10),
-        max(values.shape[0] * 1, 5),
+        max(transposed_values.shape[1] * 2, 10),
+        max(transposed_values.shape[0] / 1.5, 5),
     )
 
     # Plotting the heatmap with Seaborn's color palette
     im = ax.imshow(
-        values,
-        vmax=values.max(),
-        vmin=values.min(),
+        transposed_values,
+        vmax=transposed_values.max(),
+        vmin=-transposed_values.min(),
         cmap=sns.color_palette("vlag_r", as_cmap=True),
         aspect="auto",
     )
@@ -64,25 +98,25 @@ def create_plot(shap_values):
     # Creating colorbar
     cbar = ax.figure.colorbar(im, ax=ax)
     cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
-    cbar.ax.yaxis.set_tick_params(color="black")
-    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
+    cbar.ax.yaxis.set_tick_params(color="white")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
 
     # Setting ticks and labels with white color for visibility
-    ax.set_yticks(np.arange(len(input_names)), labels=input_names)
-    ax.set_xticks(np.arange(len(output_names)), labels=output_names)
-    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
-    plt.setp(ax.get_yticklabels(), color="black")
+    ax.set_xticks(np.arange(len(input_names)), labels=input_names)
+    ax.set_yticks(np.arange(len(output_names)), labels=output_names)
+    plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="white")
 
     # Adjusting tick labels
     ax.tick_params(
         top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
     )
 
-    # Adding text annotations with appropriate contrast
-    for i in range(values.shape[0]):
-        for j in range(values.shape[1]):
-            val = values[i, j]
-            color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
-            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
+    # Adding text annotations - not used for readability
+    # for i in range(transposed_values.shape[0]):
+    #     for j in range(transposed_values.shape[1]):
+    #         val = transposed_values[i, j]
+    #         color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
+    #         ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
 
     return plt
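
Note on `wrap_shap`: the function only (re)initializes the SHAP wrappers; the hunks above do not show how `chat_explained` consumes them. A plausible sketch of how the imported `PartitionExplainer` is typically combined with the teacher-forcing model and the text masker follows; `explain_prompt` and its body are assumed usage, not code from this commit.

```python
# Hedged sketch: combining TEACHER_FORCING and TEXT_MASKER with the
# PartitionExplainer imported at the top of explanation/interpret.py.
from shap import PartitionExplainer


def explain_prompt(model, prompt: str):
    # (re)initialize the global TEACHER_FORCING and TEXT_MASKER wrappers
    wrap_shap(model)

    # partition explainer over the wrapped generation model and the masker
    explainer = PartitionExplainer(TEACHER_FORCING, TEXT_MASKER)

    # shap_values.values / .data / .output_names are what create_plot reads
    shap_values = explainer([prompt])
    return shap_values
```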
explanation/visualize.py CHANGED
@@ -13,7 +13,6 @@ from utils import formatting as fmt
 # plotting function that plots the attention values in a heatmap
 def chat_explained(model, prompt):
 
-    # reset the model config to default
     model.set_config()
 
     # get encoded input and output vectors
@@ -21,8 +20,6 @@ def chat_explained(model, prompt):
         prompt, return_tensors="pt", add_special_tokens=True
     ).input_ids
     decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
-
-    # get the text for the input and output vectors
     encoder_text = fmt.format_tokens(
         model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
     )
@@ -40,20 +37,11 @@ def chat_explained(model, prompt):
     # create the response text, graphic and plot
     response_text = fmt.format_output_text(decoder_text)
     graphic = create_graphic(attention_output, (encoder_text, decoder_text))
-    graphic = (
-        '<div style="text-align: center"><h4>Interactive Graphic for Attention'
-        " currently WIP</h4></div>"
-    )
     plot = create_plot(attention_output, (encoder_text, decoder_text))
-    return (
-        response_text,
-        graphic,
-        plot,
-    )
+    return response_text, graphic, plot
 
 
 # creating a html graphic using BERTViz
-# TODO: FIX
 def create_graphic(attention_output, enc_dec_texts: tuple):
 
     # calls the head_view function of BERTViz to return html graphic
@@ -70,28 +58,27 @@ def create_graphic(attention_output, enc_dec_texts: tuple):
 
 
 # creating an attention heatmap plot using seaborn
-# CREDIT: adopted from official Matplotlib documentation
-## see https://matplotlib.org/stable/
-
-
 def create_plot(attention_output, enc_dec_texts: tuple):
     # get the averaged attention weights
     attention = attention_output.cross_attentions[0][0].detach().numpy()
     averaged_attention_weights = np.mean(attention, axis=0)
-    averaged_attention_weights = np.transpose(averaged_attention_weights)
 
-    # get the encoder and decoder tokens in text form
+    # get the encoder and decoder tokens
    encoder_tokens = enc_dec_texts[0]
     decoder_tokens = enc_dec_texts[1]
 
     # set seaborn style to dark and initialize figure and axis
-    sns.set(style="white")
+    sns.set(style="dark")
     fig, ax = plt.subplots()
 
+    # Making background transparent
+    ax.set_alpha(0)
+    fig.patch.set_alpha(0)
+
     # Setting figure size
     fig.set_size_inches(
         max(averaged_attention_weights.shape[1] * 2, 10),
-        max(averaged_attention_weights.shape[0] * 1, 5),
+        max(averaged_attention_weights.shape[0] / 1.5, 5),
     )
 
     # Plotting the heatmap with seaborn's color palette
@@ -105,27 +92,19 @@ def create_plot(attention_output, enc_dec_texts: tuple):
 
     # Creating colorbar
     cbar = ax.figure.colorbar(im, ax=ax)
-    cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
-    cbar.ax.yaxis.set_tick_params(color="black")
-    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
-
-    # Setting ticks and labels with black color for visibility
-    ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
-    ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
-    ax.set_title("Attention Weights by Token")
-    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
-    plt.setp(ax.get_yticklabels(), color="black")
-
-    # Adding text annotations with appropriate contrast
-    for i in range(averaged_attention_weights.shape[0]):
-        for j in range(averaged_attention_weights.shape[1]):
-            val = averaged_attention_weights[i, j]
-            color = (
-                "white"
-                if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
-                else "black"
-            )
-            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
-
-    # return the plot
+    cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
+    cbar.ax.yaxis.set_tick_params(color="white")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
+
+    # Setting ticks and labels with white color for visibility
+    ax.set_xticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
+    ax.set_yticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
+    plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="white")
+
+    # Adjusting tick labels
+    ax.tick_params(
+        top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
+    )
+
     return plt
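
Note on the attention heatmap: `create_plot` collapses the per-head cross-attention weights with `np.mean(attention, axis=0)` to get one value per decoder/encoder token pair. A small self-contained illustration of that shape transformation, with made-up dimensions, follows.

```python
# Illustration of the head-averaging step used in create_plot; the shapes here
# (8 heads, 5 decoder tokens, 7 encoder tokens) are assumptions for demo only.
import numpy as np

num_heads, target_len, source_len = 8, 5, 7
cross_attention = np.random.rand(num_heads, target_len, source_len)

# average over the head axis -> one (decoder token, encoder token) heatmap
averaged = np.mean(cross_attention, axis=0)
assert averaged.shape == (target_len, source_len)
```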
main.py CHANGED
@@ -187,7 +187,7 @@ with gr.Blocks(
         Values have been excluded for readability. See colorbar for value indication.
         """)
         # plot component that takes a matplotlib figure as input
-        xai_plot = gr.Plot(label="Token Level Explanation")
+        xai_plot = gr.Plot(label="Token Level Explanation", scale=3)
 
     # functions to trigger the controller
     ## takes information for the chat and the xai selection
@@ -205,12 +205,10 @@
         [user_prompt, chatbot, xai_interactive, xai_plot],
     )
 
-    # final row to about information
-    ## and credits, data protection and link to the License
-    with gr.Tab(label="About"):
-        gr.Markdown(value=load_md("public/about.md"))
-        with gr.Accordion(label="Credits, Data Protection and License", open=False):
-            gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
+    # final row to show legal information
+    ## - credits, data protection and link to the License
+    with gr.Tab(label="Credits, Data Protection and License"):
+        gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
 app = gr.mount_gradio_app(app, ui, path="/")
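
Note on the mount step: `gr.mount_gradio_app` serves the Gradio Blocks UI from an existing FastAPI app at the given path. A minimal, self-contained sketch of the pattern (independent of the repository's actual `main.py`) is shown below.

```python
# Minimal sketch of mounting a Gradio UI on a FastAPI app, as main.py does.
# The tiny Blocks UI below is illustrative; the repository's UI is far larger.
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

with gr.Blocks() as ui:
    gr.Markdown("Hello from the mounted Gradio app")

# serve the UI at the root path; run with: uvicorn main:app
app = gr.mount_gradio_app(app, ui, path="/")
```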
public/credits_dataprotection_license.md CHANGED
@@ -6,6 +6,7 @@
 
 
 # Credits
+For full credits, please refer to the [thesis print]()
 
 ### Models
 This implementation is built on GODEL by Microsoft, Inc.