Spaces:
Running
Running
import matplotlib.pyplot as plt | |
import numpy as np | |
def generate_diverging_colors(num_colors: int, palette: str = 'Set3') -> list:  # courtesy of ChatGPT
    """Return ``num_colors`` colours sampled from a Matplotlib colormap.

    Despite the name, any registered colormap name works for ``palette``; the
    default ``'Set3'`` is a qualitative palette, not a diverging one.

    Args:
        num_colors: how many colours to produce. Values <= 0 yield ``[]``.
        palette: name of a registered Matplotlib colormap.

    Returns:
        A list of ``num_colors`` lowercase, zero-padded 6-digit hex strings
        with no leading ``'#'`` (e.g. ``'66c2a5'``). Channel values are
        truncated, not rounded, when scaled to 0-255.
    """
    if num_colors <= 0:
        return []
    # pyplot.get_cmap(name, lut) returns the colormap resampled to `lut`
    # entries. plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap is available on both old and new versions.
    # NOTE(review): resampling a listed palette to more entries than it has
    # presumably repeats colours — confirm if many groups must stay distinct.
    cmap = plt.get_cmap(palette, num_colors)
    # Integer input indexes the lookup table directly: one RGBA row per colour.
    rgba_rows = cmap(np.arange(num_colors))
    return [
        format(int(r * 255) << 16 | int(g * 255) << 8 | int(b * 255), '06x')
        for r, g, b, _alpha in rgba_rows
    ]
# Tokens that never take part in word grouping.
_SPECIAL_TOKENS = frozenset({"</s>", "<pad>", "<unk>", "<s>"})


def _cross_attention_alignment(outputs, layer, head, threshold):
    """Link each decoder position to the encoder positions it attends to.

    Returns one ``[[dec_pos], [enc_pos, ...]]`` entry per row of
    ``outputs.cross_attentions[layer][0][head]``. An encoder position is kept
    when its weight is strictly greater than ``threshold``.
    """
    # NOTE(review): presumably a Hugging Face seq2seq output whose
    # cross_attentions[layer] is (batch, heads, tgt_len, src_len); only batch
    # item 0 is read here — confirm.
    attn = outputs.cross_attentions[layer][0][head]
    return [
        [[dec_pos], (row > threshold).nonzero().squeeze(-1).tolist()]
        for dec_pos, row in enumerate(attn)
    ]


def _merge_alignments(alignment, tokenizer, decoder_input_ids):
    """Merge consecutive alignment entries into word-level groups.

    Entries whose decoder token is in ``_SPECIAL_TOKENS`` are dropped. An entry
    is folded into the previous group when it shares an encoder position with
    it, or when its token does not start with "▁" (i.e. it continues the
    previous word).
    """
    merged = []
    for dec_positions, enc_positions in alignment:
        token = tokenizer.convert_ids_to_tokens(
            [decoder_input_ids[0][dec_positions[0]]]
        )[0]
        if token in _SPECIAL_TOKENS:
            continue
        if merged:
            prev_dec, prev_enc = merged[-1]
            shares_source = any(p in prev_enc for p in enc_positions)
            # startswith (not token[0]) so an empty token cannot raise.
            continues_word = not token.startswith("▁")
            if shares_source or continues_word:
                prev_dec.extend(dec_positions)
                prev_enc.extend(enc_positions)
                continue
        merged.append([list(dec_positions), list(enc_positions)])
    return merged


def _group_positions(merged):
    """Assign a colour-group id to every decoder/encoder position in ``merged``.

    Keys are ``("dec", pos)`` and ``("enc", pos)``. An entry reuses the group
    of the first of its positions that already has one; otherwise it opens a
    new group. A position keeps the first group it was given.

    Returns:
        ``(group_of, n_groups)``.
    """
    group_of = {}
    n_groups = 0
    for dec_positions, enc_positions in merged:
        keys = [("dec", p) for p in dec_positions] + [("enc", p) for p in enc_positions]
        group = next((group_of[k] for k in keys if k in group_of), None)
        # Compare with None: group id 0 is a real group and must be reused.
        if group is None:
            group = n_groups
            n_groups += 1
        for k in keys:
            group_of.setdefault(k, group)
    return group_of, n_groups


def _render_tokens(tokenizer, token_ids, side, color_of, skip=frozenset()):
    """Render ``token_ids`` as a row of coloured HTML spans.

    Positions in ``skip`` are omitted. A position with an entry in
    ``color_of`` (keyed by ``(side, pos)``) is drawn in that hex colour; all
    other positions use the ``--color-text-body`` CSS variable. Token text is
    HTML-escaped so tokens such as ``<s>`` are shown rather than parsed as
    markup, and "▁" word markers become spaces.
    """
    import html

    pieces = []
    for pos, token_id in enumerate(token_ids):
        if pos in skip:
            continue
        text = html.escape(tokenizer.convert_ids_to_tokens([token_id])[0])
        color = color_of.get((side, pos))
        # A bare custom property is invalid CSS; it must be read via var().
        style = f"color: #{color}" if color is not None else "color: var(--color-text-body)"
        pieces.append(f"<span style='{style}'>{text}</span>")
    body = "".join(pieces).replace("▁", " ")
    return f"<span style='font-size: 25px'>{body}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids,
                threshold=0.4, skip_first_src=True, skip_second_src=False,
                layer=2, head=6):
    """Colour-code encoder/decoder tokens aligned by one cross-attention head.

    Each decoder position is linked to the encoder positions whose weight in
    ``outputs.cross_attentions[layer][0][head]`` exceeds ``threshold``.
    Consecutive decoder tokens are merged into one group when they share an
    encoder position or when the later token does not start with "▁". The
    special tokens ``</s>``, ``<pad>``, ``<unk>`` and ``<s>`` are left out of
    grouping. Every group gets one colour from the ``Set2`` palette, and all
    linked positions on both sides are drawn in that colour.

    Args:
        outputs: model output exposing ``cross_attentions`` (presumably a
            Hugging Face seq2seq output — confirm).
        tokenizer: object providing ``convert_ids_to_tokens``.
        encoder_input_ids: encoder id sequences; only row 0 is used.
        decoder_input_ids: decoder id sequences; only row 0 is used.
        threshold: exclusive lower bound on attention weight for a link.
        skip_first_src: omit encoder position 0 from the rendered source HTML.
        skip_second_src: omit encoder position 1 from the rendered source HTML.
        layer: index into ``outputs.cross_attentions``.
        head: attention head to read within that layer.

    Returns:
        ``(srchtml, tgthtml)``: HTML strings for the encoder tokens and the
        decoder tokens respectively.
    """
    alignment = _cross_attention_alignment(outputs, layer, head, threshold)
    merged = _merge_alignments(alignment, tokenizer, decoder_input_ids)
    group_of, n_groups = _group_positions(merged)
    palette = generate_diverging_colors(n_groups, palette="Set2") if n_groups else []
    color_of = {key: palette[group] for key, group in group_of.items()}

    skipped = {pos for pos, flag in ((0, skip_first_src), (1, skip_second_src)) if flag}
    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0], "dec", color_of)
    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], "enc", color_of, skipped)
    return srchtml, tgthtml