import matplotlib.pyplot as plt
import numpy as np


def generate_diverging_colors(num_colors, palette='Set3'):
    """Return ``num_colors`` hex colour codes sampled from a matplotlib colormap.

    Args:
        num_colors: Number of colours to produce.
        palette: Name of the matplotlib colormap to sample from.

    Returns:
        A list of ``num_colors`` six-digit lowercase hex strings in ``rrggbb``
        form (no leading ``#``). Each channel is truncated, not rounded, when
        it is scaled to the 0-255 range.
    """
    if num_colors == 0:
        # Nothing to sample, so skip resampling a colormap to zero entries.
        return []

    # pyplot.get_cmap(name, lut) resamples the palette to ``num_colors``
    # entries. It replaces ``plt.cm.get_cmap``, which newer matplotlib
    # releases no longer provide.
    cmap = plt.get_cmap(palette, num_colors)

    # Calling the colormap with integer indices yields one RGBA row per index.
    colors_rgba = cmap(np.arange(num_colors))

    colors_hex = []
    for rgba in colors_rgba:
        red, green, blue = rgba[0], rgba[1], rgba[2]
        packed = int(red * 255) << 16 | int(green * 255) << 8 | int(blue * 255)
        colors_hex.append(format(packed, '06x'))
    return colors_hex


def _attention_alignment(attention, threshold):
    """Pair each row index with the column indices whose weight exceeds the threshold."""
    return [
        [[row], (weights > threshold).nonzero().squeeze(-1).tolist()]
        for row, weights in enumerate(attention)
    ]


def _merge_subwords(alignment, tokenizer, decoder_input_ids):
    """Fold alignment entries into the previous one when they overlap or continue a word."""
    merged = []
    for entry in alignment:
        src_positions, trg_positions = entry
        # ``src_positions`` still holds exactly one index at this point.
        token = tokenizer.convert_ids_to_tokens(
            [decoder_input_ids[0][src_positions]]
        )[0]

        # NOTE(review): all four entries are empty strings as written, so this
        # only filters out empty tokens. Presumably it was meant to list the
        # tokenizer's special tokens -- confirm.
        if token in ["", "", "", ""]:
            continue

        if not merged:
            merged.append(entry)
            continue

        previous = merged[-1]
        # The entry shares at least one aligned position with the previous one.
        overlaps = any(pos in previous[1] for pos in trg_positions)
        # A token that does not start with "▁" is merged into the previous entry.
        continues_word = token[0] != "▁"

        if overlaps or continues_word:
            previous[0] += src_positions
            previous[1] += trg_positions
        else:
            merged.append(entry)
    return merged


def _assign_color_ids(merged):
    """Give every merged entry a colour id, reusing the id of any already-coloured token."""
    colordict = {}
    ncolors = 0
    for src_positions, trg_positions in merged:
        labels = [f"src_{x}" for x in src_positions]
        labels += [f"trg_{x}" for x in trg_positions]

        color_id = next(
            (colordict[label] for label in labels if label in colordict), None
        )
        # Compare against None explicitly: colour id 0 is valid but falsy, and
        # a plain truthiness test would allocate a new id instead of reusing it.
        if color_id is None:
            color_id = ncolors
            ncolors += 1

        for label in labels:
            colordict.setdefault(label, color_id)
    return colordict, ncolors


def _render_tokens(tokenizer, token_ids, skipped=()):
    """Join the token strings for ``token_ids``, turning "▁" into a space."""
    pieces = []
    for position, token_id in enumerate(token_ids):
        if position in skipped:
            continue
        piece = tokenizer.convert_ids_to_tokens([token_id])[0]
        pieces.append(f"{piece}")
    return "".join(pieces).replace("▁", " ")


def align_words(
    outputs,
    tokenizer,
    encoder_input_ids,
    decoder_input_ids,
    threshold=0.4,
    skip_first_src=True,
    skip_second_src=False,
    layer=2,
    head=6,
):
    """Align decoder tokens to encoder tokens using one cross-attention head.

    The attention matrix ``outputs.cross_attentions[layer][0][head]`` is
    thresholded row by row. Entries are then merged when they overlap the
    previous entry or when their token does not start with "▁", and every
    merged group is assigned a colour. Labels of the form ``src_<i>`` refer
    to positions in ``decoder_input_ids[0]`` and ``trg_<i>`` to positions in
    ``encoder_input_ids[0]``.

    Args:
        outputs: Object exposing ``cross_attentions``, indexed as
            ``[layer][0][head]`` (assumes index 0 selects the only/first
            batch item -- TODO confirm).
        tokenizer: Object providing ``convert_ids_to_tokens``.
        encoder_input_ids: Token ids; only ``encoder_input_ids[0]`` is read.
        decoder_input_ids: Token ids; only ``decoder_input_ids[0]`` is read.
        threshold: Attention weights strictly greater than this count as
            aligned.
        skip_first_src: Omit position 0 of ``encoder_input_ids[0]`` from the
            rendered output.
        skip_second_src: Omit position 1 of ``encoder_input_ids[0]`` from the
            rendered output.
        layer: Index into ``outputs.cross_attentions``.
        head: Index of the attention head to read.

    Returns:
        A ``(srchtml, tgthtml)`` tuple of strings. ``srchtml`` is rendered
        from ``encoder_input_ids[0]`` and ``tgthtml`` from
        ``decoder_input_ids[0]``.
    """
    attention = outputs.cross_attentions[layer][0][head]
    alignment = _attention_alignment(attention, threshold)
    merged = _merge_subwords(alignment, tokenizer, decoder_input_ids)

    colordict, ncolors = _assign_color_ids(merged)
    colors = generate_diverging_colors(ncolors, palette="Set2")
    # NOTE(review): the hex colours are computed here, but the rendering below
    # emits the same text for coloured and uncoloured tokens, so they never
    # reach the returned strings. Presumably markup is missing -- confirm.
    colordict = {label: colors[color_id] for label, color_id in colordict.items()}

    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0])

    skipped = set()
    if skip_first_src:
        skipped.add(0)
    if skip_second_src:
        skipped.add(1)
    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], skipped)

    return srchtml, tgthtml