Tongjilibo committed
Commit 2fc5f2f
Parent(s): aeed0cb

init commit

Files changed:
- README.md +8 -0
- bert4torch_config.json +17 -0
- config.json +19 -0
- convert.py +465 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer_config.json +1 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,8 @@
+
+
+# Usage notes
+
+- This checkpoint is the recommended one: it was produced by convert.py, which automatically downloads the Paddle weights and converts them to PyTorch.
+- [Source project](https://github.com/universal-ie/UIE)
+- [uie_pytorch](https://github.com/HUSTAI/uie_pytorch)
+- You can also run convert.py yourself to download and convert the weights (see the sketch below).
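For reference, convert.py's two entry points can also be driven from Python; a minimal sketch, assuming convert.py sits in the working directory:

```python
# Minimal sketch: programmatic use of convert.py's entry points.
from convert import check_model, extract_and_convert

check_model("uie-base")                              # downloads the Paddle weights if missing
extract_and_convert("uie-base", "uie_base_pytorch")  # writes config, vocab and pytorch_model.bin
```

The equivalent command line is `python convert.py -i uie-base -o uie_base_pytorch` (see the argparse setup at the bottom of convert.py).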
bert4torch_config.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "attention_probs_dropout_prob": 0.1,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 768,
+    "initializer_range": 0.02,
+    "max_position_embeddings": 2048,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+    "task_type_vocab_size": 3,
+    "type_vocab_size": 4,
+    "use_task_id": true,
+    "vocab_size": 40000,
+    "layer_norm_eps": 1e-12,
+    "intermediate_size": 3072,
+    "model": "uie"
+}
config.json
ADDED
@@ -0,0 +1,19 @@
+{
+    "attention_probs_dropout_prob": 0.1,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 768,
+    "initializer_range": 0.02,
+    "max_position_embeddings": 2048,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+    "task_type_vocab_size": 3,
+    "type_vocab_size": 4,
+    "use_task_id": true,
+    "vocab_size": 40000,
+    "architectures": [
+        "UIE"
+    ],
+    "layer_norm_eps": 1e-12,
+    "intermediate_size": 3072
+}
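The two configs describe the same encoder and differ only in how they name the architecture; a quick check (a sketch, assuming both files are in the current directory):

```python
import json

with open("bert4torch_config.json", encoding="utf-8") as f:
    b4t = json.load(f)
with open("config.json", encoding="utf-8") as f:
    hf = json.load(f)

# bert4torch marks the model type with "model"; the HF-style config uses "architectures".
print({k: b4t[k] for k in b4t.keys() - hf.keys()})  # {'model': 'uie'}
print({k: hf[k] for k in hf.keys() - b4t.keys()})   # {'architectures': ['UIE']}
```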
convert.py
ADDED
@@ -0,0 +1,465 @@
+'''Download the pretrained Paddle weights and convert them to PyTorch format.
+'''
+import argparse
+import collections
+import json
+import os
+import pickle
+import torch
+import logging
+import shutil
+from tqdm import tqdm
+import time
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger('log')
+
+
+def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
+    """Download from the given url into root_dir.
+
+    If the file or directory specified by url already exists under
+    root_dir, return its path directly; otherwise download it from
+    url, decompress it if needed, and return the resulting path.
+
+    Args:
+        url (str): download url
+        root_dir (str): root dir for downloading, it should be
+                        WEIGHTS_HOME or DATASET_HOME
+        decompress (bool): decompress zip or tar file. Default is `True`
+
+    Returns:
+        str: a local path to the downloaded models & weights & datasets.
+    """
+
+    import tarfile
+    import zipfile
+
+    def is_url(path):
+        """Whether path is a URL."""
+        return path.startswith('http://') or path.startswith('https://')
+
+    def _map_path(url, root_dir):
+        # path the file will occupy under root_dir after download
+        fname = os.path.split(url)[-1]
+        fpath = fname
+        return os.path.join(root_dir, fpath)
+
+    def _get_download(url, fullname):
+        import requests
+        # download via requests.get
+        fname = os.path.basename(fullname)
+        try:
+            req = requests.get(url, stream=True)
+        except Exception as e:  # e.g. requests.exceptions.ConnectionError
+            logger.info("Downloading {} from {} failed with exception {}".format(
+                fname, url, str(e)))
+            return False
+
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # To guard against interrupted downloads, write to tmp_fullname
+        # first and move it to fullname once the download has finished.
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
+                    for chunk in req.iter_content(chunk_size=1024):
+                        f.write(chunk)
+                        pbar.update(1)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+        return fullname
+
+    def _download(url, path):
+        """Download from url, save to path.
+
+        url (str): download url
+        path (str): download to given path
+        """
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        fname = os.path.split(url)[-1]
+        fullname = os.path.join(path, fname)
+        retry_cnt = 0
+
+        logger.info("Downloading {} from {}".format(fname, url))
+        DOWNLOAD_RETRY_LIMIT = 3
+        while not os.path.exists(fullname):
+            if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+                retry_cnt += 1
+            else:
+                raise RuntimeError("Download from {} failed. "
+                                   "Retry limit reached".format(url))
+
+            if not _get_download(url, fullname):
+                time.sleep(1)
+                continue
+
+        return fullname
+
+    def _uncompress_file_zip(filepath):
+        with zipfile.ZipFile(filepath, 'r') as files:
+            file_list = files.namelist()
+
+            file_dir = os.path.dirname(filepath)
+
+            if _is_a_single_file(file_list):
+                rootpath = file_list[0]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+                files.extractall(file_dir)
+
+            elif _is_a_single_dir(file_list):
+                # `strip(os.sep)` removes a trailing `os.sep` from the path
+                rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                    os.sep)[-1]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+
+                files.extractall(file_dir)
+            else:
+                rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+                if not os.path.exists(uncompressed_path):
+                    os.makedirs(uncompressed_path)
+                files.extractall(os.path.join(file_dir, rootpath))
+
+            return uncompressed_path
+
+    def _is_a_single_file(file_list):
+        if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
+            return True
+        return False
+
+    def _is_a_single_dir(file_list):
+        new_file_list = []
+        for file_path in file_list:
+            if '/' in file_path:
+                file_path = file_path.replace('/', os.sep)
+            elif '\\' in file_path:
+                file_path = file_path.replace('\\', os.sep)
+            new_file_list.append(file_path)
+
+        file_name = new_file_list[0].split(os.sep)[0]
+        for i in range(1, len(new_file_list)):
+            if file_name != new_file_list[i].split(os.sep)[0]:
+                return False
+        return True
+
+    def _uncompress_file_tar(filepath, mode="r:*"):
+        with tarfile.open(filepath, mode) as files:
+            file_list = files.getnames()
+
+            file_dir = os.path.dirname(filepath)
+
+            if _is_a_single_file(file_list):
+                rootpath = file_list[0]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+                files.extractall(file_dir)
+            elif _is_a_single_dir(file_list):
+                rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                    os.sep)[-1]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+                files.extractall(file_dir)
+            else:
+                rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+                uncompressed_path = os.path.join(file_dir, rootpath)
+                if not os.path.exists(uncompressed_path):
+                    os.makedirs(uncompressed_path)
+
+                files.extractall(os.path.join(file_dir, rootpath))
+
+            return uncompressed_path
+
+    def _decompress(fname):
+        """Decompress zip and tar files."""
+        logger.info("Decompressing {}...".format(fname))
+
+        # The archive is unpacked next to itself; the target directory
+        # name is derived from the archive's contents.
+        if tarfile.is_tarfile(fname):
+            uncompressed_path = _uncompress_file_tar(fname)
+        elif zipfile.is_zipfile(fname):
+            uncompressed_path = _uncompress_file_zip(fname)
+        else:
+            raise TypeError("Unsupported compressed file type {}".format(fname))
+
+        return uncompressed_path
+
+    assert is_url(url), "downloading from {} is not a url".format(url)
+    fullpath = _map_path(url, root_dir)
+    if os.path.exists(fullpath) and check_exist:
+        logger.info("Found {}".format(fullpath))
+    else:
+        fullpath = _download(url, root_dir)
+
+    if decompress and (tarfile.is_tarfile(fullpath) or
+                       zipfile.is_zipfile(fullpath)):
+        fullpath = _decompress(fullpath)
+
+    return fullpath
+
+
+MODEL_MAP = {
+    "uie-base": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
+        }
+    },
+    "uie-medium": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+        }
+    },
+    "uie-mini": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+        }
+    },
+    "uie-micro": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+        }
+    },
+    "uie-nano": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+        }
+    },
+    "uie-medical-base": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+        }
+    },
+    "uie-tiny": {
+        "resource_file_urls": {
+            "model_state.pdparams":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
+            "model_config.json":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
+            "vocab_file":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
+            "special_tokens_map":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
+            "tokenizer_config":
+            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
+        }
+    }
+}
+
+
+def build_params_map(attention_num=12):
+    """Build the parameter-name map from PaddlePaddle's ERNIE to transformers-style BERT.
+
+    :return: an OrderedDict mapping Paddle parameter names to PyTorch ones
+    """
+    weight_map = collections.OrderedDict({
+        'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
+        'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
+        'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
+        'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight",  # no 'bert.' prefix here: mapped directly onto the bert4torch structure
+        'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
+        'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
+    })
+    # add attention layers
+    for i in range(attention_num):
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
+        weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
+        weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.weight'
+        weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.bias'
+        weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
+        weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
+        weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
+        weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
+        weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.weight'
+        weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.bias'
+    # add pooler and the span start/end heads
+    weight_map.update(
+        {
+            'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
+            'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
+            'linear_start.weight': 'linear_start.weight',
+            'linear_start.bias': 'linear_start.bias',
+            'linear_end.weight': 'linear_end.weight',
+            'linear_end.bias': 'linear_end.bias',
+        }
+    )
+    return weight_map
+
+
+def extract_and_convert(input_dir, output_dir):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    logger.info('=' * 20 + 'save config file' + '=' * 20)
+    config = json.load(open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
+    config = config['init_args'][0]
+    config["architectures"] = ["UIE"]
+    config['layer_norm_eps'] = 1e-12
+    del config['init_class']
+    if 'sent_type_vocab_size' in config:
+        config['type_vocab_size'] = config['sent_type_vocab_size']
+    config['intermediate_size'] = 4 * config['hidden_size']
+    json.dump(config, open(os.path.join(output_dir, 'config.json'),
+              'wt', encoding='utf-8'), indent=4)
+    logger.info('=' * 20 + 'save vocab file' + '=' * 20)
+    with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
+        words = f.read().splitlines()
+    words_set = set()
+    words_duplicate_indices = []
+    for i in range(len(words) - 1, -1, -1):
+        word = words[i]
+        if word in words_set:
+            words_duplicate_indices.append(i)
+        words_set.add(word)
+    for i, idx in enumerate(words_duplicate_indices):
+        words[idx] = chr(0x1F6A9 + i)  # replace each duplicated token with an unused placeholder character (🚩 and the code points after it)
+    with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
+        for word in words:
+            f.write(word + '\n')
+    special_tokens_map = {
+        "unk_token": "[UNK]",
+        "sep_token": "[SEP]",
+        "pad_token": "[PAD]",
+        "cls_token": "[CLS]",
+        "mask_token": "[MASK]"
+    }
+    json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
+              'wt', encoding='utf-8'))
+    tokenizer_config = {
+        "do_lower_case": True,
+        "unk_token": "[UNK]",
+        "sep_token": "[SEP]",
+        "pad_token": "[PAD]",
+        "cls_token": "[CLS]",
+        "mask_token": "[MASK]",
+        "tokenizer_class": "BertTokenizer"
+    }
+    json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
+              'wt', encoding='utf-8'))
+    logger.info('=' * 20 + 'extract weights' + '=' * 20)
+    state_dict = collections.OrderedDict()
+    weight_map = build_params_map(attention_num=config['num_hidden_layers'])
+    paddle_paddle_params = pickle.load(
+        open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
+    del paddle_paddle_params['StructuredToParameterName@@']
+    for weight_name, weight_value in paddle_paddle_params.items():
+        if 'weight' in weight_name:
+            if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
+                weight_value = weight_value.transpose()
+        # Fix the embedding: zero out row 0 (the [PAD] token)
+        if 'word_embeddings.weight' in weight_name:
+            weight_value[0, :] = 0
+        if weight_name not in weight_map:
+            logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
+            continue
+        state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
+        logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
+    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
+
+
+def check_model(input_model):
+    if not os.path.exists(input_model):
+        if input_model not in MODEL_MAP:
+            raise ValueError(f'input_model {input_model} does not exist '
+                             'and is not a known model name!')
+
+        resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
+        logger.info("Downloading resource files...")
+
+        for key, val in resource_file_urls.items():
+            file_path = os.path.join(input_model, key)
+            if not os.path.exists(file_path):
+                get_path_from_url(val, input_model)
+
+
+def do_main():
+    check_model(args.input_model)
+    extract_and_convert(args.input_model, args.output_model)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input_model", default="uie-base", type=str,
+                        help="Directory of the input Paddle model; a known model name "
+                             "(e.g. uie-base, uie-tiny) is downloaded automatically")
+    parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
+                        help="Directory for the output PyTorch model")
+    args = parser.parse_args()
+
+    do_main()
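Once convert.py has run, the produced checkpoint can be spot-checked with plain torch; a sketch, assuming the default output directory uie_base_pytorch:

```python
import torch

# Load on CPU and inspect a few tensors named by build_params_map above.
state_dict = torch.load("uie_base_pytorch/pytorch_model.bin", map_location="cpu")
print(len(state_dict))                                             # number of converted tensors
print(state_dict["bert.embeddings.word_embeddings.weight"].shape)  # torch.Size([40000, 768]) for uie-base
print(state_dict["linear_start.weight"].shape)                     # the span-start head
```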
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcc8d6a47e8ae6377bb00d4762e58a89291ff45fe8790956238713128b748d7d
+size 471850153
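The file above is a Git LFS pointer; after fetching the real binary, its SHA-256 digest should match the oid in the pointer. A sketch using only the standard library:

```python
import hashlib

h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
print(h.hexdigest())  # expected: dcc8d6a4...28b748d7d
```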
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenizer_class": "BertTokenizer"}
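Since tokenizer_config.json sets tokenizer_class to BertTokenizer, the converted directory loads with the standard Hugging Face tokenizer; a sketch, assuming transformers is installed and the converted files live in uie_base_pytorch/:

```python
from transformers import BertTokenizer

# Uses vocab.txt plus the special-token settings above.
tokenizer = BertTokenizer.from_pretrained("uie_base_pytorch")
print(tokenizer.tokenize("信息抽取"))
```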
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff