File size: 9,580 Bytes
4c65bff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""
import argparse
import os
import pathlib
import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version
from transformers import (
Data2VecTextConfig,
Data2VecTextForMaskedLM,
Data2VecTextForSequenceClassification,
Data2VecTextModel,
)
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
# IMPORTANT: For this script to run, first download the dictionary file `dict.txt`, e.g. with `wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz`
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
from transformers.utils import logging
# Refuse to run when the installed fairseq is older than 0.9.0.
if version.parse("0.9.0") > version.parse(fairseq.__version__):
    raise Exception("requires fairseq >= 0.9.0")

# Switch the transformers logging utilities to INFO verbosity and create this module's logger.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Sentence fed to both the source and the converted model to compare their outputs.
SAMPLE_TEXT = "Hello world! cécé herlolip"
def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Port a data2vec text checkpoint onto a `Data2VecText*` model and save the result.

    The checkpoint is loaded, a `Data2VecTextConfig` is derived from the loaded model's arguments, every weight
    is transferred onto a freshly built model, both models are run on `SAMPLE_TEXT` to confirm that their
    outputs agree, and the converted model is written with `save_pretrained`.

    Args:
        data2vec_checkpoint_path (`str`):
            Path to the source checkpoint file. It is split into a directory and a file name for loading.
        pytorch_dump_folder_path (`str`):
            Directory that receives the converted model. It is created (parents included) when missing.
        classification_head (`bool`):
            When `True`, build a sequence-classification model and copy the `"mnli"` classification head;
            otherwise build a masked-LM model and copy the LM head.

    Raises:
        AssertionError: If a source weight does not have the expected shape.
        Exception: If the two models' outputs are not close under `torch.allclose(..., atol=1e-3)`.
    """

    def _assign_weight_and_bias(target, source):
        # Re-bind `target`'s weight and bias attributes to the ones held by `source`.
        target.weight = source.weight
        target.bias = source.bias

    # Directory to load from, and the checkpoint's file name inside that directory.
    checkpoint_dir = os.path.dirname(data2vec_checkpoint_path)
    checkpoint_file_name = os.path.basename(data2vec_checkpoint_path)

    data2vec = Data2VecTextModel.from_pretrained(checkpoint_dir, checkpoint_file=checkpoint_file_name)
    data2vec.eval()  # disable dropout
    # NOTE(review): the encoder is read from `data2vec.models[0]`, while the classification-head code below
    # reads `data2vec.model` -- confirm that the loaded object exposes both attributes.
    source_model = data2vec.models[0]
    source_encoder = source_model.encoder.sentence_encoder
    source_args = source_model.args

    config = Data2VecTextConfig(
        vocab_size=source_encoder.embed_tokens.num_embeddings,
        hidden_size=source_args.encoder_embed_dim,
        num_hidden_layers=source_args.encoder_layers,
        num_attention_heads=source_args.encoder_attention_heads,
        intermediate_size=source_args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        # The number of labels equals the number of rows of the head's output projection.
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    if classification_head:
        model = Data2VecTextForSequenceClassification(config)
    else:
        model = Data2VecTextForMaskedLM(config)
    model.eval()

    # ---- Weight transfer ----

    # Embeddings
    target_embeddings = model.data2vec_text.embeddings
    target_embeddings.word_embeddings.weight = source_encoder.embed_tokens.weight
    target_embeddings.position_embeddings.weight = source_encoder.embed_positions.weight
    # Token type embeddings are not used by data2vec, so they are simply zeroed.
    target_embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        target_embeddings.token_type_embeddings.weight
    )
    _assign_weight_and_bias(target_embeddings.LayerNorm, source_encoder.layernorm_embedding)

    # Every q/k/v projection weight must be a square (hidden_size, hidden_size) matrix.
    expected_qkv_shape = torch.Size((config.hidden_size, config.hidden_size))

    for layer_idx in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[layer_idx]
        data2vec_layer: TransformerSentenceEncoderLayer = source_encoder.layers[layer_idx]
        source_attn = data2vec_layer.self_attn

        # self attention: validate all three projections before copying any of them
        self_attn: BertSelfAttention = layer.attention.self
        for proj_name in ("k_proj", "q_proj", "v_proj"):
            assert getattr(source_attn, proj_name).weight.data.shape == expected_qkv_shape, (
                f"Shape for data2vec_layer.self_attn.{proj_name}.weight.data should be {expected_qkv_shape}"
            )
        for target_name, source_name in (("query", "q_proj"), ("key", "k_proj"), ("value", "v_proj")):
            target_proj = getattr(self_attn, target_name)
            source_proj = getattr(source_attn, source_name)
            target_proj.weight.data = source_proj.weight
            target_proj.bias.data = source_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert (
            self_output.dense.weight.shape == source_attn.out_proj.weight.shape
        ), f"Shape for self_output.dense.weight should be {source_attn.out_proj.weight.shape}"
        _assign_weight_and_bias(self_output.dense, source_attn.out_proj)
        _assign_weight_and_bias(self_output.LayerNorm, data2vec_layer.self_attn_layer_norm)

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert (
            intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape
        ), f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
        _assign_weight_and_bias(intermediate.dense, data2vec_layer.fc1)

        # output
        bert_output: BertOutput = layer.output
        assert (
            bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape
        ), f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
        _assign_weight_and_bias(bert_output.dense, data2vec_layer.fc2)
        _assign_weight_and_bias(bert_output.LayerNorm, data2vec_layer.final_layer_norm)
        # end of layer

    if classification_head:
        source_head = data2vec.model.classification_heads["mnli"]
        _assign_weight_and_bias(model.classifier.dense, source_head.dense)
        _assign_weight_and_bias(model.classifier.out_proj, source_head.out_proj)
    else:
        # LM Head
        source_lm_head = source_model.encoder.lm_head
        _assign_weight_and_bias(model.lm_head.dense, source_lm_head.dense)
        _assign_weight_and_bias(model.lm_head.layer_norm, source_lm_head.layer_norm)
        # The decoder takes the weight and bias stored directly on the source LM head.
        _assign_weight_and_bias(model.lm_head.decoder, source_lm_head)

    # ---- Output comparison ----

    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        source_head = data2vec.model.classification_heads["mnli"]
        source_features = data2vec.extract_features(input_ids)
        their_output = source_head(source_features)
    else:
        their_output = source_model(input_ids)[0]
    print(our_output.shape, their_output.shape)

    max_absolute_diff = (our_output - their_output).abs().max().item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7

    success = torch.allclose(our_output, their_output, atol=1e-3)
    verdict = "🔥" if success else "💩"
    print("Do both models output the same tensors?", verdict)
    if not success:
        raise Exception("Something went wRoNg")

    # ---- Save ----

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion once.
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path the official PyTorch dump.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model.",
    )

    # Optional flag; false unless given on the command line.
    parser.add_argument(
        "--classification_head",
        action="store_true",
        help="Whether to convert a final classification head.",
    )

    args = parser.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        data2vec_checkpoint_path=args.checkpoint_path,
        pytorch_dump_folder_path=args.pytorch_dump_folder_path,
        classification_head=args.classification_head,
    )
|