import json
import os

import gradio as gr
import spaces
from contents import (
    pecore_citation,
    inseq_citation,
    description,
    examples,
    how_it_works_intro,
    cti_explanation,
    cci_explanation,
    how_to_use,
    example_explanation,
    show_code_modal,
    subtitle,
    title,
    powered_by,
    faq,
)
from gradio_highlightedtextbox import HighlightedTextbox
from gradio_modal import Modal
from presets import (
    set_phi3_preset,
    set_cora_preset,
    set_default_preset,
    set_mbart_mmt_preset,
    set_nllb_mmt_preset,
    set_towerinstruct_preset,
    set_zephyr_preset,
    set_gemma_preset,
    set_mistral_instruct_preset,
    update_code_snippets_fn,
)
from style import custom_css
from utils import get_formatted_attribute_context_results

from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context_with_model,
)
from inseq.models import HuggingfaceModel

# Locations of the JSON / HTML artifacts handed to AttributeContextArgs and
# served back to the user through the download buttons.
_OUTPUT_JSON_PATH = os.path.join(os.path.dirname(__file__), "outputs/output.json")
_OUTPUT_HTML_PATH = os.path.join(os.path.dirname(__file__), "outputs/output.html")

# Model cached across requests; replaced only when a different model name is requested.
loaded_model: HuggingfaceModel | None = None


def _parse_json_kwargs(raw: str | None, field_label: str) -> dict:
    """Parse the content of a JSON code box into a dictionary of keyword arguments.

    Args:
        raw: Text typed by the user in the code box.
        field_label: Human-readable name of the field, used in error messages.

    Returns:
        The parsed dictionary. A blank box or a JSON ``null`` yields ``{}``.

    Raises:
        gr.Error: If the text is not valid JSON or is not a JSON object.
    """
    if raw is None or not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as err:
        raise gr.Error(f"{field_label} must be valid JSON: {err}") from err
    if parsed is None:
        return {}
    if not isinstance(parsed, dict):
        raise gr.Error(f"{field_label} must be a JSON object (e.g. {{}}).")
    return parsed


def _hub_overrides() -> dict:
    """Build the kwargs that are forced on top of user-provided model/tokenizer kwargs."""
    # NOTE(review): trust_remote_code=True lets the selected model repository run
    # its own code while the model name is a free-text field — confirm this is intended.
    overrides = {"trust_remote_code": True}
    hub_token = os.environ.get("HF_TOKEN")
    if hub_token:
        # Only forward the token when it is configured, so that the app does not
        # fail with a KeyError when HF_TOKEN is unset.
        overrides["token"] = hub_token
    return overrides


def _ensure_model_loaded(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
) -> HuggingfaceModel:
    """Load the requested model into the module-level cache unless it is already there.

    The cached model is reused when its ``model_name`` matches ``model_name_or_path``;
    the JSON kwargs are only parsed when a load is actually needed.

    Args:
        model_name_or_path: Identifier of the model to load.
        attribution_method: Attribution method identifier passed to the loader.
        model_kwargs: JSON string with extra kwargs for the model.
        tokenizer_kwargs: JSON string with extra kwargs for the tokenizer.

    Returns:
        The cached (possibly freshly loaded) model.

    Raises:
        gr.Error: If one of the JSON strings cannot be parsed into an object.
    """
    global loaded_model
    if loaded_model is not None and model_name_or_path == loaded_model.model_name:
        return loaded_model
    overrides = _hub_overrides()
    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs")
    parsed_tokenizer_kwargs = _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs")
    gr.Info("Loading model...")
    loaded_model = HuggingfaceModel.load(
        model_name_or_path,
        attribution_method,
        # Forced overrides come last so they win over user-provided values.
        model_kwargs={**parsed_model_kwargs, **overrides},
        tokenizer_kwargs={**parsed_tokenizer_kwargs, **overrides},
    )
    if loaded_model.tokenizer.pad_token is None:
        loaded_model.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    return loaded_model


@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    output_template: str,
    contextless_input_template: str,
    contextless_output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    decoder_input_output_separator: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
) -> list:
    """Run context attribution with the selected model and format the result for the UI.

    The positional parameters mirror, in order, the components listed in the
    ``pecore_args`` list wired to the "Run PECoRe" button. The four ``*_kwargs``
    parameters are JSON strings coming from code boxes.

    Returns:
        A three-item list: the highlighted-text tuples for the output box, and
        two visible ``gr.DownloadButton`` updates pointing at the JSON and HTML
        output files.

    Raises:
        gr.Error: If ``{context}`` is used in the output template without a
            generation context, or if a kwargs field is not a valid JSON object.
    """
    if "{context}" in output_template and not output_context_text:
        raise gr.Error(
            "Parameter 'Generation context' must be set when including {context} in the output template."
        )

    # Validate every JSON field up front so malformed input fails before a model load.
    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs")
    parsed_tokenizer_kwargs = _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs")
    parsed_generation_kwargs = _parse_json_kwargs(generation_kwargs, "Generation kwargs")
    parsed_attribution_kwargs = _parse_json_kwargs(attribution_kwargs, "Attribution kwargs")

    model = _ensure_model_loaded(model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs)

    # Optional arguments are only forwarded when set (top-k values of 0 and
    # empty strings are left out).
    optional_args = {}
    if context_sensitivity_topk > 0:
        optional_args["context_sensitivity_topk"] = context_sensitivity_topk
    if attribution_topk > 0:
        optional_args["attribution_topk"] = attribution_topk
    if input_context_text:
        optional_args["input_context_text"] = input_context_text
    if output_context_text:
        optional_args["output_context_text"] = output_context_text
    if output_current_text:
        optional_args["output_current_text"] = output_current_text
    if decoder_input_output_separator:
        optional_args["decoder_input_output_separator"] = decoder_input_output_separator

    # Make sure the directory targeted by save_path / viz_path exists.
    os.makedirs(os.path.dirname(_OUTPUT_JSON_PATH), exist_ok=True)

    attribute_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=_OUTPUT_JSON_PATH,
        add_output_info=True,
        viz_path=_OUTPUT_HTML_PATH,
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=parsed_model_kwargs,
        tokenizer_kwargs=parsed_tokenizer_kwargs,
        generation_kwargs=parsed_generation_kwargs,
        attribution_kwargs=parsed_attribution_kwargs,
        context_sensitivity_metric=context_sensitivity_metric,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        attribution_std_threshold=attribution_std_threshold,
        input_current_text=input_current_text,
        input_template=input_template,
        output_template=output_template,
        contextless_input_current_text=contextless_input_template,
        contextless_output_current_text=contextless_output_template,
        handle_output_context_strategy="pre",
        **optional_args,
    )
    out = attribute_context_with_model(attribute_args, model)
    highlighted = get_formatted_attribute_context_results(model, out.info, out)
    if not highlighted:
        # Nothing to highlight: show the raw output together with a hint instead.
        msg = f"Output: {out.output_current}\nWarning: No pairs were found by PECoRe.\nTry adjusting Results Selection parameters to soften selection constraints (e.g. setting Context sensitivity threshold to 0)."
        highlighted = [(msg, None)]
    return [
        highlighted,
        gr.DownloadButton(
            label="📂 Download output",
            value=_OUTPUT_JSON_PATH,
            visible=True,
        ),
        gr.DownloadButton(
            label="🔍 Download HTML",
            value=_OUTPUT_HTML_PATH,
            visible=True,
        ),
    ]


@spaces.GPU()
def preload_model(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
) -> None:
    """Load the requested model ahead of time so that a later run can reuse it.

    Args:
        model_name_or_path: Identifier of the model to load.
        attribution_method: Attribution method identifier passed to the loader.
        model_kwargs: JSON string with extra kwargs for the model.
        tokenizer_kwargs: JSON string with extra kwargs for the tokenizer.

    Raises:
        gr.Error: If a load is needed and a kwargs field is not a valid JSON object.
    """
    _ensure_model_loaded(model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs)


# NOTE(review): the container nesting below was reconstructed from statement
# order — confirm it against the rendered layout.
with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column(scale=0.1, min_width=100):
            # NOTE(review): this HTML payload is empty — confirm the intended markup.
            gr.HTML("")
        with gr.Column(scale=0.8):
            gr.Markdown(title)
            gr.Markdown(subtitle)
        with gr.Column(scale=0.1, min_width=100):
            # NOTE(review): this HTML payload is empty — confirm the intended markup.
            gr.HTML("")
    gr.Markdown(description)
    with gr.Tab("🐑 Demo"):
        with gr.Row():
            with gr.Column():
                input_context_text = gr.Textbox(
                    label="Input context",
                    lines=3,
                    placeholder="Your input context...",
                )
                input_current_text = gr.Textbox(
                    label="Input query",
                    placeholder="Your input query...",
                )
                with gr.Row(equal_height=True):
                    show_code_btn = gr.Button("Show code", variant="secondary")
                    attribute_input_button = gr.Button("Run PECoRe", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "#5fb77d",
                        "Influential context": "#80ace8",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                with gr.Row(equal_height=True):
                    download_output_file_button = gr.DownloadButton(
                        "📂 Download output",
                        visible=False,
                    )
                    download_output_html_button = gr.DownloadButton(
                        "🔍 Download HTML",
                        visible=False,
                        value=_OUTPUT_HTML_PATH,
                    )
        preset_comment = gr.Markdown(
            "The CORA Multilingual QA model by Asai et al. (2021) is set as default and can be used with the examples below. Explore other presets in the ⚙️ Parameters tab."
        )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            examples_per_page=1,
        )
    with gr.Tab("⚙️ Parameters") as params_tab:
        gr.Markdown(
            "## ✨ Presets\nSelect a preset to load the selected model and its default parameters (e.g. prompt template, special tokens, etc.) into the fields below.\n⚠️ **This will overwrite existing parameters. If you intend to use large models that could crash the demo, please clone this Space and allocate appropriate resources for them to run comfortably.**"
        )
        check_enable_large_models = gr.Checkbox(False, label="I understand, enable large models presets")
        with gr.Row(equal_height=True):
            with gr.Column():
                default_preset = gr.Button("Default", variant="secondary")
                gr.Markdown(
                    "Default preset using templates without special tokens or parameters.\nCan be used with most decoder-only and encoder-decoder models."
                )
            with gr.Column():
                cora_preset = gr.Button("CORA mQA", variant="secondary")
                gr.Markdown(
                    "Preset for the CORA Multilingual QA model.\nUses special templates for inputs."
                )
            with gr.Column():
                phi3_preset = gr.Button("Phi-3", variant="secondary", interactive=False)
                gr.Markdown(
                    "Preset for the Phi-3 conversational model.\nUses <|user|>, <|system|> and <|assistant|> and <|end|> special tokens."
                )
            #chatml_template = gr.Button("Qwen ChatML", variant="secondary")
            #gr.Markdown(
            #    "Preset for models using the ChatML conversational template.\nUses <|im_start|>, <|im_end|> special tokens."
            #)
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                mbart_mmt_template = gr.Button(
                    "mBART Multilingual MT", variant="secondary"
                )
                gr.Markdown(
                    "Preset for the mBART Many-to-Many multilingual MT model using language tags (default: English to French)."
                )
            with gr.Column(scale=1):
                nllb_mmt_template = gr.Button(
                    "NLLB Multilingual MT", variant="secondary"
                )
                gr.Markdown(
                    "Preset for the NLLB 600M multilingual MT model using language tags (default: English to French)."
                )
            with gr.Column(scale=1):
                towerinstruct_template = gr.Button(
                    "Unbabel TowerInstruct", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for models using the Unbabel TowerInstruct conversational template.\nUses <|im_start|>, <|im_end|> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column():
                zephyr_preset = gr.Button("Zephyr Template", variant="secondary", interactive=False)
                gr.Markdown(
                    "Preset for models using the StableLM 2 Zephyr conversational template.\nUses <|system|>, <|user|> and <|assistant|> special tokens."
                )
            with gr.Column(scale=1):
                gemma_template = gr.Button(
                    "Gemma Chat Template", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for Gemma instruction-tuned models."
                )
            with gr.Column(scale=1):
                mistral_instruct_template = gr.Button(
                    "Mistral Instruct", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for models using the Mistral Instruct template.\nUses [INST]...[/INST] special tokens."
                )
        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                model_name_or_path = gr.Textbox(
                    value="gsarti/cora_mgen",
                    label="Model",
                    info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                    interactive=True,
                )
                load_model_button = gr.Button(
                    "Load model",
                    variant="secondary",
                )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=[
                    "probability",
                    "logit",
                    "kl_divergence",
                    "contrast_logits_diff",
                    "contrast_prob_diff",
                    "pcxmi",
                ],
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=[
                    "saliency",
                    "input_x_gradient",
                    "value_zeroing",
                ],
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=[
                    "probability",
                    "logit",
                    "contrast_logits_diff",
                    "contrast_prob_diff",
                ],
                interactive=True,
            )
        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=0.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=2.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=5,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=100,
            )
        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            # NOTE(review): the default templates start with a bare ": " — confirm
            # no prefix tags are missing from these values.
            input_template = gr.Textbox(
                value=": {current}\n\n: {context}",
                label="Contextual input template",
                info="Template to format the input for the model. Use {current} and {context} placeholders for Input Query and Input Context, respectively.",
                interactive=True,
            )
            output_template = gr.Textbox(
                value="{current}",
                label="Contextual output template",
                info="Template to format the output from the model. Use {current} and {context} placeholders for Generation Output and Generation Context, respectively.",
                interactive=True,
            )
            contextless_input_template = gr.Textbox(
                value=": {current}",
                label="Contextless input template",
                info="Template to format the input query in the non-contextual setting. Use {current} placeholder for Input Query.",
                interactive=True,
            )
            contextless_output_template = gr.Textbox(
                value="{current}",
                label="Contextless output template",
                info="Template to format the output from the model. Use {current} placeholder for Generation Output.",
                interactive=True,
            )
        with gr.Row(equal_height=True):
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
                value=None,
                multiselect=True,
                allow_custom_value=True,
            )
            decoder_input_output_separator = gr.Textbox(
                label="Decoder input/output separator",
                info="Separator to use between input and output in the decoder input.",
                value="",
                interactive=True,
                lines=1,
            )
        gr.Markdown("## ⚙️ Generation Parameters")
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.5):
                gr.Markdown(
                    "The following arguments can be used to control generation parameters and force specific model outputs."
                )
            with gr.Column(scale=1):
                generation_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Generation kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
                info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
                interactive=True,
            )
            output_context_text = gr.Textbox(
                label="Generation context",
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )
        gr.Markdown("## ⚙️ Other Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(
                    "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method."
                )
            with gr.Column():
                model_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Model kwargs (JSON)",
                    interactive=True,
                    lines=1,
                    min_width=160,
                )
            with gr.Column():
                tokenizer_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Tokenizer kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
            with gr.Column():
                attribution_kwargs = gr.Code(
                    value='{\n\t"logprob": true\n}',
                    language="json",
                    label="Attribution kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
    with gr.Tab("🔍 How Does It Work?"):
        gr.Markdown(how_it_works_intro)
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.60):
                gr.Markdown(cti_explanation)
            with gr.Column(scale=0.30):
                # NOTE(review): this HTML payload is empty — confirm the intended markup.
                gr.HTML('')
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.35):
                # NOTE(review): this HTML payload is empty — confirm the intended markup.
                gr.HTML('')
            with gr.Column(scale=0.65):
                gr.Markdown(cci_explanation)
    with gr.Tab("🔧 Usage Guide"):
        gr.Markdown(how_to_use)
        gr.Markdown(example_explanation)
    with gr.Tab("❓ FAQ"):
        gr.Markdown(faq)
    with gr.Tab("📚 Citing PECoRe"):
        gr.Markdown("To refer to the PECoRe framework for context usage detection, cite:")
        gr.Code(pecore_citation, interactive=False, label="PECoRe (Sarti et al., 2024)")
        gr.Markdown("If you use the Inseq implementation of PECoRe (inseq attribute-context, including this demo), please also cite:")
        gr.Code(inseq_citation, interactive=False, label="Inseq (Sarti et al., 2023)")
    with gr.Row(elem_classes="footer-container"):
        with gr.Column():
            gr.Markdown(powered_by)
        with gr.Column():
            with gr.Row(elem_classes="footer-custom-block"):
                with gr.Column(scale=0.25, min_width=150):
                    gr.Markdown("Built by Gabriele Sarti\nwith the support of\n")
                # NOTE(review): the three Markdown blocks below are empty — confirm
                # the intended content.
                with gr.Column(scale=0.25, min_width=120):
                    gr.Markdown("")
                with gr.Column(scale=0.25, min_width=120):
                    gr.Markdown("")
                with gr.Column(scale=0.25, min_width=120):
                    gr.Markdown("")
    with Modal(visible=False) as code_modal:
        gr.Markdown(show_code_modal)
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.5):
                python_code_snippet = gr.Code(
                    value="""Generate Python code snippet by pressing the button.""",
                    language="python",
                    label="Python",
                    interactive=False,
                    show_label=True,
                )
            with gr.Column(scale=0.5):
                shell_code_snippet = gr.Code(
                    value="""Generate Shell code snippet by pressing the button.""",
                    language="shell",
                    label="Shell",
                    interactive=False,
                    show_label=True,
                )

    # Main logic

    # Component lists are positional: their order must match the signatures of
    # preload_model and pecore respectively.
    load_model_args = [
        model_name_or_path,
        attribution_method,
        model_kwargs,
        tokenizer_kwargs,
    ]

    pecore_args = [
        input_current_text,
        input_context_text,
        output_current_text,
        output_context_text,
        model_name_or_path,
        attribution_method,
        attributed_fn,
        context_sensitivity_metric,
        context_sensitivity_std_threshold,
        context_sensitivity_topk,
        attribution_std_threshold,
        attribution_topk,
        input_template,
        output_template,
        contextless_input_template,
        contextless_output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]

    # Hide the download buttons first, then run the attribution, which shows
    # them again through its returned updates.
    attribute_input_button.click(
        lambda *args: [gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)],
        inputs=[],
        outputs=[download_output_file_button, download_output_html_button],
    ).then(
        pecore,
        inputs=pecore_args,
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )

    load_model_event = load_model_button.click(
        preload_model,
        inputs=load_model_args,
        outputs=[],
    )

    # Preset params

    # Presets that start disabled and are toggled by the checkbox above them.
    large_model_presets = [
        phi3_preset,
        zephyr_preset,
        towerinstruct_template,
        gemma_template,
        mistral_instruct_template,
    ]

    check_enable_large_models.input(
        lambda checkbox, *buttons: [gr.Button(interactive=checkbox) for _ in buttons],
        inputs=[check_enable_large_models, *large_model_presets],
        outputs=large_model_presets,
    )

    outputs_to_reset = [
        model_name_or_path,
        input_template,
        output_template,
        contextless_input_template,
        contextless_output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]
    reset_kwargs = {
        "fn": set_default_preset,
        "inputs": None,
        "outputs": outputs_to_reset,
    }

    def _chain_preset(button, preset_fn=None, preset_outputs=None):
        """Wire a preset button: reset fields, optionally apply the preset, then preload the model."""
        event = button.click(**reset_kwargs)
        if preset_fn is not None:
            event = event.then(preset_fn, outputs=preset_outputs)
        event.success(preload_model, inputs=load_model_args, cancels=load_model_event)

    # Presets

    _chain_preset(default_preset)

    _chain_preset(
        cora_preset,
        set_cora_preset,
        [model_name_or_path, input_template, contextless_input_template],
    )

    _chain_preset(
        zephyr_preset,
        set_zephyr_preset,
        [
            model_name_or_path,
            input_template,
            decoder_input_output_separator,
            contextless_input_template,
            special_tokens_to_keep,
            generation_kwargs,
        ],
    )

    _chain_preset(
        mbart_mmt_template,
        set_mbart_mmt_preset,
        [model_name_or_path, input_template, output_template, tokenizer_kwargs],
    )

    _chain_preset(
        nllb_mmt_template,
        set_nllb_mmt_preset,
        [model_name_or_path, input_template, output_template, tokenizer_kwargs],
    )

    # chatml_template.click(**reset_kwargs).then(
    #     set_chatml_preset,
    #     outputs=[
    #         model_name_or_path,
    #         input_template,
    #         contextless_input_template,
    #         special_tokens_to_keep,
    #         generation_kwargs,
    #     ],
    # ).success(preload_model, inputs=load_model_args, cancels=load_model_event)

    _chain_preset(
        phi3_preset,
        set_phi3_preset,
        [
            model_name_or_path,
            input_template,
            decoder_input_output_separator,
            contextless_input_template,
            special_tokens_to_keep,
            generation_kwargs,
        ],
    )

    _chain_preset(
        towerinstruct_template,
        set_towerinstruct_preset,
        [
            model_name_or_path,
            input_template,
            contextless_input_template,
            special_tokens_to_keep,
            generation_kwargs,
        ],
    )

    _chain_preset(
        gemma_template,
        set_gemma_preset,
        [
            model_name_or_path,
            input_template,
            decoder_input_output_separator,
            contextless_input_template,
            special_tokens_to_keep,
            generation_kwargs,
        ],
    )

    _chain_preset(
        mistral_instruct_template,
        set_mistral_instruct_preset,
        [
            model_name_or_path,
            input_template,
            contextless_input_template,
            generation_kwargs,
        ],
    )

    # Fill both snippets first, then reveal the modal that contains them.
    show_code_btn.click(
        update_code_snippets_fn,
        inputs=pecore_args,
        outputs=[python_code_snippet, shell_code_snippet],
    ).then(lambda: Modal(visible=True), None, code_modal)

demo.queue(api_open=False, max_size=20).launch(allowed_paths=["outputs/", "img/"], show_api=False)