theimageconvert2 / transformers_4_35_0 /models /data2vec /convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
mart9992's picture
m
4c65bff
raw
history blame contribute delete
No virus
9.58 kB
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""
import argparse
import os
import pathlib
import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version
from transformers import (
Data2VecTextConfig,
Data2VecTextForMaskedLM,
Data2VecTextForSequenceClassification,
Data2VecTextModel,
)
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
# IMPORTANT: For this script to run, first download the dictionary file `dict.txt`; it ships inside the archive at https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz (fetch it with e.g. `wget`).
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
from transformers.utils import logging
# Refuse to run against fairseq releases older than 0.9.0.
if version.parse("0.9.0") > version.parse(fairseq.__version__):
    raise Exception("requires fairseq >= 0.9.0")

# Switch the transformers logging utilities to INFO verbosity and grab a module-level logger.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Sentence encoded and fed through both models to compare their outputs after conversion.
SAMPLE_TEXT = "Hello world! cécé herlolip"
def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Port a data2vec text checkpoint into the Data2VecText (BERT-style) module layout, verify it, and save it.

    The source weights are copied tensor by tensor, `SAMPLE_TEXT` is run through both the source and the converted
    model, and the converted model is only written to disk when the two outputs agree within `atol=1e-3`.

    Args:
        data2vec_checkpoint_path: Path of the source checkpoint file. It is split into a directory and a file name
            before loading.
        pytorch_dump_folder_path: Output directory for the converted model; created (with parents) if missing.
        classification_head: If true, build a sequence-classification model and copy the head registered under
            "mnli"; otherwise build a masked-LM model and copy the LM head.

    Raises:
        AssertionError: If one of the checked weight shapes does not match.
        Exception: If the converted model's output differs from the source model's output on `SAMPLE_TEXT`.
    """
    checkpoint_dir, checkpoint_file = os.path.split(data2vec_checkpoint_path)
    # NOTE(review): `Data2VecTextModel` resolves to the transformers import in this file, yet the `checkpoint_file=`
    # argument and the `.models[0]` / `.encode` / `.extract_features` calls below look like the fairseq hub
    # interface -- confirm which class is meant to be loaded here.
    data2vec = Data2VecTextModel.from_pretrained(checkpoint_dir, checkpoint_file=checkpoint_file)
    data2vec.eval()  # disable dropout
    source_model = data2vec.models[0]
    source_encoder = source_model.encoder.sentence_encoder
    source_args = source_model.args

    config = Data2VecTextConfig(
        vocab_size=source_encoder.embed_tokens.num_embeddings,
        hidden_size=source_args.encoder_embed_dim,
        num_hidden_layers=source_args.encoder_layers,
        num_attention_heads=source_args.encoder_attention_heads,
        intermediate_size=source_args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )

    # The classification head is always looked up under the hard-coded "mnli" key.
    mnli_head = None
    if classification_head:
        mnli_head = data2vec.model.classification_heads["mnli"]
        config.num_labels = mnli_head.out_proj.weight.shape[0]
    print("Our BERT config:", config)

    if classification_head:
        model = Data2VecTextForSequenceClassification(config)
    else:
        model = Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    embeddings = model.data2vec_text.embeddings
    embeddings.word_embeddings.weight = source_encoder.embed_tokens.weight
    embeddings.position_embeddings.weight = source_encoder.embed_positions.weight
    # just zero them out b/c data2vec doesn't use them.
    embeddings.token_type_embeddings.weight.data = torch.zeros_like(embeddings.token_type_embeddings.weight)
    embeddings.LayerNorm.weight = source_encoder.layernorm_embedding.weight
    embeddings.LayerNorm.bias = source_encoder.layernorm_embedding.bias

    # Every q/k/v projection weight is checked against a (hidden_size, hidden_size) shape.
    square_shape = torch.Size((config.hidden_size, config.hidden_size))

    for layer_idx in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[layer_idx]
        data2vec_layer: TransformerSentenceEncoderLayer = source_encoder.layers[layer_idx]
        source_attn = data2vec_layer.self_attn

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        for proj_name in ("k_proj", "q_proj", "v_proj"):
            assert getattr(source_attn, proj_name).weight.data.shape == square_shape, (
                f"Shape for data2vec_layer.self_attn.{proj_name}.weight.data should be {square_shape}"
            )
        # These copies rebind `.data` on the existing parameters instead of replacing the parameter objects.
        for target_name, source_name in (("query", "q_proj"), ("key", "k_proj"), ("value", "v_proj")):
            target_proj = getattr(self_attn, target_name)
            source_proj = getattr(source_attn, source_name)
            target_proj.weight.data = source_proj.weight
            target_proj.bias.data = source_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        out_proj = source_attn.out_proj
        assert (
            self_output.dense.weight.shape == out_proj.weight.shape
        ), f"Shape for self_output.dense.weight should be {out_proj.weight.shape}"
        self_output.dense.weight = out_proj.weight
        self_output.dense.bias = out_proj.bias
        attn_norm = data2vec_layer.self_attn_layer_norm
        self_output.LayerNorm.weight = attn_norm.weight
        self_output.LayerNorm.bias = attn_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        fc1 = data2vec_layer.fc1
        assert (
            intermediate.dense.weight.shape == fc1.weight.shape
        ), f"Shape for intermediate.dense.weight should be {fc1.weight.shape}"
        intermediate.dense.weight = fc1.weight
        intermediate.dense.bias = fc1.bias

        # output
        bert_output: BertOutput = layer.output
        fc2 = data2vec_layer.fc2
        assert (
            bert_output.dense.weight.shape == fc2.weight.shape
        ), f"Shape for bert_output.dense.weight should be {fc2.weight.shape}"
        bert_output.dense.weight = fc2.weight
        bert_output.dense.bias = fc2.bias
        final_norm = data2vec_layer.final_layer_norm
        bert_output.LayerNorm.weight = final_norm.weight
        bert_output.LayerNorm.bias = final_norm.bias
        # end of layer

    if classification_head:
        classifier = model.classifier
        classifier.dense.weight = mnli_head.dense.weight
        classifier.dense.bias = mnli_head.dense.bias
        classifier.out_proj.weight = mnli_head.out_proj.weight
        classifier.out_proj.bias = mnli_head.out_proj.bias
    else:
        # LM Head
        source_lm_head = source_model.encoder.lm_head
        lm_head = model.lm_head
        lm_head.dense.weight = source_lm_head.dense.weight
        lm_head.dense.bias = source_lm_head.dense.bias
        lm_head.layer_norm.weight = source_lm_head.layer_norm.weight
        lm_head.layer_norm.bias = source_lm_head.layer_norm.bias
        # The source head exposes its output projection directly as `.weight` / `.bias`.
        lm_head.decoder.weight = source_lm_head.weight
        lm_head.decoder.bias = source_lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
    our_output = model(input_ids)[0]
    if classification_head:
        their_output = mnli_head(data2vec.extract_features(input_ids))
    else:
        their_output = source_model(input_ids)[0]
    print(our_output.shape, their_output.shape)

    abs_diff = torch.abs(our_output - their_output)
    max_absolute_diff = torch.max(abs_diff).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7

    success = torch.allclose(our_output, their_output, atol=1e-3)
    if success:
        verdict = "🔥"
    else:
        verdict = "💩"
    print("Do both models output the same tensors?", verdict)
    if not success:
        raise Exception("Something went wRoNg")

    # Only reached when the outputs matched: create the target directory and write the model.
    dump_dir = pathlib.Path(pytorch_dump_folder_path)
    dump_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # Command-line entry point: parse the three options and run the conversion.
    cli = argparse.ArgumentParser()

    # Required parameters
    required_string_options = (
        ("--checkpoint_path", "Path the official PyTorch dump."),
        ("--pytorch_dump_folder_path", "Path to the output PyTorch model."),
    )
    for flag, description in required_string_options:
        cli.add_argument(flag, default=None, type=str, required=True, help=description)

    # Optional switch; off unless the flag is present on the command line.
    cli.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )

    options = cli.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        data2vec_checkpoint_path=options.checkpoint_path,
        pytorch_dump_folder_path=options.pytorch_dump_folder_path,
        classification_head=options.classification_head,
    )