|
import json |
|
|
|
# System message substituted for the ``{system_prompt}`` placeholder in the
# Zephyr and ChatML preset templates defined in this module.
# NOTE(review): "provide" looks like a typo for "provides"; the text is left
# unchanged because it is sent to the model verbatim.
SYSTEM_PROMPT = (
    "You are a helpful assistant that provide "
    "concise and accurate answers."
)
|
|
|
def set_cora_preset():
    """Return the preset values for the ``gsarti/cora_mgen`` model.

    Returns:
        A 3-tuple holding, in order: the model identifier, a template with
        both ``{current}`` and ``{context}`` placeholders, and a template
        with only the ``{current}`` placeholder.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    model_id = "gsarti/cora_mgen"
    template_with_context = "<Q>:{current} <P>:{context}"
    template_without_context = "<Q>:{current}"
    return model_id, template_with_context, template_without_context
|
|
|
|
|
def set_default_preset():
    """Return the default preset values, built around the ``gpt2`` model.

    Returns:
        An 11-tuple holding, in order: the model identifier, a template with
        ``{current}`` and ``{context}``, three ``"{current}"`` templates, an
        empty list, an empty string, and four ``"{}"`` strings (empty JSON
        objects).

    NOTE(review): the element order presumably mirrors the UI components the
    caller resets with this preset -- confirm against the caller.
    """
    current_only = "{current}"
    empty_json = "{}"
    # A new list is created on every call so callers never share state.
    no_items = []
    return (
        "gpt2",
        "{current} {context}",
        *(current_only,) * 3,
        no_items,
        "",
        *(empty_json,) * 4,
    )
|
|
|
|
|
def set_zephyr_preset():
    """Return the preset values for ``stabilityai/stablelm-2-zephyr-1_6b``.

    Returns:
        A 5-tuple holding, in order: the model identifier, a chat template
        with ``{context}`` and ``{current}`` placeholders, a chat template
        with only ``{current}``, a newline string, and a list of
        special-token strings. ``SYSTEM_PROMPT`` is already substituted into
        both templates.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    # Only ``{system_prompt}`` may be filled in here. ``str.format`` would
    # also try to resolve ``{context}`` and ``{current}`` and raise
    # ``KeyError``, so a targeted ``str.replace`` is used instead; the other
    # placeholders must survive untouched in the returned templates.
    placeholder = "{system_prompt}"
    template_with_context = (
        "<|system|>{system_prompt}<|endoftext|>\n"
        "<|user|>\n{context}\n\n{current}<|endoftext|>\n"
        "<|assistant|>"
    ).replace(placeholder, SYSTEM_PROMPT)
    template_without_context = (
        "<|system|>{system_prompt}<|endoftext|>\n"
        "<|user|>\n{current}<|endoftext|>\n"
        "<|assistant|>"
    ).replace(placeholder, SYSTEM_PROMPT)
    return (
        "stabilityai/stablelm-2-zephyr-1_6b",
        template_with_context,
        template_without_context,
        "\n",
        ["<|im_start|>", "<|im_end|>", "<|endoftext|>"],
    )
|
|
|
|
|
def set_chatml_preset():
    """Return the preset values for the ``Qwen/Qwen1.5-0.5B-Chat`` model.

    Returns:
        A 5-tuple holding, in order: the model identifier, a ChatML-style
        template with ``{context}`` and ``{current}`` placeholders, a
        ChatML-style template with only ``{current}``, a newline string, and
        a list of special-token strings. ``SYSTEM_PROMPT`` is already
        substituted into both templates.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    # Only ``{system_prompt}`` may be filled in here. ``str.format`` would
    # also try to resolve ``{context}`` and ``{current}`` and raise
    # ``KeyError``, so a targeted ``str.replace`` is used instead; the other
    # placeholders must survive untouched in the returned templates.
    placeholder = "{system_prompt}"
    template_with_context = (
        "<|im_start|>system\n{system_prompt}<|im_end|>\n"
        "<|im_start|>user\n{context}\n\n{current}<|im_end|>\n"
        "<|im_start|>assistant"
    ).replace(placeholder, SYSTEM_PROMPT)
    template_without_context = (
        "<|im_start|>system\n{system_prompt}<|im_end|>\n"
        "<|im_start|>user\n{current}<|im_end|>\n"
        "<|im_start|>assistant"
    ).replace(placeholder, SYSTEM_PROMPT)
    return (
        "Qwen/Qwen1.5-0.5B-Chat",
        template_with_context,
        template_without_context,
        "\n",
        ["<|im_start|>", "<|im_end|>"],
    )
|
|
|
|
|
def set_mmt_preset():
    """Return the preset values for ``facebook/mbart-large-50-one-to-many-mmt``.

    Returns:
        A 4-tuple holding, in order: the model identifier, two identical
        ``"{context} {current}"`` templates, and a JSON object string that
        sets ``src_lang`` to ``en_XX`` and ``tgt_lang`` to ``fr_XX``.

    NOTE(review): which UI fields the two templates and the JSON string feed
    is not visible here -- confirm against the caller.
    """
    model_id = "facebook/mbart-large-50-one-to-many-mmt"
    context_then_current = "{context} {current}"
    language_json = '{\n\t"src_lang": "en_XX",\n\t"tgt_lang": "fr_XX"\n}'
    return model_id, context_then_current, context_then_current, language_json
|
|
|
|
|
def set_towerinstruct_preset():
    """Return the preset values for ``Unbabel/TowerInstruct-7B-v0.1``.

    Returns:
        A 5-tuple holding, in order: the model identifier, a French
        translation prompt with ``{current}`` and ``{context}`` placeholders,
        the same prompt with only ``{current}``, a newline string, and a list
        of special-token strings.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    prompt_with_context = (
        "<|im_start|>user\n"
        "Source: {current}\n"
        "Context: {context}\n"
        "Translate the above text into French. Use the context to guide your answer.\n"
        "Target:<|im_end|>\n"
        "<|im_start|>assistant"
    )
    prompt_without_context = (
        "<|im_start|>user\n"
        "Source: {current}\n"
        "Translate the above text into French.\n"
        "Target:<|im_end|>\n"
        "<|im_start|>assistant"
    )
    chat_markers = ["<|im_start|>", "<|im_end|>"]
    return (
        "Unbabel/TowerInstruct-7B-v0.1",
        prompt_with_context,
        prompt_without_context,
        "\n",
        chat_markers,
    )
|
|
|
def set_gemma_preset():
    """Return the preset values for the ``google/gemma-2b-it`` model.

    Returns:
        A 5-tuple holding, in order: the model identifier, a turn-based
        template with ``{context}`` and ``{current}`` placeholders, the same
        template with only ``{current}``, a newline string, and a list of
        turn-marker token strings.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    turn_with_context = (
        "<start_of_turn>user\n"
        "{context}\n"
        "{current}<end_of_turn>\n"
        "<start_of_turn>model"
    )
    turn_without_context = (
        "<start_of_turn>user\n"
        "{current}<end_of_turn>\n"
        "<start_of_turn>model"
    )
    turn_markers = ["<start_of_turn>", "<end_of_turn>"]
    return "google/gemma-2b-it", turn_with_context, turn_without_context, "\n", turn_markers
|
|
|
def set_mistral_instruct_preset():
    """Return the preset values for ``mistralai/Mistral-7B-Instruct-v0.2``.

    Returns:
        A 4-tuple holding, in order: the model identifier, an ``[INST]``
        template with ``{context}`` and ``{current}`` placeholders, an
        ``[INST]`` template with only ``{current}``, and a newline string.

    NOTE(review): the element order presumably mirrors the UI components the
    caller binds this preset to -- confirm against the caller.
    """
    # The commas matter: adjacent string literals without them are implicitly
    # concatenated into ONE string, so the function would return a single
    # ``str`` instead of four separate preset values.
    return (
        "mistralai/Mistral-7B-Instruct-v0.2",
        "[INST]{context}\n{current}[/INST]",
        "[INST]{current}[/INST]",
        "\n",
    )
|
|
|
def update_code_snippets_fn( |
|
input_current_text: str, |
|
input_context_text: str, |
|
output_current_text: str, |
|
output_context_text: str, |
|
model_name_or_path: str, |
|
attribution_method: str, |
|
attributed_fn: str | None, |
|
context_sensitivity_metric: str, |
|
context_sensitivity_std_threshold: float, |
|
context_sensitivity_topk: int, |
|
attribution_std_threshold: float, |
|
attribution_topk: int, |
|
input_template: str, |
|
output_template: str, |
|
contextless_input_template: str, |
|
contextless_output_template: str, |
|
special_tokens_to_keep: str | list[str] | None, |
|
decoder_input_output_separator: str, |
|
model_kwargs: str, |
|
tokenizer_kwargs: str, |
|
generation_kwargs: str, |
|
attribution_kwargs: str, |
|
) -> tuple[str, str]: |
|
if not input_current_text: |
|
input_current_text = "<MISSING INPUT CURRENT TEXT, REQUIRED>" |
|
def py_get_kwargs_str(kwargs: str, name: str, pad: str = " " * 4) -> str: |
|
kwargs_dict = json.loads(kwargs) |
|
return nl + pad + name + '=' + str(kwargs_dict) + ',' if kwargs_dict else '' |
|
def py_get_if_specified(arg: str | int | float | list | None, name: str, pad: str = " " * 4) -> str: |
|
if arg is None or (isinstance(arg, (str, list)) and not arg) or (isinstance(arg, (int, float)) and arg <= 0): |
|
return "" |
|
elif isinstance(arg, str): |
|
return nl + pad + name + "=" + tq + arg + tq + "," |
|
elif isinstance(arg, list): |
|
return nl + pad + name + "=" + str(arg) + "," |
|
else: |
|
return nl + pad + name + "=" + str(arg) + "," |
|
def sh_get_kwargs_str(kwargs: str, name: str, pad: str = " " * 4) -> str: |
|
return nl + pad + f"--{name} " + '"' + str(kwargs).replace("\n", "").replace('"', '\\"') + '"' + " \\\\" if json.loads(kwargs) else '' |
|
def sh_get_if_specified(arg: str | int | float | list | None, name: str, pad: str = " " * 4) -> str: |
|
if arg is None or (isinstance(arg, (str, list)) and not arg) or (isinstance(arg, (int, float)) and arg <= 0): |
|
return "" |
|
elif isinstance(arg, str): |
|
return nl + pad + f"--{name} " + '"' + arg.replace('"', '\\"') + '"' + " \\\\" |
|
elif isinstance(arg, list): |
|
return nl + pad + f"--{name} " + " ".join(str(arg)) + " \\\\" |
|
else: |
|
return nl + pad + f"--{name} " + str(arg) + " \\\\" |
|
nl = "\n" |
|
tq = "\"\"\"" |
|
|
|
python = f"""#!pip install inseq |
|
import inseq |
|
from inseq.commands.attribute_context import attribute_context_with_model |
|
|
|
inseq_model = inseq.load_model( |
|
"{model_name_or_path}", |
|
"{attribution_method}",{py_get_kwargs_str(model_kwargs, "model_kwargs")}{py_get_kwargs_str(tokenizer_kwargs, "tokenizer_kwargs")} |
|
) |
|
|
|
pecore_args = AttributeContextArgs( |
|
model_name_or_path="{model_name_or_path}", |
|
attribution_method="{attribution_method}", |
|
attributed_fn="{attributed_fn}", |
|
context_sensitivity_metric="{context_sensitivity_metric}", |
|
context_sensitivity_std_threshold={context_sensitivity_std_threshold},{py_get_if_specified(context_sensitivity_topk, "context_sensitivity_topk")} |
|
attribution_std_threshold={attribution_std_threshold},{py_get_if_specified(attribution_topk, "attribution_topk")} |
|
input_current_text=\"\"\"{input_current_text}\"\"\",{py_get_if_specified(input_context_text, "input_context_text")} |
|
contextless_input_current_text=\"\"\"{contextless_input_template}\"\"\", |
|
input_template=\"\"\"{input_template}\"\"\",{py_get_if_specified(output_current_text, "output_current_text")}{py_get_if_specified(output_context_text, "output_context_text")} |
|
contextless_output_current_text=\"\"\"{contextless_output_template}\"\"\", |
|
output_template="{output_template}",{py_get_if_specified(special_tokens_to_keep, "special_tokens_to_keep")}{py_get_if_specified(decoder_input_output_separator, "decoder_input_output_separator")} |
|
save_path="pecore_output.json", |
|
viz_path="pecore_output.html",{py_get_kwargs_str(model_kwargs, "model_kwargs")}{py_get_kwargs_str(tokenizer_kwargs, "tokenizer_kwargs")}{py_get_kwargs_str(generation_kwargs, "generation_kwargs")}{py_get_kwargs_str(attribution_kwargs, "attribution_kwargs")} |
|
) |
|
|
|
out = attribute_context_with_model(pecore_args, loaded_model)""" |
|
|
|
bash = f"""# pip install inseq |
|
inseq attribute-context \\\\ |
|
--model-name-or-path "{model_name_or_path}" \\\\ |
|
--attribution-method "{attribution_method}" \\\\ |
|
--attributed-fn "{attributed_fn}" \\\\ |
|
--context-sensitivity-metric "{context_sensitivity_metric}" \\\\ |
|
--context-sensitivity-std-threshold {context_sensitivity_std_threshold} \\\\{sh_get_if_specified(context_sensitivity_topk, "context-sensitivity-topk")} |
|
--attribution-std-threshold {attribution_std_threshold} \\\\{sh_get_if_specified(attribution_topk, "attribution-topk")} |
|
--input-current-text "{input_current_text}" \\\\{sh_get_if_specified(input_context_text, "input-context-text")} |
|
--contextless-input-current-text "{contextless_input_template}" \\\\ |
|
--input-template "{input_template}" \\\\{sh_get_if_specified(output_current_text, "output-current-text")}{sh_get_if_specified(output_context_text, "output-context-text")} |
|
--contextless-output-current-text "{contextless_output_template}" \\\\ |
|
--output-template "{output_template}" \\\\{sh_get_if_specified(special_tokens_to_keep, "special_tokens_to_keep")}{sh_get_if_specified(decoder_input_output_separator, "decoder-input-output-separator")} |
|
--save-path pecore_output.json \\\\ |
|
--viz-path pecore_output.html \\\\{sh_get_kwargs_str(model_kwargs, "model-kwargs")}{sh_get_kwargs_str(tokenizer_kwargs, "tokenizer-kwargs")}{sh_get_kwargs_str(generation_kwargs, "generation-kwargs")}{sh_get_kwargs_str(attribution_kwargs, "attribution-kwargs")} |
|
""" |
|
return python, bash |
|
|