Narsil HF staff commited on
Commit
cfc1bbd
1 Parent(s): 4c415fe

Adding directions exploration.

Browse files
Files changed (1) hide show
  1. app.py +64 -7
app.py CHANGED
@@ -12,6 +12,7 @@ alt.data_transformers.disable_max_rows()
12
  number_re = re.compile(r"\.[0-9]*\.")
13
 
14
  STATE_DICT = {}
 
15
  DATA = pd.DataFrame()
16
 
17
 
@@ -30,15 +31,20 @@ def scatter_plot_fn(group_name):
30
 
31
  def find_choices(state_dict):
32
  if not state_dict:
33
- return []
34
  global DATA
35
- layered_tensors = [k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2]
 
 
36
  choices = set()
37
  data = []
 
38
  for name in layered_tensors:
39
  group_name = number_re.sub(".{N}.", name)
40
  choices.add(group_name)
41
  layer = int(number_re.search(name).group()[1:-1])
 
 
42
 
43
  svdvals = torch.linalg.svdvals(state_dict[name])
44
  svdvals /= svdvals.sum()
@@ -49,19 +55,64 @@ def find_choices(state_dict):
49
  DATA["val"] = DATA["val"].astype("float")
50
  DATA["layer"] = DATA["layer"].astype("category")
51
  DATA["rank"] = DATA["rank"].astype("int32")
52
- return choices
53
 
54
 
55
  def weights_fn(model_id):
56
- global STATE_DICT
57
  try:
58
  pipe = pipeline(model=model_id)
 
59
  STATE_DICT = pipe.model.state_dict()
60
  except Exception as e:
61
  print(e)
62
  STATE_DICT = {}
63
- choices = find_choices(STATE_DICT)
64
- return gr.Dropdown.update(choices=choices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  with gr.Blocks() as scatter_plot:
@@ -69,10 +120,16 @@ with gr.Blocks() as scatter_plot:
69
  with gr.Column():
70
  model_id = gr.Textbox(label="model_id")
71
  weights = gr.Dropdown(label="weights")
 
72
  with gr.Column():
73
  plot = gr.LinePlot(show_label=False).style(container=True)
74
- model_id.change(weights_fn, inputs=model_id, outputs=weights)
 
 
75
  weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
 
 
 
76
 
77
  if __name__ == "__main__":
78
  scatter_plot.launch()
 
12
  number_re = re.compile(r"\.[0-9]*\.")
13
 
14
  STATE_DICT = {}
15
+ PIPE = None
16
  DATA = pd.DataFrame()
17
 
18
 
 
31
 
32
  def find_choices(state_dict):
33
  if not state_dict:
34
+ return [], []
35
  global DATA
36
+ layered_tensors = [
37
+ k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2
38
+ ]
39
  choices = set()
40
  data = []
41
+ max_layer = 0
42
  for name in layered_tensors:
43
  group_name = number_re.sub(".{N}.", name)
44
  choices.add(group_name)
45
  layer = int(number_re.search(name).group()[1:-1])
46
+ if layer > max_layer:
47
+ max_layer = layer
48
 
49
  svdvals = torch.linalg.svdvals(state_dict[name])
50
  svdvals /= svdvals.sum()
 
55
  DATA["val"] = DATA["val"].astype("float")
56
  DATA["layer"] = DATA["layer"].astype("category")
57
  DATA["rank"] = DATA["rank"].astype("int32")
58
+ return choices, list(range(max_layer + 1))
59
 
60
 
61
  def weights_fn(model_id):
62
+ global STATE_DICT, PIPE
63
  try:
64
  pipe = pipeline(model=model_id)
65
+ PIPE = pipe
66
  STATE_DICT = pipe.model.state_dict()
67
  except Exception as e:
68
  print(e)
69
  STATE_DICT = {}
70
+ choices, layers = find_choices(STATE_DICT)
71
+ return [gr.Dropdown.update(choices=choices), gr.Dropdown.update(choices=layers)]
72
+
73
+
74
+ def layer_fn(weights, layer):
75
+ k = 5
76
+ directions = 10
77
+
78
+ embeddings = PIPE.model.get_input_embeddings().weight
79
+ weight_name = weights.replace("{N}", str(layer))
80
+
81
+ weight = STATE_DICT[weight_name]
82
+
83
+ U, S, Vh = torch.linalg.svd(weight)
84
+
85
+ D = U if U.shape[0] == embeddings.shape[0] else Vh
86
+
87
+ # words = D[:directions].matmul(embeddings.T).topk(k=k)
88
+ # words_t = D[:, :directions].T.matmul(embeddings.T).topk(k=k)
89
+
90
+ # Cosine similarity
91
+ words = (
92
+ (D[:directions] / D[:directions].norm(dim=0))
93
+ .matmul(embeddings.T / embeddings.T.norm(dim=0))
94
+ .topk(k=k)
95
+ )
96
+ words_t = (
97
+ (D[:, :directions].T / D[:, :directions].norm(dim=1))
98
+ .matmul(embeddings.T / embeddings.T.norm(dim=0))
99
+ .topk(k=k)
100
+ )
101
+
102
+ data = [[PIPE.tokenizer.decode(w) for w in indices] for indices in words.indices]
103
+ data = np.array(data)
104
+ data = pd.DataFrame(data)
105
+
106
+ data_t = [
107
+ [PIPE.tokenizer.decode(w) for w in indices] for indices in words_t.indices
108
+ ]
109
+ data_t = np.array(data_t)
110
+ data_t = pd.DataFrame(data_t)
111
+
112
+ return (
113
+ gr.Dataframe.update(value=data, interactive=False),
114
+ gr.Dataframe.update(value=data_t, interactive=False),
115
+ )
116
 
117
 
118
  with gr.Blocks() as scatter_plot:
 
120
  with gr.Column():
121
  model_id = gr.Textbox(label="model_id")
122
  weights = gr.Dropdown(label="weights")
123
+ layer = gr.Dropdown(label="layer")
124
  with gr.Column():
125
  plot = gr.LinePlot(show_label=False).style(container=True)
126
+ directions = gr.Dataframe(interactive=False)
127
+ directions_t = gr.Dataframe(interactive=False)
128
+ model_id.change(weights_fn, inputs=model_id, outputs=[weights, layer])
129
  weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
130
+ layer.change(
131
+ fn=layer_fn, inputs=[weights, layer], outputs=[directions, directions_t]
132
+ )
133
 
134
  if __name__ == "__main__":
135
  scatter_plot.launch()