LennardZuendorf commited on
Commit
517fd4c
1 Parent(s): 1f063be

fix: final fix of attention

Browse files
explanation/attention.py CHANGED
@@ -37,8 +37,8 @@ def chat_explained(model, prompt):
37
  attention_output = mdl.format_mistral_attention(attention_output)
38
  averaged_attention = fmt.avg_attention(attention_output, model="mistral")
39
 
40
- response_text = fmt.format_output_text(output_text)
41
- response_text = mistral.format_answer(response_text)
42
 
43
  # otherwise use attention visualization for godel
44
  else:
 
37
  attention_output = mdl.format_mistral_attention(attention_output)
38
  averaged_attention = fmt.avg_attention(attention_output, model="mistral")
39
 
40
+ output_text = fmt.format_output_text(output_text)
41
+ response_text = mistral.format_answer(output_text)
42
 
43
  # otherwise use attention visualization for godel
44
  else:
explanation/markup.py CHANGED
@@ -10,7 +10,7 @@ from utils import formatting as fmt
10
 
11
  # main function that assigns each text snippet a marked bucket
12
  def markup_text(input_text: list, text_values: ndarray, variant: str):
13
- print(f"Marking up text {input_text} and {text_values} for {variant}.")
14
 
15
  # naming of the 11 buckets
16
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
 
10
 
11
  # main function that assigns each text snippet a marked bucket
12
  def markup_text(input_text: list, text_values: ndarray, variant: str):
13
+ print(f"Marking up text {input_text} for {variant}.")
14
 
15
  # naming of the 11 buckets
16
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
main.py CHANGED
@@ -252,8 +252,8 @@ with gr.Blocks(
252
  ],
253
  inputs=[
254
  user_prompt,
255
- system_prompt,
256
  xai_selection,
 
257
  model_selection,
258
  knowledge_input,
259
  ],
@@ -266,6 +266,7 @@ with gr.Blocks(
266
  examples=[
267
  [
268
  "Does money buy happiness?",
 
269
  (
270
  "Some studies have found a correlation between income"
271
  " and happiness, but this relationship often has"
@@ -275,10 +276,10 @@ with gr.Blocks(
275
  ),
276
  "",
277
  "GODEL",
278
- "SHAP",
279
  ],
280
  [
281
  "Does money buy happiness?",
 
282
  (
283
  "Some studies have found a correlation between income"
284
  " and happiness, but this relationship often has"
@@ -288,14 +289,13 @@ with gr.Blocks(
288
  ),
289
  "",
290
  "GODEL",
291
- "Attention",
292
  ],
293
  [
294
  "Does money buy happiness?",
 
295
  "",
296
  "",
297
  "GODEL",
298
- "Attention",
299
  ],
300
  ],
301
  inputs=[
 
252
  ],
253
  inputs=[
254
  user_prompt,
 
255
  xai_selection,
256
+ system_prompt,
257
  model_selection,
258
  knowledge_input,
259
  ],
 
266
  examples=[
267
  [
268
  "Does money buy happiness?",
269
+ "SHAP",
270
  (
271
  "Some studies have found a correlation between income"
272
  " and happiness, but this relationship often has"
 
276
  ),
277
  "",
278
  "GODEL",
 
279
  ],
280
  [
281
  "Does money buy happiness?",
282
+ "Attention",
283
  (
284
  "Some studies have found a correlation between income"
285
  " and happiness, but this relationship often has"
 
289
  ),
290
  "",
291
  "GODEL",
 
292
  ],
293
  [
294
  "Does money buy happiness?",
295
+ "Attention",
296
  "",
297
  "",
298
  "GODEL",
 
299
  ],
300
  ],
301
  inputs=[
model/mistral.py CHANGED
@@ -32,12 +32,11 @@ TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
32
  # default model config
33
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
34
  base_config_dict = {
35
- "temperature": 0.7,
36
- "max_new_tokens": 64,
37
  "top_p": 0.9,
38
  "repetition_penalty": 1.2,
39
  "do_sample": True,
40
- "seed": 42,
41
  }
42
  CONFIG.update(**base_config_dict)
43
 
 
32
  # default model config
33
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
34
  base_config_dict = {
35
+ "temperature": 1,
36
+ "max_new_tokens": 100,
37
  "top_p": 0.9,
38
  "repetition_penalty": 1.2,
39
  "do_sample": True,
 
40
  }
41
  CONFIG.update(**base_config_dict)
42
 
utils/formatting.py CHANGED
@@ -92,15 +92,14 @@ def avg_attention(attention_values, model: str):
92
  # check if model is godel
93
  if model == "godel":
94
  # get attention values for the input and output vectors
95
- attention = attention_values.decoder_attentions[0][0].detach().numpy()
96
- return np.mean(attention, axis=0)
97
 
98
  # extracting attention values for mistral
99
  attention = attention_values.to(torch.device("cpu")).detach().numpy()
100
 
101
  # removing the last dimension and transposing to get the correct shape
102
  attention = attention[:, :, :, 0]
103
- attention = attention.transpose()
104
 
105
  # return the averaged attention values
106
  return np.mean(attention, axis=1)
 
92
  # check if model is godel
93
  if model == "godel":
94
  # get attention values for the input and output vectors
95
+ attention = attention_values.encoder_attentions[0][0].detach().numpy()
96
+ return np.mean(attention, axis=1)
97
 
98
  # extracting attention values for mistral
99
  attention = attention_values.to(torch.device("cpu")).detach().numpy()
100
 
101
  # removing the last dimension and transposing to get the correct shape
102
  attention = attention[:, :, :, 0]
 
103
 
104
  # return the averaged attention values
105
  return np.mean(attention, axis=1)
utils/modelling.py CHANGED
@@ -100,11 +100,15 @@ def gpu_loading_config(max_memory: str = "15000MB"):
100
 
101
 
102
  # formatting mistral attention values
103
- # CREDIT: copied and adapted from BERTViz
104
  # see https://github.com/jessevig/bertviz
105
- def format_mistral_attention(attention_values):
 
 
106
  squeezed = []
107
  for layer_attention in attention_values:
108
  layer_attention = layer_attention.squeeze(0)
 
 
109
  squeezed.append(layer_attention)
110
- return torch.stack(squeezed).to(torch.device("cpu"))
 
100
 
101
 
102
  # formatting mistral attention values
103
+ # CREDIT: copied from BERTViz
104
  # see https://github.com/jessevig/bertviz
105
+ def format_mistral_attention(attention_values, layers=None, heads=None):
106
+ if layers:
107
+ attention_values = [attention_values[layer_index] for layer_index in layers]
108
  squeezed = []
109
  for layer_attention in attention_values:
110
  layer_attention = layer_attention.squeeze(0)
111
+ if heads:
112
+ layer_attention = layer_attention[heads]
113
  squeezed.append(layer_attention)
114
+ return torch.stack(squeezed)