File size: 4,260 Bytes
46f657a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import html

import matplotlib.pyplot as plt
import numpy as np


def generate_diverging_colors(num_colors, palette='Set3'):  # courtesy of ChatGPT
    """Return ``num_colors`` hex color strings sampled from a Matplotlib colormap.

    The colormap named *palette* is resampled to exactly ``num_colors``
    entries. Each entry's RGB channels are converted to a six-digit lowercase
    hex string without a leading ``#`` (format ``'rrggbb'``). Despite the
    function name, the palettes used in this file ('Set3', 'Set2') are
    Matplotlib's qualitative palettes, i.e. distinct colors rather than a
    diverging ramp.

    Args:
        num_colors: number of colors to produce; values <= 0 yield ``[]``.
        palette: name of a registered Matplotlib colormap.

    Returns:
        A list of ``num_colors`` hex strings, in colormap order.
    """
    if num_colors <= 0:
        return []

    # pyplot.get_cmap(name, lut) returns the colormap resampled to `lut`
    # entries. matplotlib.cm.get_cmap (reached via plt.cm) was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(palette, num_colors)
    rgba = cmap(np.arange(num_colors))

    # Round rather than truncate: a channel stored as 194/255 can come back as
    # 193.999... after the multiply, and int() alone would drop it to 193.
    return [
        '{:02x}{:02x}{:02x}'.format(*(int(round(channel * 255)) for channel in color[:3]))
        for color in rgba
    ]


# Tokens that are excluded from alignment groups and from the rendered HTML.
_SPECIAL_TOKENS = frozenset(["</s>", "<pad>", "<unk>", "<s>"])

# Style for tokens that belong to no alignment group. A CSS custom property
# must be read through var(); a bare "--color-text-body" value is invalid CSS
# and is discarded by the browser.
_DEFAULT_COLOR_STYLE = "color: var(--color-text-body)"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids,
                threshold=0.4, skip_first_src=True, skip_second_src=False,
                layer=2, head=6, palette="Set2"):
    """Render color-coded HTML showing which encoder and decoder tokens align.

    Alignment is read from a single cross-attention head. Every decoder
    position is linked to all encoder positions whose attention weight exceeds
    *threshold*. Consecutive non-special decoder tokens are merged into one
    group when they share an encoder position, or when the later token does
    not start with the "\u2581" word-start marker (i.e. it continues a word).
    Each group gets one color. Encoder and decoder tokens belonging to the
    same group (directly or through a shared encoder position) are rendered
    in that color. All other tokens use the page's ``--color-text-body``.

    Token text is HTML-escaped, special tokens (``<s>``, ``</s>``, ``<pad>``,
    ``<unk>``) are omitted, and "\u2581" is rendered as a space.

    Args:
        outputs: model output exposing ``cross_attentions``. It is indexed as
            ``cross_attentions[layer][0][head]`` and iterated row by row, one
            row per decoder position. Presumably a torch tensor of shape
            (tgt_len, src_len) for batch item 0 -- confirm with caller.
        tokenizer: object providing ``convert_ids_to_tokens``.
        encoder_input_ids: encoder token ids; only row 0 is used.
        decoder_input_ids: decoder token ids; only row 0 is used.
        threshold: minimum attention weight for an encoder position to count
            as aligned to a decoder position.
        skip_first_src: omit encoder position 0 from the source HTML.
        skip_second_src: omit encoder position 1 from the source HTML.
        layer: index of the cross-attention layer to read.
        head: index of the attention head to read.
        palette: Matplotlib colormap name passed to
            ``generate_diverging_colors``.

    Returns:
        ``(srchtml, tgthtml)``: HTML strings for the encoder tokens and the
        decoder tokens, respectively.
    """
    attention = outputs.cross_attentions[layer][0][head]
    # One entry per decoder position: (position, list of encoder positions
    # whose attention weight from that position exceeds `threshold`).
    alignment = [
        (dec_pos, (row > threshold).nonzero().squeeze(-1).tolist())
        for dec_pos, row in enumerate(attention)
    ]

    groups = _merge_alignment(alignment, tokenizer, decoder_input_ids[0])
    color_ids, ncolors = _assign_color_ids(groups)
    hex_colors = generate_diverging_colors(ncolors, palette=palette)
    colordict = {key: hex_colors[cid] for key, cid in color_ids.items()}

    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0], "dec", colordict)

    skipped_src = set()
    if skip_first_src:
        skipped_src.add(0)
    if skip_second_src:
        skipped_src.add(1)
    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], "enc", colordict,
                             skip=skipped_src)
    return srchtml, tgthtml


def _merge_alignment(alignment, tokenizer, decoder_ids):
    """Merge per-position alignments into word-level groups.

    Returns a list of ``[decoder_positions, encoder_positions]`` lists. Special
    decoder tokens are dropped. A token joins the previous group when it shares
    an encoder position with it or does not start with "\u2581".
    """
    merged = []
    for dec_pos, enc_positions in alignment:
        token = tokenizer.convert_ids_to_tokens([decoder_ids[dec_pos]])[0]
        if token in _SPECIAL_TOKENS:
            continue
        if merged:
            prev_dec, prev_enc = merged[-1]
            shares_source = any(pos in prev_enc for pos in enc_positions)
            continues_word = not token.startswith("\u2581")
            if shares_source or continues_word:
                prev_dec.append(dec_pos)
                prev_enc.extend(enc_positions)
                continue
        # Copy so later merges never mutate the caller's alignment lists.
        merged.append([[dec_pos], list(enc_positions)])
    return merged


def _assign_color_ids(groups):
    """Give each group a color index, reusing one already held by any member.

    Returns ``(color_ids, ncolors)``. ``color_ids`` maps ``("dec", pos)`` /
    ``("enc", pos)`` keys to a color index in ``range(ncolors)``. A key keeps
    the first index it receives.
    """
    color_ids = {}
    ncolors = 0
    for dec_positions, enc_positions in groups:
        keys = ([("dec", pos) for pos in dec_positions]
                + [("enc", pos) for pos in enc_positions])
        color = next((color_ids[k] for k in keys if k in color_ids), None)
        # Compare with None explicitly: 0 is a valid color index.
        if color is None:
            color = ncolors
            ncolors += 1
        for k in keys:
            color_ids.setdefault(k, color)
    return color_ids, ncolors


def _render_tokens(tokenizer, ids, side, colordict, skip=frozenset()):
    """Render token ids as escaped, colored HTML spans wrapped in a 25px span.

    ``side`` ("dec" or "enc") selects which keys of ``colordict`` apply.
    Positions in ``skip`` and special tokens are not rendered.
    """
    parts = []
    for pos, tok_id in enumerate(ids):
        if pos in skip:
            continue
        token = tokenizer.convert_ids_to_tokens([tok_id])[0]
        if token in _SPECIAL_TOKENS:
            continue
        hex_color = colordict.get((side, pos))
        style = _DEFAULT_COLOR_STYLE if hex_color is None else f"color: #{hex_color}"
        # Escape so token text such as "<" or "&" cannot be parsed as markup.
        parts.append(f"<span style='{style}'>{html.escape(token)}</span>")
    body = "".join(parts).replace("\u2581", " ")
    return f"<span style='font-size: 25px'>{body}</span>"