stefan-it committed on
Commit cdc6dd4
1 Parent(s): 7f03040

tools: add initial version of conversion script

convert_token_dropping_bert_original_tf2_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,184 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script converts an LM-head checkpoint from the "Token Dropping" implementation
+ into a PyTorch-compatible BERT model. The official implementation of "Token Dropping"
+ can be found in the TensorFlow Models repository:
+
+ https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
+ """
+ import argparse
+
+ import tensorflow as tf
+ import torch
+
+ from transformers import BertConfig, BertForMaskedLM
+ from transformers.models.bert.modeling_bert import (
+     BertIntermediate,
+     BertLayer,
+     BertOutput,
+     BertPooler,
+     BertSelfAttention,
+     BertSelfOutput,
+ )
+ from transformers.utils import logging
+
+
+ logging.set_verbosity_info()
+
+
+ def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
+     def get_masked_lm_array(name: str):
+         full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
+         array = tf.train.load_variable(tf_checkpoint_path, full_name)
+
+         # if "kernel" in name:
+         #     array = array.transpose()
+
+         return torch.from_numpy(array)
+
+     def get_encoder_array(name: str):
+         full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
+         array = tf.train.load_variable(tf_checkpoint_path, full_name)
+
+         # TF stores dense kernels as (in_features, out_features), while
+         # PyTorch nn.Linear weights are (out_features, in_features).
+         if "kernel" in name:
+             array = array.transpose()
+
+         return torch.from_numpy(array)
+
+     def get_encoder_layer_array(layer_index: int, name: str):
+         full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
+         array = tf.train.load_variable(tf_checkpoint_path, full_name)
+
+         if "kernel" in name:
+             array = array.transpose()
+
+         return torch.from_numpy(array)
+
+     def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
+         full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
+         array = tf.train.load_variable(tf_checkpoint_path, full_name)
+         # Reshape the attention variables back to the 2D shape of the
+         # corresponding PyTorch parameter before transposing.
+         array = array.reshape(original_shape)
+
+         if "kernel" in name:
+             array = array.transpose()
+
+         return torch.from_numpy(array)
+
+     print(f"Loading model based on config from {config_path}...")
+     config = BertConfig.from_json_file(config_path)
+     model = BertForMaskedLM(config)
+
+     # Layers
+     for layer_index in range(config.num_hidden_layers):
+         layer: BertLayer = model.bert.encoder.layer[layer_index]
+
+         # Self-attention
+         self_attn: BertSelfAttention = layer.attention.self
+
+         self_attn.query.weight.data = get_encoder_attention_layer_array(
+             layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
+         )
+         self_attn.query.bias.data = get_encoder_attention_layer_array(
+             layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
+         )
+         self_attn.key.weight.data = get_encoder_attention_layer_array(
+             layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
+         )
+         self_attn.key.bias.data = get_encoder_attention_layer_array(
+             layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
+         )
+         self_attn.value.weight.data = get_encoder_attention_layer_array(
+             layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
+         )
+         self_attn.value.bias.data = get_encoder_attention_layer_array(
+             layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
+         )
+
+         # Self-attention output
+         self_output: BertSelfOutput = layer.attention.output
+
+         self_output.dense.weight.data = get_encoder_attention_layer_array(
+             layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
+         )
+         self_output.dense.bias.data = get_encoder_attention_layer_array(
+             layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
+         )
+
+         self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
+         self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
+
+         # Intermediate
+         intermediate: BertIntermediate = layer.intermediate
+
+         intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
+         intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
+
+         # Output
+         bert_output: BertOutput = layer.output
+
+         bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
+         bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
+
+         bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
+         bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
+
+     # Embeddings
+     model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
+     model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
+     model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
+     model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
+
+     # LM head
+     lm_head = model.cls.predictions.transform
+
+     lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
+     lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
+
+     lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
+     lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
+
+     # The word embeddings are stored in the masked-lm part of the checkpoint
+     model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
+
+     # Pooling
+     model.bert.pooler = BertPooler(config=config)
+     model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel")
+     model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias")
+
+     # Export final model
+     model.save_pretrained(pytorch_dump_path)
+
+     # Integration test - the converted model should load without any errors ;)
+     new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
+     print(new_model.eval())
+
+     print("Model conversion was done successfully!")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint."
+     )
+     parser.add_argument(
+         "--bert_config_file",
+         type=str,
+         required=True,
+         help="The config json file corresponding to the BERT model. This specifies the model architecture.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_path",
+         type=str,
+         required=True,
+         help="Path to the output PyTorch model directory.",
+     )
+     args = parser.parse_args()
+     convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
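For reference, a quick way to exercise the converted checkpoint might look like the sketch below. It is illustrative rather than part of the commit: all paths are placeholders, and a standard BERT WordPiece tokenizer is assumed, since the script writes only model weights and no tokenizer files.

```python
# Hypothetical invocation of the conversion script (all paths are placeholders):
#
#   python convert_token_dropping_bert_original_tf2_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path ./token-dropping-bert/checkpoint \
#       --bert_config_file ./token-dropping-bert/config.json \
#       --pytorch_dump_path ./converted-model

import torch
from transformers import BertForMaskedLM, BertTokenizer

# Load the converted weights; "./converted-model" matches --pytorch_dump_path above.
model = BertForMaskedLM.from_pretrained("./converted-model").eval()

# Assumption: the checkpoint was trained with a standard BERT WordPiece vocab;
# the conversion script itself does not produce tokenizer files.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Decode the top prediction at the masked position as a plausibility check.
mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_positions].argmax(dim=-1)))
```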