LennardZuendorf commited on
Commit
7409c2e
1 Parent(s): 9c6520d

fix: fixing more bugs

Browse files
explanation/interpret_captum.py CHANGED
@@ -16,7 +16,9 @@ def chat_explained(model, prompt):
16
 
17
  # generation attribution
18
  attribution_input = TextTokenInput(prompt, model.TOKENIZER)
19
- attribution_result = llm_attribution.attribute(attribution_input, gen_args=model.CONFIG.to_dict())
 
 
20
 
21
  # extracting values and input tokens
22
  values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
 
16
 
17
  # generation attribution
18
  attribution_input = TextTokenInput(prompt, model.TOKENIZER)
19
+ attribution_result = llm_attribution.attribute(
20
+ attribution_input, gen_args=model.CONFIG.to_dict()
21
+ )
22
 
23
  # extracting values and input tokens
24
  values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
model/mistral.py CHANGED
@@ -25,35 +25,36 @@ else:
25
  MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
26
  MODEL.to(device)
27
  TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
28
- TOKENIZER.pad_token=TOKENIZER.eos_token
29
 
30
  # default model config
31
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
32
  CONFIG.update(**{
33
- "temperature": 0.7,
34
- "max_new_tokens": 50,
35
- "top_p": 0.9,
36
- "repetition_penalty": 1.2,
37
- "do_sample": True,
38
- "seed": 42
 
39
  })
40
 
41
 
42
  # function to (re) set config
43
  def set_config(config: dict):
44
- global CONFIG
45
 
46
  # if config dict is given, update it
47
  if config != {}:
48
  CONFIG.update(**dict)
49
  else:
50
  CONFIG.update(**{
51
- "temperature": 0.7,
52
- "max_new_tokens": 50,
53
- "top_p": 0.9,
54
- "repetition_penalty": 1.2,
55
- "do_sample": True,
56
- "seed": 42
 
57
  })
58
 
59
 
@@ -93,9 +94,6 @@ def format_answer(answer: str):
93
  # empty answer string
94
  formatted_answer = ""
95
 
96
- if type(answer) == list:
97
- answer = fmt.format_output_text(answer)
98
-
99
  # extracting text after INST tokens
100
  parts = answer.split("[/INST]")
101
  if len(parts) >= 3:
@@ -116,5 +114,6 @@ def respond(prompt: str):
116
  # generating text with tokenized input, returning output
117
  output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
118
  output_text = TOKENIZER.batch_decode(output_ids)
 
119
 
120
- return fmt.format_output_text(output_text)
 
25
  MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
26
  MODEL.to(device)
27
  TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
28
+ TOKENIZER.pad_token = TOKENIZER.eos_token
29
 
30
  # default model config
31
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
32
  CONFIG.update(**{
33
+ "temperature": 0.7,
34
+ "max_new_tokens": 50,
35
+ "max_length": 50,
36
+ "top_p": 0.9,
37
+ "repetition_penalty": 1.2,
38
+ "do_sample": True,
39
+ "seed": 42,
40
  })
41
 
42
 
43
# function to (re)set the generation config
def set_config(config: dict):
    """Update the module-level generation CONFIG in place.

    Args:
        config: generation options to apply. If non-empty, its entries are
            merged into CONFIG; if empty ({}), CONFIG is reset to the
            project defaults below.
    """

    # if a config dict is given, apply it
    if config != {}:
        # BUG FIX: original passed the *builtin* `dict` type
        # (CONFIG.update(**dict)), which is not a mapping and raises
        # TypeError — the `config` argument was never used.
        CONFIG.update(**config)
    else:
        # reset to the default generation parameters
        CONFIG.update(**{
            "temperature": 0.7,
            "max_new_tokens": 50,
            "max_length": 50,
            "top_p": 0.9,
            "repetition_penalty": 1.2,
            "do_sample": True,
            "seed": 42,
        })
59
 
60
 
 
94
  # empty answer string
95
  formatted_answer = ""
96
 
 
 
 
97
  # extracting text after INST tokens
98
  parts = answer.split("[/INST]")
99
  if len(parts) >= 3:
 
114
  # generating text with tokenized input, returning output
115
  output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
116
  output_text = TOKENIZER.batch_decode(output_ids)
117
+ output_text = fmt.format_output_text(output_text)
118
 
119
+ return format_answer(output_text)
pyproject.toml CHANGED
@@ -21,6 +21,7 @@ exclude = '''
21
 
22
  [tool.pylint.messages_control]
23
  disable = [
 
24
  "arguments-differ",
25
  "attribute-defined-outside-init",
26
  "blacklisted-name",
 
21
 
22
  [tool.pylint.messages_control]
23
  disable = [
24
+ "not-a-mapping",
25
  "arguments-differ",
26
  "attribute-defined-outside-init",
27
  "blacklisted-name",