import json
import os

import gradio as gr
import spaces
from contents import (
    citation,
    description,
    examples,
    how_it_works,
    how_to_use,
    subtitle,
    title,
)
from gradio_highlightedtextbox import HighlightedTextbox
from style import custom_css
from utils import get_tuples_from_output

from inseq import list_feature_attribution_methods, list_step_functions
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context,
)


@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    input_current_text_template: str,
    output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
    # Run the full PECoRe pipeline through Inseq's attribute-context command
    # and return highlight tuples for the UI.
    formatted_input_current_text = input_current_text_template.format(
        current=input_current_text
    )
    # Collect the UI settings into the arguments expected by attribute_context.
    pecore_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
        add_output_info=True,
        viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=json.loads(model_kwargs),
        tokenizer_kwargs=json.loads(tokenizer_kwargs),
        generation_kwargs=json.loads(generation_kwargs),
        attribution_kwargs=json.loads(attribution_kwargs),
        context_sensitivity_metric=context_sensitivity_metric,
        align_output_context_auto=False,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        context_sensitivity_topk=context_sensitivity_topk
        if context_sensitivity_topk > 0
        else None,
        attribution_std_threshold=attribution_std_threshold,
        attribution_topk=attribution_topk if attribution_topk > 0 else None,
        input_current_text=formatted_input_current_text,
        input_context_text=input_context_text if input_context_text else None,
        input_template=input_template,
        output_current_text=output_current_text if output_current_text else None,
        output_context_text=output_context_text if output_context_text else None,
        output_template=output_template,
    )
    out = attribute_context(pecore_args)
    # Return highlights for the output box and make the two download buttons visible.
    return get_tuples_from_output(out), gr.Button(visible=True), gr.Button(visible=True)


with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(title)
    gr.Markdown(subtitle)
    gr.Markdown(description)
    with gr.Tab("🐑 Attributing Context"):
        with gr.Row():
            with gr.Column():
                input_current_text = gr.Textbox(
                    label="Input query", placeholder="Your input query..."
                )
                input_context_text = gr.Textbox(
                    label="Input context", lines=4, placeholder="Your input context..."
                )
                attribute_input_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "green",
                        "Influential context": "blue",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                with gr.Row(equal_height=True):
                    download_output_file_button = gr.Button(
                        "⇓ Download output",
                        visible=False,
                        link=os.path.join(
                            os.path.dirname(__file__), "/file=outputs/output.json"
                        ),
                    )
                    download_output_html_button = gr.Button(
                        "🔍 Download HTML",
                        visible=False,
                        link=os.path.join(
                            os.path.dirname(__file__), "/file=outputs/output.html"
                        ),
                    )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            outputs=pecore_output_highlights,
        )
    with gr.Tab("⚙️ Parameters"):
        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
            model_name_or_path = gr.Textbox(
                value="gsarti/cora_mgen",
                label="Model",
                info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                interactive=True,
            )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=list_step_functions(),
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=list_feature_attribution_methods(),
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=list_step_functions(),
                interactive=True,
            )
        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=1.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=1.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=0,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=50,
            )
        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            input_template = gr.Textbox(
                value="{current} :{context}",
label="Input template",
info="Template to format the input for the model. Use {current} and {context} placeholders.",
interactive=True,
)
output_template = gr.Textbox(
value="{current}",
label="Output template",
info="Template to format the output from the model. Use {current} and {context} placeholders.",
interactive=True,
)
input_current_text_template = gr.Textbox(
value=":{current}",
label="Input current text template",
info="Template to format the input query for the model. Use {current} placeholder.",
interactive=True,
)
special_tokens_to_keep = gr.Dropdown(
label="Special tokens to keep",
info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
value=None,
multiselect=True,
allow_custom_value=True,
)
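        # Generation controls: optionally force-decode a fixed output and/or
        # provide a context prefix that generation should continue from.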
gr.Markdown("## ⚙️ Generation Parameters")
with gr.Row(equal_height=True):
output_current_text = gr.Textbox(
label="Generation output",
info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
interactive=True,
)
output_context_text = gr.Textbox(
label="Generation context",
info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
interactive=True,
)
generation_kwargs = gr.Code(
value="{}",
language="json",
label="Generation kwargs",
interactive=True,
lines=1,
)
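        # JSON-encoded keyword arguments; they are parsed with json.loads in
        # pecore() and forwarded to model loading, tokenization, and attribution.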
gr.Markdown("## ⚙️ Other Parameters")
with gr.Row(equal_height=True):
model_kwargs = gr.Code(
value="{}",
language="json",
label="Model kwargs",
interactive=True,
lines=1,
)
tokenizer_kwargs = gr.Code(
value="{}",
language="json",
label="Tokenizer kwargs",
interactive=True,
lines=1,
)
attribution_kwargs = gr.Code(
value="{}",
language="json",
label="Attribution kwargs",
interactive=True,
lines=1,
)
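    # Static documentation sections rendered below the two tabs.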
    gr.Markdown(how_it_works)
    gr.Markdown(how_to_use)
    gr.Markdown(citation)
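    # Wire the submit button to the pecore() pipeline defined above.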
    attribute_input_button.click(
        pecore,
        inputs=[
            input_current_text,
            input_context_text,
            output_current_text,
            output_context_text,
            model_name_or_path,
            attribution_method,
            attributed_fn,
            context_sensitivity_metric,
            context_sensitivity_std_threshold,
            context_sensitivity_topk,
            attribution_std_threshold,
            attribution_topk,
            input_template,
            input_current_text_template,
            output_template,
            special_tokens_to_keep,
            model_kwargs,
            tokenizer_kwargs,
            generation_kwargs,
            attribution_kwargs,
        ],
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )
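# Allow files saved under outputs/ to be served, so the download button links work.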
demo.launch(allowed_paths=["outputs/"])