File size: 7,187 Bytes
9e54323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import argparse
import json
import os
import re
import zipfile

import torch

####################################################################################################
# This code is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py

def recursive_print(name, val, spaces=0):
    # Format the message.
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # Print and recurse (if needed).
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)

        
def convert_megatron_checkpoint(input_state_dict, head_model=True):
    # The converted output model.
    output_state_dict = {}

    # The model.
    model = input_state_dict["model"]
    # The language model.
    lm = model["language_model"]
    # The embeddings.
    embeddings = lm["embedding"]

    # The word embeddings.
    word_embeddings = embeddings["word_embeddings"]["weight"]
    # Store the word embeddings.
    output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

    # The position embeddings.
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    # Trained for 512 x 1024.
    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
    # Store the position embeddings.
    output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

    # The token-type embeddings.
    tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
    # Store the position embeddings.
    output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

    # The transformer.
    transformer = lm["transformer"]

    # The regex to extract layer names.
    layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # The simple map of names for "automated" rules.
    megatron_to_transformers = {
        "attention.dense": ".attention.output.dense.",
        "mlp.dense_h_to_4h": ".intermediate.dense.",
        "mlp.dense_4h_to_h": ".output.dense.",
    }

    # Keep track of the attention/query/value tensor.
    attention_qkv_weight = None

    # Extract the layers.
    for key, val in transformer.items():
        # Match the name.
        m = layer_re.match(key)

        # Stop if that's not a layer
        if m is None:
            break

        # The index of the layer.
        layer_idx = int(m.group(1))
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # The name of the layer.
        layer_name = f"bert.encoder.layer.{layer_idx}"

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layernorm"):

            ln_name = "attention.ln" if op_name.startswith("input") else "ln"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # Transpose the QKV matrix.
        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":

            # Make sure the QKV pointer is nil.
            assert attention_qkv_weight is None, ""

            # Store the tensor as we need the bias as well to interleave QKV and biases.
            attention_qkv_weight = val

        # Transpose the bias.
        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":

            # Make sure we read the weight tensor.
            assert attention_qkv_weight is not None, ""

            # Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]

            # Split the bias.
            q_bias = val[0 * 1024 : 1 * 1024]
            k_bias = val[1 * 1024 : 2 * 1024]
            v_bias = val[2 * 1024 : 3 * 1024]

            # Store.
            output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
            output_state_dict[f"{layer_name}.attention.self.query.bias"] = q_bias
            output_state_dict[f"{layer_name}.attention.self.key.weight"] = k
            output_state_dict[f"{layer_name}.attention.self.key.bias"] = k_bias
            output_state_dict[f"{layer_name}.attention.self.value.weight"] = v
            output_state_dict[f"{layer_name}.attention.self.value.bias"] = v_bias

            # Clear the stored tensor.
            attention_qkv_weight = None

        # Copy weights and biases as is.
        elif weight_or_bias in ["weight", "bias"]:

            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + weight_or_bias] = val

    # The final layernorm.
    output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

    # The config.
    output_config = {
        "vocab_size": word_embeddings.size(0),
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "hidden_act": "gelu_new",
        "intermediate_size": 4096,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.2,
        "layer_norm_eps": 1e-12,
        "position_embedding_type": "absolute",
        "use_cache": False,
        "model_type": "megatron-bert",
    }

    if head_model:
        # The pooler.
        pooler = lm["pooler"]

        # Store the matrix and the bias.
        output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
        output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]

        # The LM head from Megatron (for RACE).
        lm_head = model["lm_head"]

        # The transform matrix.
        output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
        output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]

        # The transform LN.
        output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
        output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]

        # For the decoder, we replicate the weights.
        output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
        output_state_dict["cls.predictions.bias"] = lm_head["bias"]

        # The classifier from Megatron (for MLNI).
        binary_head = model["binary_head"]

        # Store the classifier.
        output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
        output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

    # It should be done!
    return output_state_dict, output_config