|
''' |
|
Author: Qiguang Chen |
|
LastEditors: Qiguang Chen |
|
Date: 2023-02-13 10:44:39 |
|
LastEditTime: 2023-02-14 10:28:43 |
|
Description: |
|
|
|
''' |
|
|
|
import os |
|
import dill |
|
from common import utils |
|
from common.utils import InputData, download |
|
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PretrainedConfigForSLU(PretrainedConfig):
    """Config class for SLU models; a thin wrapper over ``PretrainedConfig``.

    All keyword arguments are forwarded unchanged to the transformers base
    class, so any standard config field is accepted.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
class PretrainedModelForSLU(PreTrainedModel):
    """SLU model wrapper: builds the inner model from ``config.model`` via
    the project's ``utils.instantiate`` factory."""

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # NOTE(review): this stores a config *instance* on the instance
        # attribute ``config_class``, shadowing the class-level attribute
        # that transformers expects to hold a config *class* — presumably
        # intentional for this toolkit, but worth confirming against
        # transformers' save/load machinery.
        self.config_class = config
        # Instantiate the actual SLU network described by the config
        # (``config.model`` is assumed to be an instantiable spec handled
        # by utils.instantiate — TODO confirm its expected structure).
        self.model = utils.instantiate(config.model)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained SLU model, forcing configs to be parsed as
        :class:`PretrainedConfigForSLU` before delegating to transformers.

        NOTE(review): mutating ``cls.config_class`` here affects the class
        globally (and all subclasses) for every subsequent call — verify
        this side effect is acceptable.
        """
        cls.config_class = PretrainedConfigForSLU
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
|
|
|
|
class PreTrainedTokenizerForSLU(PreTrainedTokenizer):
    """Tokenizer wrapper for SLU models.

    ``from_pretrained`` does not build an instance of this class; it
    downloads (and caches) a dill-pickled tokenizer object from the
    HuggingFace Hub and returns whatever object was pickled.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Return the tokenizer pickled at
        ``https://huggingface.co/{name}/resolve/main/tokenizer.pkl``,
        caching it under ``./save/{name}/tokenizer.pkl``.

        Args:
            pretrained_model_name_or_path: HuggingFace Hub repo id (may
                contain ``/``, e.g. ``org/model``).
            *model_args, **kwargs: accepted for interface compatibility
                with transformers; unused here.

        Returns:
            The unpickled tokenizer object.
        """
        cache_path = f"./save/{pretrained_model_name_or_path}/tokenizer.pkl"
        # Create the whole nested cache directory in one race-safe call.
        # (The original built each path prefix by hand and called os.mkdir
        # after an exists() check — a TOCTOU crash if the directory was
        # created concurrently between the check and the mkdir.)
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        if not os.path.exists(cache_path):
            download(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/tokenizer.pkl", cache_path)
        # SECURITY NOTE: dill.load executes arbitrary code embedded in the
        # pickle — only load tokenizers from trusted model repositories.
        with open(cache_path, "rb") as f:
            tokenizer = dill.load(f)
        return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|