"""Convert a TensorFlow checkpoint into a PyTorch state-dict file.

Every variable listed in the checkpoint at ``tf_checkpoint_path`` is loaded,
renamed with a fixed set of substring substitutions, wrapped in a
``torch.Tensor`` and stored in an ordered dict that is finally written to
``pytorch_model.bin`` in the current working directory.
"""
import os
from collections import OrderedDict

import tensorflow as tf
import torch

tf_checkpoint_path = "chinese_GAU-alpha-char_L-24_H-768-tf/bert_model.ckpt"

# (old, new) substring substitutions applied to each TensorFlow variable
# name. They are applied one after another in this exact order, so a later
# rule sees the output of the earlier ones.
_NAME_REPLACEMENTS = (
    ("GAU_alpha", "gau_alpha"),
    ("bert", "gau_alpha"),
    ("/", "."),
    ("layer_", "layer."),
    ("kernel", "weight"),
    ("gamma", "weight"),
)


def _to_pytorch_name(tf_name: str) -> str:
    """Return the state-dict key used for the TensorFlow variable *tf_name*.

    Applies ``_NAME_REPLACEMENTS`` in order, then appends ``".weight"`` to
    any resulting name that contains ``"embeddings"``.
    """
    new_name = tf_name
    for old, new in _NAME_REPLACEMENTS:
        new_name = new_name.replace(old, new)
    if "embeddings" in new_name:
        new_name += ".weight"
    return new_name


tf_path = os.path.abspath(tf_checkpoint_path)
# List of (variable name, shape) pairs; only the names are needed below.
init_vars = tf.train.list_variables(tf_path)

pytorch_state_dict = OrderedDict()
for name, _shape in init_vars:
    array = tf.train.load_variable(tf_path, name)
    new_name = _to_pytorch_name(name)
    if "_dense" in new_name:
        # NOTE(review): presumably TensorFlow stores dense kernels as
        # (in, out) while the PyTorch model expects (out, in) -- confirm
        # against the target module. The check is name-based, so it applies
        # to every "_dense" variable, not only kernels.
        array = array.T
    pytorch_state_dict[new_name] = torch.from_numpy(array)

torch.save(pytorch_state_dict, "pytorch_model.bin")