# plotting functions

# external imports
import numpy as np
import matplotlib.pyplot as plt

# lower bound (in inches) for the figure height, so small, zero, or negative
# attribution values still yield a valid, readable figure
_MIN_FIG_HEIGHT = 3.0

# gap (in points) between the end of a bar and its value label
_LABEL_OFFSET_PT = 3


def plot_seq(seq_values: list, method: str = ""):
    """Draw a bar chart of per-token attribution values.

    Positive values are drawn in red (``#ff0051``), all other values in blue
    (``#008bfb``). Every bar is annotated with its value formatted to two
    decimals.

    Args:
        seq_values: Non-empty sequence of ``(token, importance)`` pairs.
        method: Name of the attribution method, shown in the plot title.

    Returns:
        The ``matplotlib.pyplot`` module, with the newly created figure as
        the current figure.

    Raises:
        ValueError: If ``seq_values`` is empty.
    """
    pairs = list(seq_values)
    if not pairs:
        raise ValueError(
            "seq_values must contain at least one (token, importance) pair"
        )

    # separate the tokens and their corresponding importance values
    tokens, importance = zip(*pairs)

    # convert importance values to numpy array for conditional coloring
    importance = np.array(importance)

    # determine the colors based on the sign of the importance values
    colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]

    # The height follows the largest attribution, but is clamped from below:
    # a raw max that is negative makes matplotlib reject the figure size, and
    # one that is tiny (or non-finite) gives an unusable figure.
    peak = float(np.max(importance))
    if np.isfinite(peak) and peak > _MIN_FIG_HEIGHT:
        fig_height = peak
    else:
        fig_height = _MIN_FIG_HEIGHT

    # create a bar plot
    plt.figure(figsize=(len(tokens) * 0.9, fig_height))
    x_positions = range(len(tokens))  # positions for the bars

    # creating vertical bar plot
    bar_width = 0.8
    plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)

    # Annotate each bar with its value. The label is anchored at the bar end
    # (always inside the axes) and padded in points rather than data units:
    # a data-unit pad can move the anchor outside the axes, in which case
    # matplotlib does not draw the annotation at all.
    for x, (y, color) in enumerate(zip(importance, colors)):
        is_positive = y > 0
        sign = "+" if is_positive else ""
        plt.annotate(
            f"{sign}{y:.2f}",  # format the value with sign
            xy=(x, y),
            xytext=(0, _LABEL_OFFSET_PT if is_positive else -_LABEL_OFFSET_PT),
            textcoords="offset points",
            ha="center",
            color=color,
            va="bottom" if is_positive else "top",  # label sits beyond the bar end
            fontweight="bold",
            bbox={
                "facecolor": "white",
                "edgecolor": "none",
                "boxstyle": "round,pad=0.1",
            },  # white background
        )

    # setting plot properties, labels, and title
    plt.axhline(0, color="black", linewidth=1)
    plt.title(f"Input Token Attribution with {method}")
    plt.xlabel("Input Tokens", labelpad=0.5)
    plt.ylabel("Attribution")
    plt.xticks(x_positions, tokens, rotation=45)

    # adjusting y-axis limits to ensure there's enough space for labels
    y_min, y_max = plt.ylim()
    y_range = y_max - y_min
    plt.ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range)

    return plt