from typing import Optional

from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context_helpers import (
    AttributeContextOutput,
    filter_rank_tokens,
    get_filtered_tokens,
)
from inseq.models import HuggingfaceModel


def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    For every entry of ``output.cci_scores`` this emits a numbered section made
    of ``(text, label)`` tuples: the generated output tokens (with the token at
    ``cti_idx`` labelled ``"Context sensitive"``), followed by the input and/or
    output context tokens (with the tokens selected by ``filter_rank_tokens``
    labelled ``"Influential context"``). Unlabelled pieces carry ``None``.

    Args:
        model: Model forwarded to ``get_filtered_tokens`` for tokenization.
        args: Run configuration; this function reads ``special_tokens_to_keep``,
            ``has_input_context``, ``has_output_context``,
            ``attribution_std_threshold`` and ``attribution_topk``.
        output: Attribution results; this function reads ``output_current``,
            ``input_context``, ``output_context`` and ``cci_scores``.

    Returns:
        A flat list of ``(text, label)`` tuples covering all examples in order.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        """Tokenize ``context`` and label the tokens picked by ``filter_rank_tokens``.

        When ``has_other_context`` is true, the scores of the other context are
        appended after ``context_scores`` before ranking.
        NOTE(review): presumably so that the threshold is computed over both
        contexts jointly -- confirm against ``filter_rank_tokens``.
        """
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        scores = context_scores
        if has_other_context:
            # Build a NEW sequence instead of `scores += ...`: the augmented
            # assignment extended the caller's list in place, so the stored
            # scores on `cci_out` grew on every call and the second context
            # was ranked against duplicated scores.
            scores = context_scores + other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    # The generated output is the same for every example, so tokenize it once.
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        # Fresh per-example copy so only this example's token gets the label.
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            # Blank-line separator between consecutive examples.
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out