taka-yamakoshi committed
Commit ebfe870
Parent: 5958ae4

add model options

Files changed (3)
  1. app.py +28 -15
  2. skeleton_modeling_bert.py +73 -0
  3. skeleton_modeling_roberta.py +73 -0
app.py CHANGED
@@ -10,10 +10,7 @@ import seaborn as sns
 import torch
 import torch.nn.functional as F
 
-from transformers import AlbertTokenizer, AlbertForMaskedLM
-
 #from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
-from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
 
 def wide_setup():
     max_width = 1500
@@ -48,10 +45,25 @@ def load_css(file_name):
 
 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model(model_name):
-    tokenizer = AlbertTokenizer.from_pretrained(model_name)
-    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
-    model = AlbertForMaskedLM.from_pretrained(model_name)
-    return tokenizer,model
+    if model_name.startswith('albert'):
+        from transformers import AlbertTokenizer, AlbertForMaskedLM
+        from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
+        tokenizer = AlbertTokenizer.from_pretrained(model_name)
+        model = AlbertForMaskedLM.from_pretrained(model_name)
+        skeleton_model = SkeletonAlbertForMaskedLM
+    elif model_name.startswith('bert'):
+        from transformers import BertTokenizer, BertForMaskedLM
+        from skeleton_modeling_bert import SkeletonBertForMaskedLM
+        tokenizer = BertTokenizer.from_pretrained(model_name)
+        model = BertForMaskedLM.from_pretrained(model_name)
+        skeleton_model = SkeletonBertForMaskedLM  # without this assignment the return below raises UnboundLocalError
+    elif model_name.startswith('roberta'):
+        from transformers import RobertaTokenizer, RobertaForMaskedLM
+        from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
+        tokenizer = RobertaTokenizer.from_pretrained(model_name)
+        model = RobertaForMaskedLM.from_pretrained(model_name)
+        skeleton_model = SkeletonRobertaForMaskedLM  # likewise required here
+    return tokenizer,model,skeleton_model
 
 def clear_data():
     for key in st.session_state:
@@ -147,14 +159,14 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
     return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
 
 
-def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
+def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
     probs = []
     for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
         input_ids = torch.tensor([
             *[masked_ids['sent_1'] for _ in range(batch_size)],
             *[masked_ids['sent_2'] for _ in range(batch_size)]
         ])
-        outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
+        outputs = skeleton_model(model,input_ids,interventions=interventions)
         logprobs = F.log_softmax(outputs['logits'], dim = -1)
         logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
         evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
@@ -181,9 +193,10 @@ if __name__=='__main__':
         st.session_state['page_status'] = 'type_in'
         st.experimental_rerun()
 
-    tokenizer,model = load_model(st.session_state['model_name'])
-    num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
-    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+    if st.session_state['page_status']!='model_selection':
+        tokenizer,model,skeleton_model = load_model(st.session_state['model_name'])
+        num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
+        mask_id = tokenizer.mask_token_id  # works for all three models; tokenizer('[MASK]') breaks for RoBERTa, whose mask token is '<mask>'
 
     if st.session_state['page_status']=='type_in':
         show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
@@ -263,7 +276,7 @@ if __name__=='__main__':
             option_2_tokens = option_2_tokens_1
 
         interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
-        probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+        probs_original = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
         df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
                                 [probs_original[0,1][0],probs_original[1,1][0]]],
                           columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
@@ -292,9 +305,9 @@ if __name__=='__main__':
         for layer_id in range(num_layers):
            interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
            if multihead:
-               probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+               probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
            else:
-               probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+               probs = run_intervention(interventions,num_heads,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
            effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
            effect_list.append(effect)
        effect_array.append(effect_list)
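
A note on the interventions format threaded through run_intervention: the skeleton models expect one dict per layer, keyed by representation type ('lay' for the layer input, 'qry'/'key'/'val' for the attention projections). Each entry is a (head_id, positions, (idx_a, idx_b)) triple that swaps that head's representation at the given positions between batch items idx_a and idx_b. A minimal sketch of driving the loaded skeleton model directly; the checkpoint name, sentences, and swap coordinates here are illustrative:

    import torch

    tokenizer, model, skeleton_model = load_model('bert-base-uncased')
    num_layers = model.config.num_hidden_layers

    # start from no-op interventions: one empty dict per layer
    interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for _ in range(num_layers)]

    # in layer 3, swap head 0's query vectors at position 1 between batch items 0 and 1
    interventions[3]['qry'].append((0, [1], (0, 1)))

    sents = ['The doctor met the nurse.', 'The nurse met the doctor.']
    input_ids = torch.tensor([tokenizer(s).input_ids for s in sents])  # assumes equal token counts

    outputs = skeleton_model(model, input_ids, interventions=interventions)
    print(outputs['logits'].shape)  # (2, sequence_length, vocab_size)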
skeleton_modeling_bert.py ADDED
@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+@torch.no_grad()
+def SkeletonBertLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention.self
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+    assert num_heads*head_dim == hidden.shape[2]
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    assert qry.shape == hidden.shape
+    assert key.shape == hidden.shape
+    assert val.shape == hidden.shape
+
+    # swap representations between batch items as specified by the interventions
+    reps = {
+        'lay': hidden,
+        'qry': qry,
+        'key': key,
+        'val': val,
+    }
+    for rep_type in ['lay','qry','key','val']:
+        interv_rep = interventions[layer_id][rep_type]
+        new_state = reps[rep_type].clone()
+        for head_id, pos, swap_ids in interv_rep:
+            new_state[swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+            new_state[swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+        reps[rep_type] = new_state.clone()
+
+    hidden = reps['lay'].clone()
+    qry = reps['qry'].clone()
+    key = reps['key'].clone()
+    val = reps['val'].clone()
+
+
+    # split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    # calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.output.dense(z_rep)+hidden # residual connection
+    hidden_post_attn = layer.attention.output.LayerNorm(hidden_post_attn_res) # layer norm
+
+    hidden_post_interm = layer.intermediate(hidden_post_attn) # massive feed forward
+    hidden_post_interm_res = layer.output.dense(hidden_post_interm)+hidden_post_attn # residual connection
+    new_hidden = layer.output.LayerNorm(hidden_post_interm_res) # layer norm
+    return new_hidden
+
+def SkeletonBertForMaskedLM(model,input_ids,interventions):
+    core_model = model.bert
+    lm_head = model.cls
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.layer[layer_id]
+            hidden = SkeletonBertLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+    return {'logits':logits,'hidden_states':output_hidden}
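
Because SkeletonBertForMaskedLM re-implements the encoder stack by hand (no attention mask, no dropout), a quick sanity check is that with empty interventions its logits match the regular Hugging Face forward pass in eval mode. A sketch, assuming the bert-base-uncased checkpoint:

    import torch
    from transformers import BertTokenizer, BertForMaskedLM
    from skeleton_modeling_bert import SkeletonBertForMaskedLM

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()  # eval() disables dropout

    input_ids = torch.tensor([tokenizer('Paris is the capital of [MASK].').input_ids])
    no_op = [{'lay':[],'qry':[],'key':[],'val':[]} for _ in range(model.config.num_hidden_layers)]

    skeleton_logits = SkeletonBertForMaskedLM(model, input_ids, interventions=no_op)['logits']
    hf_logits = model(input_ids).logits

    # with a single unpadded sequence there is no attention mask to account for,
    # so the two should agree up to numerical noise
    print(torch.allclose(skeleton_logits, hf_logits, atol=1e-4))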
skeleton_modeling_roberta.py ADDED
@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+@torch.no_grad()
+def SkeletonRobertaLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention.self
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+    assert num_heads*head_dim == hidden.shape[2]
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    assert qry.shape == hidden.shape
+    assert key.shape == hidden.shape
+    assert val.shape == hidden.shape
+
+    # swap representations between batch items as specified by the interventions
+    reps = {
+        'lay': hidden,
+        'qry': qry,
+        'key': key,
+        'val': val,
+    }
+    for rep_type in ['lay','qry','key','val']:
+        interv_rep = interventions[layer_id][rep_type]
+        new_state = reps[rep_type].clone()
+        for head_id, pos, swap_ids in interv_rep:
+            new_state[swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+            new_state[swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+        reps[rep_type] = new_state.clone()
+
+    hidden = reps['lay'].clone()
+    qry = reps['qry'].clone()
+    key = reps['key'].clone()
+    val = reps['val'].clone()
+
+
+    # split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    # calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.output.dense(z_rep)+hidden # residual connection
+    hidden_post_attn = layer.attention.output.LayerNorm(hidden_post_attn_res) # layer norm
+
+    hidden_post_interm = layer.intermediate(hidden_post_attn) # massive feed forward
+    hidden_post_interm_res = layer.output.dense(hidden_post_interm)+hidden_post_attn # residual connection
+    new_hidden = layer.output.LayerNorm(hidden_post_interm_res) # layer norm
+    return new_hidden
+
+def SkeletonRobertaForMaskedLM(model,input_ids,interventions):
+    core_model = model.roberta
+    lm_head = model.lm_head
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.layer[layer_id]
+            hidden = SkeletonRobertaLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+    return {'logits':logits,'hidden_states':output_hidden}
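
The RoBERTa skeleton is identical to the BERT one except for the module paths (model.roberta for the encoder, model.lm_head for the prediction head) and the function name SkeletonRobertaForMaskedLM, which matches the import in app.py. When driving it directly, remember that RoBERTa's mask token is '<mask>', not '[MASK]'. A sketch, assuming the roberta-base checkpoint:

    import torch
    from transformers import RobertaTokenizer, RobertaForMaskedLM
    from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForMaskedLM.from_pretrained('roberta-base').eval()

    input_ids = torch.tensor([tokenizer('Paris is the capital of <mask>.').input_ids])
    no_op = [{'lay':[],'qry':[],'key':[],'val':[]} for _ in range(model.config.num_hidden_layers)]

    logits = SkeletonRobertaForMaskedLM(model, input_ids, interventions=no_op)['logits']
    mask_pos = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    print(tokenizer.decode([logits[0, mask_pos].argmax().item()]))  # expect something like ' France'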