susnato committed
Commit
f3581c2
1 Parent(s): 885227c

Added conversion files and README

Files changed (3)
  1. README.md +150 -0
  2. convert.py +125 -0
  3. pytorch_weights_postprocess.py +67 -0
README.md ADDED
@@ -0,0 +1,150 @@
+ ---
+ library_name: paddlenlp
+ license: apache-2.0
+ datasets:
+ - xnli
+ - mlqa
+ - paws-x
+ language:
+ - fr
+ - es
+ - en
+ - de
+ - sw
+ - ru
+ - zh
+ - el
+ - bg
+ - ar
+ - vi
+ - th
+ - hi
+ - ur
+ ---
+
+ ### Disclaimer: I do not own the weights of `ernie-m-base`, nor did I train the model. I only converted the model weights from Paddle to PyTorch (using the scripts listed in the files).
+ The original (Paddle) weights can be found [here](https://huggingface.co/PaddlePaddle/ernie-m-base).
+
+ The rest of this README is copied from the page linked above.
+
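+ A minimal sketch of inspecting the converted checkpoint (this assumes the two scripts below have been run and that the output path of `pytorch_weights_postprocess.py` is unchanged):
+
+ ```python
+ import torch
+
+ # load the converted state dict produced by the scripts in this repo
+ state_dict = torch.load("./ernie-m-base_pytorch/pytorch_model.bin", map_location="cpu")
+
+ # print a few parameter names and shapes to verify the conversion
+ for name, tensor in list(state_dict.items())[:5]:
+     print(name, tuple(tensor.shape))
+ ```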
+
+ [![paddlenlp-banner](https://user-images.githubusercontent.com/1371212/175816733-8ec25eb0-9af3-4380-9218-27c154518258.png)](https://github.com/PaddlePaddle/PaddleNLP)
+
+ # PaddlePaddle/ernie-m-base
+
+ ## Ernie-M
+
+ ERNIE-M, proposed by Baidu, is a new training method that encourages the model to align the representation of multiple languages with monolingual corpora,
+ to overcome the constraint that the parallel corpus size places on the model performance. The insight is to integrate back-translation into the pre-training
+ process by generating pseudo-parallel sentence pairs on a monolingual corpus to enable the learning of semantic alignments between different languages,
+ thereby enhancing the semantic modeling of cross-lingual models. Experimental results show that ERNIE-M outperforms existing cross-lingual models and
+ delivers new state-of-the-art results in various cross-lingual downstream tasks.
+
+ We proposed two novel methods to align the representation of multiple languages (a toy sketch follows the list):
+
+ * Cross-Attention Masked Language Modeling (CAMLM): in CAMLM, we learn the multilingual semantic representation by restoring the MASK tokens in the input sentences.
+ * Back-Translation Masked Language Modeling (BTMLM): we use BTMLM to train our model to generate pseudo-parallel sentences from the monolingual sentences. The generated pairs are then used as the input of the model to further align the cross-lingual semantics, thus enhancing the multilingual representation.
+
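+ A purely illustrative toy sketch of the CAMLM-style masking objective (not the paper's implementation; the tokens and mask rate below are made up):
+
+ ```python
+ import random
+
+ # a concatenated cross-lingual sentence pair, as used in CAMLM-style pre-training
+ pair = ["the", "cat", "sleeps", "[SEP]", "le", "chat", "dort"]
+
+ # mask random tokens; the model is trained to restore them, which encourages
+ # aligned semantics across the two languages
+ masked = [tok if random.random() > 0.15 else "[MASK]" for tok in pair]
+ print(masked)
+ ```
+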
+ ![ernie-m](ernie_m.png)
+
+ ## Benchmark
+
+ ### XNLI
+
+ XNLI is a subset of MNLI that has been translated into 14 different languages, including some low-resource ones. The goal of the task is to predict textual entailment (whether sentence A implies / contradicts / is neutral towards sentence B).
+
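+ For illustration, a single XNLI-style example might look like this (invented data, not taken from the dataset):
+
+ ```python
+ # one premise/hypothesis pair with one of the three XNLI labels
+ example = {
+     "premise": "A man is playing a guitar on stage.",
+     "hypothesis": "A person is performing music.",
+     "label": "entailment",  # or "contradiction" / "neutral"
+ }
+ ```
+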
+ | Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur | Avg |
+ | ---------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
+ | Cross-lingual Transfer | | | | | | | | | | | | | | | | |
+ | XLM | 85.0 | 78.7 | 78.9 | 77.8 | 76.6 | 77.4 | 75.3 | 72.5 | 73.1 | 76.1 | 73.2 | 76.5 | 69.6 | 68.4 | 67.3 | 75.1 |
+ | Unicoder | 85.1 | 79.0 | 79.4 | 77.8 | 77.2 | 77.2 | 76.3 | 72.8 | 73.5 | 76.4 | 73.6 | 76.2 | 69.4 | 69.7 | 66.7 | 75.4 |
+ | XLM-R | 85.8 | 79.7 | 80.7 | 78.7 | 77.5 | 79.6 | 78.1 | 74.2 | 73.8 | 76.5 | 74.6 | 76.7 | 72.4 | 66.5 | 68.3 | 76.2 |
+ | INFOXLM | **86.4** | **80.6** | 80.8 | 78.9 | 77.8 | 78.9 | 77.6 | 75.6 | 74.0 | 77.0 | 73.7 | 76.7 | 72.0 | 66.4 | 67.1 | 76.2 |
+ | **ERNIE-M** | 85.5 | 80.1 | **81.2** | **79.2** | **79.1** | **80.4** | **78.1** | **76.8** | **76.3** | **78.3** | **75.8** | **77.4** | **72.9** | **69.5** | **68.8** | **77.3** |
+ | XLM-R Large | 89.1 | 84.1 | 85.1 | 83.9 | 82.9 | 84.0 | 81.2 | 79.6 | 79.8 | 80.8 | 78.1 | 80.2 | 76.9 | 73.9 | 73.8 | 80.9 |
+ | INFOXLM Large | **89.7** | 84.5 | 85.5 | 84.1 | 83.4 | 84.2 | 81.3 | 80.9 | 80.4 | 80.8 | 78.9 | 80.9 | 77.9 | 74.8 | 73.7 | 81.4 |
+ | VECO Large | 88.2 | 79.2 | 83.1 | 82.9 | 81.2 | 84.2 | 82.8 | 76.2 | 80.3 | 74.3 | 77.0 | 78.4 | 71.3 | **80.4** | **79.1** | 79.9 |
+ | **ERNIE-M Large** | 89.3 | **85.1** | **85.7** | **84.4** | **83.7** | **84.5** | 82.0 | **81.2** | **81.2** | **81.9** | **79.2** | **81.0** | **78.6** | 76.2 | 75.4 | **82.0** |
+ | Translate-Train-All | | | | | | | | | | | | | | | | |
+ | XLM | 85.0 | 80.8 | 81.3 | 80.3 | 79.1 | 80.9 | 78.3 | 75.6 | 77.6 | 78.5 | 76.0 | 79.5 | 72.9 | 72.8 | 68.5 | 77.8 |
+ | Unicoder | 85.6 | 81.1 | 82.3 | 80.9 | 79.5 | 81.4 | 79.7 | 76.8 | 78.2 | 77.9 | 77.1 | 80.5 | 73.4 | 73.8 | 69.6 | 78.5 |
+ | XLM-R | 85.4 | 81.4 | 82.2 | 80.3 | 80.4 | 81.3 | 79.7 | 78.6 | 77.3 | 79.7 | 77.9 | 80.2 | 76.1 | 73.1 | 73.0 | 79.1 |
+ | INFOXLM | 86.1 | 82.0 | 82.8 | 81.8 | 80.9 | 82.0 | 80.2 | 79.0 | 78.8 | 80.5 | 78.3 | 80.5 | 77.4 | 73.0 | 71.6 | 79.7 |
+ | **ERNIE-M** | **86.2** | **82.5** | **83.8** | **82.6** | **82.4** | **83.4** | **80.2** | **80.6** | **80.5** | **81.1** | **79.2** | **80.5** | **77.7** | **75.0** | **73.3** | **80.6** |
+ | XLM-R Large | 89.1 | 85.1 | 86.6 | 85.7 | 85.3 | 85.9 | 83.5 | 83.2 | 83.1 | 83.7 | 81.5 | **83.7** | **81.6** | 78.0 | 78.1 | 83.6 |
+ | VECO Large | 88.9 | 82.4 | 86.0 | 84.7 | 85.3 | 86.2 | **85.8** | 80.1 | 83.0 | 77.2 | 80.9 | 82.8 | 75.3 | **83.1** | **83.0** | 83.0 |
+ | **ERNIE-M Large** | **89.5** | **86.5** | **86.9** | **86.1** | **86.0** | **86.8** | 84.1 | **83.8** | **84.1** | **84.5** | **82.1** | 83.5 | 81.1 | 79.4 | 77.9 | **84.2** |
+
+ ### Cross-lingual Named Entity Recognition
+
+ * dataset: CoNLL
+
+ | Model | en | nl | es | de | Avg |
+ | ------------------------------ | --------- | --------- | --------- | --------- | --------- |
+ | *Fine-tune on English dataset* | | | | | |
+ | mBERT | 91.97 | 77.57 | 74.96 | 69.56 | 78.52 |
+ | XLM-R | 92.25 | **78.08** | 76.53 | **69.60** | 79.11 |
+ | **ERNIE-M** | **92.78** | 78.01 | **79.37** | 68.08 | **79.56** |
+ | XLM-R LARGE | 92.92 | 80.80 | 78.64 | 71.40 | 80.94 |
+ | **ERNIE-M LARGE** | **93.28** | **81.45** | **78.83** | **72.99** | **81.64** |
+ | *Fine-tune on all datasets* | | | | | |
+ | XLM-R | 91.08 | 89.09 | 87.28 | 83.17 | 87.66 |
+ | **ERNIE-M** | **93.04** | **91.73** | **88.33** | **84.20** | **89.32** |
+ | XLM-R LARGE | 92.00 | 91.60 | **89.52** | 84.60 | 89.43 |
+ | **ERNIE-M LARGE** | **94.01** | **93.81** | 89.23 | **86.20** | **90.81** |
+
+ ### Cross-lingual Question Answering
+
+ * dataset: MLQA (each cell reports F1 / EM)
+
+ | Model | en | es | de | ar | hi | vi | zh | Avg |
+ | ----------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
+ | mBERT | 77.7 / 65.2 | 64.3 / 46.6 | 57.9 / 44.3 | 45.7 / 29.8 | 43.8 / 29.7 | 57.1 / 38.6 | 57.5 / 37.3 | 57.7 / 41.6 |
+ | XLM | 74.9 / 62.4 | 68.0 / 49.8 | 62.2 / 47.6 | 54.8 / 36.3 | 48.8 / 27.3 | 61.4 / 41.8 | 61.1 / 39.6 | 61.6 / 43.5 |
+ | XLM-R | 77.1 / 64.6 | 67.4 / 49.6 | 60.9 / 46.7 | 54.9 / 36.6 | 59.4 / 42.9 | 64.5 / 44.7 | 61.8 / 39.3 | 63.7 / 46.3 |
+ | INFOXLM | 81.3 / 68.2 | 69.9 / 51.9 | 64.2 / 49.6 | 60.1 / 40.9 | 65.0 / 47.5 | 70.0 / 48.6 | 64.7 / **41.2** | 67.9 / 49.7 |
+ | **ERNIE-M** | **81.6 / 68.5** | **70.9 / 52.6** | **65.8 / 50.7** | **61.8 / 41.9** | **65.4 / 47.5** | **70.0 / 49.2** | **65.6** / 41.0 | **68.7 / 50.2** |
+ | XLM-R LARGE | 80.6 / 67.8 | 74.1 / 56.0 | 68.5 / 53.6 | 63.1 / 43.5 | 62.9 / 51.6 | 71.3 / 50.9 | 68.0 / 45.4 | 70.7 / 52.7 |
+ | INFOXLM LARGE | **84.5 / 71.6** | **75.1 / 57.3** | **71.2 / 56.2** | **67.6 / 47.6** | 72.5 / 54.2 | **75.2 / 54.1** | 69.2 / 45.4 | 73.6 / 55.2 |
+ | **ERNIE-M LARGE** | 84.4 / 71.5 | 74.8 / 56.6 | 70.8 / 55.9 | 67.4 / 47.2 | **72.6 / 54.7** | 75.0 / 53.7 | **71.1 / 47.5** | **73.7 / 55.3** |
+
+ ### Cross-lingual Paraphrase Identification
+
+ * dataset: PAWS-X
+
+ | Model | en | de | es | fr | ja | ko | zh | Avg |
+ | ---------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
+ | Cross-lingual Transfer | | | | | | | | |
+ | mBERT | 94.0 | 85.7 | 87.4 | 87.0 | 73.0 | 69.6 | 77.0 | 81.9 |
+ | XLM | 94.0 | 85.9 | 88.3 | 87.4 | 69.3 | 64.8 | 76.5 | 80.9 |
+ | MMTE | 93.1 | 85.1 | 87.2 | 86.9 | 72.0 | 69.2 | 75.9 | 81.3 |
+ | XLM-R LARGE | 94.7 | 89.7 | 90.1 | 90.4 | 78.7 | 79.0 | 82.3 | 86.4 |
+ | VECO LARGE | **96.2** | 91.3 | 91.4 | 92.0 | 81.8 | 82.9 | 85.1 | 88.7 |
+ | **ERNIE-M LARGE** | 96.0 | **91.9** | **91.4** | **92.2** | **83.9** | **84.5** | **86.9** | **89.5** |
+ | Translate-Train-All | | | | | | | | |
+ | VECO LARGE | 96.4 | 93.0 | 93.0 | 93.5 | 87.2 | 86.8 | 87.9 | 91.1 |
+ | **ERNIE-M LARGE** | **96.5** | **93.5** | **93.3** | **93.8** | **87.9** | **88.4** | **89.2** | **91.8** |
+
+
+ ### Cross-lingual Sentence Retrieval
+
+ * dataset: Tatoeba
+
+ | Model | Avg |
+ | -------------------------------------- | -------- |
+ | XLM-R LARGE | 75.2 |
+ | VECO LARGE | 86.9 |
+ | **ERNIE-M LARGE** | **87.9** |
+ | **ERNIE-M LARGE (after fine-tuning)** | **93.3** |
+
+
+ ## Citation Info
+
+ ```text
+ @article{Ouyang2021ERNIEMEM,
+   title={ERNIE-M: Enhanced Multilingual Representation by Aligning Cross-lingual Semantics with Monolingual Corpora},
+   author={Xuan Ouyang and Shuohuan Wang and Chao Pang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
+   journal={ArXiv},
+   year={2021},
+   volume={abs/2012.15674}
+ }
+ ```
convert.py ADDED
@@ -0,0 +1,125 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ # Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
+ """
+ File Description:
+ ERNIE series model conversion based on the paddlenlp repository
+ official repo: https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo
+ Author: nghuyong liushu
+ Mail: nghuyong@163.com 1554987494@qq.com
+ Created Time: 2022/8/17
+ """
+ import collections
+ import os
+ import json
+ import paddle.fluid.dygraph as D
+ import torch
+ from paddle import fluid
+
+
+ def build_params_map(attention_num=12):
+     """
+     Build the parameter-name map from PaddlePaddle's ERNIE-M to the PyTorch state dict.
+     :return: an OrderedDict mapping paddle parameter names to torch parameter names
+     """
+     weight_map = collections.OrderedDict({
+         'embeddings.word_embeddings.weight': "embeddings.word_embeddings.weight",
+         'embeddings.position_embeddings.weight': "embeddings.position_embeddings.weight",
+         'embeddings.layer_norm.weight': 'embeddings.layer_norm.weight',
+         'embeddings.layer_norm.bias': 'embeddings.layer_norm.bias',
+     })
+     # add the per-layer self-attention, layer-norm and feed-forward parameters
+     for i in range(attention_num):
+         weight_map[f'encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.layers.{i}.self_attn.q_proj.weight'
+         weight_map[f'encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.layers.{i}.self_attn.q_proj.bias'
+         weight_map[f'encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.layers.{i}.self_attn.k_proj.weight'
+         weight_map[f'encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.layers.{i}.self_attn.k_proj.bias'
+         weight_map[f'encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.layers.{i}.self_attn.v_proj.weight'
+         weight_map[f'encoder.layers.{i}.self_attn.v_proj.bias'] = f'encoder.layers.{i}.self_attn.v_proj.bias'
+         weight_map[f'encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.layers.{i}.self_attn.out_proj.weight'
+         weight_map[f'encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.layers.{i}.self_attn.out_proj.bias'
+         weight_map[f'encoder.layers.{i}.norm1.weight'] = f'encoder.layers.{i}.norm1.weight'
+         weight_map[f'encoder.layers.{i}.norm1.bias'] = f'encoder.layers.{i}.norm1.bias'
+         weight_map[f'encoder.layers.{i}.linear1.weight'] = f'encoder.layers.{i}.linear1.weight'
+         weight_map[f'encoder.layers.{i}.linear1.bias'] = f'encoder.layers.{i}.linear1.bias'
+         weight_map[f'encoder.layers.{i}.linear2.weight'] = f'encoder.layers.{i}.linear2.weight'
+         weight_map[f'encoder.layers.{i}.linear2.bias'] = f'encoder.layers.{i}.linear2.bias'
+         weight_map[f'encoder.layers.{i}.norm2.weight'] = f'encoder.layers.{i}.norm2.weight'
+         weight_map[f'encoder.layers.{i}.norm2.bias'] = f'encoder.layers.{i}.norm2.bias'
+     weight_map.update(
+         {
+             'pooler.dense.weight': 'pooler.dense.weight',
+             'pooler.dense.bias': 'pooler.dense.bias',
+         }
+     )
+     return weight_map
+
+
+ def extract_and_convert(input_dir, output_dir):
+     """
+     Extract the Paddle weights and convert them to a PyTorch checkpoint.
+     :param input_dir: directory containing config.json, vocab.txt and model_state.pdparams
+     :param output_dir: directory that will receive config.json, vocab.txt and pytorch_model.bin
+     :return:
+     """
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     print('=' * 20 + 'save config file' + '=' * 20)
+     config = json.load(open(os.path.join(input_dir, 'config.json'), 'rt', encoding='utf-8'))
+     config['layer_norm_eps'] = 1e-5
+     json.dump(config, open(os.path.join(output_dir, 'config.json'), 'wt', encoding='utf-8'), indent=4)
+     print('=' * 20 + 'save vocab file' + '=' * 20)
+     with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
+         words = f.read().splitlines()
+     words = [word.split('\t')[0] for word in words]
+     with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
+         for word in words:
+             f.write(word + "\n")
+     print('=' * 20 + 'extract weights' + '=' * 20)
+     state_dict = collections.OrderedDict()
+     weight_map = build_params_map(attention_num=config['num_hidden_layers'])
+     with fluid.dygraph.guard():
+         paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
+         for weight_name, weight_value in paddle_paddle_params.items():
+             # Paddle stores Linear weights as [in_features, out_features] while
+             # torch expects [out_features, in_features], so transpose every
+             # encoder weight (a no-op for the 1-D layer-norm weights).
+             if 'weight' in weight_name and 'encoder' in weight_name:
+                 weight_value = weight_value.transpose()
+             if weight_name not in weight_map:
+                 print('=' * 20, '[SKIP]', weight_name, '=' * 20)
+                 continue
+             state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
+             print(weight_name, '->', weight_map[weight_name], weight_value.shape)
+     torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
+
+
+ if __name__ == '__main__':
+     extract_and_convert("./ernie_m_large_paddle/", "./ernie_m_large_torch/")
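+
+     # a minimal post-conversion check (illustrative sketch, assuming the default
+     # output path above): reload the saved checkpoint and print a few entries
+     sd = torch.load("./ernie_m_large_torch/pytorch_model.bin")
+     for name in list(sd)[:5]:
+         print(name, tuple(sd[name].shape))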
pytorch_weights_postprocess.py ADDED
@@ -0,0 +1,67 @@
+ # This script takes the PyTorch weights produced by the conversion script (convert.py) and stacks
+ # the Query, Key and Value projections of each encoder self-attention layer
+ # (to match the in_proj layout of torch.nn.MultiheadAttention).
+
+ import os
+ import torch
+
+ # load the converted checkpoint and strip the leading name component
+ # (e.g. "ernie.") from every parameter name
+ full_state_dict = torch.load("./pytorch_model.bin")
+ full_state_dict = dict((".".join(k.split(".")[1:]), v)
+                        for k, v in full_state_dict.items())
+
+
+ def con_cat(kqv_dict):
+     """Concatenate one layer's q/k/v projections (in q, k, v order) into a
+     single in_proj_weight / in_proj_bias tensor."""
+     first_key = list(kqv_dict.keys())[0]
+     proj = first_key.split(".")[3]  # e.g. "q_proj"
+     kind = "weight" if "weight" in first_key else "bias"
+     value = torch.cat([kqv_dict[first_key.replace(proj, "q_proj")],
+                        kqv_dict[first_key.replace(proj, "k_proj")],
+                        kqv_dict[first_key.replace(proj, "v_proj")]])
+     key = ".".join(first_key.split(".")[:3] + [f"in_proj_{kind}"])
+     return {f"encoder.{key}": value}
+
+
+ mod_dict = {}
+
+ # embedding weights
+ for k, v in full_state_dict.items():
+     if "embedding" in k or "layer_norm" in k:
+         mod_dict.update({f"embeddings.{k}": v})
+
+ # encoder weights (12 layers in the base model); the trailing dot keeps
+ # "layers.1" from also matching "layers.10" and "layers.11"
+ for i in range(12):
+     sd = dict((k, v) for k, v in full_state_dict.items() if f"layers.{i}." in k)
+     kvq_weight = {}
+     kvq_bias = {}
+     for k, v in sd.items():
+         if "self_attn" in k and "out_proj" not in k:
+             if "weight" in k:
+                 kvq_weight[k] = v
+             if "bias" in k:
+                 kvq_bias[k] = v
+         else:
+             mod_dict[f"encoder.{k}"] = v
+
+     mod_dict.update(con_cat(kvq_weight))
+     mod_dict.update(con_cat(kvq_bias))
+
+ # pooler weights
+ for k, v in full_state_dict.items():
+     if "pooler" in k:
+         mod_dict.update({k: v})
+
+ for k, v in mod_dict.items():
+     print(k, v.size())
+
+ model_name = "ernie-m-base_pytorch"
+ os.makedirs(model_name, exist_ok=True)  # ensure the output directory exists
+ PATH = f"./{model_name}/pytorch_model.bin"
+ torch.save(mod_dict, PATH)
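+
+ # a quick structural sanity check (illustrative sketch): every stacked
+ # in_proj_weight should be (3 * hidden, hidden) and every in_proj_bias
+ # (3 * hidden,), the layout torch.nn.MultiheadAttention expects
+ for k, v in mod_dict.items():
+     if k.endswith("in_proj_weight"):
+         assert v.shape[0] == 3 * v.shape[1], (k, tuple(v.shape))
+     if k.endswith("in_proj_bias"):
+         assert v.dim() == 1 and v.shape[0] % 3 == 0, (k, tuple(v.shape))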