Added conversion files and README
- README.md +150 -0
- convert.py +125 -0
- pytorch_weights_postprocess.py +67 -0
README.md
ADDED
@@ -0,0 +1,150 @@
---
library_name: paddlenlp
license: apache-2.0
datasets:
- xnli
- mlqa
- paws-x
language:
- fr
- es
- en
- de
- sw
- ru
- zh
- el
- bg
- ar
- vi
- th
- hi
- ur
---

### Disclaimer

I don't own the weights of `ernie-m-base`, nor did I train the model. I only converted the weights from Paddle to PyTorch (using the scripts listed in the files).
The original (Paddle) weights can be found [here](https://huggingface.co/PaddlePaddle/ernie-m-base).
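
A minimal sketch of inspecting the converted checkpoint (assumes you have downloaded `pytorch_model.bin` from this repo; after the post-processing script below, key names follow the `torch.nn.TransformerEncoder` layout):

```python
import torch

# Load on CPU and list every tensor; no model class is needed just to inspect keys.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```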

The rest of this README is copied from the page linked above.

[![paddlenlp-banner](https://user-images.githubusercontent.com/1371212/175816733-8ec25eb0-9af3-4380-9218-27c154518258.png)](https://github.com/PaddlePaddle/PaddleNLP)

# PaddlePaddle/ernie-m-base

## Ernie-M

ERNIE-M, proposed by Baidu, is a new training method that encourages the model to align the representations of multiple languages with monolingual corpora, to overcome the constraint that the parallel corpus size places on model performance. The insight is to integrate back-translation into the pre-training process by generating pseudo-parallel sentence pairs on a monolingual corpus, enabling the learning of semantic alignments between different languages and thereby enhancing the semantic modeling of cross-lingual models. Experimental results show that ERNIE-M outperforms existing cross-lingual models and delivers new state-of-the-art results on various cross-lingual downstream tasks.

We proposed two novel methods to align the representations of multiple languages:

* Cross-Attention Masked Language Modeling (CAMLM): in CAMLM, we learn the multilingual semantic representation by restoring the MASK tokens in the input sentences.
* Back-Translation Masked Language Modeling (BTMLM): we use BTMLM to train the model to generate pseudo-parallel sentences from monolingual sentences. The generated pairs are then used as model input to further align the cross-lingual semantics, thus enhancing the multilingual representation.

![ernie-m](ernie_m.png)

## Benchmark

### XNLI

XNLI is a subset of MNLI that has been translated into 14 different languages, including some low-resource ones. The goal of the task is to predict textual entailment (whether sentence A implies, contradicts, or is neutral toward sentence B).

| Model                  | en       | fr       | es       | de       | el       | bg       | ru       | tr       | ar       | vi       | th       | zh       | hi       | sw       | ur       | Avg      |
| ---------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
| Cross-lingual Transfer |          |          |          |          |          |          |          |          |          |          |          |          |          |          |          |          |
| XLM                    | 85.0     | 78.7     | 78.9     | 77.8     | 76.6     | 77.4     | 75.3     | 72.5     | 73.1     | 76.1     | 73.2     | 76.5     | 69.6     | 68.4     | 67.3     | 75.1     |
| Unicoder               | 85.1     | 79.0     | 79.4     | 77.8     | 77.2     | 77.2     | 76.3     | 72.8     | 73.5     | 76.4     | 73.6     | 76.2     | 69.4     | 69.7     | 66.7     | 75.4     |
| XLM-R                  | 85.8     | 79.7     | 80.7     | 78.7     | 77.5     | 79.6     | 78.1     | 74.2     | 73.8     | 76.5     | 74.6     | 76.7     | 72.4     | 66.5     | 68.3     | 76.2     |
| INFOXLM                | **86.4** | **80.6** | 80.8     | 78.9     | 77.8     | 78.9     | 77.6     | 75.6     | 74.0     | 77.0     | 73.7     | 76.7     | 72.0     | 66.4     | 67.1     | 76.2     |
| **ERNIE-M**            | 85.5     | 80.1     | **81.2** | **79.2** | **79.1** | **80.4** | **78.1** | **76.8** | **76.3** | **78.3** | **75.8** | **77.4** | **72.9** | **69.5** | **68.8** | **77.3** |
| XLM-R Large            | 89.1     | 84.1     | 85.1     | 83.9     | 82.9     | 84.0     | 81.2     | 79.6     | 79.8     | 80.8     | 78.1     | 80.2     | 76.9     | 73.9     | 73.8     | 80.9     |
| INFOXLM Large          | **89.7** | 84.5     | 85.5     | 84.1     | 83.4     | 84.2     | 81.3     | 80.9     | 80.4     | 80.8     | 78.9     | 80.9     | 77.9     | 74.8     | 73.7     | 81.4     |
| VECO Large             | 88.2     | 79.2     | 83.1     | 82.9     | 81.2     | 84.2     | 82.8     | 76.2     | 80.3     | 74.3     | 77.0     | 78.4     | 71.3     | **80.4** | **79.1** | 79.9     |
| **ERNIE-M Large**      | 89.3     | **85.1** | **85.7** | **84.4** | **83.7** | **84.5** | 82.0     | **81.2** | **81.2** | **81.9** | **79.2** | **81.0** | **78.6** | 76.2     | 75.4     | **82.0** |
| Translate-Train-All    |          |          |          |          |          |          |          |          |          |          |          |          |          |          |          |          |
| XLM                    | 85.0     | 80.8     | 81.3     | 80.3     | 79.1     | 80.9     | 78.3     | 75.6     | 77.6     | 78.5     | 76.0     | 79.5     | 72.9     | 72.8     | 68.5     | 77.8     |
| Unicoder               | 85.6     | 81.1     | 82.3     | 80.9     | 79.5     | 81.4     | 79.7     | 76.8     | 78.2     | 77.9     | 77.1     | 80.5     | 73.4     | 73.8     | 69.6     | 78.5     |
| XLM-R                  | 85.4     | 81.4     | 82.2     | 80.3     | 80.4     | 81.3     | 79.7     | 78.6     | 77.3     | 79.7     | 77.9     | 80.2     | 76.1     | 73.1     | 73.0     | 79.1     |
| INFOXLM                | 86.1     | 82.0     | 82.8     | 81.8     | 80.9     | 82.0     | 80.2     | 79.0     | 78.8     | 80.5     | 78.3     | 80.5     | 77.4     | 73.0     | 71.6     | 79.7     |
| **ERNIE-M**            | **86.2** | **82.5** | **83.8** | **82.6** | **82.4** | **83.4** | **80.2** | **80.6** | **80.5** | **81.1** | **79.2** | **80.5** | **77.7** | **75.0** | **73.3** | **80.6** |
| XLM-R Large            | 89.1     | 85.1     | 86.6     | 85.7     | 85.3     | 85.9     | 83.5     | 83.2     | 83.1     | 83.7     | 81.5     | **83.7** | **81.6** | 78.0     | 78.1     | 83.6     |
| VECO Large             | 88.9     | 82.4     | 86.0     | 84.7     | 85.3     | 86.2     | **85.8** | 80.1     | 83.0     | 77.2     | 80.9     | 82.8     | 75.3     | **83.1** | **83.0** | 83.0     |
| **ERNIE-M Large**      | **89.5** | **86.5** | **86.9** | **86.1** | **86.0** | **86.8** | 84.1     | **83.8** | **84.1** | **84.5** | **82.1** | 83.5     | 81.1     | 79.4     | 77.9     | **84.2** |

### Cross-lingual Named Entity Recognition

* dataset: CoNLL

| Model                          | en        | nl        | es        | de        | Avg       |
| ------------------------------ | --------- | --------- | --------- | --------- | --------- |
| *Fine-tune on English dataset* |           |           |           |           |           |
| mBERT                          | 91.97     | 77.57     | 74.96     | 69.56     | 78.52     |
| XLM-R                          | 92.25     | **78.08** | 76.53     | **69.60** | 79.11     |
| **ERNIE-M**                    | **92.78** | 78.01     | **79.37** | 68.08     | **79.56** |
| XLM-R LARGE                    | 92.92     | 80.80     | 78.64     | 71.40     | 80.94     |
| **ERNIE-M LARGE**              | **93.28** | **81.45** | **78.83** | **72.99** | **81.64** |
| *Fine-tune on all datasets*    |           |           |           |           |           |
| XLM-R                          | 91.08     | 89.09     | 87.28     | 83.17     | 87.66     |
| **ERNIE-M**                    | **93.04** | **91.73** | **88.33** | **84.20** | **89.32** |
| XLM-R LARGE                    | 92.00     | 91.60     | **89.52** | 84.60     | 89.43     |
| **ERNIE-M LARGE**              | **94.01** | **93.81** | 89.23     | **86.20** | **90.81** |

### Cross-lingual Question Answering

* dataset: MLQA

| Model             | en              | es              | de              | ar              | hi              | vi              | zh              | Avg             |
| ----------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| mBERT             | 77.7 / 65.2     | 64.3 / 46.6     | 57.9 / 44.3     | 45.7 / 29.8     | 43.8 / 29.7     | 57.1 / 38.6     | 57.5 / 37.3     | 57.7 / 41.6     |
| XLM               | 74.9 / 62.4     | 68.0 / 49.8     | 62.2 / 47.6     | 54.8 / 36.3     | 48.8 / 27.3     | 61.4 / 41.8     | 61.1 / 39.6     | 61.6 / 43.5     |
| XLM-R             | 77.1 / 64.6     | 67.4 / 49.6     | 60.9 / 46.7     | 54.9 / 36.6     | 59.4 / 42.9     | 64.5 / 44.7     | 61.8 / 39.3     | 63.7 / 46.3     |
| INFOXLM           | 81.3 / 68.2     | 69.9 / 51.9     | 64.2 / 49.6     | 60.1 / 40.9     | 65.0 / 47.5     | 70.0 / 48.6     | 64.7 / **41.2** | 67.9 / 49.7     |
| **ERNIE-M**       | **81.6 / 68.5** | **70.9 / 52.6** | **65.8 / 50.7** | **61.8 / 41.9** | **65.4 / 47.5** | **70.0 / 49.2** | **65.6** / 41.0 | **68.7 / 50.2** |
| XLM-R LARGE       | 80.6 / 67.8     | 74.1 / 56.0     | 68.5 / 53.6     | 63.1 / 43.5     | 62.9 / 51.6     | 71.3 / 50.9     | 68.0 / 45.4     | 70.7 / 52.7     |
| INFOXLM LARGE     | **84.5 / 71.6** | **75.1 / 57.3** | **71.2 / 56.2** | **67.6 / 47.6** | 72.5 / 54.2     | **75.2 / 54.1** | 69.2 / 45.4     | 73.6 / 55.2     |
| **ERNIE-M LARGE** | 84.4 / 71.5     | 74.8 / 56.6     | 70.8 / 55.9     | 67.4 / 47.2     | **72.6 / 54.7** | 75.0 / 53.7     | **71.1 / 47.5** | **73.7 / 55.3** |

### Cross-lingual Paraphrase Identification

* dataset: PAWS-X

| Model                  | en       | de       | es       | fr       | ja       | ko       | zh       | Avg      |
| ---------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
| Cross-lingual Transfer |          |          |          |          |          |          |          |          |
| mBERT                  | 94.0     | 85.7     | 87.4     | 87.0     | 73.0     | 69.6     | 77.0     | 81.9     |
| XLM                    | 94.0     | 85.9     | 88.3     | 87.4     | 69.3     | 64.8     | 76.5     | 80.9     |
| MMTE                   | 93.1     | 85.1     | 87.2     | 86.9     | 72.0     | 69.2     | 75.9     | 81.3     |
| XLM-R LARGE            | 94.7     | 89.7     | 90.1     | 90.4     | 78.7     | 79.0     | 82.3     | 86.4     |
| VECO LARGE             | **96.2** | 91.3     | 91.4     | 92.0     | 81.8     | 82.9     | 85.1     | 88.7     |
| **ERNIE-M LARGE**      | 96.0     | **91.9** | **91.4** | **92.2** | **83.9** | **84.5** | **86.9** | **89.5** |
| Translate-Train-All    |          |          |          |          |          |          |          |          |
| VECO LARGE             | 96.4     | 93.0     | 93.0     | 93.5     | 87.2     | 86.8     | 87.9     | 91.1     |
| **ERNIE-M LARGE**      | **96.5** | **93.5** | **93.3** | **93.8** | **87.9** | **88.4** | **89.2** | **91.8** |

### Cross-lingual Sentence Retrieval

* dataset: Tatoeba

| Model                                 | Avg      |
| ------------------------------------- | -------- |
| XLM-R LARGE                           | 75.2     |
| VECO LARGE                            | 86.9     |
| **ERNIE-M LARGE**                     | **87.9** |
| **ERNIE-M LARGE (after fine-tuning)** | **93.3** |

## Citation Info

```text
@article{Ouyang2021ERNIEMEM,
  title={ERNIE-M: Enhanced Multilingual Representation by Aligning Cross-lingual Semantics with Monolingual Corpora},
  author={Xuan Ouyang and Shuohuan Wang and Chao Pang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
  journal={ArXiv},
  year={2021},
  volume={abs/2012.15674}
}
```
convert.py
ADDED
@@ -0,0 +1,125 @@
#!/usr/bin/env python
# encoding: utf-8
# Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
"""
File Description:
ERNIE series model conversion based on the paddlenlp repository,
used here to convert ernie-m checkpoints.
official repo: https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo
Author: nghuyong liushu
Mail: nghuyong@163.com 1554987494@qq.com
Created Time: 2022/8/17
"""
import collections
import json
import os

import paddle.fluid.dygraph as D
import torch
from paddle import fluid


def build_params_map(attention_num=12):
    """
    Build the parameter-name map from PaddlePaddle's ERNIE to the PyTorch layout.
    :return: OrderedDict mapping Paddle weight names to PyTorch weight names
    """
    weight_map = collections.OrderedDict({
        'embeddings.word_embeddings.weight': "embeddings.word_embeddings.weight",
        'embeddings.position_embeddings.weight': "embeddings.position_embeddings.weight",
        # ernie-m has no token_type/task_type embeddings, so those entries are omitted
        'embeddings.layer_norm.weight': 'embeddings.layer_norm.weight',
        'embeddings.layer_norm.bias': 'embeddings.layer_norm.bias',
    })
    # add attention layers
    for i in range(attention_num):
        weight_map[f'encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.layers.{i}.self_attn.q_proj.weight'
        weight_map[f'encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.layers.{i}.self_attn.q_proj.bias'
        weight_map[f'encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.layers.{i}.self_attn.k_proj.weight'
        weight_map[f'encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.layers.{i}.self_attn.k_proj.bias'
        weight_map[f'encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.layers.{i}.self_attn.v_proj.weight'
        weight_map[f'encoder.layers.{i}.self_attn.v_proj.bias'] = f'encoder.layers.{i}.self_attn.v_proj.bias'
        weight_map[f'encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.layers.{i}.self_attn.out_proj.weight'
        weight_map[f'encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.layers.{i}.self_attn.out_proj.bias'
        weight_map[f'encoder.layers.{i}.norm1.weight'] = f'encoder.layers.{i}.norm1.weight'
        weight_map[f'encoder.layers.{i}.norm1.bias'] = f'encoder.layers.{i}.norm1.bias'
        weight_map[f'encoder.layers.{i}.linear1.weight'] = f'encoder.layers.{i}.linear1.weight'
        weight_map[f'encoder.layers.{i}.linear1.bias'] = f'encoder.layers.{i}.linear1.bias'
        weight_map[f'encoder.layers.{i}.linear2.weight'] = f'encoder.layers.{i}.linear2.weight'
        weight_map[f'encoder.layers.{i}.linear2.bias'] = f'encoder.layers.{i}.linear2.bias'
        weight_map[f'encoder.layers.{i}.norm2.weight'] = f'encoder.layers.{i}.norm2.weight'
        weight_map[f'encoder.layers.{i}.norm2.bias'] = f'encoder.layers.{i}.norm2.bias'
    # pooler
    weight_map.update(
        {
            'pooler.dense.weight': 'pooler.dense.weight',
            'pooler.dense.bias': 'pooler.dense.bias',
        }
    )
    return weight_map


def extract_and_convert(input_dir, output_dir):
    """
    Extract the Paddle weights from input_dir and write a PyTorch checkpoint to output_dir.
    :param input_dir: directory with config.json, vocab.txt and model_state.pdparams
    :param output_dir: directory for the converted config.json, vocab.txt and pytorch_model.bin
    :return:
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print('=' * 20 + 'save config file' + '=' * 20)
    config = json.load(open(os.path.join(input_dir, 'config.json'), 'rt', encoding='utf-8'))
    config['layer_norm_eps'] = 1e-5
    json.dump(config, open(os.path.join(output_dir, 'config.json'), 'wt', encoding='utf-8'), indent=4)
    print('=' * 20 + 'save vocab file' + '=' * 20)
    with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
        words = f.read().splitlines()
    words = [word.split('\t')[0] for word in words]
    with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
        for word in words:
            f.write(word + "\n")
    print('=' * 20 + 'extract weights' + '=' * 20)
    state_dict = collections.OrderedDict()
    weight_map = build_params_map(attention_num=config['num_hidden_layers'])
    with fluid.dygraph.guard():
        paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
    for weight_name, weight_value in paddle_paddle_params.items():
        if 'weight' in weight_name and 'encoder' in weight_name:
            # Paddle's nn.Linear stores weights as (in_features, out_features);
            # torch.nn expects (out_features, in_features), so transpose every
            # encoder weight matrix (a no-op for the 1-D layer-norm weights).
            weight_value = weight_value.transpose()
        if weight_name not in weight_map:
            print('=' * 20, '[SKIP]', weight_name, '=' * 20)
            continue
        state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
        print(weight_name, '->', weight_map[weight_name], weight_value.shape)
    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))


if __name__ == '__main__':
    extract_and_convert("./ernie_m_large_paddle/", "./ernie_m_large_torch/")
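
A quick sanity check after running the script (a sketch; the path is the `output_dir` from the `__main__` call above, and the expected count of 16 tensors per encoder layer plus 4 embedding and 2 pooler tensors is my reading of the weight map, not something the script prints):

```python
import torch

sd = torch.load("./ernie_m_large_torch/pytorch_model.bin", map_location="cpu")
# Every converted tensor should be finite.
assert all(torch.isfinite(v).all() for v in sd.values())
print(f"{len(sd)} tensors converted")  # expect 16 * num_hidden_layers + 6
```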
pytorch_weights_postprocess.py
ADDED
@@ -0,0 +1,67 @@
# This script takes the PyTorch weights produced by the conversion script (convert.py) and stacks
# the query, key and value projections of each encoder self-attention layer into single in_proj
# tensors, matching the layout of torch.nn.MultiheadAttention.

import os

import torch

full_state_dict = torch.load("./pytorch_model.bin", map_location="cpu")
# Drop the leading name-scope component from every key,
# e.g. "encoder.layers.0.self_attn.q_proj.weight" -> "layers.0.self_attn.q_proj.weight".
full_state_dict = dict((".".join(k.split(".")[1:]), v)
                       for k, v in full_state_dict.items())


def con_cat(kqv_dict):
    """Concatenate one layer's separate q/k/v projection tensors into a single in_proj
    tensor, stacked in the (q, k, v) row order that torch.nn.MultiheadAttention expects."""
    kqv_dict_keys = list(kqv_dict.keys())
    tmp = kqv_dict_keys[0].split(".")[3]  # "q_proj", "k_proj" or "v_proj"
    if "weight" in kqv_dict_keys[0]:
        c_dict_value = torch.cat([kqv_dict[kqv_dict_keys[0].replace(tmp, "q_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "k_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "v_proj")]])
        c_dict_key = ".".join(kqv_dict_keys[0].split(".")[:3] + ["in_proj_weight"])
        return {f"encoder.{c_dict_key}": c_dict_value}
    if "bias" in kqv_dict_keys[0]:
        c_dict_value = torch.cat([kqv_dict[kqv_dict_keys[0].replace(tmp, "q_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "k_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "v_proj")]])
        c_dict_key = ".".join(kqv_dict_keys[0].split(".")[:3] + ["in_proj_bias"])
        return {f"encoder.{c_dict_key}": c_dict_value}


mod_dict = {}
# Embedding weights
for k, v in full_state_dict.items():
    if "embedding" in k or "layer_norm" in k:
        mod_dict.update({f"embeddings.{k}": v})

# Encoder weights (12 layers for the base model); match "layers.{i}." with the trailing
# dot so that e.g. "layers.1" does not also pick up layers 10 and 11
for i in range(12):
    sd = dict((k, v) for k, v in full_state_dict.items() if f"layers.{i}." in k)
    kvq_weight = {}
    kvq_bias = {}
    for k, v in sd.items():
        if "self_attn" in k and "out_proj" not in k:
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            mod_dict[f"encoder.{k}"] = v

    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))

# Pooler: the leading "pooler." component was stripped above, so match the remaining
# "dense" keys and restore the prefix
for k, v in full_state_dict.items():
    if "dense" in k and "layers" not in k:
        mod_dict.update({f"pooler.{k}": v})

for k, v in mod_dict.items():
    print(k, v.size())

model_name = "ernie-m-base_pytorch"
os.makedirs(model_name, exist_ok=True)  # make sure the output directory exists
PATH = f"./{model_name}/pytorch_model.bin"
torch.save(mod_dict, PATH)