LennardZuendorf committed
Commit b0721f8
Parent: 226ad46

fix: fixing model config settings

Files changed (3)
  1. explanation/attention.py +1 -1
  2. main.py +9 -8
  3. model/godel.py +12 -16
explanation/attention.py CHANGED
@@ -15,7 +15,7 @@ def chat_explained(model, prompt):
     ).input_ids
     # generate output together with attentions of the model
     decoder_input_ids = model.MODEL.generate(
-        encoder_input_ids, output_attentions=True, **model.CONFIG
+        encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
     )
 
     # get input and output text as list of strings
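The fix here reflects how transformers' generate consumes settings: the old call unpacked a plain kwargs dict, while the new one hands over a GenerationConfig object. A minimal sketch of the two call styles, assuming a small GODEL checkpoint and a throwaway prompt (both illustrative, not part of this commit):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

# illustrative checkpoint; the app itself loads GODEL-v1_1-large-seq2seq
tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")

encoder_input_ids = tokenizer(
    "Instruction: be helpful. Hello!", return_tensors="pt"
).input_ids

# before: generation settings unpacked as loose keyword arguments
kwargs = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
out_old = model.generate(encoder_input_ids, output_attentions=True, **kwargs)

# after: the same settings bundled into a single GenerationConfig object
gen_config = GenerationConfig(**kwargs)
out_new = model.generate(
    encoder_input_ids, output_attentions=True, generation_config=gen_config
)

Both calls produce token ids; the generation_config form keeps one shared, inspectable settings object instead of scattering kwargs across call sites.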
main.py CHANGED
@@ -110,9 +110,10 @@ with gr.Blocks(
             label="System Prompt",
             info="Set the models system prompt, dictating how it answers.",
             # default system prompt is set to this in the backend
-            placeholder=(
-                "You are a helpful, respectful and honest assistant. Always"
-                " answer as helpfully as possible, while being safe."
+            placeholder=("""
+                You are a helpful, respectful and honest assistant. Always
+                answer as helpfully as possible, while being safe.
+                """
             ),
         )
         # column that takes up 1/4 of the row
@@ -121,7 +122,7 @@
         xai_selection = gr.Radio(
             ["None", "SHAP", "Attention"],
             label="Interpretability Settings",
-            info="Select a Interpretability Implementation to use.",
+            info="Select a Interpretability Approach Implementation to use.",
             value="None",
             interactive=True,
             show_label=True,
@@ -133,15 +134,15 @@
             ["GODEL", "Mistral"],
             label="Model Settings",
             info="Select a Model to use.",
-            value="GODEL",
+            value="Mistral",
             interactive=True,
             show_label=True,
         )
 
     # calling info functions on inputs/submits for different settings
-    system_prompt.submit(system_prompt_info, [system_prompt])
-    xai_selection.input(xai_info, [xai_selection])
-    model_selection.input(model_info, [model_selection])
+    system_prompt.change(system_prompt_info, [system_prompt])
+    xai_selection.change(xai_info, [xai_selection])
+    model_selection.change(model_info, [model_selection])
 
     # row with chatbot ui displaying "conversation" with the model
     with gr.Row(equal_height=True):
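The event-wiring change swaps .submit/.input for .change, which fires on any value change, including programmatic updates, so the info callbacks also run when a setting is changed from code. A self-contained sketch of that pattern with a stand-in callback (the gr.Info toast and the demo scaffolding are illustrative, not from the repository):

import gradio as gr

def xai_info(selection):
    # notify the user which interpretability approach is now active
    gr.Info(f"Interpretability set to: {selection}")

with gr.Blocks() as demo:
    xai_selection = gr.Radio(
        ["None", "SHAP", "Attention"],
        label="Interpretability Settings",
        value="None",
        interactive=True,
    )
    # .change fires on any value change, not only direct user input
    xai_selection.change(xai_info, [xai_selection])

demo.launch()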
model/godel.py CHANGED
@@ -1,7 +1,7 @@
 # GODEL model module for chat interaction and model instance control
 
 # external imports
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
 
 # internal imports
 from utils import modelling as mdl
@@ -10,24 +10,20 @@ from utils import modelling as mdl
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 
-# default model config
-CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+
+# model config definition
+CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
+base_config_dict = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+CONFIG.update(**base_config_dict)
 
 
 # function to (re) set config
-def set_config(config: dict):
-    global CONFIG
+def set_config(config_dict: dict):
 
-    # if config dict is given, update it
-    if config != {}:
-        CONFIG = config
-    else:
-        # hard setting model config to default
-        # needed for shap
-        MODEL.config.max_new_tokens = 50
-        MODEL.config.min_length = 8
-        MODEL.config.top_p = 0.9
-        MODEL.config.do_sample = True
+    # if config dict is not given, set to default
+    if config_dict == {}:
+        config_dict = base_config_dict
+    CONFIG.update(**config_dict)
 
 
 # formatting class to formatting input for the model
@@ -67,7 +63,7 @@ def respond(prompt):
     input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
 
     # generating using config and decoding output
-    outputs = MODEL.generate(input_ids, **CONFIG)
+    outputs = MODEL.generate(input_ids, generation_config=CONFIG)
     output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
 
     # returns the model output string
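The rewritten godel.py moves from a plain dict (and manual MODEL.config.* assignments) to a GenerationConfig that starts from the checkpoint's stored generation defaults and overlays the app's settings. The core pattern in isolation, with values copied from the diff (the print line is just for illustration):

from transformers import GenerationConfig

# start from the generation defaults shipped with the checkpoint
config = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

# overlay the app-level settings; update() mutates the config in place
config.update(max_new_tokens=50, min_length=8, top_p=0.9, do_sample=True)

print(config.max_new_tokens, config.do_sample)  # 50 True

Because update() mutates the shared CONFIG object, set_config no longer needs global, and callers like respond and the attention explainer pick up the new settings on their next generate(..., generation_config=CONFIG) call.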