LennardZuendorf committed
Commit 2b7cdd8
1 Parent(s): 7ad098c

fix: fixing gen args

explanation/interpret_captum.py CHANGED
@@ -16,7 +16,7 @@ def chat_explained(model, prompt):
 
     # generation attribution
     attribution_input = TextTokenInput(prompt, model.TOKENIZER)
-    attribution_result = llm_attribution.attribute(attribution_input)
+    attribution_result = llm_attribution.attribute(attribution_input, gen_args=model.CONFIG.to_dict())
 
     # extracting values and input tokens
     values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
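For context, here is a minimal sketch of the call chain this hunk sits in, assuming Captum's LLMAttribution API; the FeatureAblation method and the surrounding function body are assumptions, since the hunk only shows the attribute() call. The gen_args dict is forwarded to the model's generate() call inside Captum, so generation during attribution now follows the shared CONFIG instead of Captum's built-in defaults.

# minimal sketch, not the project's actual setup: FeatureAblation is an
# assumed attribution method, and model is assumed to expose MODEL,
# TOKENIZER and CONFIG as in model/mistral.py below
import torch
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

def chat_explained_sketch(model, prompt):
    # wrap the causal LM so Captum can attribute generated text to input tokens
    llm_attribution = LLMAttribution(FeatureAblation(model.MODEL), model.TOKENIZER)

    # gen_args is passed through to model.generate() during attribution
    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
    attribution_result = llm_attribution.attribute(
        attribution_input, gen_args=model.CONFIG.to_dict()
    )

    # per-input-token attribution scores, moved to CPU for further processing
    return attribution_result.seq_attr.to(torch.device("cpu")).numpy()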
model/mistral.py CHANGED
@@ -1,7 +1,7 @@
 # Mistral model module for chat interaction and model instance control
 
 # external imports
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 import torch
 import gradio as gr
 
@@ -26,7 +26,15 @@ else:
     TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
 
     # default model config
-    CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+    CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+    CONFIG.update(**{
+        "temperature": 0.7,
+        "max_new_tokens": 50,
+        "top_p": 0.9,
+        "repetition_penalty": 1.2,
+        "do_sample": True,
+        "seed": 42
+    })
 
 
     # function to (re) set config
@@ -35,14 +43,16 @@ def set_config(config: dict):
 
     # if config dict is given, update it
    if config != {}:
-        CONFIG = config
+        CONFIG.update(**config)
     else:
-        # hard setting model config to default
-        # needed for shap
-        MODEL.config.max_new_tokens = 50
-        MODEL.config.min_length = 8
-        MODEL.config.top_p = 0.9
-        MODEL.config.do_sample = True
+        CONFIG.update(**{
+            "temperature": 0.7,
+            "max_new_tokens": 50,
+            "top_p": 0.9,
+            "repetition_penalty": 1.2,
+            "do_sample": True,
+            "seed": 42
+        })
 
 
     # advanced formatting function that takes into account a conversation history
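The new default-config pattern deserves a note: GenerationConfig.from_pretrained loads the generation defaults shipped with the checkpoint, and update() then overwrites individual fields in place. A small sketch of the mechanics, assuming current transformers behavior where update() returns any keys that do not match a GenerationConfig attribute rather than raising:

# minimal sketch of the pattern introduced above
from transformers import GenerationConfig

config = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# update() sets matching attributes in place and returns the leftovers;
# "seed" is presumably not a GenerationConfig field, so it would come back here
unused = config.update(
    temperature=0.7,
    max_new_tokens=50,
    top_p=0.9,
    repetition_penalty=1.2,
    do_sample=True,
    seed=42,
)
print(unused)            # e.g. {"seed": 42} if seed is not a known attribute
print(config.to_dict())  # plain dict, ready to pass as gen_args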
 
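Tying the two diffs together, a usage sketch under an assumed module layout (the names come from the hunks above; the import paths are inferred from the file headers, and the prompt is a placeholder). Because update() mutates the GenerationConfig in place, set_config changes the module-level CONFIG without the global rebinding that the old CONFIG = config assignment would have needed:

# usage sketch; import paths are assumptions based on the file paths above
from model import mistral
from explanation.interpret_captum import chat_explained

mistral.set_config({})                      # empty dict: reset to the hard-coded defaults
mistral.set_config({"max_new_tokens": 80})  # non-empty dict: override fields in place

# the explanation path then picks up the same shared config via gen_args
result = chat_explained(mistral, "Why is the sky blue?")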