"""Gradio app for inspecting the singular-value structure of a model's weights.

Enter a model id (anything ``transformers.pipeline(model=...)`` accepts). The
"weights" dropdown then lists every 2-D parameter whose name contains a dotted
layer index, with that index replaced by ``{N}`` (e.g.
``encoder.layer.{N}.attention.self.query.weight``). Selecting one plots the
normalized singular values of that parameter for every layer. Selecting a
layer additionally shows, for the top singular directions of that matrix, the
vocabulary tokens whose input embeddings have the highest cosine similarity.
"""
import re

import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import pipeline

# Let Altair plot data sets larger than its default row limit; one weight
# group contributes one row per singular value per layer.
alt.data_transformers.disable_max_rows()

# Matches a dotted layer index such as ".12." inside a parameter name.
# ``+`` (not ``*``) so an empty ".." segment is never mistaken for an index.
number_re = re.compile(r"\.[0-9]+\.")

# Mutable app state shared between the Gradio callbacks below.
STATE_DICT = {}  # parameter name -> tensor, for the currently loaded model
PIPE = None  # currently loaded transformers pipeline (None if none/failed)
DATA = pd.DataFrame()  # one row per (tensor, singular-value rank); see find_choices

_DATA_COLUMNS = ["name", "layer", "group_name", "rank", "val"]


def scatter_plot_fn(group_name):
    """Plot the normalized singular values of every layer in ``group_name``.

    Returns a ``gr.LinePlot`` update with rank on x, normalized singular
    value on y, and one line per layer. Before any spectra have been
    computed an empty plot is returned instead of raising.
    """
    if DATA.empty:
        # Nothing computed yet: ``DATA`` has no ``group_name`` column to filter on.
        df = pd.DataFrame(columns=_DATA_COLUMNS)
    else:
        df = DATA[DATA.group_name == group_name]
    return gr.LinePlot.update(
        value=df,
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )


def find_choices(state_dict):
    """Group layered 2-D tensors and compute their singular-value spectra.

    Every 2-D tensor whose name contains a dotted index (``.<digits>.``) is
    assigned to a group named after it with the FIRST such index replaced by
    ``{N}``; that first index is taken as the layer number. Later indices
    (e.g. T5's ``block.3.layer.1``) stay literal, so ``layer_fn`` can rebuild
    the exact tensor name from the group name and a layer number.

    For each tensor the singular values are normalized to sum to 1 and
    stored in the module-level ``DATA`` frame (columns ``name``, ``layer``,
    ``group_name``, ``rank``, ``val``), replacing its previous content.

    Args:
        state_dict: mapping of parameter name to tensor, e.g. the result of
            ``model.state_dict()``.

    Returns:
        ``(group_names, layers)``: the group names sorted alphabetically,
        and ``list(range(max_layer + 1))``. Both are empty when
        ``state_dict`` is empty (``DATA`` is then left untouched) or when no
        layered 2-D tensor was found.
    """
    global DATA
    if not state_dict:
        return [], []

    rows = []
    choices = set()
    max_layer = 0
    for name, tensor in state_dict.items():
        match = number_re.search(name)
        if match is None or len(tensor.shape) != 2:
            continue
        group_name = number_re.sub(".{N}.", name, count=1)
        layer = int(match.group()[1:-1])
        choices.add(group_name)
        max_layer = max(max_layer, layer)

        # svdvals rejects half-precision and integer tensors; compute in float32.
        svdvals = torch.linalg.svdvals(tensor.float())
        total = svdvals.sum()
        if total > 0:  # an all-zero matrix would otherwise become NaN
            svdvals /= total
        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )

    # Build from the tuples directly: routing them through np.array would
    # coerce every column to strings (and fails on an empty list).
    data = pd.DataFrame(rows, columns=_DATA_COLUMNS)
    data["val"] = data["val"].astype("float")
    data["layer"] = data["layer"].astype("category")
    data["rank"] = data["rank"].astype("int32")
    DATA = data

    layers = list(range(max_layer + 1)) if choices else []
    return sorted(choices), layers


def weights_fn(model_id):
    """Load ``model_id`` into a pipeline and refresh the two dropdowns.

    Returns updates for the "weights" dropdown (group names) and the "layer"
    dropdown (layer numbers). Loading is best-effort: on any error the
    message is printed and ``PIPE``/``STATE_DICT`` are cleared, so the
    dropdowns empty out and later callbacks cannot use a stale pipeline.
    """
    global STATE_DICT, PIPE
    try:
        pipe = pipeline(model=model_id)
        state_dict = pipe.model.state_dict()
    except Exception as e:  # deliberate: any invalid/partial id just clears state
        print(e)
        PIPE = None
        STATE_DICT = {}
    else:
        PIPE = pipe
        STATE_DICT = state_dict
    choices, layers = find_choices(STATE_DICT)
    return [gr.Dropdown.update(choices=choices), gr.Dropdown.update(choices=layers)]


def layer_fn(weights, layer):
    """Show the tokens closest to the top singular directions of one matrix.

    ``weights`` is a group name produced by ``find_choices`` (containing
    ``{N}``) and ``layer`` is the layer number substituted for it. The
    matrix is decomposed with a full SVD, and the square singular-vector
    factor whose side equals the embedding width is used as ``D``. For the
    first ``directions`` rows of ``D``, and separately its first
    ``directions`` columns, the ``k`` vocabulary tokens whose input
    embeddings have the highest cosine similarity are decoded.

    Returns two ``gr.Dataframe`` updates (row directions, column
    directions), one row per direction and one column per token. Both are
    empty when no model is loaded, nothing is selected, the tensor does not
    exist for that layer, or neither SVD factor matches the embedding width.
    """
    k = 5
    directions = 10

    def _result(table, table_t):
        return (
            gr.Dataframe.update(value=table, interactive=False),
            gr.Dataframe.update(value=table_t, interactive=False),
        )

    empty = pd.DataFrame()
    if PIPE is None or weights is None or layer is None:
        return _result(empty, empty)
    weight = STATE_DICT.get(weights.replace("{N}", str(layer), 1))
    if weight is None or len(weight.shape) != 2:
        return _result(empty, empty)

    normalize = torch.nn.functional.normalize
    with torch.no_grad():  # inspection only; no autograd graph needed
        embeddings = PIPE.model.get_input_embeddings().weight.float()
        hidden = embeddings.shape[1]
        U, _, Vh = torch.linalg.svd(weight.float())
        # U is (out, out) and Vh is (in, in). Only a factor whose side is the
        # embedding width can be compared with token embeddings; Vh is tried
        # first so square matrices use their right singular vectors.
        if Vh.shape[0] == hidden:
            D = Vh
        elif U.shape[0] == hidden:
            D = U
        else:
            return _result(empty, empty)

        # Cosine similarity: unit-normalize each direction and each token
        # embedding along the hidden dimension. normalize() clamps the norm,
        # so all-zero embeddings (e.g. untrained padding rows) score 0, not NaN.
        tokens = normalize(embeddings, dim=1).T  # (hidden, vocab)
        row_dirs = normalize(D[:directions], dim=1)
        col_dirs = normalize(D[:, :directions].T, dim=1)
        words = row_dirs.matmul(tokens).topk(k=k)
        words_t = col_dirs.matmul(tokens).topk(k=k)

    def _decode(indices):
        return pd.DataFrame(
            [[PIPE.tokenizer.decode(i) for i in row] for row in indices.tolist()]
        )

    return _result(_decode(words.indices), _decode(words_t.indices))


with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(label="model_id")
            weights = gr.Dropdown(label="weights")
            layer = gr.Dropdown(label="layer")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
            directions = gr.Dataframe(interactive=False)
            directions_t = gr.Dataframe(interactive=False)

    # NOTE: Textbox.change fires on every edit, so each keystroke attempts a
    # model load; failures are swallowed by weights_fn.
    model_id.change(weights_fn, inputs=model_id, outputs=[weights, layer])
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
    # Recompute the token tables when either selection changes, so they never
    # show directions for a previously selected weight.
    weights.change(
        fn=layer_fn, inputs=[weights, layer], outputs=[directions, directions_t]
    )
    layer.change(
        fn=layer_fn, inputs=[weights, layer], outputs=[directions, directions_t]
    )

if __name__ == "__main__":
    scatter_plot.launch()