LennardZuendorf commited on
Commit
bfa9c36
1 Parent(s): c7e16d0

fix: fixing bug

Browse files
Files changed (1) hide show
  1. explanation/attention.py +2 -2
explanation/attention.py CHANGED
@@ -18,7 +18,7 @@ def chat_explained(model, prompt):
18
  prompt, return_tensors="pt", add_special_tokens=True
19
  ).input_ids
20
  # generate output together with attentions of the model
21
- decoder_input_ids = model.MODEL(
22
  encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
23
  )
24
 
@@ -36,7 +36,7 @@ def chat_explained(model, prompt):
36
 
37
  # get attention values for the input and output vectors
38
  # using already generated input and output
39
- attention_output = model.MODEL(
40
  input_ids=encoder_input_ids,
41
  decoder_input_ids=decoder_input_ids,
42
  output_attentions=True,
 
18
  prompt, return_tensors="pt", add_special_tokens=True
19
  ).input_ids
20
  # generate output together with attentions of the model
21
+ decoder_input_ids = model.MODEL.generate(
22
  encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
23
  )
24
 
 
36
 
37
  # get attention values for the input and output vectors
38
  # using already generated input and output
39
+ attention_output = model.MODEL.generate(
40
  input_ids=encoder_input_ids,
41
  decoder_input_ids=decoder_input_ids,
42
  output_attentions=True,