# visualization module that creates an attention visualization

# internal imports
from utils import formatting as fmt
from .markup import markup_text


# chat function that returns an answer
# and marked text based on attention
def chat_explained(model, prompt):
    model.set_config({"return_dict": True})

    # get encoded input
    encoder_input_ids = model.TOKENIZER(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids

    # generate output together with attentions of the model
    decoder_input_ids = model.MODEL.generate(
        encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
    )

    # get input and output text as list of strings
    encoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
    )
    decoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
    )

    averaged_attention = fmt.avg_attention(decoder_input_ids)

    # format response text for clean output
    response_text = fmt.format_output_text(decoder_text)

    # setting placeholder for iFrame graphic
    # (the original HTML wrapper around this message was lost; a minimal
    # div is used here as a stand-in)
    graphic = (
        "<div><h4>Attention"
        " Visualization doesn't support an interactive graphic.</h4></div>"
    )

    # creating marked text using markup_text function and attention
    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")

    # returning response, graphic and marked text array
    return response_text, graphic, marked_text, None