Dr. Jorge Abreu Vicente committed on
Commit c623373
1 Parent(s): 6e8f412

Create convert_biomegatron_checkpoint.py

Files changed (1)
  1. convert_biomegatron_checkpoint.py +198 -0
convert_biomegatron_checkpoint.py ADDED
@@ -0,0 +1,198 @@
+ import argparse
+ import json
+ import os
+ import re
+ import zipfile
+
+ import torch
+
+ ####################################################################################################
+ # This file is a modification of the original
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
+
+ def recursive_print(name, val, spaces=0):
+     # Format the message.
+     if name is None:
+         msg = None
+     else:
+         fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
+         msg = fmt.format(name)
+
+     # Print and recurse (if needed).
+     if isinstance(val, dict):
+         if msg is not None:
+             print(msg)
+         for k in val.keys():
+             recursive_print(k, val[k], spaces + 2)
+     elif isinstance(val, torch.Tensor):
+         print(msg, ":", val.size())
+     else:
+         print(msg, ":", val)
+
+
+ def convert_megatron_checkpoint(input_state_dict, head_model=True):
+     # The converted output model.
+     output_state_dict = {}
+
+     # The model.
+     model = input_state_dict["model"]
+     # The language model.
+     lm = model["language_model"]
+     # The embeddings.
+     embeddings = lm["embedding"]
+
+     # The word embeddings.
+     word_embeddings = embeddings["word_embeddings"]["weight"]
+     # Store the word embeddings.
+     output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings
+
+     # The position embeddings.
+     pos_embeddings = embeddings["position_embeddings"]["weight"]
+     # Trained for 512 x 1024.
+     assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
+     # Store the position embeddings.
+     output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings
+
+     # The token-type embeddings.
+     tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
+     # Store the token-type embeddings.
+     output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings
+
+     # The transformer.
+     transformer = lm["transformer"]
+
+     # The regex to extract layer names.
+     layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
+
+     # The simple map of names for "automated" rules.
+     megatron_to_transformers = {
+         "attention.dense": ".attention.output.dense.",
+         "mlp.dense_h_to_4h": ".intermediate.dense.",
+         "mlp.dense_4h_to_h": ".output.dense.",
+     }
+
+     # Keep track of the attention/query/value tensor.
+     attention_qkv_weight = None
+
+     # Extract the layers.
+     for key, val in transformer.items():
+         # Match the name.
+         m = layer_re.match(key)
+
+         # Stop if that's not a layer
+         if m is None:
+             break
+
+         # The index of the layer.
+         layer_idx = int(m.group(1))
+         # The name of the operation.
+         op_name = m.group(2)
+         # Is it a weight or a bias?
+         weight_or_bias = m.group(3)
+
+         # The name of the layer.
+         layer_name = f"bert.encoder.layer.{layer_idx}"
+
+         # For layernorm(s), simply store the layer norm.
+         if op_name.endswith("layernorm"):
+
+             ln_name = "attention.ln" if op_name.startswith("input") else "ln"
+             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val
+
+         # Transpose the QKV matrix.
+         elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+
+             # Make sure the QKV pointer is nil.
+             assert attention_qkv_weight is None, ""
+
+             # Store the tensor as we need the bias as well to interleave QKV and biases.
+             attention_qkv_weight = val
+
+         # Transpose the bias.
+         elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+
+             # Make sure we read the weight tensor.
+             assert attention_qkv_weight is not None, ""
+
+             # Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
+             q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
+             k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
+             v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]
+
+             # Split the bias.
+             q_bias = val[0 * 1024 : 1 * 1024]
+             k_bias = val[1 * 1024 : 2 * 1024]
+             v_bias = val[2 * 1024 : 3 * 1024]
+
+             # Store.
+             output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
+             output_state_dict[f"{layer_name}.attention.self.query.bias"] = q_bias
+             output_state_dict[f"{layer_name}.attention.self.key.weight"] = k
+             output_state_dict[f"{layer_name}.attention.self.key.bias"] = k_bias
+             output_state_dict[f"{layer_name}.attention.self.value.weight"] = v
+             output_state_dict[f"{layer_name}.attention.self.value.bias"] = v_bias
+
+             # Clear the stored tensor.
+             attention_qkv_weight = None
+
+         # Copy weights and biases as is.
+         elif weight_or_bias in ["weight", "bias"]:
+
+             out_name = megatron_to_transformers[op_name]
+             output_state_dict[layer_name + out_name + weight_or_bias] = val
+
+     # The final layernorm.
+     output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
+     output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]
+
+     # The config.
+     output_config = {
+         "vocab_size": word_embeddings.size(0),
+         "hidden_size": 1024,
+         "num_hidden_layers": 24,
+         "num_attention_heads": 16,
+         "hidden_act": "gelu_new",
+         "intermediate_size": 4096,
+         "hidden_dropout_prob": 0.1,
+         "attention_probs_dropout_prob": 0.1,
+         "max_position_embeddings": 512,
+         "type_vocab_size": 2,
+         "initializer_range": 0.02,
+         "layer_norm_eps": 1e-12,
+         "position_embedding_type": "absolute",
+         "use_cache": False,
+         "model_type": "megatron-bert",
+     }
+
+     if head_model:
+         # The pooler.
+         pooler = lm["pooler"]
+
+         # Store the matrix and the bias.
+         output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
+         output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]
+
+         # The LM head from Megatron (for RACE).
+         lm_head = model["lm_head"]
+
+         # The transform matrix.
+         output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
+         output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]
+
+         # The transform LN.
+         output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
+         output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]
+
+         # For the decoder, we replicate the weights.
+         output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
+         output_state_dict["cls.predictions.bias"] = lm_head["bias"]
+
+         # The classifier from Megatron (for MNLI).
+         binary_head = model["binary_head"]
+
+         # Store the classifier.
+         output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
+         output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]
+
+     # It should be done!
+     return output_state_dict, output_config
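
For context, here is a minimal driver sketch that is not part of the commit. It assumes the Megatron-LM BioMegatron checkpoint has already been unpacked to a local model_optim_rng.pt file (all paths below are illustrative), converts it with convert_megatron_checkpoint, and writes config.json and pytorch_model.bin in the layout transformers expects:

import json
import os

import torch

from convert_biomegatron_checkpoint import convert_megatron_checkpoint

# Hypothetical paths; point these at the unpacked Megatron-LM checkpoint and an output folder.
checkpoint_path = "model_optim_rng.pt"
output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)

# Load the raw Megatron-LM state dict on CPU.
input_state_dict = torch.load(checkpoint_path, map_location="cpu")

# Convert to the transformers "megatron-bert" layout, including the pretraining heads.
output_state_dict, output_config = convert_megatron_checkpoint(input_state_dict, head_model=True)

# Write the files in the layout expected by transformers.
with open(os.path.join(output_dir, "config.json"), "w") as f:
    json.dump(output_config, f, indent=2)
torch.save(output_state_dict, os.path.join(output_dir, "pytorch_model.bin"))

With head_model=True the converted keys should line up with MegatronBertForPreTraining (pooler, LM head, and seq-relationship head); with head_model=False a bare MegatronBertModel is the likely target. The matching vocabulary file still has to be supplied separately.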