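"""Convert a TensorFlow ConvBERT generator checkpoint (Finnish ConvBERT) into a
Hugging Face PyTorch `ConvBertForMaskedLM` model."""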
import os
from operator import attrgetter

import torch

from transformers import ConvBertConfig, ConvBertForMaskedLM
from transformers.utils import logging


logger = logging.get_logger(__name__)

# Paths to the original TF checkpoint, its config, and the PyTorch output directory.
config_file = "/researchdisk/convbert-base-generator-finnish/config.json"
tf_path = "/researchdisk/convbert-base-finnish/renamed-model.ckpt"
pytorch_dump_path = "/researchdisk/convbert-base-generator-finnish"

config = ConvBertConfig.from_json_file(config_file)

# Randomly initialised model; the TF weights are copied into it below.
model = ConvBertForMaskedLM(config)

def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
    """Load TF checkpoint weights into a PyTorch ConvBERT model."""
    try:
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Read every variable in the checkpoint into a name -> numpy array dict.
    init_vars = tf.train.list_variables(tf_path)
    tf_data = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        tf_data[name] = array

    # Map each PyTorch parameter name to its variable name in the TF checkpoint.
    # In this checkpoint the embedding variables live under the "electra/" scope;
    # everything else lives under "generator/".
    param_mapping = {
        "convbert.embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
        "convbert.embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
        "convbert.embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
        "convbert.embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
        "convbert.embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
        "convbert.embeddings_project.weight": "generator/embeddings_project/kernel",
        "convbert.embeddings_project.bias": "generator/embeddings_project/bias",
        "generator_predictions.LayerNorm.weight": "generator_predictions/LayerNorm/gamma",
        "generator_predictions.LayerNorm.bias": "generator_predictions/LayerNorm/beta",
        "generator_predictions.dense.weight": "generator_predictions/dense/kernel",
        "generator_predictions.dense.bias": "generator_predictions/dense/bias",
        "generator_lm_head.bias": "generator_predictions/output_bias",
    }

    # Grouped feed-forward layers are stored as "g_dense" in the checkpoint.
    group_dense_name = "g_dense" if config.num_groups > 1 else "dense"

    for j in range(config.num_hidden_layers):
        pt = f"convbert.encoder.layer.{j}"
        tf_layer = f"generator/encoder/layer_{j}"
        param_mapping[f"{pt}.attention.self.query.weight"] = f"{tf_layer}/attention/self/query/kernel"
        param_mapping[f"{pt}.attention.self.query.bias"] = f"{tf_layer}/attention/self/query/bias"
        param_mapping[f"{pt}.attention.self.key.weight"] = f"{tf_layer}/attention/self/key/kernel"
        param_mapping[f"{pt}.attention.self.key.bias"] = f"{tf_layer}/attention/self/key/bias"
        param_mapping[f"{pt}.attention.self.value.weight"] = f"{tf_layer}/attention/self/value/kernel"
        param_mapping[f"{pt}.attention.self.value.bias"] = f"{tf_layer}/attention/self/value/bias"
        param_mapping[f"{pt}.attention.self.key_conv_attn_layer.depthwise.weight"] = (
            f"{tf_layer}/attention/self/conv_attn_key/depthwise_kernel"
        )
        param_mapping[f"{pt}.attention.self.key_conv_attn_layer.pointwise.weight"] = (
            f"{tf_layer}/attention/self/conv_attn_key/pointwise_kernel"
        )
        param_mapping[f"{pt}.attention.self.key_conv_attn_layer.bias"] = f"{tf_layer}/attention/self/conv_attn_key/bias"
        param_mapping[f"{pt}.attention.self.conv_kernel_layer.weight"] = f"{tf_layer}/attention/self/conv_attn_kernel/kernel"
        param_mapping[f"{pt}.attention.self.conv_kernel_layer.bias"] = f"{tf_layer}/attention/self/conv_attn_kernel/bias"
        param_mapping[f"{pt}.attention.self.conv_out_layer.weight"] = f"{tf_layer}/attention/self/conv_attn_point/kernel"
        param_mapping[f"{pt}.attention.self.conv_out_layer.bias"] = f"{tf_layer}/attention/self/conv_attn_point/bias"
        param_mapping[f"{pt}.attention.output.dense.weight"] = f"{tf_layer}/attention/output/dense/kernel"
        param_mapping[f"{pt}.attention.output.dense.bias"] = f"{tf_layer}/attention/output/dense/bias"
        param_mapping[f"{pt}.attention.output.LayerNorm.weight"] = f"{tf_layer}/attention/output/LayerNorm/gamma"
        param_mapping[f"{pt}.attention.output.LayerNorm.bias"] = f"{tf_layer}/attention/output/LayerNorm/beta"
        param_mapping[f"{pt}.intermediate.dense.weight"] = f"{tf_layer}/intermediate/{group_dense_name}/kernel"
        param_mapping[f"{pt}.intermediate.dense.bias"] = f"{tf_layer}/intermediate/{group_dense_name}/bias"
        param_mapping[f"{pt}.output.dense.weight"] = f"{tf_layer}/output/{group_dense_name}/kernel"
        param_mapping[f"{pt}.output.dense.bias"] = f"{tf_layer}/output/{group_dense_name}/bias"
        param_mapping[f"{pt}.output.LayerNorm.weight"] = f"{tf_layer}/output/LayerNorm/gamma"
        param_mapping[f"{pt}.output.LayerNorm.bias"] = f"{tf_layer}/output/LayerNorm/beta"

    # Copy each PyTorch parameter from its TF counterpart, adjusting tensor
    # layouts where the TF and PyTorch conventions differ.
    for param_name, _ in model.named_parameters():
        result = attrgetter(param_name)(model)
        tf_name = param_mapping[param_name]
        value = torch.from_numpy(tf_data[tf_name])
        logger.info(f"TF: {tf_name}, PT: {param_name}")
        if tf_name.endswith("/kernel") and not tf_name.endswith(
            ("/intermediate/g_dense/kernel", "/output/g_dense/kernel")
        ):
            # TF dense kernels are (in, out); PyTorch nn.Linear weights are (out, in).
            value = value.T
        if tf_name.endswith("/depthwise_kernel"):
            # Permute conv kernel axes from the TF ordering to the layout the
            # PyTorch module expects.
            value = value.permute(1, 2, 0)
        if tf_name.endswith("/pointwise_kernel"):
            value = value.permute(2, 1, 0)
        if tf_name.endswith("/conv_attn_key/bias"):
            value = value.unsqueeze(-1)
        result.data = value
    return model


model = load_tf_weights_in_convbert(model, config, tf_path)
model.save_pretrained(pytorch_dump_path)
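
# Optional sanity check: a minimal sketch, not part of the conversion itself.
# It assumes a tokenizer for the Finnish model is available at `pytorch_dump_path`
# (the conversion above writes only the weights and config, so save or copy a
# tokenizer there first), and the Finnish sentence is an arbitrary illustration.
#
# from transformers import AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained(pytorch_dump_path)
# inputs = tokenizer("Helsinki on Suomen [MASK].", return_tensors="pt")
# model.eval()
# with torch.no_grad():
#     logits = model(**inputs).logits
# mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
# print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))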