# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""


import argparse
import os
import pathlib

import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers import (
    Data2VecTextConfig,
    Data2VecTextForMaskedLM,
    Data2VecTextForSequenceClassification,
    Data2VecTextModel,
)
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)

# IMPORTANT: for this script to run, make sure to download the dictionary `dict.txt` from
# https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz (e.g. via `wget`)
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
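# A possible way to fetch the dictionary, assuming a POSIX shell and that the archive still extracts to a
# `roberta.large/` directory containing `dict.txt` (which fairseq expects next to the data2vec checkpoint):
#
#   wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
#   tar -xzvf roberta.large.tar.gz
#   cp roberta.large/dict.txt /path/to/data2vec_checkpoint_dir/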
from transformers.utils import logging


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert (
            self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape
        ), f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert (
            intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape
        ), f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert (
            bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape
        ), f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
    else:
        their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )
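

# After a successful conversion the dump folder holds a standard Transformers checkpoint; a minimal
# reload sketch (assuming the default masked-LM conversion and the output path used above):
#
#   from transformers import Data2VecTextForMaskedLM
#   model = Data2VecTextForMaskedLM.from_pretrained("/path/to/output_dir")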