gsarti committed
Commit 2b66ced • Parent: 45e7dbd

Add presets and model preloading

Files changed (4):
  1. app.py +241 -58
  2. contents.py +1 -1
  3. presets.py +58 -0
  4. utils.py +1 -14
app.py CHANGED
@@ -13,14 +13,25 @@ from contents import (
    title,
)
from gradio_highlightedtextbox import HighlightedTextbox
+ from presets import (
+     set_chatml_preset,
+     set_cora_preset,
+     set_default_preset,
+     set_mmt_preset,
+     set_towerinstruct_preset,
+     set_zephyr_preset,
+ )
from style import custom_css
- from utils import get_tuples_from_output
+ from utils import get_formatted_attribute_context_results

- from inseq import list_feature_attribution_methods, list_step_functions
+ from inseq import list_feature_attribution_methods, list_step_functions, load_model
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
-     attribute_context,
+     attribute_context_with_model,
)
+ from inseq.models import HuggingfaceModel
+
+ loaded_model: HuggingfaceModel = None


@spaces.GPU()
@@ -38,17 +49,41 @@ def pecore(
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
-     input_current_text_template: str,
+     contextless_input_current_text: str,
    output_template: str,
    special_tokens_to_keep: str | list[str] | None,
+     decoder_input_output_separator: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
-     formatted_input_current_text = input_current_text_template.format(
-         current=input_current_text
-     )
+     global loaded_model
+     if "{context}" in output_template and not output_context_text:
+         raise gr.Error(
+             "Parameter 'Generated context' is required when using {context} in the output template."
+         )
+     if loaded_model is None or model_name_or_path != loaded_model.model_name:
+         gr.Info("Loading model...")
+         loaded_model = load_model(
+             model_name_or_path,
+             attribution_method,
+             model_kwargs=json.loads(model_kwargs),
+             tokenizer_kwargs=json.loads(tokenizer_kwargs),
+         )
+     kwargs = {}
+     if context_sensitivity_topk > 0:
+         kwargs["context_sensitivity_topk"] = context_sensitivity_topk
+     if attribution_topk > 0:
+         kwargs["attribution_topk"] = attribution_topk
+     if input_context_text:
+         kwargs["input_context_text"] = input_context_text
+     if output_context_text:
+         kwargs["output_context_text"] = output_context_text
+     if output_current_text:
+         kwargs["output_current_text"] = output_current_text
+     if decoder_input_output_separator:
+         kwargs["decoder_input_output_separator"] = decoder_input_output_separator
    pecore_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
@@ -66,24 +101,41 @@ def pecore(
        generation_kwargs=json.loads(generation_kwargs),
        attribution_kwargs=json.loads(attribution_kwargs),
        context_sensitivity_metric=context_sensitivity_metric,
-         align_output_context_auto=False,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
-         context_sensitivity_topk=context_sensitivity_topk
-         if context_sensitivity_topk > 0
-         else None,
        attribution_std_threshold=attribution_std_threshold,
-         attribution_topk=attribution_topk if attribution_topk > 0 else None,
-         input_current_text=formatted_input_current_text,
-         input_context_text=input_context_text if input_context_text else None,
+         input_current_text=input_current_text,
        input_template=input_template,
-         output_current_text=output_current_text if output_current_text else None,
-         output_context_text=output_context_text if output_context_text else None,
        output_template=output_template,
+         contextless_input_current_text=contextless_input_current_text,
+         handle_output_context_strategy="pre",
+         **kwargs,
    )
-     out = attribute_context(pecore_args)
-     return get_tuples_from_output(out), gr.Button(visible=True), gr.Button(visible=True)
+     out = attribute_context_with_model(pecore_args, loaded_model)
+     tuples = get_formatted_attribute_context_results(loaded_model, out.info, out)
+     if not tuples:
+         msg = "Warning: No pairs were found by PECoRe. Try adjusting Results Selection parameters."
+         tuples = [(msg, None)]
+     return tuples, gr.Button(visible=True), gr.Button(visible=True)
+
+
+ @spaces.GPU()
+ def preload_model(
+     model_name_or_path: str,
+     attribution_method: str,
+     model_kwargs: str,
+     tokenizer_kwargs: str,
+ ):
+     global loaded_model
+     if loaded_model is None or model_name_or_path != loaded_model.model_name:
+         gr.Info("Loading model...")
+         loaded_model = load_model(
+             model_name_or_path,
+             attribution_method,
+             model_kwargs=json.loads(model_kwargs),
+             tokenizer_kwargs=json.loads(tokenizer_kwargs),
+         )


with gr.Blocks(css=custom_css) as demo:
@@ -93,12 +145,12 @@ with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("👁 Attributing Context"):
        with gr.Row():
            with gr.Column():
-                 input_current_text = gr.Textbox(
-                     label="Input query", placeholder="Your input query..."
-                 )
                input_context_text = gr.Textbox(
                    label="Input context", lines=4, placeholder="Your input context..."
                )
+                 input_current_text = gr.Textbox(
+                     label="Input query", placeholder="Your input query..."
+                 )
                attribute_input_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
@@ -139,15 +191,57 @@
        inputs=[input_current_text, input_context_text],
        outputs=pecore_output_highlights,
    )
-     with gr.Tab("⚙️ Parameters"):
+     with gr.Tab("⚙️ Parameters") as params_tab:
+         gr.Markdown("## ✨ Presets")
+         with gr.Row(equal_height=True):
+             with gr.Column():
+                 default_preset = gr.Button("Default", variant="secondary")
+                 gr.Markdown(
+                     "Default preset using templates without special tokens or parameters.\nCan be used with most decoder-only and encoder-decoder models."
+                 )
+             with gr.Column():
+                 cora_preset = gr.Button("CORA mQA", variant="secondary")
+                 gr.Markdown(
+                     "Preset for the <a href='https://huggingface.co/gsarti/cora_mgen' target='_blank'>CORA Multilingual QA</a> model.\nUses special templates for inputs."
+                 )
+             with gr.Column():
+                 zephyr_preset = gr.Button("Zephyr Template", variant="secondary")
+                 gr.Markdown(
+                     "Preset for models using the <a href='https://huggingface.co/HuggingFaceH4/zephyr-7b-beta' target='_blank'>Zephyr conversational template</a>.\nUses <code><|system|></code>, <code><|user|></code> and <code><|assistant|></code> special tokens."
+                 )
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1):
+                 multilingual_mt_template = gr.Button(
+                     "Multilingual MT", variant="secondary"
+                 )
+                 gr.Markdown(
+                     "Preset for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
+                 )
+             with gr.Column(scale=1):
+                 chatml_template = gr.Button("ChatML Template", variant="secondary")
+                 gr.Markdown(
+                     "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
+                 )
+             with gr.Column(scale=1):
+                 towerinstruct_template = gr.Button(
+                     "Unbabel TowerInstruct", variant="secondary"
+                 )
+                 gr.Markdown(
+                     "Preset for models using the <a href='https://huggingface.co/Unbabel/TowerInstruct-7B-v0.1' target='_blank'>Unbabel TowerInstruct</a> conversational template.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
+                 )
        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
-             model_name_or_path = gr.Textbox(
-                 value="gsarti/cora_mgen",
-                 label="Model",
-                 info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
-                 interactive=True,
-             )
+             with gr.Column():
+                 model_name_or_path = gr.Textbox(
+                     value="gpt2",
+                     label="Model",
+                     info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
+                     interactive=True,
+                 )
+                 load_model_button = gr.Button(
+                     "Load model",
+                     variant="secondary",
+                 )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
@@ -224,12 +318,13 @@ with gr.Blocks(css=custom_css) as demo:
                info="Template to format the output from the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
-             input_current_text_template = gr.Textbox(
+             contextless_input_current_text = gr.Textbox(
                value="<Q>:{current}",
                label="Input current text template",
                info="Template to format the input query for the model. Use {current} placeholder.",
                interactive=True,
            )
+         with gr.Row(equal_height=True):
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
@@ -237,8 +332,28 @@ with gr.Blocks(css=custom_css) as demo:
                multiselect=True,
                allow_custom_value=True,
            )
+             decoder_input_output_separator = gr.Textbox(
+                 label="Decoder input/output separator",
+                 info="Separator to use between input and output in the decoder input.",
+                 value="",
+                 interactive=True,
+                 lines=1,
+             )

        gr.Markdown("## ⚙️ Generation Parameters")
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=0.5):
+                 gr.Markdown(
+                     "The following arguments can be used to control generation parameters and force specific model outputs."
+                 )
+             with gr.Column(scale=1):
+                 generation_kwargs = gr.Code(
+                     value="{}",
+                     language="json",
+                     label="Generation kwargs (JSON)",
+                     interactive=True,
+                     lines=1,
+                 )
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
@@ -250,36 +365,37 @@ with gr.Blocks(css=custom_css) as demo:
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )
-             generation_kwargs = gr.Code(
-                 value="{}",
-                 language="json",
-                 label="Generation kwargs",
-                 interactive=True,
-                 lines=1,
-             )
        gr.Markdown("## ⚙️ Other Parameters")
        with gr.Row(equal_height=True):
-             model_kwargs = gr.Code(
-                 value="{}",
-                 language="json",
-                 label="Model kwargs",
-                 interactive=True,
-                 lines=1,
-             )
-             tokenizer_kwargs = gr.Code(
-                 value="{}",
-                 language="json",
-                 label="Tokenizer kwargs",
-                 interactive=True,
-                 lines=1,
-             )
-             attribution_kwargs = gr.Code(
-                 value="{}",
-                 language="json",
-                 label="Attribution kwargs",
-                 interactive=True,
-                 lines=1,
-             )
+             with gr.Column():
+                 gr.Markdown(
+                     "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method."
+                 )
+             with gr.Column():
+                 model_kwargs = gr.Code(
+                     value="{}",
+                     language="json",
+                     label="Model kwargs (JSON)",
+                     interactive=True,
+                     lines=1,
+                     min_width=160,
+                 )
+             with gr.Column():
+                 tokenizer_kwargs = gr.Code(
+                     value="{}",
+                     language="json",
+                     label="Tokenizer kwargs (JSON)",
+                     interactive=True,
+                     lines=1,
+                 )
+             with gr.Column():
+                 attribution_kwargs = gr.Code(
+                     value="{}",
+                     language="json",
+                     label="Attribution kwargs (JSON)",
+                     interactive=True,
+                     lines=1,
+                 )

    gr.Markdown(how_it_works)
    gr.Markdown(how_to_use)
@@ -301,9 +417,10 @@ with gr.Blocks(css=custom_css) as demo:
            attribution_std_threshold,
            attribution_topk,
            input_template,
-             input_current_text_template,
+             contextless_input_current_text,
            output_template,
            special_tokens_to_keep,
+             decoder_input_output_separator,
            model_kwargs,
            tokenizer_kwargs,
            generation_kwargs,
@@ -316,4 +433,70 @@ with gr.Blocks(css=custom_css) as demo:
        ],
    )

+     load_model_button.click(
+         preload_model,
+         inputs=[model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs],
+         outputs=[],
+     )
+
+     # Preset params
+
+     outputs_to_reset = [
+         model_name_or_path,
+         input_template,
+         contextless_input_current_text,
+         output_template,
+         special_tokens_to_keep,
+         decoder_input_output_separator,
+         model_kwargs,
+         tokenizer_kwargs,
+         generation_kwargs,
+         attribution_kwargs,
+     ]
+     reset_kwargs = {
+         "fn": set_default_preset,
+         "inputs": None,
+         "outputs": outputs_to_reset,
+     }
+
+     # Presets
+
+     default_preset.click(**reset_kwargs)
+     cora_preset.click(**reset_kwargs).then(
+         set_cora_preset,
+         outputs=[model_name_or_path, input_template, contextless_input_current_text],
+     )
+     zephyr_preset.click(**reset_kwargs).then(
+         set_zephyr_preset,
+         outputs=[
+             model_name_or_path,
+             input_template,
+             contextless_input_current_text,
+             decoder_input_output_separator,
+         ],
+     )
+     multilingual_mt_template.click(**reset_kwargs).then(
+         set_mmt_preset,
+         outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs],
+     )
+     chatml_template.click(**reset_kwargs).then(
+         set_chatml_preset,
+         outputs=[
+             model_name_or_path,
+             input_template,
+             contextless_input_current_text,
+             decoder_input_output_separator,
+             special_tokens_to_keep,
+         ],
+     )
+     towerinstruct_template.click(**reset_kwargs).then(
+         set_towerinstruct_preset,
+         outputs=[
+             model_name_or_path,
+             input_template,
+             contextless_input_current_text,
+             decoder_input_output_separator,
+         ],
+     )
+
demo.launch(allowed_paths=["outputs/"])
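The change above is the core of the commit: the Inseq model now lives in a module-level cache (loaded_model) that both pecore() and the new preload_model() callback reuse whenever the requested model name matches, and attribution runs through attribute_context_with_model() instead of attribute_context(), which reloaded the model on every request. A minimal sketch of the same pattern outside Gradio follows; the model name, attribution method, and example texts are illustrative, and only argument fields that appear in the diff above are set.

import inseq
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context_with_model,
)

loaded_model = None  # module-level cache, as in app.py


def get_model(model_name_or_path: str, attribution_method: str):
    # Load once; reuse while the requested name matches the cached model.
    global loaded_model
    if loaded_model is None or model_name_or_path != loaded_model.model_name:
        loaded_model = inseq.load_model(model_name_or_path, attribution_method)
    return loaded_model


model = get_model("gpt2", "saliency")
args = AttributeContextArgs(
    model_name_or_path="gpt2",
    input_current_text="When was Banff National Park established?",
    input_context_text="Banff National Park was established in 1885.",
    input_template="{context} {current}",
    contextless_input_current_text="{current}",
    show_intermediate_outputs=False,
)
out = attribute_context_with_model(args, model)  # no reload between calls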
 
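Every preset button is wired the same way: a click first restores all ten template-related components with set_default_preset via the shared reset_kwargs, then .then() layers the preset's own values on top. Gradio assigns a callback's returned tuple to the outputs components positionally, which is why the order of each preset's return values must match its outputs list. A stripped-down sketch of this reset-then-apply chaining, using two hypothetical components:

import gradio as gr


def set_demo_defaults():
    # One return value per component in `outputs`, assigned positionally.
    return "gpt2", "{current} {context}"


def set_demo_preset():
    return "gsarti/cora_mgen", "<Q>:{current} <P>:{context}"


with gr.Blocks() as demo:
    model_name = gr.Textbox(label="Model")
    input_template = gr.Textbox(label="Input template")
    preset_btn = gr.Button("CORA preset")
    # Reset first, then apply the preset, mirroring
    # cora_preset.click(**reset_kwargs).then(set_cora_preset, ...) above.
    preset_btn.click(
        set_demo_defaults, inputs=None, outputs=[model_name, input_template]
    ).then(set_demo_preset, outputs=[model_name, input_template])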
contents.py CHANGED
@@ -48,6 +48,6 @@ citation = r"""
examples = [
    [
        "When was Banff National Park established?",
-         "Banff National Park is Canada's oldest national park, established in 1885 as Rocky Mountains Park. Located in Alberta's Rocky Mountains, 110–180 kilometres (68–112 mi) west of Calgary, Banff encompasses 6,641 square kilometres (2,564 sq mi) of mountainous terrain.",
+         "Banff National Park is Canada's oldest national park, established in 1885 as Rocky Mountains Park. Located in Alberta's Rocky Mountains, 110-180 kilometres (68-112 mi) west of Calgary, Banff encompasses 6,641 square kilometres (2,564 sq mi) of mountainous terrain.",
    ]
]
presets.py ADDED
@@ -0,0 +1,58 @@
+ def set_cora_preset():
+     return (
+         "gsarti/cora_mgen",  # model_name_or_path
+         "<Q>:{current} <P>:{context}",  # input_template
+         "<Q>:{current}",  # input_current_text_template
+     )
+
+
+ def set_default_preset():
+     return (
+         "gpt2",  # model_name_or_path
+         "{current} {context}",  # input_template
+         "{current}",  # input_current_text_template
+         "{current}",  # output_template
+         [],  # special_tokens_to_keep
+         "",  # decoder_input_output_separator
+         "{}",  # model_kwargs
+         "{}",  # tokenizer_kwargs
+         "{}",  # generation_kwargs
+         "{}",  # attribution_kwargs
+     )
+
+
+ def set_zephyr_preset():
+     return (
+         "stabilityai/stablelm-2-zephyr-1_6b",  # model_name_or_path
+         "<|system|>\n{context}</s>\n<|user|>\n{current}</s>\n<|assistant|>\n",  # input_template
+         "<|user|>\n{current}</s>\n<|assistant|>\n",  # input_current_text_template
+         "\n",  # decoder_input_output_separator
+     )
+
+
+ def set_chatml_preset():
+     return (
+         "Qwen/Qwen1.5-0.5B-Chat",  # model_name_or_path
+         "<|im_start|>system\n{context}<|im_end|>\n<|im_start|>user\n{current}<|im_end|>\n<|im_start|>assistant\n",  # input_template
+         "<|im_start|>user\n{current}<|im_end|>\n<|im_start|>assistant\n",  # input_current_text_template
+         "",  # decoder_input_output_separator
+         ["<|im_start|>", "<|im_end|>"],  # special_tokens_to_keep
+     )
+
+
+ def set_mmt_preset():
+     return (
+         "facebook/mbart-large-50-one-to-many-mmt",  # model_name_or_path
+         "{context} {current}",  # input_template
+         "{context} {current}",  # output_template
+         '{\n\t"src_lang": "en_XX",\n\t"tgt_lang": "fr_XX"\n}',  # tokenizer_kwargs
+     )
+
+
+ def set_towerinstruct_preset():
+     return (
+         "Unbabel/TowerInstruct-7B-v0.1",  # model_name_or_path
+         "<|im_start|>user\nSource: {current}\nContext: {context}\nTranslate the above text into French. Use the context to guide your answer.\nTarget:<|im_end|>\n<|im_start|>assistant\n",  # input_template
+         "<|im_start|>user\nSource: {current}\nTranslate the above text into French.\nTarget:<|im_end|>\n<|im_start|>assistant\n",  # input_current_text_template
+         "",  # decoder_input_output_separator
+     )
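Each preset's input_template is later filled with the context and query typed in the UI, through the {context} and {current} placeholders. For the ChatML preset above, the expansion looks as follows; the example strings are illustrative:

input_template = (
    "<|im_start|>system\n{context}<|im_end|>\n"
    "<|im_start|>user\n{current}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
prompt = input_template.format(
    context="You are a helpful assistant.",
    current="When was Banff National Park established?",
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# When was Banff National Park established?<|im_end|>
# <|im_start|>assistant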
utils.py CHANGED
@@ -1,7 +1,5 @@
- from copy import deepcopy
from typing import Optional

- from inseq import load_model
from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context_helpers import (
    AttributeContextOutput,
@@ -81,7 +79,6 @@ def get_formatted_attribute_context_results(
            cci_out.output_context_scores,
            cci_out.input_context_scores,
            is_target=True,
-             context_type="Output",
        )
        out += [
            ("\n\n" if example_idx > 1 else "", None),
@@ -95,16 +92,6 @@
        out += [("\nInput context:\t", None)]
        out += input_context_tokens
        if args.has_output_context:
-             out += [("\\Output context:\t", None)]
+             out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
-
-
- def get_tuples_from_output(output: AttributeContextOutput):
-     model = load_model(
-         output.info.model_name_or_path,
-         output.info.attribution_method,
-         model_kwargs=deepcopy(output.info.model_kwargs),
-         tokenizer_kwargs=deepcopy(output.info.tokenizer_kwargs),
-     )
-     return get_formatted_attribute_context_results(model, output.info, output)
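With get_tuples_from_output removed, callers now pass the already-loaded model straight to get_formatted_attribute_context_results instead of reloading it from output.info on every call, as in this sketch of the new call site (model and out as produced in app.py above):

# model: the cached HuggingfaceModel; out: the AttributeContextOutput
# returned by attribute_context_with_model.
tuples = get_formatted_attribute_context_results(model, out.info, out)
# tuples is a list of (text, label-or-None) pairs, suitable for display in a
# HighlightedTextbox-style component.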