LennardZuendorf committed on
Commit
f301e04
1 Parent(s): 67a34bd

fix: fixes for plotting and attention visualization

backend/controller.py CHANGED
@@ -14,8 +14,47 @@ from explanation import (
14
  )
15
 
16
 
17
- # main interference function that that calls chat functions depending on selections
18
- # is getting called on every chat submit
19
  def interference(
20
  prompt: str,
21
  history: list,
@@ -31,6 +70,7 @@ def interference(
31
  Always answer as helpfully as possible, while being safe.
32
  """
33
 
 
34
  if model_selection.lower() == "mistral":
35
  model = mistral
36
  print("Indentified model as Mistral")
@@ -39,6 +79,7 @@ def interference(
39
  print("Indentified model as GODEL")
40
 
41
  # if an XAI approach is selected, grab the XAI module instance
 
42
  if xai_selection in ("SHAP", "Attention"):
43
  # matching selection
44
  match xai_selection.lower():
@@ -71,7 +112,7 @@ def interference(
71
  )
72
  # if no XAI approach is selected call the vanilla chat function
73
  else:
74
- # call the vanilla chat function
75
  prompt_output, history_output = vanilla_chat(
76
  model=model,
77
  message=prompt,
@@ -91,43 +132,3 @@ def interference(
91
 
92
  # return the outputs
93
  return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
94
-
95
-
96
- # simple chat function that calls the model
97
- # formats prompts, calls for an answer and returns updated conversation history
98
- def vanilla_chat(
99
- model, message: str, history: list, system_prompt: str, knowledge: str = ""
100
- ):
101
- print(f"Running normal chat with {model}.")
102
-
103
- # formatting the prompt using the model's format_prompt function
104
- prompt = model.format_prompt(message, history, system_prompt, knowledge)
105
-
106
- # generating an answer using the model's respond function
107
- answer = model.respond(prompt)
108
-
109
- # updating the chat history with the new answer
110
- history.append((message, answer))
111
- # returning the updated history
112
- return "", history
113
-
114
-
115
- def explained_chat(
116
- model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
117
- ):
118
- print(f"Running explained chat with {xai} with {model}.")
119
-
120
- # formatting the prompt using the model's format_prompt function
121
- # message, history, system_prompt, knowledge = mdl.prompt_limiter(
122
- # message, history, system_prompt, knowledge
123
- # )
124
- prompt = model.format_prompt(message, history, system_prompt, knowledge)
125
-
126
- # generating an answer using the method's chat function
127
- answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
128
-
129
- # updating the chat history with the new answer
130
- history.append((message, answer))
131
-
132
- # returning the updated history, xai graphic and xai plot elements
133
- return "", history, xai_graphic, xai_markup, xai_plot
 
14
  )
15
 
16
 
17
+ # simple chat function that calls the model
18
+ # formats prompts, calls for an answer and returns updated conversation history
19
+ def vanilla_chat(
20
+ model, message: str, history: list, system_prompt: str, knowledge: str = ""
21
+ ):
22
+ print(f"Running normal chat with {model}.")
23
+
24
+ # formatting the prompt using the model's format_prompt function
25
+ prompt = model.format_prompt(message, history, system_prompt, knowledge)
26
+
27
+ # generating an answer using the model's respond function
28
+ answer = model.respond(prompt)
29
+
30
+ # updating the chat history with the new answer
31
+ history.append((message, answer))
32
+ # returning the updated history
33
+ return "", history
34
+
35
+
36
+ def explained_chat(
37
+ model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
38
+ ):
39
+ print(f"Running explained chat with {xai} with {model}.")
40
+
41
+ # formatting the prompt using the model's format_prompt function
42
+ # message, history, system_prompt, knowledge = mdl.prompt_limiter(
43
+ # message, history, system_prompt, knowledge
44
+ # )
45
+ prompt = model.format_prompt(message, history, system_prompt, knowledge)
46
+
47
+ # generating an answer using the method's chat function
48
+ answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
49
+
50
+ # updating the chat history with the new answer
51
+ history.append((message, answer))
52
+
53
+ # returning the updated history, xai graphic and xai plot elements
54
+ return "", history, xai_graphic, xai_markup, xai_plot
55
+
56
+
57
+ # main interference function that calls chat functions depending on selections
58
  def interference(
59
  prompt: str,
60
  history: list,
 
70
  Always answer as helpfully as possible, while being safe.
71
  """
72
 
73
+ # if a model is selected, grab the model instance
74
  if model_selection.lower() == "mistral":
75
  model = mistral
76
  print("Indentified model as Mistral")
 
79
  print("Indentified model as GODEL")
80
 
81
  # if an XAI approach is selected, grab the XAI module instance
82
+ # and call the explained chat function
83
  if xai_selection in ("SHAP", "Attention"):
84
  # matching selection
85
  match xai_selection.lower():
 
112
  )
113
  # if no XAI approach is selected call the vanilla chat function
114
  else:
115
+ # calling the vanilla chat function
116
  prompt_output, history_output = vanilla_chat(
117
  model=model,
118
  message=prompt,
 
132
 
133
  # return the outputs
134
  return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
 
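For context, a minimal sketch of how the refactored controller could be wired from a Gradio interface. Only `interference` and its first two parameters (`prompt`, `history`) come from this file; the component names, the import path, and the order of the remaining inputs are assumptions for illustration.

```python
# Hypothetical wiring sketch (not part of this commit): hooking interference()
# up to a Gradio submit event. Component names and the input order after
# prompt/history are assumed, not taken from the repository.
import gradio as gr

from backend.controller import interference  # import path assumed

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    prompt_box = gr.Textbox(label="Prompt")
    system_prompt = gr.Textbox(label="System Prompt")
    knowledge = gr.Textbox(label="Knowledge")
    model_choice = gr.Radio(["Mistral", "GODEL"], value="GODEL", label="Model")
    xai_choice = gr.Radio(["None", "SHAP", "Attention"], value="None", label="XAI")
    xai_interactive = gr.HTML()
    xai_markup = gr.HighlightedText()
    xai_plot = gr.Plot()

    # interference() returns: cleared prompt, updated history, and the three XAI outputs
    prompt_box.submit(
        interference,
        inputs=[prompt_box, chatbot, system_prompt, knowledge, model_choice, xai_choice],
        outputs=[prompt_box, chatbot, xai_interactive, xai_markup, xai_plot],
    )

demo.launch()
```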
explanation/attention.py CHANGED
@@ -28,14 +28,14 @@ def chat_explained(model, prompt):
28
  # checking if model is mistral
29
  if type(model.MODEL) == type(mistral.MODEL):
30
 
31
- # get attention values for the input vectors
32
  attention_output = model.MODEL(input_ids, output_attentions=True).attentions
33
 
34
  # averaging attention across layers and heads
35
  attention_output = mdl.format_mistral_attention(attention_output)
36
  averaged_attention = fmt.avg_attention(attention_output, model="mistral")
37
 
38
- # attention visualization for godel
39
  else:
40
  # get attention values for the input and output vectors
41
  # using already generated input and output
 
28
  # checking if model is mistral
29
  if type(model.MODEL) == type(mistral.MODEL):
30
 
31
+ # get attention values for the input vectors, specific to mistral
32
  attention_output = model.MODEL(input_ids, output_attentions=True).attentions
33
 
34
  # averaging attention across layers and heads
35
  attention_output = mdl.format_mistral_attention(attention_output)
36
  averaged_attention = fmt.avg_attention(attention_output, model="mistral")
37
 
38
+ # otherwise use attention visualization for godel
39
  else:
40
  # get attention values for the input and output vectors
41
  # using already generated input and output
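For reference, a standalone sketch of the pattern the Mistral branch above relies on: calling the model with `output_attentions=True` yields one tensor per layer of shape (batch, heads, seq_len, seq_len), which can then be averaged across layers and heads. The dummy tensors and the final reduction here are illustrative only; the repository's own reduction lives in `utils/formatting.py` below.

```python
# Standalone sketch with dummy tensors (illustrative, not the repo's exact code):
# averaging per-layer self-attention maps across layers and heads.
import torch

num_layers, num_heads, seq_len = 4, 8, 6

# model(input_ids, output_attentions=True).attentions is a tuple with one
# tensor per layer, each of shape (batch, heads, seq_len, seq_len)
attentions = tuple(torch.rand(1, num_heads, seq_len, seq_len) for _ in range(num_layers))

# squeeze the batch dimension and stack the layers -> (layers, heads, seq, seq)
stacked = torch.stack([layer.squeeze(0) for layer in attentions])

# average over layers and heads -> a single (seq_len, seq_len) attention map
averaged = stacked.mean(dim=(0, 1))
print(averaged.shape)  # torch.Size([6, 6])
```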
explanation/plotting.py CHANGED
@@ -12,7 +12,6 @@ def plot_seq(seq_values: list, method: str = ""):
12
 
13
  # Convert importance values to numpy array for conditional coloring
14
  importance = np.array(importance)
15
- importance = importance.log
16
 
17
  # Determine the colors based on the sign of the importance values
18
  colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
@@ -22,9 +21,8 @@ def plot_seq(seq_values: list, method: str = ""):
22
  x_positions = range(len(tokens)) # Positions for the bars
23
 
24
  # Creating vertical bar plot
25
- bar_width = 0.8 # Increase this value to make the bars wider
26
  plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
27
- plt.yscale("symlog")
28
 
29
  # Annotating each bar with its value
30
  padding = 0.1 # Padding for text annotation
 
12
 
13
  # Convert importance values to numpy array for conditional coloring
14
  importance = np.array(importance)
 
15
 
16
  # Determine the colors based on the sign of the importance values
17
  colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
 
21
  x_positions = range(len(tokens)) # Positions for the bars
22
 
23
  # Creating vertical bar plot
24
+ bar_width = 0.8
25
  plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
 
26
 
27
  # Annotating each bar with its value
28
  padding = 0.1 # Padding for text annotation
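A self-contained version of the plot `plot_seq` produces after this commit drops the `.log` transform and the symlog y-scale; the tokens and importance values below are made up, only the colors and bar settings come from the file.

```python
# Illustrative sketch: vertical bar plot with sign-dependent colors,
# mirroring plot_seq() after the log scaling was removed.
# The token strings and importance values are made-up example data.
import matplotlib.pyplot as plt
import numpy as np

tokens = ["The", "movie", "was", "surprisingly", "good"]
importance = np.array([0.05, 0.4, -0.1, 0.8, 0.6])

# red for positive importance, blue for negative (same hex values as plot_seq)
colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
x_positions = range(len(tokens))

bar_width = 0.8
plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
plt.xticks(x_positions, tokens, rotation=45)
plt.ylabel("importance")
plt.tight_layout()
plt.show()
```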
model/mistral.py CHANGED
@@ -31,7 +31,6 @@ CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
31
  base_config_dict = {
32
  "temperature": 0.7,
33
  "max_new_tokens": 64,
34
- "max_length": 64,
35
  "top_p": 0.9,
36
  "repetition_penalty": 1.2,
37
  "do_sample": True,
 
31
  base_config_dict = {
32
  "temperature": 0.7,
33
  "max_new_tokens": 64,
 
34
  "top_p": 0.9,
35
  "repetition_penalty": 1.2,
36
  "do_sample": True,
utils/formatting.py CHANGED
@@ -88,11 +88,19 @@ def flatten_attention(values: ndarray, axis: int = 0):
88
 
89
  # function to get averaged decoder attention from attention values
90
  def avg_attention(attention_values, model: str):
 
91
  # check if model is godel
92
  if model == "godel":
93
  # get attention values for the input and output vectors
94
  attention = attention_values.decoder_attentions[0][0].detach().numpy()
95
  return np.mean(attention, axis=0)
 
96
  # extracting attention values for mistral
97
- attention_np = attention_values.to(torch.device("cpu")).detach().numpy()
98
- return np.mean(attention_np, axis=(0, 1, 2))
 
88
 
89
  # function to get averaged decoder attention from attention values
90
  def avg_attention(attention_values, model: str):
91
+
92
  # check if model is godel
93
  if model == "godel":
94
  # get attention values for the input and output vectors
95
  attention = attention_values.decoder_attentions[0][0].detach().numpy()
96
  return np.mean(attention, axis=0)
97
+
98
  # extracting attention values for mistral
99
+ attention = attention_values.to(torch.device("cpu")).detach().numpy()
100
+
101
+ # removing the last dimension and transposing to get the correct shape
102
+ attention = attention[:, :, :, 0]
103
+ attention = attention.transpose()
104
+
105
+ # return the averaged attention values
106
+ return np.mean(attention, axis=1)
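A dummy-tensor trace of the shapes flowing through the rewritten Mistral branch, assuming `format_mistral_attention` has already stacked the per-layer tensors into (layers, heads, seq_len, seq_len); the sizes are arbitrary.

```python
# Shape walk-through with dummy data (sizes are arbitrary, not from the repo).
import numpy as np
import torch

layers, heads, seq_len = 4, 8, 6
stacked = torch.rand(layers, heads, seq_len, seq_len)  # stand-in for format_mistral_attention output

attention = stacked.to(torch.device("cpu")).detach().numpy()  # (layers, heads, seq, seq)
attention = attention[:, :, :, 0]                             # (layers, heads, seq)
attention = attention.transpose()                             # (seq, heads, layers)
averaged = np.mean(attention, axis=1)                         # (seq_len, layers)
print(averaged.shape)  # (6, 4)
```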
utils/modelling.py CHANGED
@@ -107,4 +107,4 @@ def format_mistral_attention(attention_values):
107
  for layer_attention in attention_values:
108
  layer_attention = layer_attention.squeeze(0)
109
  squeezed.append(layer_attention)
110
- return torch.stack(squeezed)
 
107
  for layer_attention in attention_values:
108
  layer_attention = layer_attention.squeeze(0)
109
  squeezed.append(layer_attention)
110
+ return torch.stack(squeezed).to(torch.device("cpu"))
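With the added `.to(torch.device("cpu"))`, the stacked tensor is moved off the GPU at this point, so the later `.numpy()` call in `avg_attention` works regardless of where the model ran. A small dummy-tensor check of the resulting shape; the sizes are arbitrary.

```python
# Dummy-tensor check (arbitrary sizes): squeeze the batch dimension per layer,
# stack the layers, and move the result to CPU as format_mistral_attention now does.
import torch

attention_values = tuple(torch.rand(1, 8, 6, 6) for _ in range(4))  # (1, heads, seq, seq) per layer

squeezed = [layer_attention.squeeze(0) for layer_attention in attention_values]
stacked = torch.stack(squeezed).to(torch.device("cpu"))
print(stacked.shape)  # torch.Size([4, 8, 6, 6]) -> (layers, heads, seq_len, seq_len)
```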