LennardZuendorf commited on
Commit
226ad46
1 Parent(s): b324c38

fix: fixing various bugs

Browse files
backend/controller.py CHANGED
@@ -8,9 +8,9 @@ import gradio as gr
8
  from model import godel
9
  from model import mistral
10
  from explanation import (
 
11
  interpret_shap as shap_int,
12
  interpret_captum as cpt_int,
13
- visualize_att as viz,
14
  )
15
 
16
 
@@ -48,7 +48,7 @@ def interference(
48
  else:
49
  xai = shap_int
50
  case "attention":
51
- xai = viz
52
  case _:
53
  # use Gradio warning to display error message
54
  gr.Warning(f"""
 
8
  from model import godel
9
  from model import mistral
10
  from explanation import (
11
+ attention as attention_viz,
12
  interpret_shap as shap_int,
13
  interpret_captum as cpt_int,
 
14
  )
15
 
16
 
 
48
  else:
49
  xai = shap_int
50
  case "attention":
51
+ xai = attention_viz
52
  case _:
53
  # use Gradio warning to display error message
54
  gr.Warning(f"""
explanation/attention.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # visualization module that creates an attention visualization
2
+
3
+
4
+ # internal imports
5
+ from utils import formatting as fmt
6
+ from .markup import markup_text
7
+
8
+ # chat function that returns an answer
9
+ # and marked text based on attention
10
+ def chat_explained(model, prompt):
11
+
12
+ # get encoded input
13
+ encoder_input_ids = model.TOKENIZER(
14
+ prompt, return_tensors="pt", add_special_tokens=True
15
+ ).input_ids
16
+ # generate output together with attentions of the model
17
+ decoder_input_ids = model.MODEL.generate(
18
+ encoder_input_ids, output_attentions=True, **model.CONFIG
19
+ )
20
+
21
+ # get input and output text as list of strings
22
+ encoder_text = fmt.format_tokens(
23
+ model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
24
+ )
25
+ decoder_text = fmt.format_tokens(
26
+ model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
27
+ )
28
+
29
+ # get attention values for the input and output vectors
30
+ # using already generated input and output
31
+ attention_output = model.MODEL(
32
+ input_ids=encoder_input_ids,
33
+ decoder_input_ids=decoder_input_ids,
34
+ output_attentions=True,
35
+ )
36
+
37
+ # averaging attention across layers
38
+ averaged_attention = fmt.avg_attention(attention_output)
39
+
40
+ # format response text for clean output
41
+ response_text = fmt.format_output_text(decoder_text)
42
+ # setting placeholder for iFrame graphic
43
+ graphic = (
44
+ "<div style='text-align: center; font-family:arial;'><h4>Attention"
45
+ " Visualization doesn't support an interactive graphic.</h4></div>"
46
+ )
47
+ # creating marked text using markup_text function and attention
48
+ marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
49
+
50
+ # returning response, graphic and marked text array
51
+ return response_text, graphic, marked_text, None
explanation/interpret_captum.py CHANGED
@@ -46,8 +46,9 @@ def chat_explained(model, prompt):
46
  # getting response text, graphic placeholder and marked text object
47
  response_text = fmt.format_output_text(attribution_result.output_tokens)
48
  graphic = (
49
- "<div style='text-align: center; font-family:arial;'><h4>Attention"
50
- "Intepretation with Captum doesn't support an interactive graphic.</h4></div>"
 
51
  )
52
  marked_text = markup_text(input_tokens, values, variant="captum")
53
 
 
46
  # getting response text, graphic placeholder and marked text object
47
  response_text = fmt.format_output_text(attribution_result.output_tokens)
48
  graphic = (
49
+ """<div style='text-align: center; font-family:arial;'><h4>
50
+ Intepretation with Captum doesn't support an interactive graphic.</h4></div>
51
+ """
52
  )
53
  marked_text = markup_text(input_tokens, values, variant="captum")
54
 
explanation/markup.py CHANGED
@@ -71,11 +71,11 @@ def color_codes():
71
  "-4": "#68a1fd",
72
  "-3": "#96b7fe",
73
  "-2": "#bcceff",
74
- "-1:": "#dee6ff",
75
  "0": "#ffffff",
76
- "1": "#ffd9d9",
77
- "2": "#ffb3b5",
78
- "3": "#ff8b92",
79
- "4": "#ff5c71",
80
- "5": "#ff0051",
81
  }
 
71
  "-4": "#68a1fd",
72
  "-3": "#96b7fe",
73
  "-2": "#bcceff",
74
+ "-1": "#dee6ff",
75
  "0": "#ffffff",
76
+ "+1": "#ffd9d9",
77
+ "+2": "#ffb3b5",
78
+ "+3": "#ff8b92",
79
+ "+4": "#ff5c71",
80
+ "+5": "#ff0051",
81
  }
explanation/visualize_att.py DELETED
File without changes