taka-yamakoshi commited on
Commit
d1e605d
1 Parent(s): 340640b
Files changed (1) hide show
  1. skeleton_modeling_albert.py +17 -20
skeleton_modeling_albert.py CHANGED
@@ -21,27 +21,24 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
21
  assert val.shape == hidden.shape
22
 
23
  # swap representations
24
- interv_layer = interventions.pop(layer_id,None)
25
- if interv_layer is not None:
26
- reps = {
27
- 'lay': hidden,
28
- 'qry': qry,
29
- 'key': key,
30
- 'val': val,
31
- }
32
- for rep_type in ['lay','qry','key','val']:
33
- interv_rep = interv_layer.pop(rep_type,None)
34
- if interv_rep is not None:
35
- new_state = reps[rep_type].clone()
36
- for head_id, pos, swap_ids in interv_rep:
37
- new_state[swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)]
38
- new_state[swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)]
39
- reps[rep_type] = new_state.clone()
40
 
41
- hidden = reps['lay'].clone()
42
- qry = reps['qry'].clone()
43
- key = reps['key'].clone()
44
- val = reps['val'].clone()
45
 
46
 
47
  #split into multiple heads
 
21
  assert val.shape == hidden.shape
22
 
23
  # swap representations
24
+ reps = {
25
+ 'lay': hidden,
26
+ 'qry': qry,
27
+ 'key': key,
28
+ 'val': val,
29
+ }
30
+ for rep_type in ['lay','qry','key','val']:
31
+ interv_rep = interventions[layer_id][rep_type]
32
+ new_state = reps[rep_type].clone()
33
+ for head_id, pos, swap_ids in interv_rep:
34
+ new_state[swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
35
+ new_state[swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
36
+ reps[rep_type] = new_state.clone()
 
 
 
37
 
38
+ hidden = reps['lay'].clone()
39
+ qry = reps['qry'].clone()
40
+ key = reps['key'].clone()
41
+ val = reps['val'].clone()
42
 
43
 
44
  #split into multiple heads