File size: 1,807 Bytes
2492536
fe1089d
 
 
 
d2116db
fe1089d
 
2492536
 
fe1089d
 
2492536
fe1089d
 
 
2492536
 
4e18c39
2492536
 
 
fe1089d
 
 
 
 
 
 
 
2492536
fe1089d
 
 
 
 
 
2492536
f5ebee7
d2116db
2492536
fe1089d
2492536
1c4497c
 
 
 
2492536
d2116db
fe1089d
2492536
1c4497c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# visualization module that creates an attention visualization


# internal imports
from utils import formatting as fmt
from .markup import markup_text


# chat function that returns an answer
# and marked text based on attention
def chat_explained(model, prompt):
    """Generate a reply to ``prompt`` and mark up the prompt by attention.

    The prompt is tokenized with ``model.TOKENIZER``, a reply is produced via
    ``model.MODEL.generate`` (using the keyword settings in ``model.CONFIG``),
    and the model is then run once more on the prompt ids plus the generated
    ids to obtain attention outputs. Those are reduced by
    ``fmt.avg_attention`` and handed to ``markup_text`` together with the
    formatted prompt tokens.

    Args:
        model: wrapper object exposing ``TOKENIZER``, ``MODEL`` and ``CONFIG``.
            NOTE(review): presumably a Hugging Face encoder-decoder setup,
            since ``decoder_input_ids`` is passed to the forward call --
            confirm against the model module.
        prompt: text handed to the tokenizer.

    Returns:
        A 3-tuple ``(response_text, graphic, marked_text)``: the formatted
        reply, a static HTML placeholder string, and the result of
        ``markup_text`` with ``variant="visualizer"``.
    """
    # static placeholder shown instead of an interactive iFrame graphic
    graphic = (
        "<div style='text-align: center; font-family:arial;'>"
        "<h4>Attention Visualization doesn't support an interactive graphic.</h4>"
        "</div>"
    )

    # tokenize the prompt into a tensor of input ids
    tokenizer = model.TOKENIZER
    encoded_prompt = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=True
    )
    prompt_ids = encoded_prompt.input_ids

    # let the model generate the reply ids
    network = model.MODEL
    reply_ids = network.generate(
        prompt_ids, output_attentions=True, **model.CONFIG
    )

    # turn the first sequence of prompt and reply ids back into token strings
    prompt_tokens = tokenizer.convert_ids_to_tokens(prompt_ids[0])
    encoder_text = fmt.format_tokens(prompt_tokens)
    reply_tokens = tokenizer.convert_ids_to_tokens(reply_ids[0])
    decoder_text = fmt.format_tokens(reply_tokens)

    # second pass over the finished prompt/reply pair to collect attentions
    forward_result = network(
        input_ids=prompt_ids,
        decoder_input_ids=reply_ids,
        output_attentions=True,
    )

    # reduce the attentions to a single set of values via the formatting helper
    # (the original comment describes this as averaging across layers)
    attention_values = fmt.avg_attention(forward_result)

    # clean reply text plus prompt tokens marked up with the attention values
    response_text = fmt.format_output_text(decoder_text)
    marked_text = markup_text(
        encoder_text, attention_values, variant="visualizer"
    )

    # response, placeholder graphic and marked text, in that order
    return response_text, graphic, marked_text