# levanti_en_ar/colorize.py
# Author: Guy Mor-Lan
import matplotlib.pyplot as plt
import numpy as np

def generate_diverging_colors(num_colors, palette='Set3'):  # courtesy of ChatGPT
    """Return num_colors hex color codes (without a leading '#') from a qualitative palette."""
    # Generate a colormap with the specified number of colors.
    # plt.cm.get_cmap was removed in recent matplotlib releases; plt.get_cmap is the direct equivalent.
    cmap = plt.get_cmap(palette, num_colors)
    # Get the RGBA values of the colors in the colormap
    colors_rgb = cmap(np.arange(num_colors))
    # Convert the RGB values to hexadecimal color codes
    colors_hex = [format(int(color[0]*255) << 16 | int(color[1]*255) << 8 | int(color[2]*255), '06x') for color in colors_rgb]
    return colors_hex
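
# Example (illustrative, not from the original file): Set2/Set3 are qualitative
# ColorBrewer palettes, so generate_diverging_colors(3, palette="Set2") returns
# three hex strings without a leading '#', ready to drop into an HTML
# "color: #..." style as done in align_words below.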


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids,
                threshold=0.4, skip_first_src=True, skip_second_src=False,
                layer=2, head=6):
    """Align source and target tokens using one cross-attention head and return
    (srchtml, tgthtml): HTML strings in which aligned words share a color."""
    alignment = []
    # threshold = 0.05
    # For each target (decoder) position, collect the source (encoder) positions
    # whose attention weight exceeds the threshold.
    for i, tok in enumerate(outputs.cross_attentions[layer][0][head]):
        alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
    # for i in alignment:
    #     src_tok = [tokenizer.decode(decoder_input_ids[0][x]) for x in i[0]]
    #     trg_tok = [tokenizer.decode(encoder_input_ids[0][x]) for x in i[1]]
    #     print(src_tok, "=>", trg_tok)

    # Merge subword alignments into word-level alignments: a decoder token that
    # does not start with "▁" (the SentencePiece word boundary) or that attends to
    # the same source positions as the previous entry is folded into that entry.
    merged = []
    for i in alignment:
        token = tokenizer.convert_ids_to_tokens([decoder_input_ids[0][i[0]]])[0]
        # print(token)
        if token not in ["</s>", "<pad>", "<unk>", "<s>"]:
            if merged:
                tomerge = False
                # check overlap with previous entry
                for x in i[1]:
                    if x in merged[-1][1]:  # or tokenizer.convert_ids_to_tokens([encoder_input_ids[0][x]])[0][0] != "▁":
                        tomerge = True
                        break
                # if the first character is not a "▁", the token continues the previous word
                if token[0] != "▁":
                    tomerge = True
                if tomerge:
                    merged[-1][0] += i[0]
                    merged[-1][1] += i[1]
                else:
                    merged.append(i)
            else:
                merged.append(i)
# print("=====MERGED=====")
# for i in merged:
# src_tok = [tokenizer.decode(decoder_input_ids[0][x]) for x in i[0]]
# trg_tok = [tokenizer.decode(encoder_input_ids[0][x]) for x in i[1]]
# print(src_tok, "=>", trg_tok)
colordict = {}
ncolors = 0
for i in merged:
src_tok = [f"src_{x}" for x in i[0]]
trg_tok = [f"trg_{x}" for x in i[1]]
all_tok = src_tok + trg_tok
# see if any tokens in entry already have associated color
newcolor = None
for t in all_tok:
if t in colordict:
newcolor = colordict[t]
break
if not newcolor:
newcolor = ncolors
ncolors += 1
for t in all_tok:
if t not in colordict:
colordict[t] = newcolor
colors = generate_diverging_colors(ncolors, palette="Set2")
id_to_color = {i: c for i, c in enumerate(colors)}
for k, v in colordict.items():
colordict[k] = id_to_color[v]

    # Render the target (decoder) side; uncolored tokens fall back to the page's
    # default text color via the CSS variable.
    tgthtml = []
    for i, token in enumerate(decoder_input_ids[0]):
        if f"src_{i}" in colordict:
            label = f"src_{i}"
            tgthtml.append(f"<span style='color: #{colordict[label]}'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
        else:
            tgthtml.append(f"<span style='color: var(--color-text-body)'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
    tgthtml = "".join(tgthtml)
    tgthtml = tgthtml.replace("▁", " ")
    tgthtml = f"<span style='font-size: 25px'>{tgthtml}</span>"

    # Render the source (encoder) side, optionally skipping the first one or two
    # source tokens (e.g. a leading language-code or BOS token).
    srchtml = []
    for i, token in enumerate(encoder_input_ids[0]):
        if (i == 0 and skip_first_src) or (i == 1 and skip_second_src):
            continue
        if f"trg_{i}" in colordict:
            label = f"trg_{i}"
            srchtml.append(f"<span style='color: #{colordict[label]}'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
        else:
            srchtml.append(f"<span style='color: var(--color-text-body)'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
    srchtml = "".join(srchtml)
    srchtml = srchtml.replace("▁", " ")
    srchtml = f"<span style='font-size: 25px'>{srchtml}</span>"
    return srchtml, tgthtml
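

if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original module). It assumes a
    # Marian-style Hugging Face seq2seq translation model with SentencePiece
    # tokenization; "Helsinki-NLP/opus-mt-en-ar" is only an illustrative stand-in
    # for whichever en->ar model this repo actually serves.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
    model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

    encoder_input_ids = tokenizer("How are you today?", return_tensors="pt").input_ids
    decoder_input_ids = model.generate(encoder_input_ids, max_new_tokens=64)

    # A forward pass over the generated sequence exposes cross_attentions as a
    # tuple of per-layer tensors shaped (batch, heads, tgt_len, src_len), which is
    # the structure align_words indexes into.
    outputs = model(input_ids=encoder_input_ids,
                    decoder_input_ids=decoder_input_ids,
                    output_attentions=True)

    srchtml, tgthtml = align_words(outputs, tokenizer, encoder_input_ids,
                                   decoder_input_ids, threshold=0.4,
                                   skip_first_src=False,  # opus-mt encoders add no leading special token
                                   layer=2, head=6)
    print(srchtml)
    print(tgthtml)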