# Adapted from https://github.com/huggingface/transformers/issues/9920#issuecomment-770970712
import os
from operator import attrgetter

import torch
from transformers import ConvBertConfig, ConvBertForMaskedLM
from transformers.utils import logging

logger = logging.get_logger(__name__)

config_file = "/researchdisk/convbert-base-generator-finnish/config.json"
tf_path = "/researchdisk/convbert-base-finnish/renamed-model.ckpt"
pytorch_dump_path = "/researchdisk/convbert-base-generator-finnish"
config = ConvBertConfig.from_json_file(config_file)
model = ConvBertForMaskedLM(config)


def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
    """Load TF checkpoints into a PyTorch model."""
    try:
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load all weights from the TF checkpoint into a name -> numpy array dict
    init_vars = tf.train.list_variables(tf_path)
    tf_data = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        tf_data[name] = tf.train.load_variable(tf_path, name)
    # Static (non-layer) parameter mapping: PyTorch parameter name -> TF variable name.
    # Embeddings live under "electra/" in this checkpoint, everything else under "generator/".
    param_mapping = {
        "convbert.embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
        "convbert.embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
        "convbert.embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
        "convbert.embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
        "convbert.embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
        "convbert.embeddings_project.weight": "generator/embeddings_project/kernel",
        "convbert.embeddings_project.bias": "generator/embeddings_project/bias",
        "generator_predictions.LayerNorm.weight": "generator_predictions/LayerNorm/gamma",
        "generator_predictions.LayerNorm.bias": "generator_predictions/LayerNorm/beta",
        "generator_predictions.dense.weight": "generator_predictions/dense/kernel",
        "generator_predictions.dense.bias": "generator_predictions/dense/bias",
        "generator_lm_head.bias": "generator_predictions/output_bias",
    }
    # Grouped feed-forward layers are stored as "g_dense" in the TF checkpoint
    group_dense_name = "g_dense" if config.num_groups > 1 else "dense"
    # Per-layer parameter mapping
    for j in range(config.num_hidden_layers):
        pt_prefix = f"convbert.encoder.layer.{j}"
        tf_prefix = f"generator/encoder/layer_{j}"
        param_mapping.update(
            {
                f"{pt_prefix}.attention.self.query.weight": f"{tf_prefix}/attention/self/query/kernel",
                f"{pt_prefix}.attention.self.query.bias": f"{tf_prefix}/attention/self/query/bias",
                f"{pt_prefix}.attention.self.key.weight": f"{tf_prefix}/attention/self/key/kernel",
                f"{pt_prefix}.attention.self.key.bias": f"{tf_prefix}/attention/self/key/bias",
                f"{pt_prefix}.attention.self.value.weight": f"{tf_prefix}/attention/self/value/kernel",
                f"{pt_prefix}.attention.self.value.bias": f"{tf_prefix}/attention/self/value/bias",
                f"{pt_prefix}.attention.self.key_conv_attn_layer.depthwise.weight": f"{tf_prefix}/attention/self/conv_attn_key/depthwise_kernel",
                f"{pt_prefix}.attention.self.key_conv_attn_layer.pointwise.weight": f"{tf_prefix}/attention/self/conv_attn_key/pointwise_kernel",
                f"{pt_prefix}.attention.self.key_conv_attn_layer.bias": f"{tf_prefix}/attention/self/conv_attn_key/bias",
                f"{pt_prefix}.attention.self.conv_kernel_layer.weight": f"{tf_prefix}/attention/self/conv_attn_kernel/kernel",
                f"{pt_prefix}.attention.self.conv_kernel_layer.bias": f"{tf_prefix}/attention/self/conv_attn_kernel/bias",
                f"{pt_prefix}.attention.self.conv_out_layer.weight": f"{tf_prefix}/attention/self/conv_attn_point/kernel",
                f"{pt_prefix}.attention.self.conv_out_layer.bias": f"{tf_prefix}/attention/self/conv_attn_point/bias",
                f"{pt_prefix}.attention.output.dense.weight": f"{tf_prefix}/attention/output/dense/kernel",
                f"{pt_prefix}.attention.output.dense.bias": f"{tf_prefix}/attention/output/dense/bias",
                f"{pt_prefix}.attention.output.LayerNorm.weight": f"{tf_prefix}/attention/output/LayerNorm/gamma",
                f"{pt_prefix}.attention.output.LayerNorm.bias": f"{tf_prefix}/attention/output/LayerNorm/beta",
                f"{pt_prefix}.intermediate.dense.weight": f"{tf_prefix}/intermediate/{group_dense_name}/kernel",
                f"{pt_prefix}.intermediate.dense.bias": f"{tf_prefix}/intermediate/{group_dense_name}/bias",
                f"{pt_prefix}.output.dense.weight": f"{tf_prefix}/output/{group_dense_name}/kernel",
                f"{pt_prefix}.output.dense.bias": f"{tf_prefix}/output/{group_dense_name}/bias",
                f"{pt_prefix}.output.LayerNorm.weight": f"{tf_prefix}/output/LayerNorm/gamma",
                f"{pt_prefix}.output.LayerNorm.bias": f"{tf_prefix}/output/LayerNorm/beta",
            }
        )
    # Copy every PyTorch parameter from its mapped TF variable
    for param in model.named_parameters():
        param_name = param[0]
        retriever = attrgetter(param_name)
        result = retriever(model)
        tf_name = param_mapping[param_name]
        value = torch.from_numpy(tf_data[tf_name])
        logger.info(f"TF: {tf_name}, PT: {param_name}")
if tf_name.endswith("/kernel"):
if not tf_name.endswith("/intermediate/g_dense/kernel"):
if not tf_name.endswith("/output/g_dense/kernel"):
value = value.T
if tf_name.endswith("/depthwise_kernel"):
value = value.permute(1, 2, 0) # 2, 0, 1
if tf_name.endswith("/pointwise_kernel"):
value = value.permute(2, 1, 0) # 2, 1, 0
if tf_name.endswith("/conv_attn_key/bias"):
value = value.unsqueeze(-1)
result.data = value
return model


model = load_tf_weights_in_convbert(model, config, tf_path)
model.save_pretrained(pytorch_dump_path)
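
# Optional smoke test (illustrative sketch; assumes tokenizer files already live
# in pytorch_dump_path, which this script does not write -- the guard below skips
# the test when they are absent, and the Finnish sentence is just example input):
if os.path.exists(os.path.join(pytorch_dump_path, "vocab.txt")):
    from transformers import ConvBertTokenizer

    tokenizer = ConvBertTokenizer.from_pretrained(pytorch_dump_path)
    reloaded = ConvBertForMaskedLM.from_pretrained(pytorch_dump_path)
    inputs = tokenizer("Helsinki on Suomen [MASK].", return_tensors="pt")
    with torch.no_grad():
        logits = reloaded(**inputs).logits
    # Predict the token at each [MASK] position
    mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    predicted_ids = logits[0, mask_positions].argmax(dim=-1)
    logger.info(f"Smoke-test prediction for [MASK]: {tokenizer.decode(predicted_ids)}")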