LennardZuendorf committed
Commit c7e16d0
1 Parent(s): b0721f8

fix: fixing attention visualization

explanation/attention.py CHANGED
@@ -3,18 +3,22 @@
 
 # internal imports
 from utils import formatting as fmt
+from model import godel
 from .markup import markup_text
 
+
 # chat function that returns an answer
 # and marked text based on attention
 def chat_explained(model, prompt):
 
+    model.set_config({"return_dict": True})
+
     # get encoded input
     encoder_input_ids = model.TOKENIZER(
         prompt, return_tensors="pt", add_special_tokens=True
     ).input_ids
     # generate output together with attentions of the model
-    decoder_input_ids = model.MODEL.generate(
+    decoder_input_ids = model.MODEL(
         encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
     )
 
@@ -26,16 +30,24 @@ def chat_explained(model, prompt):
         model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
     )
 
-    # get attention values for the input and output vectors
-    # using already generated input and output
-    attention_output = model.MODEL(
-        input_ids=encoder_input_ids,
-        decoder_input_ids=decoder_input_ids,
-        output_attentions=True,
-    )
+    # getting attention if model is godel
+    if isinstance(model, godel):
+        print("attention.py: Model detected to be GODEL")
+
+        # get attention values for the input and output vectors
+        # using already generated input and output
+        attention_output = model.MODEL(
+            input_ids=encoder_input_ids,
+            decoder_input_ids=decoder_input_ids,
+            output_attentions=True,
+        )
+
+        # averaging attention across layers
+        averaged_attention = fmt.avg_attention(attention_output)
 
-    # averaging attention across layers
-    averaged_attention = fmt.avg_attention(attention_output)
+    # getting attention if model is mistral
+    else:
+        averaged_attention = fmt.avg_attention(decoder_input_ids)
 
     # format response text for clean output
     response_text = fmt.format_output_text(decoder_text)
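
For context on the GODEL branch above: a minimal, self-contained sketch of averaging a seq2seq model's cross-attention across layers and heads. It illustrates the kind of reduction `fmt.avg_attention` is used for here, but the helper's actual implementation lives in the `utils` package and may differ; the prompt and `max_new_tokens` value are illustrative only.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# assumption: the GODEL checkpoint referenced elsewhere in this commit
tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

# encode a prompt and generate an answer
encoder_input_ids = tokenizer("Does money buy happiness?", return_tensors="pt").input_ids
decoder_input_ids = model.generate(encoder_input_ids, max_new_tokens=20)

# re-run a forward pass with the generated ids to obtain attention tensors
outputs = model(
    input_ids=encoder_input_ids,
    decoder_input_ids=decoder_input_ids,
    output_attentions=True,
)

# cross_attentions: one (batch, heads, target_len, source_len) tensor per layer
stacked = torch.stack(outputs.cross_attentions)  # (layers, batch, heads, tgt, src)
averaged = stacked.mean(dim=(0, 2)).squeeze(0)   # (tgt, src) after averaging layers and heads
print(averaged.shape)
```

Each row of `averaged` then weights one generated token against the input tokens, which is the per-token weighting that the markup step can visualize.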
explanation/interpret_captum.py CHANGED
@@ -45,11 +45,9 @@ def chat_explained(model, prompt):
 
     # getting response text, graphic placeholder and marked text object
     response_text = fmt.format_output_text(attribution_result.output_tokens)
-    graphic = (
-        """<div style='text-align: center; font-family:arial;'><h4>
+    graphic = """<div style='text-align: center; font-family:arial;'><h4>
     Intepretation with Captum doesn't support an interactive graphic.</h4></div>
     """
-    )
     marked_text = markup_text(input_tokens, values, variant="captum")
 
     # return response, graphic and marked_text array
explanation/interpret_shap.py CHANGED
@@ -32,7 +32,7 @@ def wrap_shap(model):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # updating the model settings
-    model.set_config()
+    model.set_config({})
 
     # (re)initialize the shap models and masker
     # creating a shap text_generation model
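
The `set_config` helper that now takes a dict is defined in the model wrapper modules and is not shown in this diff. A hypothetical sketch of what such a helper could look like, reusing the `CONFIG` and `base_config_dict` names from `model/godel.py`, might be:

```python
from transformers import GenerationConfig

CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
base_config_dict = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}


# hypothetical helper, not the repository's actual implementation
def set_config(config_dict: dict):
    # re-apply the base settings, then any caller-specific overrides
    CONFIG.update(**base_config_dict)
    if config_dict:
        CONFIG.update(**config_dict)


# callers pass an empty dict for defaults, or overrides such as {"return_dict": True}
set_config({})
```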
main.py CHANGED
@@ -110,11 +110,10 @@ with gr.Blocks(
             label="System Prompt",
             info="Set the models system prompt, dictating how it answers.",
             # default system prompt is set to this in the backend
-            placeholder=("""
+            placeholder="""
             You are a helpful, respectful and honest assistant. Always
             answer as helpfully as possible, while being safe.
-            """
-            ),
+            """,
         )
         # column that takes up 1/4 of the row
         with gr.Column(scale=1):
@@ -122,7 +121,9 @@ with gr.Blocks(
         xai_selection = gr.Radio(
             ["None", "SHAP", "Attention"],
             label="Interpretability Settings",
-            info="Select a Interpretability Approach Implementation to use.",
+            info=(
+                "Select a Interpretability Approach Implementation to use."
+            ),
             value="None",
             interactive=True,
             show_label=True,
@@ -209,10 +210,15 @@ with gr.Blocks(
         gr.Examples(
             label="Example Questions",
             examples=[
-                ["Does money buy happiness?", "Mistral", "SHAP"],
-                ["Does money buy happiness?", "Mistral", "Attention"],
+                ["Does money buy happiness?", "", "Mistral", "SHAP"],
+                ["Does money buy happiness?", "", "Mistral", "Attention"],
+            ],
+            inputs=[
+                user_prompt,
+                knowledge_input,
+                model_selection,
+                xai_selection,
             ],
-            inputs=[user_prompt, model_selection, xai_selection],
         )
         with gr.Accordion("GODEL Model Examples", open=False):
             # examples util component
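
The main.py change wires a fourth component, `knowledge_input`, into the example rows: in Gradio, each row in `examples` must supply one value per component listed in `inputs`, in the same order. A stripped-down sketch of that pattern (layout simplified, component labels and choices assumed):

```python
import gradio as gr

with gr.Blocks() as demo:
    user_prompt = gr.Textbox(label="User Prompt")
    knowledge_input = gr.Textbox(label="Knowledge")
    model_selection = gr.Radio(["GODEL", "Mistral"], label="Model")
    xai_selection = gr.Radio(["None", "SHAP", "Attention"], label="Interpretability Settings")

    # each example row provides one value per input component, in order
    gr.Examples(
        label="Example Questions",
        examples=[
            ["Does money buy happiness?", "", "Mistral", "SHAP"],
            ["Does money buy happiness?", "", "Mistral", "Attention"],
        ],
        inputs=[user_prompt, knowledge_input, model_selection, xai_selection],
    )

if __name__ == "__main__":
    demo.launch()
```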
model/godel.py CHANGED
@@ -13,7 +13,12 @@ MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2se
 
 # model config definition
 CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
-base_config_dict = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+base_config_dict = {
+    "max_new_tokens": 50,
+    "min_length": 8,
+    "top_p": 0.9,
+    "do_sample": True,
+}
 CONFIG.update(**base_config_dict)
 
 
@@ -59,11 +64,13 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
 # CREDIT: Copied from official interference example on Huggingface
 ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
 def respond(prompt):
+    set_config({})
+
     # tokenizing input string
     input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
 
     # generating using config and decoding output
-    outputs = MODEL.generate(input_ids,generation_config=CONFIG)
+    outputs = MODEL.generate(input_ids, generation_config=CONFIG)
    output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
 
     # returns the model output string
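
As a quick reference for the config handling above: `GenerationConfig.update(**kwargs)` merges keyword arguments into the config object in place, which is what `CONFIG.update(**base_config_dict)` relies on.

```python
from transformers import GenerationConfig

# same checkpoint as in model/godel.py
config = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
config.update(max_new_tokens=50, min_length=8, top_p=0.9, do_sample=True)
print(config.max_new_tokens, config.do_sample)  # 50 True
```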
model/mistral.py CHANGED
@@ -110,6 +110,7 @@ def format_answer(answer: str):
 
 
 def respond(prompt: str):
+    set_config({})
 
     # tokenizing inputs and configuring model
     input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"].to(device)