|
'''Download the pretrained model and convert it to pytorch format.

'''
|
import argparse |
|
import collections |
|
import json |
|
import os |
|
import pickle |
|
import torch |
|
import logging |
|
import shutil |
|
from tqdm import tqdm |
|
import time |
|
|
|
# Use getLogger() so the logger is registered in the logging hierarchy;
# instantiating logging.Logger directly creates a detached logger whose
# records bypass any root/handler configuration. NOTE(review): to actually
# see the INFO messages this script emits, the caller still needs e.g.
# logging.basicConfig(level=logging.INFO) -- confirm desired verbosity.
logger = logging.getLogger('log')
|
|
|
|
|
def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
    """ Download from given url to root_dir.
    if file or directory specified by url is exists under
    root_dir, return the path directly, otherwise download
    from url and decompress it, return the path.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
            WEIGHTS_HOME or DATASET_HOME
        check_exist (bool): if True and the target file already exists
            under `root_dir`, skip the download. Default is `True`
        decompress (bool): decompress zip or tar file. Default is `True`

    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """

    import os.path
    import os
    import tarfile
    import zipfile

    def is_url(path):
        """
        Whether path is URL.
        Args:
            path (string): URL string or not.
        """
        return path.startswith('http://') or path.startswith('https://')

    def _map_path(url, root_dir):
        # Map a URL onto its local destination: <root_dir>/<url basename>.
        fname = os.path.split(url)[-1]
        fpath = fname
        return os.path.join(root_dir, fpath)

    def _get_download(url, fullname):
        # Stream-download `url` into `fullname`. Returns `fullname` (truthy)
        # on success, False when the request itself raises; a non-200 HTTP
        # status raises RuntimeError instead of returning False.
        import requests

        fname = os.path.basename(fullname)
        try:
            req = requests.get(url, stream=True)
        except Exception as e:
            logger.info("Downloading {} from {} failed with exception {}".format(
                fname, url, str(e)))
            return False

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # Write to a "<name>_tmp" file first and move into place afterwards,
        # so a partially written download is never mistaken for a finished
        # file by the retry loop in _download().
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                # content-length known: show a KB-granularity progress bar
                # (round the total up to whole KB).
                with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                # Unknown size: just stream chunks, skipping keep-alives.
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

        return fullname

    def _download(url, path):
        """
        Download from url, save to path; retries up to
        DOWNLOAD_RETRY_LIMIT times before raising RuntimeError.

        url (str): download url
        path (str): download to given path
        """

        if not os.path.exists(path):
            os.makedirs(path)

        fname = os.path.split(url)[-1]
        fullname = os.path.join(path, fname)
        retry_cnt = 0

        logger.info("Downloading {} from {}".format(fname, url))
        DOWNLOAD_RETRY_LIMIT = 3
        # Loop until the finished file exists; each failed attempt sleeps
        # one second before retrying.
        while not os.path.exists(fullname):
            if retry_cnt < DOWNLOAD_RETRY_LIMIT:
                retry_cnt += 1
            else:
                raise RuntimeError("Download from {} failed. "
                                   "Retry limit reached".format(url))

            if not _get_download(url, fullname):
                time.sleep(1)
                continue

        return fullname

    def _uncompress_file_zip(filepath):
        # Extract a zip archive next to `filepath` and return the path of
        # the extracted content (single file, single dir, or a new dir
        # named after the archive).
        with zipfile.ZipFile(filepath, 'r') as files:
            file_list = files.namelist()

            file_dir = os.path.dirname(filepath)

            if _is_a_single_file(file_list):
                # Archive holds exactly one file: extract it in place.
                rootpath = file_list[0]
                uncompressed_path = os.path.join(file_dir, rootpath)
                files.extractall(file_dir)

            elif _is_a_single_dir(file_list):
                # Archive holds one top-level directory: return its path.
                rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
                    os.sep)[-1]
                uncompressed_path = os.path.join(file_dir, rootpath)

                files.extractall(file_dir)
            else:
                # Multiple top-level entries: gather them under a fresh
                # directory named after the archive file.
                rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
                uncompressed_path = os.path.join(file_dir, rootpath)
                if not os.path.exists(uncompressed_path):
                    os.makedirs(uncompressed_path)
                files.extractall(os.path.join(file_dir, rootpath))

            return uncompressed_path

    def _is_a_single_file(file_list):
        # True when the archive contains exactly one entry with no
        # directory component in its name.
        if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
            return True
        return False

    def _is_a_single_dir(file_list):
        # True when every archive entry shares the same top-level
        # directory. Separators are normalised to os.sep first so zip
        # ('/') and Windows ('\\') style paths compare consistently.
        new_file_list = []
        for file_path in file_list:
            if '/' in file_path:
                file_path = file_path.replace('/', os.sep)
            elif '\\' in file_path:
                file_path = file_path.replace('\\', os.sep)
            new_file_list.append(file_path)

        file_name = new_file_list[0].split(os.sep)[0]
        for i in range(1, len(new_file_list)):
            if file_name != new_file_list[i].split(os.sep)[0]:
                return False
        return True

    def _uncompress_file_tar(filepath, mode="r:*"):
        # Extract a tar archive (any compression, per mode "r:*") next to
        # `filepath`; mirrors the zip logic above.
        # NOTE(review): extractall() without member sanitising is unsafe
        # on untrusted archives (path traversal). The URLs used here are
        # first-party model hosts, but confirm before reusing elsewhere.
        with tarfile.open(filepath, mode) as files:
            file_list = files.getnames()

            file_dir = os.path.dirname(filepath)

            if _is_a_single_file(file_list):
                rootpath = file_list[0]
                uncompressed_path = os.path.join(file_dir, rootpath)
                files.extractall(file_dir)
            elif _is_a_single_dir(file_list):
                rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
                    os.sep)[-1]
                uncompressed_path = os.path.join(file_dir, rootpath)
                files.extractall(file_dir)
            else:
                rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
                uncompressed_path = os.path.join(file_dir, rootpath)
                if not os.path.exists(uncompressed_path):
                    os.makedirs(uncompressed_path)

                files.extractall(os.path.join(file_dir, rootpath))

            return uncompressed_path

    def _decompress(fname):
        """
        Decompress for zip and tar file
        """
        logger.info("Decompressing {}...".format(fname))

        if tarfile.is_tarfile(fname):
            uncompressed_path = _uncompress_file_tar(fname)
        elif zipfile.is_zipfile(fname):
            uncompressed_path = _uncompress_file_zip(fname)
        else:
            raise TypeError("Unsupport compress file type {}".format(fname))

        return uncompressed_path

    # Main flow: reuse an existing local copy when allowed, otherwise
    # download, then optionally decompress archives.
    assert is_url(url), "downloading from {} not a url".format(url)
    fullpath = _map_path(url, root_dir)
    if os.path.exists(fullpath) and check_exist:
        logger.info("Found {}".format(fullpath))
    else:
        fullpath = _download(url, root_dir)

    if decompress and (tarfile.is_tarfile(fullpath) or
                       zipfile.is_zipfile(fullpath)):
        fullpath = _decompress(fullpath)

    return fullpath
|
|
|
|
|
# Registry of downloadable UIE checkpoints: model name -> URLs of the
# resource files on the PaddleNLP BCEBOS bucket. The keys under
# `resource_file_urls` are logical resource names; the file saved locally
# is the URL basename. All models except uie-tiny reuse the uie-base
# vocab/tokenizer files.
MODEL_MAP = {
    "uie-base": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    "uie-medium": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
        }
    },
    "uie-mini": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
        }
    },
    "uie-micro": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
        }
    },
    "uie-nano": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
        }
    },
    "uie-medical-base": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
        }
    },
    "uie-tiny": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
        }
    }
}
|
|
|
|
|
def build_params_map(attention_num=12):
    """
    build params map from paddle-paddle's ERNIE to transformer's BERT

    Args:
        attention_num (int): number of transformer encoder layers to map.

    Returns:
        collections.OrderedDict: paddle parameter name -> BERT parameter name.
    """
    # Embedding-level parameters. NOTE(review): task_type_embeddings maps to
    # a name without the "bert." prefix in the original table -- presumably
    # intentional (dropped by the torch model); confirm before changing.
    weight_map = collections.OrderedDict({
        'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
        'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
        'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
        'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight",
        'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
        'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
    })

    # Per-layer (paddle suffix, bert suffix) pairs; order matters because
    # the mapping is an OrderedDict.
    layer_suffix_pairs = (
        ('self_attn.q_proj.weight', 'attention.self.query.weight'),
        ('self_attn.q_proj.bias', 'attention.self.query.bias'),
        ('self_attn.k_proj.weight', 'attention.self.key.weight'),
        ('self_attn.k_proj.bias', 'attention.self.key.bias'),
        ('self_attn.v_proj.weight', 'attention.self.value.weight'),
        ('self_attn.v_proj.bias', 'attention.self.value.bias'),
        ('self_attn.out_proj.weight', 'attention.output.dense.weight'),
        ('self_attn.out_proj.bias', 'attention.output.dense.bias'),
        ('norm1.weight', 'attention.output.LayerNorm.weight'),
        ('norm1.bias', 'attention.output.LayerNorm.bias'),
        ('linear1.weight', 'intermediate.dense.weight'),
        ('linear1.bias', 'intermediate.dense.bias'),
        ('linear2.weight', 'output.dense.weight'),
        ('linear2.bias', 'output.dense.bias'),
        ('norm2.weight', 'output.LayerNorm.weight'),
        ('norm2.bias', 'output.LayerNorm.bias'),
    )
    for layer_idx in range(attention_num):
        for paddle_suffix, bert_suffix in layer_suffix_pairs:
            src = f'encoder.encoder.layers.{layer_idx}.{paddle_suffix}'
            dst = f'bert.encoder.layer.{layer_idx}.{bert_suffix}'
            weight_map[src] = dst

    # Pooler and UIE span-prediction heads (start/end linear layers keep
    # their paddle names on the torch side).
    weight_map.update(
        {
            'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
            'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
            'linear_start.weight': 'linear_start.weight',
            'linear_start.bias': 'linear_start.bias',
            'linear_end.weight': 'linear_end.weight',
            'linear_end.bias': 'linear_end.bias',
        }
    )
    return weight_map
|
|
|
|
|
def extract_and_convert(input_dir, output_dir):
    """Convert a PaddlePaddle UIE checkpoint into a pytorch-style one.

    Reads `model_config.json`, `vocab.txt` and `model_state.pdparams`
    from `input_dir` and writes `config.json`, `vocab.txt`,
    `special_tokens_map.json`, `tokenizer_config.json` and
    `pytorch_model.bin` into `output_dir` (created if missing).

    Args:
        input_dir (str): directory holding the downloaded paddle files.
        output_dir (str): destination directory for the pytorch model.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    config = _convert_config(input_dir, output_dir)
    _convert_vocab(input_dir, output_dir)
    _write_tokenizer_files(output_dir)
    _convert_weights(input_dir, output_dir, config)


def _convert_config(input_dir, output_dir):
    """Translate paddle's model_config.json into config.json; return the
    resulting config dict (needed later for the weight conversion)."""
    logger.info('=' * 20 + 'save config file' + '=' * 20)
    # Use context managers so the file handles are closed deterministically
    # (the original json.load(open(...)) pattern leaked them).
    with open(os.path.join(input_dir, 'model_config.json'), 'rt',
              encoding='utf-8') as f:
        config = json.load(f)
    config = config['init_args'][0]
    config["architectures"] = ["UIE"]
    config['layer_norm_eps'] = 1e-12
    del config['init_class']
    if 'sent_type_vocab_size' in config:
        config['type_vocab_size'] = config['sent_type_vocab_size']
    config['intermediate_size'] = 4 * config['hidden_size']
    with open(os.path.join(output_dir, 'config.json'), 'wt',
              encoding='utf-8') as f:
        json.dump(config, f, indent=4)
    return config


def _convert_vocab(input_dir, output_dir):
    """Copy vocab.txt, replacing all-but-the-last occurrence of any
    duplicated token with unique placeholder characters (U+1F6A9, U+1F6AA,
    ...) so every vocab row maps to a distinct string."""
    logger.info('=' * 20 + 'save vocab file' + '=' * 20)
    with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
        words = f.read().splitlines()
    words_set = set()
    words_duplicate_indices = []
    # Walk backwards so the LAST occurrence of a duplicated token keeps
    # its original spelling.
    for i in range(len(words) - 1, -1, -1):
        word = words[i]
        if word in words_set:
            words_duplicate_indices.append(i)
        words_set.add(word)
    for i, idx in enumerate(words_duplicate_indices):
        words[idx] = chr(0x1F6A9 + i)
    with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
        for word in words:
            f.write(word + '\n')


def _write_tokenizer_files(output_dir):
    """Write the static special_tokens_map.json and tokenizer_config.json."""
    special_tokens_map = {
        "unk_token": "[UNK]",
        "sep_token": "[SEP]",
        "pad_token": "[PAD]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]"
    }
    with open(os.path.join(output_dir, 'special_tokens_map.json'), 'wt',
              encoding='utf-8') as f:
        json.dump(special_tokens_map, f)
    tokenizer_config = {
        "do_lower_case": True,
        "unk_token": "[UNK]",
        "sep_token": "[SEP]",
        "pad_token": "[PAD]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]",
        "tokenizer_class": "BertTokenizer"
    }
    with open(os.path.join(output_dir, 'tokenizer_config.json'), 'wt',
              encoding='utf-8') as f:
        json.dump(tokenizer_config, f)


def _convert_weights(input_dir, output_dir, config):
    """Rename (and where needed transpose) paddle tensors per
    build_params_map() and save them as pytorch_model.bin."""
    logger.info('=' * 20 + 'extract weights' + '=' * 20)
    state_dict = collections.OrderedDict()
    weight_map = build_params_map(attention_num=config['num_hidden_layers'])
    # NOTE(review): the .pdparams file is loaded as a plain pickle, which
    # executes arbitrary code on malicious input -- only use checkpoints
    # from the trusted first-party URLs in MODEL_MAP.
    with open(os.path.join(input_dir, 'model_state.pdparams'), 'rb') as f:
        paddle_paddle_params = pickle.load(f)
    del paddle_paddle_params['StructuredToParameterName@@']
    for weight_name, weight_value in paddle_paddle_params.items():
        if 'weight' in weight_name:
            # paddle stores Linear weights as (in, out); torch expects
            # (out, in), so transpose encoder/pooler/linear matrices.
            if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
                weight_value = weight_value.transpose()

        if 'word_embeddings.weight' in weight_name:
            # Zero row 0 of the word embeddings -- presumably the padding
            # token's vector; TODO confirm index 0 is [PAD] in this vocab.
            weight_value[0, :] = 0
        if weight_name not in weight_map:
            logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
            continue
        state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
        logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
|
|
|
|
|
def check_model(input_model):
    """Ensure the model resources named by `input_model` exist locally.

    If `input_model` is an existing directory, nothing is done. Otherwise
    it must be a key of MODEL_MAP, and every resource file that is not
    already present is downloaded into a directory of that name.

    Args:
        input_model (str): local model directory, or a MODEL_MAP key
            such as "uie-base".

    Raises:
        ValueError: `input_model` is neither an existing path nor a
            known model name.
    """
    if os.path.exists(input_model):
        return
    if input_model not in MODEL_MAP:
        raise ValueError('input_model not exists!')

    resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
    logger.info("Downloading resource files...")

    for key, val in resource_file_urls.items():
        # Check the basename the download is actually saved under: the
        # map key is a logical name (e.g. 'vocab_file') while the file on
        # disk is the URL basename (e.g. 'vocab.txt'), so checking the key
        # would wrongly report files like vocab.txt as missing.
        file_path = os.path.join(input_model, os.path.basename(val))
        if not os.path.exists(file_path):
            get_path_from_url(val, input_model)
|
|
|
|
|
def do_main():
    """Download the requested paddle checkpoint if necessary, then convert
    it to pytorch format.

    Relies on the module-level ``args`` namespace created by the
    ``__main__`` guard (``args.input_model``, ``args.output_model``).
    """
    check_model(args.input_model)
    extract_and_convert(args.input_model, args.output_model)
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_model", default="uie-base", type=str,
                        help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]")
    parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
                        help="Directory of output pytorch model")
    # `args` is deliberately module-level: do_main() reads it as a global.
    args = parser.parse_args()

    do_main()
|
|