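# Convert the chinese_t5_pegasus (sushen) TensorFlow checkpoint into a PyTorch
# state dict plus a bert4torch_config.json for use with bert4torch.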
import json

import tensorflow as tf
import torch

choice = 'small'  # which checkpoint size to convert: 'small' or 'base'

if choice == 'small':
    ckpt_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_torch/'
    tf_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_tf/'
    torch_path = ckpt_dir + 'pytorch_model.bin'
elif choice == 'base':
    ckpt_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_torch/'
    tf_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_tf/'
    torch_path = ckpt_dir + 'pytorch_model.bin'
else:
    raise ValueError(f'{choice} not in supported choices')

tf_path = tf_dir + 'model.ckpt'
with open(tf_dir + 'config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)
num_layers = config['num_hidden_layers']
torch_state_dict = {}
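
# Mapping from TF variable names to bert4torch/PyTorch parameter names.
# A '##T' suffix marks TF kernels stored as [in, out] that have to be
# transposed into PyTorch's [out, in] Linear layout before saving.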
mapping = {
    'shared/embedding': 'shared.weight',
    'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'encoder/rms_norm/scale': 'encoder.final_layer_norm.weight',
    'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'decoder/rms_norm/scale': 'decoder.final_layer_norm.weight',
    'decoder/logits/kernel': 'lm_head.weight##T',
}
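
# Per-layer weights: the TF names zero-pad the block index (block_000,
# block_001, ...), while the PyTorch names use the plain integer.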
for i in range(num_layers):
    i1 = str(i).rjust(3, '0')
    mapping.update({
        f'encoder/block_{i1}/layer_000/SelfAttention/q': f'encoder.block.{i}.layer.0.SelfAttention.q.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/k': f'encoder.block.{i}.layer.0.SelfAttention.k.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/v': f'encoder.block.{i}.layer.0.SelfAttention.v.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/o': f'encoder.block.{i}.layer.0.SelfAttention.o.weight##T',
        f'encoder/block_{i1}/layer_000/rms_norm/scale': f'encoder.block.{i}.layer.0.layer_norm.weight',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wi_0/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight##T',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wi_1/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight##T',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wo/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wo.weight##T',
        f'encoder/block_{i1}/layer_001/rms_norm/scale': f'encoder.block.{i}.layer.1.layer_norm.weight',
        f'decoder/block_{i1}/layer_000/SelfAttention/q': f'decoder.block.{i}.layer.0.SelfAttention.q.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/k': f'decoder.block.{i}.layer.0.SelfAttention.k.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/v': f'decoder.block.{i}.layer.0.SelfAttention.v.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/o': f'decoder.block.{i}.layer.0.SelfAttention.o.weight##T',
        f'decoder/block_{i1}/layer_000/rms_norm/scale': f'decoder.block.{i}.layer.0.layer_norm.weight',
        f'decoder/block_{i1}/layer_001/EncDecAttention/q': f'decoder.block.{i}.layer.1.EncDecAttention.q.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/k': f'decoder.block.{i}.layer.1.EncDecAttention.k.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/v': f'decoder.block.{i}.layer.1.EncDecAttention.v.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/o': f'decoder.block.{i}.layer.1.EncDecAttention.o.weight##T',
        f'decoder/block_{i1}/layer_001/rms_norm/scale': f'decoder.block.{i}.layer.1.layer_norm.weight',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wi_0/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight##T',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wi_1/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight##T',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wo/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wo.weight##T',
        f'decoder/block_{i1}/layer_002/rms_norm/scale': f'decoder.block.{i}.layer.2.layer_norm.weight',
    })

for k, v in mapping.items():
    ts = torch.from_numpy(tf.train.load_variable(tf_path, k))
    if v.endswith('##T'):
        # drop the '##T' marker and transpose the TF kernel
        torch_state_dict[v[:-len('##T')]] = ts.T
    else:
        torch_state_dict[v] = ts

torch.save(torch_state_dict, torch_path)
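
# Optional sanity check (not part of the original conversion): reload the
# saved weights and spot-check a couple of tensor shapes.
check = torch.load(torch_path)
print('shared.weight:', tuple(check['shared.weight'].shape))
print('lm_head.weight:', tuple(check['lm_head.weight'].shape))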
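
# Write a bert4torch-style config alongside the converted weights.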
if choice == 'base':
    config = {
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 2048,
        "num_attention_heads": 12,
        "attention_head_size": 64,
        "num_hidden_layers": 12,
        "vocab_size": 50000,
        "relative_attention_num_buckets": 32,
        "attention_scale": False,
        "is_dropout": True
    }
elif choice == 'small':
    config = {
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 512,
        "initializer_range": 0.02,
        "intermediate_size": 1024,
        "num_attention_heads": 6,
        "attention_head_size": 64,
        "num_hidden_layers": 8,
        "vocab_size": 50000,
        "relative_attention_num_buckets": 32,
        "attention_scale": False,
        "is_dropout": True
    }

with open(ckpt_dir + 'bert4torch_config.json', 'w') as f:
    f.write(json.dumps(config, indent=4))
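
# Minimal usage sketch (assumption: bert4torch's build_transformer_model API;
# the exact `model=` flag for this T5 variant may differ between versions):
# from bert4torch.models import build_transformer_model
# model = build_transformer_model(
#     config_path=ckpt_dir + 'bert4torch_config.json',
#     checkpoint_path=torch_path,
#     model='t5.1.1',
# )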