Spaces:
Runtime error
Runtime error
File size: 1,807 Bytes
2492536 fe1089d d2116db fe1089d 2492536 fe1089d 2492536 fe1089d 2492536 4e18c39 2492536 fe1089d 2492536 fe1089d 2492536 f5ebee7 d2116db 2492536 fe1089d 2492536 1c4497c 2492536 d2116db fe1089d 2492536 1c4497c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# attention visualization module: marks up the input text based on model attention
# internal imports
from utils import formatting as fmt
from .markup import markup_text
# chat function that returns an answer
# and marked text based on attention
def chat_explained(model, prompt):
    """Answer *prompt* and mark up the input tokens using model attention.

    Args:
        model: Object exposing ``TOKENIZER`` (callable tokenizer),
            ``MODEL`` (callable model with a ``generate`` method) and
            ``CONFIG`` (mapping of extra keyword arguments for ``generate``).
        prompt: Text to tokenize and feed to the model.

    Returns:
        A 3-tuple ``(response_text, graphic, marked_text)``:
        the output text produced by ``fmt.format_output_text``, a static
        HTML placeholder (this visualization has no interactive graphic),
        and the result of ``markup_text`` for the input tokens.
    """
    tokenizer = model.TOKENIZER
    network = model.MODEL

    # Static placeholder shown in place of an iFrame graphic.
    placeholder_html = (
        "<div style='text-align: center; font-family:arial;'>"
        "<h4>Attention Visualization doesn't support an interactive graphic.</h4>"
        "</div>"
    )

    # Tokenize the prompt into PyTorch tensors ("pt"), special tokens included.
    encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    prompt_ids = encoded_prompt.input_ids

    # Generate the answer ids; generation settings come from model.CONFIG.
    answer_ids = network.generate(prompt_ids, output_attentions=True, **model.CONFIG)

    # Turn the first sequence of input and output ids into formatted tokens.
    prompt_tokens = tokenizer.convert_ids_to_tokens(prompt_ids[0])
    prompt_text = fmt.format_tokens(prompt_tokens)
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids[0])
    answer_text = fmt.format_tokens(answer_tokens)

    # Run the model once more on the finished input/output pair so the
    # attention values cover the complete generated sequence.
    forward_pass = network(
        input_ids=prompt_ids,
        decoder_input_ids=answer_ids,
        output_attentions=True,
    )

    # Reduce the attentions to a single structure via fmt.avg_attention
    # (the original comment describes this as averaging across layers).
    attention_summary = fmt.avg_attention(forward_pass)

    # Clean up the generated tokens for display.
    response_text = fmt.format_output_text(answer_text)

    # Mark up the input tokens according to the summarized attention.
    highlighted = markup_text(prompt_text, attention_summary, variant="visualizer")

    return response_text, placeholder_html, highlighted
|