File size: 5,441 Bytes
3e5558b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Convert T5-PEGASUS weights from the TensorFlow checkpoint to the PyTorch format used by bert4torch
# Weights: https://github.com/ZhuiyiTechnology/t5-pegasus
import torch
import tensorflow as tf
import json

# Which checkpoint variant to convert; must be a key of `_checkpoint_dirs`.
choice = 'small'

# Per-variant directories: (PyTorch output dir, TensorFlow input dir).
_checkpoint_dirs = {
    'small': (
        'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_torch/',
        'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_tf/',
    ),
    'base': (
        'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_torch/',
        'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_tf/',
    ),
}

if choice not in _checkpoint_dirs:
    raise ValueError(f'{choice} not in pre maintained choices')

ckpt_dir, tf_dir = _checkpoint_dirs[choice]
torch_path = ckpt_dir + 'pytorch_model.bin'  # where the converted weights are saved
tf_path = tf_dir + 'model.ckpt'              # TensorFlow checkpoint prefix to read from

# The TensorFlow-side config.json provides the number of transformer blocks.
with open(tf_dir + 'config.json', 'r', encoding='utf-8') as config_fp:
    config = json.load(config_fp)
num_layers = config['num_hidden_layers']

# Filled later with {PyTorch parameter name: tensor}.
torch_state_dict = {}

# TensorFlow variable name -> PyTorch parameter name.
# A trailing '##T' on the PyTorch name is a custom marker meaning the weight
# has to be transposed when it is copied over.
mapping = {
    'shared/embedding': 'shared.weight',
    'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
        'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'encoder/rms_norm/scale': 'encoder.final_layer_norm.weight',
    'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
        'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'decoder/rms_norm/scale': 'decoder.final_layer_norm.weight',
    'decoder/logits/kernel': 'lm_head.weight##T',
}

# Ordered sub-layers inside one block of each stack; the position in the tuple
# is the sub-layer index used in both naming schemes.
_stack_sublayers = (
    ('encoder', ('SelfAttention', 'DenseReluDense')),
    ('decoder', ('SelfAttention', 'EncDecAttention', 'DenseReluDense')),
)

for block_idx in range(num_layers):
    block_tag = f'{block_idx:03d}'  # TensorFlow names zero-pad the block index to 3 digits
    for stack, sublayers in _stack_sublayers:
        for sub_idx, kind in enumerate(sublayers):
            tf_prefix = f'{stack}/block_{block_tag}/layer_{sub_idx:03d}'
            pt_prefix = f'{stack}.block.{block_idx}.layer.{sub_idx}'
            if kind == 'DenseReluDense':
                # Feed-forward projections are stored under a '/kernel' leaf.
                for proj in ('wi_0', 'wi_1', 'wo'):
                    mapping[f'{tf_prefix}/{kind}/{proj}/kernel'] = f'{pt_prefix}.{kind}.{proj}.weight##T'
            else:
                # Attention projections q, k, v, o are stored directly under the sub-layer.
                for proj in ('q', 'k', 'v', 'o'):
                    mapping[f'{tf_prefix}/{kind}/{proj}'] = f'{pt_prefix}.{kind}.{proj}.weight##T'
            # Each sub-layer ends with its RMS-norm scale, copied without transposition.
            mapping[f'{tf_prefix}/rms_norm/scale'] = f'{pt_prefix}.layer_norm.weight'

# Copy every mapped TensorFlow variable into the PyTorch state dict.
# Values in `mapping` that end with the custom '##T' marker are stored
# transposed under the name with the marker removed.
transpose_marker = '##T'
for tf_name, torch_name in mapping.items():
    weight = torch.from_numpy(tf.train.load_variable(tf_path, tf_name))

    if torch_name.endswith(transpose_marker):
        # Slice the marker off explicitly: str.rstrip('##T') removes any run of
        # trailing '#'/'T' characters, so it would also eat into a parameter
        # name that itself ends with 'T' or '#'.
        torch_state_dict[torch_name[:-len(transpose_marker)]] = weight.T
    else:
        torch_state_dict[torch_name] = weight

# Persist the converted weights.
torch.save(torch_state_dict, torch_path)

# Settings that differ between the variants:
# (hidden_size, intermediate_size, num_attention_heads, num_hidden_layers).
_variant_dims = {
    'base': (768, 2048, 12, 12),
    'small': (512, 1024, 6, 8),
}

# Build the bert4torch config for the selected variant. For any other value of
# `choice`, `config` is left as it was before this point.
if choice in _variant_dims:
    hidden_size, intermediate_size, num_heads, num_blocks = _variant_dims[choice]
    config = {
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": hidden_size,
        "initializer_range": 0.02,
        "intermediate_size": intermediate_size,
        "num_attention_heads": num_heads,
        "attention_head_size": 64,
        "num_hidden_layers": num_blocks,
        "vocab_size": 50000,
        "relative_attention_num_buckets": 32,
        "attention_scale": False,
        "is_dropout": True,
    }

# Write the config next to the converted weights as indented JSON.
with open(ckpt_dir+'/bert4torch_config.json', 'w') as f:
    serialized = json.dumps(config, indent=4)
    f.write(serialized)