'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-13 10:44:39
LastEditTime: 2023-02-14 10:28:43
Description: HuggingFace-hub-compatible config / model / tokenizer wrappers for SLU checkpoints.

'''

import os
import dill
from common import utils
from common.utils import InputData, download
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


class PretrainedConfigForSLU(PretrainedConfig):
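    """Thin PretrainedConfig subclass serving as the config type for SLU
    checkpoints on the HuggingFace hub."""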
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class PretrainedModelForSLU(PreTrainedModel):
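    """PreTrainedModel wrapper so SLU models can be loaded through the
    standard HuggingFace `from_pretrained` interface."""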
    # PreTrainedModel.from_pretrained reads config_class to know which config
    # type to instantiate, so set it as a class attribute rather than
    # assigning a config *instance* to it inside __init__.
    config_class = PretrainedConfigForSLU

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # Build the wrapped SLU model from the `model` section of the config.
        self.model = utils.instantiate(config.model)
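
    def forward(self, *args, **kwargs):
        # Assumption: the instantiated SLU model defines the actual forward
        # pass; delegating here makes the wrapper itself callable, as the
        # usage example at the bottom of the file expects.
        return self.model(*args, **kwargs)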


class PreTrainedTokenizerForSLU(PreTrainedTokenizer):
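    """PreTrainedTokenizer stand-in whose `from_pretrained` returns a
    dill-pickled tokenizer downloaded from the HuggingFace hub rather than
    one built from vocab files."""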
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Cache the pickled tokenizer under save/<repo-id>/tokenizer.pkl,
        # downloading it from the hub on first use.
        cache_dir = os.path.join("save", pretrained_model_name_or_path)
        os.makedirs(cache_dir, exist_ok=True)
        cache_path = os.path.join(cache_dir, "tokenizer.pkl")
        if not os.path.exists(cache_path):
            download(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/tokenizer.pkl",
                     cache_path)
        # The tokenizer is serialized with dill rather than as standard
        # HuggingFace tokenizer files, so load it directly.
        with open(cache_path, "rb") as f:
            return dill.load(f)


if __name__ == "__main__":
    # Smoke test (assumptions: the LightChen2333/joint-bert-slu-atis checkpoint
    # is available on the hub, and InputData accepts a list of tokenizer
    # outputs, as in the original commented-out example).
    pretrained_tokenizer = PreTrainedTokenizerForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis")
    test_model = PretrainedModelForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis")
    print(test_model(InputData([pretrained_tokenizer("I want to go to Beijing !")])))