LennardZuendorf commited on
Commit
1c4497c
1 Parent(s): 49066ce

fix: removing attention viz graphic again

Browse files
Files changed (2) hide show
  1. explanation/visualize.py +5 -29
  2. utils/formatting.py +1 -11
explanation/visualize.py CHANGED
@@ -1,10 +1,6 @@
1
  # visualization module that creates an attention visualization using BERTViz
2
 
3
 
4
- # external imports
5
- from bertviz import neuron_view as nv
6
-
7
-
8
  # internal imports
9
  from utils import formatting as fmt
10
  from .markup import markup_text
@@ -38,30 +34,10 @@ def chat_explained(model, prompt):
38
 
39
  # create the response text and marked text for ui
40
  response_text = fmt.format_output_text(decoder_text)
41
- xai_graphic = attention_graphic(encoder_text, decoder_text, model)
 
 
 
42
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
43
 
44
- return response_text, xai_graphic, marked_text
45
-
46
-
47
- def attention_graphic(encoder_text, decoder_text, model):
48
-
49
- # set model type to BERT (to fake out BERTViz)
50
- model_type = "bert"
51
-
52
- # create sentence a and b from list of strings
53
- sentence_a = " ".join(encoder_text)
54
- sentence_b = " ".join(decoder_text)
55
-
56
- # display neuron view
57
- return nv.show(
58
- model.MODEL,
59
- model_type,
60
- model.TOKENIZER,
61
- sentence_a,
62
- sentence_b,
63
- display_mode="light",
64
- layer=2,
65
- head=0,
66
- html_action="return",
67
- )
 
1
  # visualization module that creates an attention visualization using BERTViz
2
 
3
 
 
 
 
 
4
  # internal imports
5
  from utils import formatting as fmt
6
  from .markup import markup_text
 
34
 
35
  # create the response text and marked text for ui
36
  response_text = fmt.format_output_text(decoder_text)
37
+ graphic = (
38
+ "<div style='text-align: center; font-family:arial;'><h4>Attention"
39
+ " Visualization doesn't support an interactive graphic.</h4></div>"
40
+ )
41
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
42
 
43
+ return response_text, graphic, marked_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/formatting.py CHANGED
@@ -35,17 +35,7 @@ def format_output_text(output: list):
35
  # format the tokens by removing special tokens and special characters
36
  def format_tokens(tokens: list):
37
  # define special tokens to remove and initialize empty list
38
- special_tokens = [
39
- "[CLS]",
40
- "[SEP]",
41
- "[PAD]",
42
- "[UNK]",
43
- "[MASK]",
44
- "▁",
45
- "Ġ",
46
- "</w>",
47
- "/n",
48
- ]
49
  updated_tokens = []
50
 
51
  # loop through tokens
 
35
  # format the tokens by removing special tokens and special characters
36
  def format_tokens(tokens: list):
37
  # define special tokens to remove and initialize empty list
38
+ special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
 
 
 
 
 
 
 
 
 
 
39
  updated_tokens = []
40
 
41
  # loop through tokens