# File size: 4,478 Bytes
# e086ae4
# roberta-small预训练模型tensorflow转pytorch
# 源项目:https://github.com/ZhuiyiTechnology/pretrained-models
# roberta-small下载: https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roberta_L-6_H-384_A-12.zip
# 注意苏神版本的roberta-small/tiny的ckpt无pooler层, 区别于bert base转换脚本需要删除pooler层
# 使用的时候需要with_pool=False, 否则会有warnings, CLS的输出直接按last_hidden_state[:, 0]取得
import json
import os

import tensorflow as tf
import torch
def convert(tf_dir, load_variable=None):
    """Convert a TF BERT/RoBERTa checkpoint directory into a PyTorch state dict file.

    Reads ``bert_config.json`` and the ``bert_model.ckpt`` checkpoint from
    ``tf_dir`` and writes ``pytorch_model.bin`` into the same directory.
    Variables whose mapped name carries the ``##`` marker (dense kernels) are
    transposed to match the ``(out_features, in_features)`` layout of
    ``torch.nn.Linear.weight``. The MLM decoder weight is tied to the word
    embeddings, and the decoder bias to ``cls.predictions.bias``. No pooler
    weights are mapped, because the roberta-small/tiny checkpoints this script
    targets have none.

    Args:
        tf_dir: Directory containing the checkpoint files. A trailing path
            separator is optional.
        load_variable: Callable ``(ckpt_path, var_name) -> numpy.ndarray``
            used to read each variable. Defaults to
            ``tf.train.load_variable``. Override it to read from another
            source or to run without TensorFlow.

    Raises:
        OSError / json.JSONDecodeError: if ``bert_config.json`` cannot be read.
        KeyError: if the config lacks ``num_hidden_layers`` or ``hidden_size``
            (the latter only when ``embedding_size`` is present).
        Any error raised by ``load_variable`` for a missing checkpoint variable.
    """
    if load_variable is None:
        load_variable = tf.train.load_variable

    # os.path.join works whether or not tf_dir ends with a separator.
    tf_path = os.path.join(tf_dir, 'bert_model.ckpt')
    torch_path = os.path.join(tf_dir, 'pytorch_model.bin')
    with open(os.path.join(tf_dir, 'bert_config.json'), 'r', encoding='utf-8') as f:
        config = json.load(f)
    num_layers = config['num_hidden_layers']

    prefix = 'bert'
    # TF variable name -> PyTorch parameter name. A trailing '##' marks
    # kernels that must be transposed.
    mapping = {
        'bert/embeddings/word_embeddings': f'{prefix}.embeddings.word_embeddings.weight',
        'bert/embeddings/position_embeddings': f'{prefix}.embeddings.position_embeddings.weight',
        'bert/embeddings/token_type_embeddings': f'{prefix}.embeddings.token_type_embeddings.weight',
        'bert/embeddings/LayerNorm/beta': f'{prefix}.embeddings.LayerNorm.bias',
        'bert/embeddings/LayerNorm/gamma': f'{prefix}.embeddings.LayerNorm.weight',
        'cls/predictions/transform/dense/kernel': 'cls.predictions.transform.dense.weight##',
        'cls/predictions/transform/dense/bias': 'cls.predictions.transform.dense.bias',
        'cls/predictions/transform/LayerNorm/beta': 'cls.predictions.transform.LayerNorm.bias',
        'cls/predictions/transform/LayerNorm/gamma': 'cls.predictions.transform.LayerNorm.weight',
        'cls/predictions/output_bias': 'cls.predictions.bias',
    }
    # A separate embedding->hidden projection only exists when the two sizes differ.
    if ('embedding_size' in config) and (config['embedding_size'] != config['hidden_size']):
        mapping.update({
            'bert/encoder/embedding_hidden_mapping_in/kernel': f'{prefix}.encoder.embedding_hidden_mapping_in.weight##',
            'bert/encoder/embedding_hidden_mapping_in/bias': f'{prefix}.encoder.embedding_hidden_mapping_in.bias',
        })
    for i in range(num_layers):
        prefix_i = f'{prefix}.encoder.layer.{i}.'
        mapping.update({
            f'bert/encoder/layer_{i}/attention/self/query/kernel': prefix_i + 'attention.self.query.weight##',
            f'bert/encoder/layer_{i}/attention/self/query/bias': prefix_i + 'attention.self.query.bias',
            f'bert/encoder/layer_{i}/attention/self/key/kernel': prefix_i + 'attention.self.key.weight##',
            f'bert/encoder/layer_{i}/attention/self/key/bias': prefix_i + 'attention.self.key.bias',
            f'bert/encoder/layer_{i}/attention/self/value/kernel': prefix_i + 'attention.self.value.weight##',
            f'bert/encoder/layer_{i}/attention/self/value/bias': prefix_i + 'attention.self.value.bias',
            f'bert/encoder/layer_{i}/attention/output/dense/kernel': prefix_i + 'attention.output.dense.weight##',
            f'bert/encoder/layer_{i}/attention/output/dense/bias': prefix_i + 'attention.output.dense.bias',
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta': prefix_i + 'attention.output.LayerNorm.bias',
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma': prefix_i + 'attention.output.LayerNorm.weight',
            f'bert/encoder/layer_{i}/intermediate/dense/kernel': prefix_i + 'intermediate.dense.weight##',
            f'bert/encoder/layer_{i}/intermediate/dense/bias': prefix_i + 'intermediate.dense.bias',
            f'bert/encoder/layer_{i}/output/dense/kernel': prefix_i + 'output.dense.weight##',
            f'bert/encoder/layer_{i}/output/dense/bias': prefix_i + 'output.dense.bias',
            f'bert/encoder/layer_{i}/output/LayerNorm/beta': prefix_i + 'output.LayerNorm.bias',
            f'bert/encoder/layer_{i}/output/LayerNorm/gamma': prefix_i + 'output.LayerNorm.weight',
        })

    torch_state_dict = {}
    for key, value in mapping.items():
        ts = load_variable(tf_path, key)
        if value.endswith('##'):
            # Materialize the transpose so the saved tensor has its own
            # compact storage instead of a strided view of the numpy buffer.
            torch_state_dict[value[:-2]] = torch.from_numpy(ts).T.contiguous()
        else:
            torch_state_dict[value] = torch.from_numpy(ts)

    # Tie the MLM decoder to the input embeddings and the prediction bias.
    torch_state_dict['cls.predictions.decoder.weight'] = torch_state_dict[f'{prefix}.embeddings.word_embeddings.weight']
    torch_state_dict['cls.predictions.decoder.bias'] = torch_state_dict['cls.predictions.bias']
    torch.save(torch_state_dict, torch_path)
if __name__ == '__main__':
    # Example invocation: point this at the unzipped roberta-small checkpoint directory.
    # Guarded so importing this module does not trigger a conversion.
    convert('E:/pretrain_ckpt/roberta/sushen@chinese_roberta_L-6_H-384_A-12/')