stefan-it commited on
Commit
932cb24
1 Parent(s): cc010f0

tools: update conversion script

Browse files
convert_token_dropping_bert_original_tf2_checkpoint_to_pytorch.py CHANGED
@@ -46,8 +46,8 @@ def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pyt
46
  full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
47
  array = tf.train.load_variable(tf_checkpoint_path, full_name)
48
 
49
- #if "kernel" in name:
50
- # array = array.transpose()
51
 
52
  return torch.from_numpy(array)
53
 
46
  full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
47
  array = tf.train.load_variable(tf_checkpoint_path, full_name)
48
 
49
+ if "kernel" in name:
50
+ array = array.transpose()
51
 
52
  return torch.from_numpy(array)
53