LennardZuendorf commited on
Commit
b324c38
1 Parent(s): 4577044

fix/feat: fixing model config, adding new examples

Browse files
Files changed (3) hide show
  1. backend/controller.py +10 -8
  2. main.py +46 -26
  3. model/mistral.py +8 -16
backend/controller.py CHANGED
@@ -59,13 +59,15 @@ def interference(
59
  raise RuntimeError("There was an error in the selected XAI approach.")
60
 
61
  # call the explained chat function with the model instance
62
- prompt_output, history_output, xai_interactive, xai_markup, xai_plot = explained_chat(
63
- model=model,
64
- xai=xai,
65
- message=prompt,
66
- history=history,
67
- system_prompt=system_prompt,
68
- knowledge=knowledge,
 
 
69
  )
70
  # if no XAI approach is selected call the vanilla chat function
71
  else:
@@ -84,7 +86,7 @@ def interference(
84
  no graphic will be displayed</h4></div>
85
  """,
86
  [("", "")],
87
- None
88
  )
89
 
90
  # return the outputs
 
59
  raise RuntimeError("There was an error in the selected XAI approach.")
60
 
61
  # call the explained chat function with the model instance
62
+ prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
63
+ explained_chat(
64
+ model=model,
65
+ xai=xai,
66
+ message=prompt,
67
+ history=history,
68
+ system_prompt=system_prompt,
69
+ knowledge=knowledge,
70
+ )
71
  )
72
  # if no XAI approach is selected call the vanilla chat function
73
  else:
 
86
  no graphic will be displayed</h4></div>
87
  """,
88
  [("", "")],
89
+ None,
90
  )
91
 
92
  # return the outputs
main.py CHANGED
@@ -202,34 +202,54 @@ with gr.Blocks(
202
  submit_btn = gr.Button("Submit", variant="primary")
203
  # row with content examples that get autofilled on click
204
  with gr.Row(elem_classes="examples"):
205
- # examples util component
206
- # see: https://www.gradio.app/docs/examples
207
- gr.Examples(
208
- label="Example Questions",
209
- examples=[
210
- [
211
- "How does a black hole form in space?",
212
- (
213
- "Black holes are created when a massive star's core"
214
- " collapses after a supernova, forming an object with"
215
- " gravity so intense that even light cannot escape."
216
- ),
217
  ],
218
- [
219
- (
220
- "Explain the importance of the Rosetta Stone in"
221
- " understanding ancient languages."
222
- ),
223
- (
224
- "The Rosetta Stone, an ancient Egyptian artifact, was key"
225
- " in decoding hieroglyphs, featuring the same text in three"
226
- " scripts: hieroglyphs, Demotic, and Greek."
227
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  ],
229
- ["Does money buy happiness?", ""],
230
- ],
231
- inputs=[user_prompt, knowledge_input],
232
- )
 
 
 
233
 
234
  # explanations tab used to provide explanations for a specific conversation
235
  with gr.Tab("Explanations"):
 
202
  submit_btn = gr.Button("Submit", variant="primary")
203
  # row with content examples that get autofilled on click
204
  with gr.Row(elem_classes="examples"):
205
+ with gr.Accordion("Mistral Model Examples", open=False):
206
+ # examples util component
207
+ # see: https://www.gradio.app/docs/examples
208
+ gr.Examples(
209
+ label="Example Questions",
210
+ examples=[
211
+ ["Does money buy happiness?", "Mistral", "SHAP"],
212
+ ["Does money buy happiness?", "Mistral", "Attention"],
 
 
 
 
213
  ],
214
+ inputs=[user_prompt, model_selection, xai_selection],
215
+ )
216
+ with gr.Accordion("GODEL Model Examples", open=False):
217
+ # examples util component
218
+ # see: https://www.gradio.app/docs/examples
219
+ gr.Examples(
220
+ label="Example Questions",
221
+ examples=[
222
+ [
223
+ "How does a black hole form in space?",
224
+ (
225
+ "Black holes are created when a massive star's core"
226
+ " collapses after a supernova, forming an object with"
227
+ " gravity so intense that even light cannot escape."
228
+ ),
229
+ "GODEL",
230
+ "SHAP",
231
+ ],
232
+ [
233
+ (
234
+ "Explain the importance of the Rosetta Stone in"
235
+ " understanding ancient languages."
236
+ ),
237
+ (
238
+ "The Rosetta Stone, an ancient Egyptian artifact, was"
239
+ " key in decoding hieroglyphs, featuring the same text"
240
+ " in three scripts: hieroglyphs, Demotic, and Greek."
241
+ ),
242
+ "GODEL",
243
+ "Attention",
244
+ ],
245
  ],
246
+ inputs=[
247
+ user_prompt,
248
+ knowledge_input,
249
+ model_selection,
250
+ xai_selection,
251
+ ],
252
+ )
253
 
254
  # explanations tab used to provide explanations for a specific conversation
255
  with gr.Tab("Explanations"):
model/mistral.py CHANGED
@@ -28,15 +28,16 @@ TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
28
 
29
  # default model config
30
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
31
- CONFIG.update(**{
32
  "temperature": 0.7,
33
- "max_new_tokens": 50,
34
- "max_length": 50,
35
  "top_p": 0.9,
36
  "repetition_penalty": 1.2,
37
  "do_sample": True,
38
  "seed": 42,
39
- })
 
40
 
41
 
42
  # function to (re) set config
@@ -44,22 +45,13 @@ def set_config(config_dict: dict):
44
 
45
  # if config dict is not given, set to default
46
  if config_dict == {}:
47
- config_dict = {
48
- "temperature": 0.7,
49
- "max_new_tokens": 50,
50
- "max_length": 50,
51
- "top_p": 0.9,
52
- "repetition_penalty": 1.2,
53
- "do_sample": True,
54
- "seed": 42,
55
- }
56
-
57
- CONFIG.update(**dict)
58
 
59
 
60
  # advanced formatting function that takes into a account a conversation history
61
  # CREDIT: adapated from the Mistral AI Instruct chat template
62
- # see https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
63
  def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
64
  prompt = ""
65
 
 
28
 
29
  # default model config
30
  CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
31
+ base_config_dict = {
32
  "temperature": 0.7,
33
+ "max_new_tokens": 64,
34
+ "max_length": 64,
35
  "top_p": 0.9,
36
  "repetition_penalty": 1.2,
37
  "do_sample": True,
38
  "seed": 42,
39
+ }
40
+ CONFIG.update(**base_config_dict)
41
 
42
 
43
  # function to (re) set config
 
45
 
46
  # if config dict is not given, set to default
47
  if config_dict == {}:
48
+ config_dict = base_config_dict
49
+ CONFIG.update(**config_dict)
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  # advanced formatting function that takes into a account a conversation history
53
  # CREDIT: adapated from the Mistral AI Instruct chat template
54
+ # see https://github.com/chujiezheng/chat_templates/
55
  def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
56
  prompt = ""
57