# Convert the t5_pegasus TensorFlow checkpoint into a PyTorch state dict
# whose parameter names follow the layout expected by bert4torch.
# Weights: https://github.com/ZhuiyiTechnology/t5-pegasus
#
# The script reads `<tf_dir>/model.ckpt` and `<tf_dir>/config.json`, then
# writes `<ckpt_dir>/pytorch_model.bin` and `<ckpt_dir>/bert4torch_config.json`.
import torch
import tensorflow as tf
import json

# Which pre-trained variant to convert: 'small' or 'base'.
choice = 'small'

if choice == 'small':
    ckpt_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_torch/'
    tf_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_small_tf/'
    torch_path = ckpt_dir + 'pytorch_model.bin'
elif choice == 'base':
    ckpt_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_torch/'
    tf_dir = 'E:/pretrain_ckpt/t5/sushen@chinese_t5_pegasus_base_tf/'
    torch_path = ckpt_dir + 'pytorch_model.bin'
else:
    raise ValueError(f'{choice} not in pre maintained choices')

tf_path = tf_dir + 'model.ckpt'

# The TF-side config is only consulted for the number of layers to map.
with open(tf_dir + 'config.json', 'r', encoding='utf-8') as f:
    tf_config = json.load(f)
num_layers = tf_config['num_hidden_layers']

# Custom marker: a target name ending with this suffix means the loaded
# tensor must be transposed before it is stored.
transpose_mark = '##T'

torch_state_dict = {}

# TF variable name -> PyTorch parameter name (optionally suffixed with ##T).
mapping = {
    'shared/embedding': 'shared.weight',
    'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'encoder/rms_norm/scale': 'encoder.final_layer_norm.weight',
    'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight##T',
    'decoder/rms_norm/scale': 'decoder.final_layer_norm.weight',
    'decoder/logits/kernel': 'lm_head.weight##T'
}

for i in range(num_layers):
    # TF block indices are zero-padded to three digits (e.g. block_007).
    i1 = str(i).rjust(3, '0')
    mapping.update({
        f'encoder/block_{i1}/layer_000/SelfAttention/q': f'encoder.block.{i}.layer.0.SelfAttention.q.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/k': f'encoder.block.{i}.layer.0.SelfAttention.k.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/v': f'encoder.block.{i}.layer.0.SelfAttention.v.weight##T',
        f'encoder/block_{i1}/layer_000/SelfAttention/o': f'encoder.block.{i}.layer.0.SelfAttention.o.weight##T',
        f'encoder/block_{i1}/layer_000/rms_norm/scale': f'encoder.block.{i}.layer.0.layer_norm.weight',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wi_0/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight##T',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wi_1/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight##T',
        f'encoder/block_{i1}/layer_001/DenseReluDense/wo/kernel': f'encoder.block.{i}.layer.1.DenseReluDense.wo.weight##T',
        f'encoder/block_{i1}/layer_001/rms_norm/scale': f'encoder.block.{i}.layer.1.layer_norm.weight',
        f'decoder/block_{i1}/layer_000/SelfAttention/q': f'decoder.block.{i}.layer.0.SelfAttention.q.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/k': f'decoder.block.{i}.layer.0.SelfAttention.k.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/v': f'decoder.block.{i}.layer.0.SelfAttention.v.weight##T',
        f'decoder/block_{i1}/layer_000/SelfAttention/o': f'decoder.block.{i}.layer.0.SelfAttention.o.weight##T',
        f'decoder/block_{i1}/layer_000/rms_norm/scale': f'decoder.block.{i}.layer.0.layer_norm.weight',
        f'decoder/block_{i1}/layer_001/EncDecAttention/q': f'decoder.block.{i}.layer.1.EncDecAttention.q.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/k': f'decoder.block.{i}.layer.1.EncDecAttention.k.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/v': f'decoder.block.{i}.layer.1.EncDecAttention.v.weight##T',
        f'decoder/block_{i1}/layer_001/EncDecAttention/o': f'decoder.block.{i}.layer.1.EncDecAttention.o.weight##T',
        f'decoder/block_{i1}/layer_001/rms_norm/scale': f'decoder.block.{i}.layer.1.layer_norm.weight',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wi_0/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight##T',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wi_1/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight##T',
        f'decoder/block_{i1}/layer_002/DenseReluDense/wo/kernel': f'decoder.block.{i}.layer.2.DenseReluDense.wo.weight##T',
        f'decoder/block_{i1}/layer_002/rms_norm/scale': f'decoder.block.{i}.layer.2.layer_norm.weight',
    })

for k, v in mapping.items():
    ts = torch.from_numpy(tf.train.load_variable(tf_path, k))
    # Debug aid: uncomment to list square 2-D tensors, for which a missed
    # transpose would not show up as a shape mismatch.
    # if len(ts.shape) == 2 and ts.shape[0] == ts.shape[1]:
    #     print(k, v)
    if v.endswith(transpose_mark):
        # Slice the marker off as a suffix. str.rstrip('##T') would strip any
        # trailing run of '#'/'T' characters and could eat into the name.
        torch_state_dict[v[:-len(transpose_mark)]] = ts.T
    else:
        torch_state_dict[v] = ts

torch.save(torch_state_dict, torch_path)

# Hand-maintained bert4torch configs for the supported variants.
if choice == 'base':
    config = \
        {
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 768,
            "initializer_range": 0.02,
            "intermediate_size": 2048,
            "num_attention_heads": 12,
            "attention_head_size": 64,
            "num_hidden_layers": 12,
            "vocab_size": 50000,
            "relative_attention_num_buckets": 32,
            "attention_scale": False,
            "is_dropout": True
        }
elif choice == 'small':
    config = \
        {
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 1024,
            "num_attention_heads": 6,
            "attention_head_size": 64,
            "num_hidden_layers": 8,
            "vocab_size": 50000,
            "relative_attention_num_buckets": 32,
            "attention_scale": False,
            "is_dropout": True
        }

with open(ckpt_dir+'/bert4torch_config.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(config, indent=4))