# theimageconvert2/transformers_4_35_0/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
# coding=utf-8 | |
# Copyright 2022 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Convert data2vec checkpoint.""" | |
import argparse | |
import os | |
import pathlib | |
import fairseq | |
import torch | |
from fairseq.modules import TransformerSentenceEncoderLayer | |
from packaging import version | |
from transformers import ( | |
Data2VecTextConfig, | |
Data2VecTextForMaskedLM, | |
Data2VecTextForSequenceClassification, | |
Data2VecTextModel, | |
) | |
from transformers.models.bert.modeling_bert import ( | |
BertIntermediate, | |
BertLayer, | |
BertOutput, | |
BertSelfAttention, | |
BertSelfOutput, | |
) | |
# IMPORTANT: For this script to run, first download the dictionary file `dict.txt`; it can be obtained from https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz (e.g. with wget).
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py | |
from transformers.utils import logging | |
# Refuse to run against fairseq releases older than 0.9.0.
if version.parse("0.9.0") > version.parse(fairseq.__version__):
    raise Exception("requires fairseq >= 0.9.0")

# Switch the transformers logging utilities to INFO verbosity and grab a module logger.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Sentence that is encoded and fed to both models when their outputs are compared.
SAMPLE_TEXT = "Hello world! cécé herlolip"
def _require_shape(name: str, actual: torch.Size, expected: torch.Size) -> None:
    """Raise ``ValueError`` unless ``actual`` equals ``expected``.

    An explicit raise is used rather than ``assert`` so that the check is still
    performed when Python runs with ``-O`` (which strips assert statements).
    """
    if actual != expected:
        raise ValueError(f"Shape for {name} should be {expected}")


def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.

    The checkpoint is loaded, a `Data2VecTextConfig` is derived from the loaded model's arguments, the weights are
    copied module by module into a freshly built model, both models are run on `SAMPLE_TEXT` to compare their
    outputs, and the converted model is written with `save_pretrained`.

    Args:
        data2vec_checkpoint_path: Path to the original checkpoint file. It is split into directory and file name,
            which are passed separately to `from_pretrained`.
        pytorch_dump_folder_path: Output folder for the converted model; created (with parents) if missing.
        classification_head: If `True`, build a `Data2VecTextForSequenceClassification` and copy the `"mnli"`
            classification head; otherwise build a `Data2VecTextForMaskedLM` and copy the LM head.

    Raises:
        ValueError: If a source weight does not have the shape expected by the target module.
        Exception: If the outputs of the original and converted models differ by more than `atol=1e-3`.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    # NOTE(review): `Data2VecTextModel` is imported from `transformers` at the top of this file, yet it is used
    # below like a fairseq hub object (`checkpoint_file=`, `.models[0]`, `.encode`) -- confirm the intended class.
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        # NOTE(review): this branch reads `data2vec.model` while the rest of the function uses
        # `data2vec.models[0]` -- verify that both attributes exist on the loaded object.
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    # Every q/k/v projection matrix must be square with side `hidden_size`.
    square_hidden = torch.Size((config.hidden_size, config.hidden_size))

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        for proj_name in ("k_proj", "q_proj", "v_proj"):
            _require_shape(
                f"data2vec_layer.self_attn.{proj_name}.weight.data",
                getattr(data2vec_layer.self_attn, proj_name).weight.data.shape,
                square_hidden,
            )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        _require_shape(
            "self_output.dense.weight",
            self_output.dense.weight.shape,
            data2vec_layer.self_attn.out_proj.weight.shape,
        )
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        _require_shape(
            "intermediate.dense.weight",
            intermediate.dense.weight.shape,
            data2vec_layer.fc1.weight.shape,
        )
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        _require_shape(
            "bert_output.dense.weight",
            bert_output.dense.weight.shape,
            data2vec_layer.fc2.weight.shape,
        )
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    # The comparison only needs forward values, so skip autograd bookkeeping.
    with torch.no_grad():
        our_output = model(input_ids)[0]
        if classification_head:
            their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
        else:
            their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # Command-line entry point: collect the three options and run the conversion.
    cli = argparse.ArgumentParser()

    # Required parameters
    cli.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path the official PyTorch dump.",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model.",
    )
    # Optional flag; False unless given on the command line.
    cli.add_argument(
        "--classification_head",
        action="store_true",
        help="Whether to convert a final classification head.",
    )

    options = cli.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        data2vec_checkpoint_path=options.checkpoint_path,
        pytorch_dump_folder_path=options.pytorch_dump_folder_path,
        classification_head=options.classification_head,
    )