LennardZuendorf commited on
Commit
2230009
1 Parent(s): 43cce2a

feat: implementing fixes and updates for version 1.0.1

Browse files
backend/controller.py CHANGED
@@ -10,7 +10,6 @@ from explanation import interpret, visualize
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
13
- # TODO: Limit maximum tokens/model input
14
  def interference(
15
  prompt: str,
16
  history: list,
 
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
 
13
  def interference(
14
  prompt: str,
15
  history: list,
explanation/interpret.py CHANGED
@@ -3,16 +3,11 @@
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- from shap import models, maskers, plots, PartitionExplainer
7
- import torch
8
 
9
  # internal imports
10
  from utils import formatting as fmt
11
 
12
- # global variables
13
- TEACHER_FORCING = None
14
- TEXT_MASKER = None
15
-
16
 
17
  # main explain function that returns a chat with explanations
18
  def chat_explained(model, prompt):
@@ -32,27 +27,6 @@ def chat_explained(model, prompt):
32
  return response_text, graphic, plot
33
 
34
 
35
- def wrap_shap(model):
36
- global TEXT_MASKER, TEACHER_FORCING
37
-
38
- # set the device to cuda if gpu is available
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
-
41
- # updating the model settings again
42
- model.set_config()
43
-
44
- # (re)initialize the shap models and masker
45
- text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
46
- TEACHER_FORCING = models.TeacherForcing(
47
- text_generation,
48
- model.TOKENIZER,
49
- device=str(device),
50
- similarity_model=model.MODEL,
51
- similarity_tokenizer=model.TOKENIZER,
52
- )
53
- TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
54
-
55
-
56
  # graphic plotting function that creates a html graphic (as string) for the explanation
57
  def create_graphic(shap_values):
58
  # create the html graphic using shap text plot function
@@ -68,29 +42,21 @@ def create_plot(shap_values):
68
  output_names = shap_values.output_names
69
  input_names = shap_values.data[0]
70
 
71
- # Transpose the values for horizontal input names
72
- transposed_values = np.transpose(values)
73
-
74
  # Set seaborn style to dark
75
- sns.set(style="dark")
76
-
77
  fig, ax = plt.subplots()
78
 
79
- # Making background transparent
80
- ax.set_alpha(0)
81
- fig.patch.set_alpha(0)
82
-
83
  # Setting figure size
84
  fig.set_size_inches(
85
- max(transposed_values.shape[1] * 2, 10),
86
- max(transposed_values.shape[0] / 1.5, 5),
87
  )
88
 
89
  # Plotting the heatmap with Seaborn's color palette
90
  im = ax.imshow(
91
- transposed_values,
92
- vmax=transposed_values.max(),
93
- vmin=-transposed_values.min(),
94
  cmap=sns.color_palette("vlag_r", as_cmap=True),
95
  aspect="auto",
96
  )
@@ -98,25 +64,25 @@ def create_plot(shap_values):
98
  # Creating colorbar
99
  cbar = ax.figure.colorbar(im, ax=ax)
100
  cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
101
- cbar.ax.yaxis.set_tick_params(color="white")
102
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
103
 
104
  # Setting ticks and labels with white color for visibility
105
- ax.set_xticks(np.arange(len(input_names)), labels=input_names)
106
- ax.set_yticks(np.arange(len(output_names)), labels=output_names)
107
- plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
108
- plt.setp(ax.get_yticklabels(), color="white")
109
 
110
  # Adjusting tick labels
111
  ax.tick_params(
112
  top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
113
  )
114
 
115
- # Adding text annotations - not used for readability
116
- # for i in range(transposed_values.shape[0]):
117
- # for j in range(transposed_values.shape[1]):
118
- # val = transposed_values[i, j]
119
- # color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
120
- # ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
121
 
122
  return plt
 
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ from shap import plots, PartitionExplainer
 
7
 
8
  # internal imports
9
  from utils import formatting as fmt
10
 
 
 
 
 
11
 
12
  # main explain function that returns a chat with explanations
13
  def chat_explained(model, prompt):
 
27
  return response_text, graphic, plot
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # graphic plotting function that creates a html graphic (as string) for the explanation
31
  def create_graphic(shap_values):
32
  # create the html graphic using shap text plot function
 
42
  output_names = shap_values.output_names
43
  input_names = shap_values.data[0]
44
 
 
 
 
45
  # Set seaborn style to dark
46
+ sns.set(style="white")
 
47
  fig, ax = plt.subplots()
48
 
 
 
 
 
49
  # Setting figure size
50
  fig.set_size_inches(
51
+ max(values.shape[1] * 2, 10),
52
+ max(values.shape[0] * 1, 5),
53
  )
54
 
55
  # Plotting the heatmap with Seaborn's color palette
56
  im = ax.imshow(
57
+ values,
58
+ vmax=values.max(),
59
+ vmin=values.min(),
60
  cmap=sns.color_palette("vlag_r", as_cmap=True),
61
  aspect="auto",
62
  )
 
64
  # Creating colorbar
65
  cbar = ax.figure.colorbar(im, ax=ax)
66
  cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
67
+ cbar.ax.yaxis.set_tick_params(color="black")
68
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
69
 
70
  # Setting ticks and labels with white color for visibility
71
+ ax.set_yticks(np.arange(len(input_names)), labels=input_names)
72
+ ax.set_xticks(np.arange(len(output_names)), labels=output_names)
73
+ plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
74
+ plt.setp(ax.get_yticklabels(), color="black")
75
 
76
  # Adjusting tick labels
77
  ax.tick_params(
78
  top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
79
  )
80
 
81
+ # Adding text annotations with appropriate contrast
82
+ for i in range(values.shape[0]):
83
+ for j in range(values.shape[1]):
84
+ val = values[i, j]
85
+ color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
86
+ ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
87
 
88
  return plt
explanation/visualize.py CHANGED
@@ -13,6 +13,7 @@ from utils import formatting as fmt
13
  # plotting function that plots the attention values in a heatmap
14
  def chat_explained(model, prompt):
15
 
 
16
  model.set_config()
17
 
18
  # get encoded input and output vectors
@@ -20,6 +21,8 @@ def chat_explained(model, prompt):
20
  prompt, return_tensors="pt", add_special_tokens=True
21
  ).input_ids
22
  decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
 
 
23
  encoder_text = fmt.format_tokens(
24
  model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
25
  )
@@ -37,11 +40,20 @@ def chat_explained(model, prompt):
37
  # create the response text, graphic and plot
38
  response_text = fmt.format_output_text(decoder_text)
39
  graphic = create_graphic(attention_output, (encoder_text, decoder_text))
 
 
 
 
40
  plot = create_plot(attention_output, (encoder_text, decoder_text))
41
- return response_text, graphic, plot
 
 
 
 
42
 
43
 
44
  # creating a html graphic using BERTViz
 
45
  def create_graphic(attention_output, enc_dec_texts: tuple):
46
 
47
  # calls the head_view function of BERTViz to return html graphic
@@ -58,27 +70,28 @@ def create_graphic(attention_output, enc_dec_texts: tuple):
58
 
59
 
60
  # creating an attention heatmap plot using seaborn
 
 
 
 
61
  def create_plot(attention_output, enc_dec_texts: tuple):
62
  # get the averaged attention weights
63
  attention = attention_output.cross_attentions[0][0].detach().numpy()
64
  averaged_attention_weights = np.mean(attention, axis=0)
 
65
 
66
- # get the encoder and decoder tokens
67
  encoder_tokens = enc_dec_texts[0]
68
  decoder_tokens = enc_dec_texts[1]
69
 
70
  # set seaborn style to dark and initialize figure and axis
71
- sns.set(style="dark")
72
  fig, ax = plt.subplots()
73
 
74
- # Making background transparent
75
- ax.set_alpha(0)
76
- fig.patch.set_alpha(0)
77
-
78
  # Setting figure size
79
  fig.set_size_inches(
80
  max(averaged_attention_weights.shape[1] * 2, 10),
81
- max(averaged_attention_weights.shape[0] / 1.5, 5),
82
  )
83
 
84
  # Plotting the heatmap with seaborn's color palette
@@ -92,19 +105,27 @@ def create_plot(attention_output, enc_dec_texts: tuple):
92
 
93
  # Creating colorbar
94
  cbar = ax.figure.colorbar(im, ax=ax)
95
- cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
96
- cbar.ax.yaxis.set_tick_params(color="white")
97
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
98
-
99
- # Setting ticks and labels with white color for visibility
100
- ax.set_xticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
101
- ax.set_yticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
102
- plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
103
- plt.setp(ax.get_yticklabels(), color="white")
104
-
105
- # Adjusting tick labels
106
- ax.tick_params(
107
- top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
108
- )
109
-
 
 
 
 
 
 
 
 
110
  return plt
 
13
  # plotting function that plots the attention values in a heatmap
14
  def chat_explained(model, prompt):
15
 
16
+ # reset the model config to default
17
  model.set_config()
18
 
19
  # get encoded input and output vectors
 
21
  prompt, return_tensors="pt", add_special_tokens=True
22
  ).input_ids
23
  decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
24
+
25
+ # get the text for the input and output vectors
26
  encoder_text = fmt.format_tokens(
27
  model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
28
  )
 
40
  # create the response text, graphic and plot
41
  response_text = fmt.format_output_text(decoder_text)
42
  graphic = create_graphic(attention_output, (encoder_text, decoder_text))
43
+ graphic = (
44
+ '<div style="text-align: center"><h4>Interactive Graphic for Attention'
45
+ " currently WIP</h4></div>"
46
+ )
47
  plot = create_plot(attention_output, (encoder_text, decoder_text))
48
+ return (
49
+ response_text,
50
+ graphic,
51
+ plot,
52
+ )
53
 
54
 
55
  # creating a html graphic using BERTViz
56
+ # TODO: FIX
57
  def create_graphic(attention_output, enc_dec_texts: tuple):
58
 
59
  # calls the head_view function of BERTViz to return html graphic
 
70
 
71
 
72
  # creating an attention heatmap plot using seaborn
73
+ # CREDIT: adopted from official Matplotlib documentation
74
+ ## see https://matplotlib.org/stable/
75
+
76
+
77
  def create_plot(attention_output, enc_dec_texts: tuple):
78
  # get the averaged attention weights
79
  attention = attention_output.cross_attentions[0][0].detach().numpy()
80
  averaged_attention_weights = np.mean(attention, axis=0)
81
+ averaged_attention_weights = np.transpose(averaged_attention_weights)
82
 
83
+ # get the encoder and decoder tokens in text form
84
  encoder_tokens = enc_dec_texts[0]
85
  decoder_tokens = enc_dec_texts[1]
86
 
87
  # set seaborn style to dark and initialize figure and axis
88
+ sns.set(style="white")
89
  fig, ax = plt.subplots()
90
 
 
 
 
 
91
  # Setting figure size
92
  fig.set_size_inches(
93
  max(averaged_attention_weights.shape[1] * 2, 10),
94
+ max(averaged_attention_weights.shape[0] * 1, 5),
95
  )
96
 
97
  # Plotting the heatmap with seaborn's color palette
 
105
 
106
  # Creating colorbar
107
  cbar = ax.figure.colorbar(im, ax=ax)
108
+ cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
109
+ cbar.ax.yaxis.set_tick_params(color="black")
110
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
111
+
112
+ # Setting ticks and labels with black color for visibility
113
+ ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
114
+ ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
115
+ ax.set_title("Attention Weights by Token")
116
+ plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
117
+ plt.setp(ax.get_yticklabels(), color="black")
118
+
119
+ # Adding text annotations with appropriate contrast
120
+ for i in range(averaged_attention_weights.shape[0]):
121
+ for j in range(averaged_attention_weights.shape[1]):
122
+ val = averaged_attention_weights[i, j]
123
+ color = (
124
+ "white"
125
+ if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
126
+ else "black"
127
+ )
128
+ ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
129
+
130
+ # return the plot
131
  return plt
main.py CHANGED
@@ -187,7 +187,7 @@ with gr.Blocks(
187
  Values have been excluded for readability. See colorbar for value indication.
188
  """)
189
  # plot component that takes a matplotlib figure as input
190
- xai_plot = gr.Plot(label="Token Level Explanation", scale=3)
191
 
192
  # functions to trigger the controller
193
  ## takes information for the chat and the xai selection
@@ -205,10 +205,12 @@ with gr.Blocks(
205
  [user_prompt, chatbot, xai_interactive, xai_plot],
206
  )
207
 
208
- # final row to show legal information
209
- ## - credits, data protection and link to the License
210
- with gr.Tab(label="Credits, Data Protection and License"):
211
- gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 
212
 
213
  # mount function for fastAPI Application
214
  app = gr.mount_gradio_app(app, ui, path="/")
 
187
  Values have been excluded for readability. See colorbar for value indication.
188
  """)
189
  # plot component that takes a matplotlib figure as input
190
+ xai_plot = gr.Plot(label="Token Level Explanation")
191
 
192
  # functions to trigger the controller
193
  ## takes information for the chat and the xai selection
 
205
  [user_prompt, chatbot, xai_interactive, xai_plot],
206
  )
207
 
208
+ # final row to show about information
209
+ ## and credits, data protection and link to the License
210
+ with gr.Tab(label="About"):
211
+ gr.Markdown(value=load_md("public/about.md"))
212
+ with gr.Accordion(label="Credits, Data Protection and License", open=False):
213
+ gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
214
 
215
  # mount function for fastAPI Application
216
  app = gr.mount_gradio_app(app, ui, path="/")
public/about.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # About
2
+
3
+ This is a non-commercial research project as part of a Bachelor Thesis with the topic **"Building an Interpretable Natural Language AI Tool based on Transformer Models and Approaches of Explainable AI".**
4
+
5
+ This research tackles the rise of LLM-based applications such as chatbots and explores the possibilities of model interpretation and explainability. The goal is to build a tool that can be used to explain the predictions of an LLM-based chatbot.
6
+
7
+ ## Implementation
8
+
9
+ This project is an implementation of PartitionSHAP and BERTViz into GODEL by Microsoft - [GODEL Model](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq) which is a generative seq2seq transformer fine-tuned for goal directed dialog. It supports context and knowledge base inputs.
10
+
11
+ The UI is built with Gradio.
12
+
13
+ ## Usage
14
+
15
+ You can chat with the model by entering a message into the input field and pressing enter. The model will then generate a response. You can also enter a context and knowledge base by clicking on the respective buttons and entering the data into the input fields. The model will then generate a response based on the context and knowledge base.
16
+
17
+ To explore explanations, choose one of the explanation methods (HINT: The runtime can increase significantly). Then keep on chatting and explore the explanations in the respective tab.
18
+
19
+ ### Self Hosted
20
+
21
+ You can run this application locally by cloning this repository, setting up a python environment and installing the requirements. Then run the `app.py` file or use "uvicorn main:app --reload" in the *python terminal*.
22
+
23
+ For self-hosting you can use the Dockerfile to build a docker image and run it locally or directly use the provided docker image on the [GitHub page](https://github.com/lennardzuendorf/thesis-webapp/).
24
+
25
+ ## Credit & License
26
+ This Product is licensed under the MIT license. See [LICENSE](https://github.com/LennardZuendorf/thesis-webapp/blob/main/LICENSE.md) at GitHub for more information.
27
+
28
+ Please credit the original author of this project (Lennard Zündorf) and the credits listed below if you use this project or parts of it in your own work.
29
+
30
+ ## Contact
31
+
32
+ ### Author
33
+
34
+ - Lennard Zündorf
35
+ - lennard.zuendorf@student.htw-berlin.de
36
+ - [GitHub](https://github.com/LennardZuendorf)
37
+ - [LinkedIn](https://www.zuendorf.me/linkd)
38
+
39
+
40
+ #### University
41
+ Hochschule für Technik und Wirtschaft Berlin (HTW Berlin) - University of Applied Sciences for Engineering and Economics Berlin
42
+
43
+ 1. Supervisor: Prof. Dr. Katarina Simbeck
44
+ 2. Supervisor: Prof. Dr. Axel Hochstein
public/credits_dataprotection_license.md CHANGED
@@ -6,7 +6,6 @@
6
 
7
 
8
  # Credits
9
- For full credits, please refer to the [thesis print]()
10
 
11
  ### Models
12
  This implementation is built on GODEL by Microsoft, Inc.
 
6
 
7
 
8
  # Credits
 
9
 
10
  ### Models
11
  This implementation is built on GODEL by Microsoft, Inc.