Tongjilibo committed
Commit 2fc5f2f
1 Parent(s): aeed0cb

init commit

README.md ADDED
@@ -0,0 +1,8 @@
+
+
+ # Usage
+
+ - These weights are the recommended ones: they were produced by using convert.py to automatically download the Paddle weights and convert them to PyTorch.
+ - [Source project](https://github.com/universal-ie/UIE)
+ - [uie_pytorch](https://github.com/HUSTAI/uie_pytorch)
+ - You can also download and convert the weights yourself with convert.py (a usage sketch follows this file).
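Converting the weights yourself boils down to calling the two functions defined in convert.py (shown further down in this commit). A minimal sketch, assuming convert.py sits in the working directory and using the script's own argparse defaults (`-i uie-base -o uie_base_pytorch`):

```python
# Minimal sketch: drive the conversion from Python instead of `python convert.py`.
# Assumes convert.py from this commit is importable from the current directory.
from convert import check_model, extract_and_convert

input_model = "uie-base"           # a key of MODEL_MAP; resources are downloaded if missing
output_model = "uie_base_pytorch"  # config.json / vocab.txt / pytorch_model.bin are written here

check_model(input_model)                        # fetch the Paddle resource files
extract_and_convert(input_model, output_model)  # convert model_state.pdparams to pytorch_model.bin
```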
bert4torch_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+     "attention_probs_dropout_prob": 0.1,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "initializer_range": 0.02,
+     "max_position_embeddings": 2048,
+     "num_attention_heads": 12,
+     "num_hidden_layers": 12,
+     "task_type_vocab_size": 3,
+     "type_vocab_size": 4,
+     "use_task_id": true,
+     "vocab_size": 40000,
+     "layer_norm_eps": 1e-12,
+     "intermediate_size": 3072,
+     "model": "uie"
+ }
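bert4torch_config.json is the config meant for bert4torch (note the UIE-specific `task_type_vocab_size`, `use_task_id`, and `"model": "uie"` fields). A rough loading sketch follows; it assumes bert4torch's `build_transformer_model(config_path=..., checkpoint_path=...)` interface as shown in that project's examples and that this repository is checked out locally, so treat it as a hint rather than tested code:

```python
# Hedged sketch: build the UIE model with bert4torch from the files in this repo.
# Assumption: build_transformer_model picks the architecture up from the config's
# "model": "uie" field; argument names follow bert4torch's published examples.
from bert4torch.models import build_transformer_model

config_path = 'bert4torch_config.json'
checkpoint_path = 'pytorch_model.bin'   # the converted weights from this repo

model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)
model.eval()
```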
config.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "attention_probs_dropout_prob": 0.1,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "initializer_range": 0.02,
+     "max_position_embeddings": 2048,
+     "num_attention_heads": 12,
+     "num_hidden_layers": 12,
+     "task_type_vocab_size": 3,
+     "type_vocab_size": 4,
+     "use_task_id": true,
+     "vocab_size": 40000,
+     "architectures": [
+         "UIE"
+     ],
+     "layer_norm_eps": 1e-12,
+     "intermediate_size": 3072
+ }
convert.py ADDED
@@ -0,0 +1,465 @@
+ '''Download the pretrained Paddle model and convert it to PyTorch format.
+ '''
+ import argparse
+ import collections
+ import json
+ import os
+ import pickle
+ import torch
+ import logging
+ import shutil
+ from tqdm import tqdm
+ import time
+ logging.basicConfig(level=logging.INFO)  # attach a handler so the progress messages below are visible
+ logger = logging.getLogger('log')
+
+
+ def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
+     """ Download from given url to root_dir.
+     If the file or directory specified by url already exists under
+     root_dir, return the path directly; otherwise download
+     from url, decompress it, and return the path.
+
+     Args:
+         url (str): download url
+         root_dir (str): root dir for downloading, it should be
+                         WEIGHTS_HOME or DATASET_HOME
+         decompress (bool): decompress zip or tar file. Default is `True`
+
+     Returns:
+         str: a local path to save downloaded models & weights & datasets.
+     """
+
+     import os.path
+     import os
+     import tarfile
+     import zipfile
+
+     def is_url(path):
+         """
+         Whether path is URL.
+         Args:
+             path (string): URL string or not.
+         """
+         return path.startswith('http://') or path.startswith('https://')
+
+     def _map_path(url, root_dir):
+         # parse path after download under root_dir
+         fname = os.path.split(url)[-1]
+         fpath = fname
+         return os.path.join(root_dir, fpath)
+
+     def _get_download(url, fullname):
+         import requests
+         # using requests.get method
+         fname = os.path.basename(fullname)
+         try:
+             req = requests.get(url, stream=True)
+         except Exception as e:  # requests.exceptions.ConnectionError
+             logger.info("Downloading {} from {} failed with exception {}".format(
+                 fname, url, str(e)))
+             return False
+
+         if req.status_code != 200:
+             raise RuntimeError("Downloading from {} failed with code "
+                                "{}!".format(url, req.status_code))
+
+         # To guard against interrupted downloads, download to
+         # tmp_fullname first, then move tmp_fullname to fullname
+         # once the download has finished
+         tmp_fullname = fullname + "_tmp"
+         total_size = req.headers.get('content-length')
+         with open(tmp_fullname, 'wb') as f:
+             if total_size:
+                 with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
+                     for chunk in req.iter_content(chunk_size=1024):
+                         f.write(chunk)
+                         pbar.update(1)
+             else:
+                 for chunk in req.iter_content(chunk_size=1024):
+                     if chunk:
+                         f.write(chunk)
+         shutil.move(tmp_fullname, fullname)
+
+         return fullname
+
+     def _download(url, path):
+         """
+         Download from url, save to path.
+
+         url (str): download url
+         path (str): download to given path
+         """
+
+         if not os.path.exists(path):
+             os.makedirs(path)
+
+         fname = os.path.split(url)[-1]
+         fullname = os.path.join(path, fname)
+         retry_cnt = 0
+
+         logger.info("Downloading {} from {}".format(fname, url))
+         DOWNLOAD_RETRY_LIMIT = 3
+         while not os.path.exists(fullname):
+             if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+                 retry_cnt += 1
+             else:
+                 raise RuntimeError("Download from {} failed. "
+                                    "Retry limit reached".format(url))
+
+             if not _get_download(url, fullname):
+                 time.sleep(1)
+                 continue
+
+         return fullname
+
+     def _uncompress_file_zip(filepath):
+         with zipfile.ZipFile(filepath, 'r') as files:
+             file_list = files.namelist()
+
+             file_dir = os.path.dirname(filepath)
+
+             if _is_a_single_file(file_list):
+                 rootpath = file_list[0]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+                 files.extractall(file_dir)
+
+             elif _is_a_single_dir(file_list):
+                 # `strip(os.sep)` to remove `os.sep` in the tail of path
+                 rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                     os.sep)[-1]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+
+                 files.extractall(file_dir)
+             else:
+                 rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+                 if not os.path.exists(uncompressed_path):
+                     os.makedirs(uncompressed_path)
+                 files.extractall(os.path.join(file_dir, rootpath))
+
+             return uncompressed_path
+
+     def _is_a_single_file(file_list):
+         if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
+             return True
+         return False
+
+     def _is_a_single_dir(file_list):
+         new_file_list = []
+         for file_path in file_list:
+             if '/' in file_path:
+                 file_path = file_path.replace('/', os.sep)
+             elif '\\' in file_path:
+                 file_path = file_path.replace('\\', os.sep)
+             new_file_list.append(file_path)
+
+         file_name = new_file_list[0].split(os.sep)[0]
+         for i in range(1, len(new_file_list)):
+             if file_name != new_file_list[i].split(os.sep)[0]:
+                 return False
+         return True
+
+     def _uncompress_file_tar(filepath, mode="r:*"):
+         with tarfile.open(filepath, mode) as files:
+             file_list = files.getnames()
+
+             file_dir = os.path.dirname(filepath)
+
+             if _is_a_single_file(file_list):
+                 rootpath = file_list[0]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+                 files.extractall(file_dir)
+             elif _is_a_single_dir(file_list):
+                 rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                     os.sep)[-1]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+                 files.extractall(file_dir)
+             else:
+                 rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+                 uncompressed_path = os.path.join(file_dir, rootpath)
+                 if not os.path.exists(uncompressed_path):
+                     os.makedirs(uncompressed_path)
+
+                 files.extractall(os.path.join(file_dir, rootpath))
+
+             return uncompressed_path
+
+     def _decompress(fname):
+         """
+         Decompress zip and tar files
+         """
+         logger.info("Decompressing {}...".format(fname))
+
+         # To guard against interrupted decompression, decompress into a
+         # temporary directory first; if decompression succeeds, move the
+         # files to fpath, then delete the temporary directory and the
+         # downloaded archive.
+
+         if tarfile.is_tarfile(fname):
+             uncompressed_path = _uncompress_file_tar(fname)
+         elif zipfile.is_zipfile(fname):
+             uncompressed_path = _uncompress_file_zip(fname)
+         else:
+             raise TypeError("Unsupported compressed file type {}".format(fname))
+
+         return uncompressed_path
+
+     assert is_url(url), "downloading from {} is not a url".format(url)
+     fullpath = _map_path(url, root_dir)
+     if os.path.exists(fullpath) and check_exist:
+         logger.info("Found {}".format(fullpath))
+     else:
+         fullpath = _download(url, root_dir)
+
+     if decompress and (tarfile.is_tarfile(fullpath) or
+                        zipfile.is_zipfile(fullpath)):
+         fullpath = _decompress(fullpath)
+
+     return fullpath
+
+
+ MODEL_MAP = {
+     "uie-base": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
+         }
+     },
+     "uie-medium": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+         }
+     },
+     "uie-mini": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+         }
+     },
+     "uie-micro": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+         }
+     },
+     "uie-nano": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+         }
+     },
+     "uie-medical-base": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
+         }
+     },
+     "uie-tiny": {
+         "resource_file_urls": {
+             "model_state.pdparams":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
+             "model_config.json":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
+             "vocab_file":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
+             "special_tokens_map":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
+             "tokenizer_config":
+             "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
+         }
+     }
+ }
+
+
+ def build_params_map(attention_num=12):
+     """
+     Build the parameter-name map from PaddlePaddle's ERNIE to the transformers-style BERT naming.
+     :return:
+     """
+     weight_map = collections.OrderedDict({
+         'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
+         'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
+         'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
+         'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight",  # no 'bert.' prefix here: mapped directly onto the bert4torch structure
+         'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
+         'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
+     })
+     # add attention layers
+     for i in range(attention_num):
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
+         weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
+         weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.weight'
+         weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.bias'
+         weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
+         weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
+         weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
+         weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
+         weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.weight'
+         weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.bias'
+     # add pooler
+     weight_map.update(
+         {
+             'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
+             'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
+             'linear_start.weight': 'linear_start.weight',
+             'linear_start.bias': 'linear_start.bias',
+             'linear_end.weight': 'linear_end.weight',
+             'linear_end.bias': 'linear_end.bias',
+         }
+     )
+     return weight_map
+
+
+ def extract_and_convert(input_dir, output_dir):
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     logger.info('=' * 20 + 'save config file' + '=' * 20)
+     config = json.load(open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
+     config = config['init_args'][0]
+     config["architectures"] = ["UIE"]
+     config['layer_norm_eps'] = 1e-12
+     del config['init_class']
+     if 'sent_type_vocab_size' in config:
+         config['type_vocab_size'] = config['sent_type_vocab_size']
+     config['intermediate_size'] = 4 * config['hidden_size']
+     json.dump(config, open(os.path.join(output_dir, 'config.json'),
+                            'wt', encoding='utf-8'), indent=4)
+     logger.info('=' * 20 + 'save vocab file' + '=' * 20)
+     with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
+         words = f.read().splitlines()
+     words_set = set()
+     words_duplicate_indices = []
+     for i in range(len(words)-1, -1, -1):
+         word = words[i]
+         if word in words_set:
+             words_duplicate_indices.append(i)
+         words_set.add(word)
+     for i, idx in enumerate(words_duplicate_indices):
+         words[idx] = chr(0x1F6A9+i)  # replace duplicated tokens with unused emoji (🚩, ...) so vocab indices stay unique
+     with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
+         for word in words:
+             f.write(word+'\n')
+     special_tokens_map = {
+         "unk_token": "[UNK]",
+         "sep_token": "[SEP]",
+         "pad_token": "[PAD]",
+         "cls_token": "[CLS]",
+         "mask_token": "[MASK]"
+     }
+     json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
+                                        'wt', encoding='utf-8'))
+     tokenizer_config = {
+         "do_lower_case": True,
+         "unk_token": "[UNK]",
+         "sep_token": "[SEP]",
+         "pad_token": "[PAD]",
+         "cls_token": "[CLS]",
+         "mask_token": "[MASK]",
+         "tokenizer_class": "BertTokenizer"
+     }
+     json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
+                                      'wt', encoding='utf-8'))
+     logger.info('=' * 20 + 'extract weights' + '=' * 20)
+     state_dict = collections.OrderedDict()
+     weight_map = build_params_map(attention_num=config['num_hidden_layers'])
+     paddle_paddle_params = pickle.load(
+         open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
+     del paddle_paddle_params['StructuredToParameterName@@']
+     for weight_name, weight_value in paddle_paddle_params.items():
+         if 'weight' in weight_name:
+             if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
+                 weight_value = weight_value.transpose()  # Paddle nn.Linear stores weights as (in, out); PyTorch expects (out, in)
+         # Fix: embedding error (zero out row 0 of the word embedding matrix)
+         if 'word_embeddings.weight' in weight_name:
+             weight_value[0, :] = 0
+         if weight_name not in weight_map:
+             logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
+             continue
+         state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
+         logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
+     torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
+
+
+ def check_model(input_model):
+     if not os.path.exists(input_model):
+         if input_model not in MODEL_MAP:
+             raise ValueError('input_model does not exist and is not a known model name!')
+
+         resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
+         logger.info("Downloading resource files...")
+
+         for key, val in resource_file_urls.items():
+             file_path = os.path.join(input_model, key)
+             if not os.path.exists(file_path):
+                 get_path_from_url(val, input_model)
+
+
+ def do_main():
+     check_model(args.input_model)
+     extract_and_convert(args.input_model, args.output_model)
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-i", "--input_model", default="uie-base", type=str,
+                         help="Directory of the input Paddle model; a known model name (e.g. uie-base, uie-tiny) is downloaded automatically")
+     parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
+                         help="Directory of the output PyTorch model")
+     args = parser.parse_args()
+
+     do_main()
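After the script finishes, a quick sanity check of the checkpoint it wrote can be useful; this is a minimal sketch that only assumes the default output directory `uie_base_pytorch` used above:

```python
# Sanity-check the state dict produced by extract_and_convert().
import torch

state_dict = torch.load("uie_base_pytorch/pytorch_model.bin", map_location="cpu")
print(f"{len(state_dict)} tensors in the converted checkpoint")

# Keys taken from build_params_map() above; print their shapes if present.
for key in ("bert.embeddings.word_embeddings.weight",
            "embeddings.task_type_embeddings.weight",
            "linear_start.weight",
            "linear_end.weight"):
    if key in state_dict:
        print(key, tuple(state_dict[key].shape))
```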
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcc8d6a47e8ae6377bb00d4762e58a89291ff45fe8790956238713128b748d7d
+ size 471850153
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenizer_class": "BertTokenizer"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff