# Adapted from https://github.com/huggingface/transformers/issues/9920#issuecomment-770970712
import os
from operator import attrgetter

import torch
from transformers import ConvBertConfig, ConvBertForMaskedLM
from transformers.utils import logging

logger = logging.get_logger(__name__)

config_file = "/researchdisk/convbert-base-generator-finnish/config.json"
tf_path = "/researchdisk/convbert-base-finnish/renamed-model.ckpt"
pytorch_dump_path = "/researchdisk/convbert-base-generator-finnish"

config = ConvBertConfig.from_json_file(config_file)
model = ConvBertForMaskedLM(config)


def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
    """Load TF checkpoint weights into a PyTorch ConvBERT model."""
    try:
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load all variables from the TF checkpoint into a name -> array dict.
    init_vars = tf.train.list_variables(tf_path)
    tf_data = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        tf_data[name] = tf.train.load_variable(tf_path, name)

    # Map PyTorch parameter names to TF variable names. The embeddings live under
    # the "electra/" scope (shared with the discriminator); everything else lives
    # under the "generator/" scope.
    param_mapping = {
        "convbert.embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
        "convbert.embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
        "convbert.embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
        "convbert.embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
        "convbert.embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
        "convbert.embeddings_project.weight": "generator/embeddings_project/kernel",
        "convbert.embeddings_project.bias": "generator/embeddings_project/bias",
        "generator_predictions.LayerNorm.weight": "generator_predictions/LayerNorm/gamma",
        "generator_predictions.LayerNorm.bias": "generator_predictions/LayerNorm/beta",
        "generator_predictions.dense.weight": "generator_predictions/dense/kernel",
        "generator_predictions.dense.bias": "generator_predictions/dense/bias",
        "generator_lm_head.bias": "generator_predictions/output_bias",
    }

    if config.num_groups > 1:
        group_dense_name = "g_dense"
    else:
        group_dense_name = "dense"

    # Per-layer mappings for every encoder layer.
    for j in range(config.num_hidden_layers):
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.query.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/query/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.query.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/query/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.key.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/key/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.key.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/key/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.value.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/value/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.value.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/value/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.key_conv_attn_layer.depthwise.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_key/depthwise_kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.key_conv_attn_layer.pointwise.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_key/pointwise_kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.key_conv_attn_layer.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_key/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.conv_kernel_layer.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_kernel/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.conv_kernel_layer.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_kernel/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.conv_out_layer.weight"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_point/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.self.conv_out_layer.bias"
        ] = f"generator/encoder/layer_{j}/attention/self/conv_attn_point/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.output.dense.weight"
        ] = f"generator/encoder/layer_{j}/attention/output/dense/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.output.LayerNorm.weight"
        ] = f"generator/encoder/layer_{j}/attention/output/LayerNorm/gamma"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.output.dense.bias"
        ] = f"generator/encoder/layer_{j}/attention/output/dense/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.attention.output.LayerNorm.bias"
        ] = f"generator/encoder/layer_{j}/attention/output/LayerNorm/beta"
        param_mapping[
            f"convbert.encoder.layer.{j}.intermediate.dense.weight"
        ] = f"generator/encoder/layer_{j}/intermediate/{group_dense_name}/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.intermediate.dense.bias"
        ] = f"generator/encoder/layer_{j}/intermediate/{group_dense_name}/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.output.dense.weight"
        ] = f"generator/encoder/layer_{j}/output/{group_dense_name}/kernel"
        param_mapping[
            f"convbert.encoder.layer.{j}.output.dense.bias"
        ] = f"generator/encoder/layer_{j}/output/{group_dense_name}/bias"
        param_mapping[
            f"convbert.encoder.layer.{j}.output.LayerNorm.weight"
        ] = f"generator/encoder/layer_{j}/output/LayerNorm/gamma"
        param_mapping[
            f"convbert.encoder.layer.{j}.output.LayerNorm.bias"
        ] = f"generator/encoder/layer_{j}/output/LayerNorm/beta"

    # Copy every PyTorch parameter from its mapped TF variable.
    for param_name, _ in model.named_parameters():
        pt_param = attrgetter(param_name)(model)
        tf_name = param_mapping[param_name]
        value = torch.from_numpy(tf_data[tf_name])
        logger.info(f"TF: {tf_name}, PT: {param_name}")
        # TF stores dense kernels as (in, out); PyTorch expects (out, in).
        # Grouped dense kernels already match and must not be transposed.
        if tf_name.endswith("/kernel"):
            if not tf_name.endswith("/intermediate/g_dense/kernel"):
                if not tf_name.endswith("/output/g_dense/kernel"):
                    value = value.T
        # Reorder conv kernel axes from the TF layout to the PyTorch layout.
        if tf_name.endswith("/depthwise_kernel"):
            value = value.permute(1, 2, 0)  # 2, 0, 1
        if tf_name.endswith("/pointwise_kernel"):
            value = value.permute(2, 1, 0)  # 2, 1, 0
        # The separable-conv key bias needs a trailing singleton dimension in PyTorch.
        if tf_name.endswith("/conv_attn_key/bias"):
            value = value.unsqueeze(-1)
        pt_param.data = value
    return model


model = load_tf_weights_in_convbert(model, config, tf_path)
model.save_pretrained(pytorch_dump_path)