Tongjilibo committed on
Commit
f7f0a9c
1 Parent(s): 5855858

Update readme

Files changed (2)
  1. README.md +7 -2
  2. convert_simbert.py +80 -0
README.md CHANGED
@@ -2,6 +2,11 @@
 license: apache-2.0
 ---
 
-
+## Description
 - config.json is for transformers
-- bert4torch_config.json is for bert4torch
+- bert4torch_config.json is for bert4torch
+
+## Weight conversion
+- This project's weights were converted from the TF weights; you can use them directly, or download the original TF weights below and convert them with convert_simbert.py
+- Source project: https://github.com/ZhuiyiTechnology/simbert
+- Conversion script: `convert_simbert.py`
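
For context, a minimal sketch of loading the converted checkpoint with transformers; the local directory is hypothetical and is assumed to contain config.json, vocab.txt, and the converted pytorch_model.bin:

```python
# Minimal usage sketch (hypothetical local directory holding config.json,
# vocab.txt, and the converted pytorch_model.bin).
from transformers import BertModel, BertTokenizer

model_dir = 'E:/pretrain_ckpt/simbert/sushen@simbert_chinese_base/'  # hypothetical path
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BertModel.from_pretrained(model_dir)  # the cls.* MLM head keys are simply ignored

inputs = tokenizer('今天天气不错', return_tensors='pt')
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```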
convert_simbert.py ADDED
@@ -0,0 +1,80 @@
# Convert the SimBERT pretrained model from TensorFlow to PyTorch
# Source project: https://github.com/ZhuiyiTechnology/simbert

import json

import tensorflow as tf
import torch

# Pick one checkpoint by uncommenting its block. In the original script all
# three blocks were assigned in sequence, so only the last one (tiny) took effect.

# base
# tf_dir = 'E:/pretrain_ckpt/simbert/sushen@chinese_simbert_L-12_H-768_A-12/'
# tf_path = tf_dir + 'bert_model.ckpt'
# torch_path = 'E:/pretrain_ckpt/simbert/sushen@simbert_chinese_base/pytorch_model.bin'

# small
# tf_dir = 'E:/pretrain_ckpt/simbert/sushen@chinese_simbert_L-6_H-384_A-12/'
# tf_path = tf_dir + 'bert_model.ckpt'
# torch_path = 'E:/pretrain_ckpt/simbert/sushen@simbert_chinese_small/pytorch_model.bin'

# tiny
tf_dir = 'E:/pretrain_ckpt/simbert/sushen@chinese_simbert_L-4_H-312_A-12/'
tf_path = tf_dir + 'bert_model.ckpt'
torch_path = 'E:/pretrain_ckpt/simbert/sushen@simbert_chinese_tiny/pytorch_model.bin'

with open(tf_dir + 'bert_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)
num_layers = config['num_hidden_layers']

torch_state_dict = {}

# Map TF variable names to PyTorch state-dict keys. A trailing '##' marks a
# dense kernel that must be transposed: TF stores kernels as [in, out],
# while PyTorch nn.Linear weights are [out, in].
prefix = 'bert'
mapping = {
    'bert/embeddings/word_embeddings': f'{prefix}.embeddings.word_embeddings.weight',
    'bert/embeddings/position_embeddings': f'{prefix}.embeddings.position_embeddings.weight',
    'bert/embeddings/token_type_embeddings': f'{prefix}.embeddings.token_type_embeddings.weight',
    'bert/embeddings/LayerNorm/beta': f'{prefix}.embeddings.LayerNorm.bias',
    'bert/embeddings/LayerNorm/gamma': f'{prefix}.embeddings.LayerNorm.weight',
    'cls/predictions/transform/dense/kernel': 'cls.predictions.transform.dense.weight##',
    'cls/predictions/transform/dense/bias': 'cls.predictions.transform.dense.bias',
    'cls/predictions/transform/LayerNorm/beta': 'cls.predictions.transform.LayerNorm.bias',
    'cls/predictions/transform/LayerNorm/gamma': 'cls.predictions.transform.LayerNorm.weight',
    'cls/predictions/output_bias': 'cls.predictions.bias',
    'bert/pooler/dense/kernel': f'{prefix}.pooler.dense.weight##',
    'bert/pooler/dense/bias': f'{prefix}.pooler.dense.bias'}

# The small/tiny checkpoints use an embedding size smaller than the hidden
# size, so they carry an extra projection from embeddings to the hidden dim.
if ('embedding_size' in config) and (config['embedding_size'] != config['hidden_size']):
    mapping.update({'bert/encoder/embedding_hidden_mapping_in/kernel': f'{prefix}.encoder.embedding_hidden_mapping_in.weight##',
                    'bert/encoder/embedding_hidden_mapping_in/bias': f'{prefix}.encoder.embedding_hidden_mapping_in.bias'})

for i in range(num_layers):
    prefix_i = f'{prefix}.encoder.layer.{i}.'
    mapping.update({
        f'bert/encoder/layer_{i}/attention/self/query/kernel': prefix_i + 'attention.self.query.weight##',  # transpose marker
        f'bert/encoder/layer_{i}/attention/self/query/bias': prefix_i + 'attention.self.query.bias',
        f'bert/encoder/layer_{i}/attention/self/key/kernel': prefix_i + 'attention.self.key.weight##',
        f'bert/encoder/layer_{i}/attention/self/key/bias': prefix_i + 'attention.self.key.bias',
        f'bert/encoder/layer_{i}/attention/self/value/kernel': prefix_i + 'attention.self.value.weight##',
        f'bert/encoder/layer_{i}/attention/self/value/bias': prefix_i + 'attention.self.value.bias',
        f'bert/encoder/layer_{i}/attention/output/dense/kernel': prefix_i + 'attention.output.dense.weight##',
        f'bert/encoder/layer_{i}/attention/output/dense/bias': prefix_i + 'attention.output.dense.bias',
        f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta': prefix_i + 'attention.output.LayerNorm.bias',
        f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma': prefix_i + 'attention.output.LayerNorm.weight',
        f'bert/encoder/layer_{i}/intermediate/dense/kernel': prefix_i + 'intermediate.dense.weight##',
        f'bert/encoder/layer_{i}/intermediate/dense/bias': prefix_i + 'intermediate.dense.bias',
        f'bert/encoder/layer_{i}/output/dense/kernel': prefix_i + 'output.dense.weight##',
        f'bert/encoder/layer_{i}/output/dense/bias': prefix_i + 'output.dense.bias',
        f'bert/encoder/layer_{i}/output/LayerNorm/beta': prefix_i + 'output.LayerNorm.bias',
        f'bert/encoder/layer_{i}/output/LayerNorm/gamma': prefix_i + 'output.LayerNorm.weight'
    })

for key, value in mapping.items():
    ts = tf.train.load_variable(tf_path, key)
    if value.endswith('##'):  # kernel needs transposing
        value = value.replace('##', '')
        torch_state_dict[value] = torch.from_numpy(ts).T
    else:
        torch_state_dict[value] = torch.from_numpy(ts)

# The MLM decoder ties its weight to the word embeddings and its bias to the
# prediction output bias.
torch_state_dict['cls.predictions.decoder.weight'] = torch_state_dict[f'{prefix}.embeddings.word_embeddings.weight']
torch_state_dict['cls.predictions.decoder.bias'] = torch_state_dict['cls.predictions.bias']

torch.save(torch_state_dict, torch_path)
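
A quick post-conversion check can catch mapping mistakes. Below is a sketch, assuming transformers is installed and the base checkpoint is being converted (the small/tiny variants' embedding_hidden_mapping_in weights have no counterpart in transformers' BERT and would show up as unexpected keys); BertForMaskedLM is only one plausible wrapper for these weights:

```python
# Sanity-check sketch: rebuild a transformers model from the TF config and
# load the converted state dict. strict=False because BertForMaskedLM has no
# pooler, so the converted pooler weights are reported as 'unexpected'.
from transformers import BertConfig, BertForMaskedLM

cfg = BertConfig.from_json_file(tf_dir + 'bert_config.json')
model = BertForMaskedLM(cfg)
missing, unexpected = model.load_state_dict(torch.load(torch_path), strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```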