# convert_original_convbert_tf_checkpoint_to_generator_pytorch.py
# Adapted from https://github.com/huggingface/transformers/issues/9920#issuecomment-770970712
import os
from operator import attrgetter

import torch
from transformers import ConvBertConfig, ConvBertForMaskedLM
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Paths: the generator config, the renamed TF checkpoint, and the output
# directory for the converted PyTorch model.
config_file = "/researchdisk/convbert-base-generator-finnish/config.json"
tf_path = "/researchdisk/convbert-base-finnish/renamed-model.ckpt"
pytorch_dump_path = "/researchdisk/convbert-base-generator-finnish"

config = ConvBertConfig.from_json_file(config_file)
model = ConvBertForMaskedLM(config)
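
# Note: the config above is the generator config. Since the mapping below
# includes an embeddings_project entry, embedding_size and hidden_size
# presumably differ, as is typical for ELECTRA-style generators.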


def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
    """Load TF checkpoint weights into a PyTorch ConvBERT generator model."""
    try:
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Read every variable in the TF checkpoint into a name -> numpy array dict.
    init_vars = tf.train.list_variables(tf_path)
    tf_data = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        tf_data[name] = tf.train.load_variable(tf_path, name)

    # Fixed mappings for the embeddings and the MLM head. The shared embeddings
    # live under the "electra/" scope in the renamed checkpoint; everything
    # generator-specific lives under "generator/".
    param_mapping = {
        "convbert.embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
        "convbert.embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
        "convbert.embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
        "convbert.embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
        "convbert.embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
        "convbert.embeddings_project.weight": "generator/embeddings_project/kernel",
        "convbert.embeddings_project.bias": "generator/embeddings_project/bias",
        "generator_predictions.LayerNorm.weight": "generator_predictions/LayerNorm/gamma",
        "generator_predictions.LayerNorm.bias": "generator_predictions/LayerNorm/beta",
        "generator_predictions.dense.weight": "generator_predictions/dense/kernel",
        "generator_predictions.dense.bias": "generator_predictions/dense/bias",
        "generator_lm_head.bias": "generator_predictions/output_bias",
    }

    # ConvBERT uses grouped ("g_dense") feed-forward layers when num_groups > 1.
    if config.num_groups > 1:
        group_dense_name = "g_dense"
    else:
        group_dense_name = "dense"

    # Per-layer mapping: PyTorch parameter suffix -> TF variable suffix.
    layer_mapping = {
        "attention.self.query.weight": "attention/self/query/kernel",
        "attention.self.query.bias": "attention/self/query/bias",
        "attention.self.key.weight": "attention/self/key/kernel",
        "attention.self.key.bias": "attention/self/key/bias",
        "attention.self.value.weight": "attention/self/value/kernel",
        "attention.self.value.bias": "attention/self/value/bias",
        "attention.self.key_conv_attn_layer.depthwise.weight": "attention/self/conv_attn_key/depthwise_kernel",
        "attention.self.key_conv_attn_layer.pointwise.weight": "attention/self/conv_attn_key/pointwise_kernel",
        "attention.self.key_conv_attn_layer.bias": "attention/self/conv_attn_key/bias",
        "attention.self.conv_kernel_layer.weight": "attention/self/conv_attn_kernel/kernel",
        "attention.self.conv_kernel_layer.bias": "attention/self/conv_attn_kernel/bias",
        "attention.self.conv_out_layer.weight": "attention/self/conv_attn_point/kernel",
        "attention.self.conv_out_layer.bias": "attention/self/conv_attn_point/bias",
        "attention.output.dense.weight": "attention/output/dense/kernel",
        "attention.output.dense.bias": "attention/output/dense/bias",
        "attention.output.LayerNorm.weight": "attention/output/LayerNorm/gamma",
        "attention.output.LayerNorm.bias": "attention/output/LayerNorm/beta",
        "intermediate.dense.weight": f"intermediate/{group_dense_name}/kernel",
        "intermediate.dense.bias": f"intermediate/{group_dense_name}/bias",
        "output.dense.weight": f"output/{group_dense_name}/kernel",
        "output.dense.bias": f"output/{group_dense_name}/bias",
        "output.LayerNorm.weight": "output/LayerNorm/gamma",
        "output.LayerNorm.bias": "output/LayerNorm/beta",
    }
    for j in range(config.num_hidden_layers):
        for pt_suffix, tf_suffix in layer_mapping.items():
            param_mapping[f"convbert.encoder.layer.{j}.{pt_suffix}"] = f"generator/encoder/layer_{j}/{tf_suffix}"

    for param_name, _ in model.named_parameters():
        retriever = attrgetter(param_name)
        result = retriever(model)
        tf_name = param_mapping[param_name]
        value = torch.from_numpy(tf_data[tf_name])
        logger.info(f"TF: {tf_name}, PT: {param_name}")
        # TF stores dense kernels as [in, out]; PyTorch nn.Linear expects
        # [out, in]. Grouped dense kernels are the exception: copied as-is.
        if tf_name.endswith("/kernel") and not tf_name.endswith(
            ("/intermediate/g_dense/kernel", "/output/g_dense/kernel")
        ):
            value = value.T
        # Reorder the separable-convolution kernels from the TF layout to the
        # axis order PyTorch's Conv1d modules expect.
        if tf_name.endswith("/depthwise_kernel"):
            value = value.permute(1, 2, 0)
        if tf_name.endswith("/pointwise_kernel"):
            value = value.permute(2, 1, 0)
        # The separable-conv key bias carries a trailing singleton dim in PyTorch.
        if tf_name.endswith("/conv_attn_key/bias"):
            value = value.unsqueeze(-1)
        assert result.shape == value.shape, f"Shape mismatch for {param_name}: {result.shape} vs {value.shape}"
        result.data = value
    return model


model = load_tf_weights_in_convbert(model, config, tf_path)
model.save_pretrained(pytorch_dump_path)
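
# Quick verification sketch (assumes TensorFlow and NumPy are available in this
# environment): reload the converted model and check that the word embeddings
# round-trip the values stored in the TF checkpoint.
import numpy as np
import tensorflow as tf

reloaded = ConvBertForMaskedLM.from_pretrained(pytorch_dump_path)
tf_embeddings = tf.train.load_variable(tf_path, "electra/embeddings/word_embeddings")
assert np.allclose(
    reloaded.convbert.embeddings.word_embeddings.weight.detach().numpy(),
    tf_embeddings,
), "Converted word embeddings do not match the TF checkpoint"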