|
''' |
|
Author: Qiguang Chen |
|
LastEditors: Qiguang Chen |
|
Date: 2023-02-13 10:44:39 |
|
LastEditTime: 2023-02-19 15:45:08 |
|
Description: |
|
|
|
''' |
|
|
|
import argparse |
|
import sys |
|
import os |
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
import dill |
|
|
|
from common.config import Config |
|
from common.model_manager import ModelManager |
|
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoTokenizer, PreTrainedTokenizer |
|
|
|
class PretrainedConfigForSLUToSave(PretrainedConfig):
    """HuggingFace config wrapper that snapshots the global ``model_manager``'s
    SLU configuration (label vocabularies, tokenizer settings, model
    hyper-parameters) so it can be serialized with ``save_pretrained``.

    NOTE: relies on a module-level ``model_manager`` being initialized before
    instantiation (see the ``__main__`` section of this file).
    """

    def __init__(self, **kwargs) -> None:
        cfg = model_manager.config
        # Identity / decoding flags expected by transformers.
        kwargs["name_or_path"] = cfg.base["name"]
        kwargs["return_dict"] = False
        kwargs["is_decoder"] = True
        # Intent/slot label vocabularies (both directions) and their sizes.
        kwargs["_id2label"] = {"intent": model_manager.intent_list, "slot": model_manager.slot_list}
        kwargs["_label2id"] = {"intent": model_manager.intent_dict, "slot": model_manager.slot_dict}
        kwargs["_num_labels"] = {"intent": len(model_manager.intent_list), "slot": len(model_manager.slot_list)}
        kwargs["tokenizer_class"] = cfg.base["name"]
        kwargs["vocab_size"] = model_manager.tokenizer.vocab_size
        # Embed the full model config, patching in the label counts that the
        # classifier heads will need when the model is reloaded.
        kwargs["model"] = cfg.model
        kwargs["model"]["decoder"]["intent_classifier"]["intent_label_num"] = len(model_manager.intent_list)
        kwargs["model"]["decoder"]["slot_classifier"]["slot_label_num"] = len(model_manager.slot_list)
        kwargs["tokenizer"] = cfg.tokenizer
        # (Removed a stray no-op statement that computed len(model_manager.slot_list)
        # and discarded the result.)
        super().__init__(**kwargs)
|
|
|
class PretrainedModelForSLUToSave(PreTrainedModel):
    """Thin ``PreTrainedModel`` shell that wraps the already-loaded SLU model
    from the module-level ``model_manager`` so its weights can be exported
    via ``save_pretrained``."""

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # NOTE(review): ``config_class`` is conventionally a *class* attribute
        # holding a config type; here it is assigned the config *instance*.
        # Preserved as-is — worth confirming this is intentional.
        self.config_class = config
        self.model = model_manager.model
|
|
|
|
|
class PreTrainedTokenizerForSLUToSave(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` shell that serializes the project's tokenizer
    (taken from the module-level ``model_manager``) by pickling it with
    ``dill`` inside ``save_vocabulary``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = model_manager.tokenizer

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        """Dump the wrapped tokenizer to ``<save_directory>/[<prefix>-]tokenizer.pkl``
        and return the written path as a 1-tuple (transformers convention)."""
        filename = "tokenizer.pkl" if filename_prefix is None else filename_prefix + "-tokenizer.pkl"
        path = os.path.join(save_directory, filename)
        with open(path, 'wb') as f:
            dill.dump(self.tokenizer, f)
        return (path,)
|
|
|
if __name__ == "__main__":
    # CLI entry point: export a trained SLU checkpoint to HuggingFace
    # save_pretrained format (config + weights + pickled tokenizer).
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', '-cp', type=str, required=True)
    parser.add_argument('--output_path', '-op', type=str, default="save/temp")
    args = parser.parse_args()
    config = Config.load_from_yaml(args.config_path)
    # Disable training/testing: the model is only loaded for export.
    config.base["train"] = False
    config.base["test"] = False
    # Fall back to the save directory when no explicit load directory is set.
    if config.model_manager["load_dir"] is None:
        config.model_manager["load_dir"] = config.model_manager["save_dir"]
    # NOTE: ``model_manager`` must remain a module-level global — the wrapper
    # classes defined in this file read it directly in their __init__.
    model_manager = ModelManager(config)
    model_manager.load()
    model_manager.config.autoload_template()

    # Export the config and model weights in HuggingFace format.
    pretrained_config = PretrainedConfigForSLUToSave()
    pretrained_model= PretrainedModelForSLUToSave(pretrained_config)
    pretrained_model.save_pretrained(args.output_path)

    # Export the tokenizer (pickled via dill in save_vocabulary).
    pretrained_tokenizer = PreTrainedTokenizerForSLUToSave()
    pretrained_tokenizer.save_pretrained(args.output_path)
|
|