Dr. Jorge Abreu Vicente committed on
Commit 9e54323
1 Parent(s): fd77cf7

Create convert_biomegatron_checkpoint.py

Files changed (1)
  1. convert_biomegatron_checkpoint.py +197 -0
convert_biomegatron_checkpoint.py ADDED
@@ -0,0 +1,197 @@
import argparse
import json
import os
import re
import zipfile

import torch

####################################################################################################
# This code is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py


def recursive_print(name, val, spaces=0):
    # Format the message.
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # Print and recurse (if needed).
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)


def convert_megatron_checkpoint(input_state_dict, head_model=True):
    # The converted output model.
    output_state_dict = {}

    # The model.
    model = input_state_dict["model"]
    # The language model.
    lm = model["language_model"]
    # The embeddings.
    embeddings = lm["embedding"]

    # The word embeddings.
    word_embeddings = embeddings["word_embeddings"]["weight"]
    # Store the word embeddings.
    output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

    # The position embeddings.
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    # Trained for 512 x 1024.
    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
    # Store the position embeddings.
    output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

    # The token-type embeddings.
    tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
    # Store the token-type embeddings.
    output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

    # The transformer.
    transformer = lm["transformer"]

    # The regex to extract layer names.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # The simple map of names for "automated" rules.
    megatron_to_transformers = {
        "attention.dense": ".attention.output.dense.",
        "mlp.dense_h_to_4h": ".intermediate.dense.",
        "mlp.dense_4h_to_h": ".output.dense.",
    }

    # Keep track of the attention query/key/value tensor.
    attention_qkv_weight = None

    # Extract the layers.
    for key, val in transformer.items():
        # Match the name.
        m = layer_re.match(key)

        # Stop if that's not a layer.
        if m is None:
            break

        # The index of the layer.
        layer_idx = int(m.group(1))
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # The name of the layer.
        layer_name = f"bert.encoder.layer.{layer_idx}"

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layernorm"):
            ln_name = "attention.ln" if op_name.startswith("input") else "ln"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # Remember the QKV weight until the matching bias arrives.
        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
            # Make sure the QKV pointer is nil.
            assert attention_qkv_weight is None, ""

            # Store the tensor as we need the bias as well to split QKV weights and biases together.
            attention_qkv_weight = val

        # Split the QKV weight and bias once both are available.
        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
            # Make sure we read the weight tensor.
            assert attention_qkv_weight is not None, ""

            # Split the QKV matrix into Q, K and V (stored one after the other along the first dimension).
            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]

            # Split the bias.
            q_bias = val[0 * 1024 : 1 * 1024]
            k_bias = val[1 * 1024 : 2 * 1024]
            v_bias = val[2 * 1024 : 3 * 1024]

            # Store.
            output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
            output_state_dict[f"{layer_name}.attention.self.query.bias"] = q_bias
            output_state_dict[f"{layer_name}.attention.self.key.weight"] = k
            output_state_dict[f"{layer_name}.attention.self.key.bias"] = k_bias
            output_state_dict[f"{layer_name}.attention.self.value.weight"] = v
            output_state_dict[f"{layer_name}.attention.self.value.bias"] = v_bias

            # Clear the stored tensor.
            attention_qkv_weight = None

        # Copy weights and biases as is.
        elif weight_or_bias in ["weight", "bias"]:
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + weight_or_bias] = val

    # The final layernorm.
    output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

    # The config.
    output_config = {
        "vocab_size": word_embeddings.size(0),
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "hidden_act": "gelu_new",
        "intermediate_size": 4096,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.2,
        "layer_norm_eps": 1e-12,
        "position_embedding_type": "absolute",
        "use_cache": False,
        "model_type": "megatron-bert",
    }

    if head_model:
        # The pooler.
        pooler = lm["pooler"]

        # Store the matrix and the bias.
        output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
        output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]

        # The LM head from Megatron (for RACE).
        lm_head = model["lm_head"]

        # The transform matrix.
        output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
        output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]

        # The transform LN.
        output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
        output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]

        # For the decoder, we replicate the weights.
        output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
        output_state_dict["cls.predictions.bias"] = lm_head["bias"]

        # The classifier from Megatron (for MNLI).
        binary_head = model["binary_head"]

        # Store the classifier.
        output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
        output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

    # It should be done!
    return output_state_dict, output_config
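
For reference, a minimal sketch of how this converter might be driven end to end. The checkpoint filename, the assumption that it is a plain PyTorch .pt file loadable with torch.load, and the output directory are all hypothetical; only convert_megatron_checkpoint and recursive_print come from the script above.

# Hypothetical driver for the converter above; paths and checkpoint layout are assumptions.
import json
import os

import torch

from convert_biomegatron_checkpoint import convert_megatron_checkpoint, recursive_print

# Load the raw Megatron checkpoint on CPU (assumed here to be a plain .pt file).
input_state_dict = torch.load("model_optim_rng.pt", map_location="cpu")

# Convert the weights and derive a matching Hugging Face config dict.
state_dict, config_dict = convert_megatron_checkpoint(input_state_dict, head_model=True)

# Optional: inspect the converted keys and tensor shapes.
recursive_print(None, state_dict)

# Save in the Hugging Face layout: config.json + pytorch_model.bin.
os.makedirs("converted", exist_ok=True)
with open(os.path.join("converted", "config.json"), "w") as f:
    json.dump(config_dict, f, indent=2)
torch.save(state_dict, os.path.join("converted", "pytorch_model.bin"))

From there the folder should load with transformers' MegatronBertForPreTraining.from_pretrained("converted"), assuming the checkpoint really is the 1024-hidden, 24-layer, 512-position variant that the hard-coded sizes and asserts in the converter expect.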