import os
import json
from collections import OrderedDict
from typing import Mapping, Optional, Union

from transformers import PretrainedConfig, AutoConfig
from transformers.onnx import OnnxConfig


class GectorConfig(PretrainedConfig):
    """Configuration for models registered under the ``"gector"`` model type.

    Stores two tag vocabularies, their padding ids and sizes, and a few
    training/inference hyper-parameters:

    * ``id2label`` / ``label2id``: the main tag vocabulary. These are
      presumably the correction tags counted by ``num_correct_tags``
      (confirm against the model code).
    * ``detect_id2label`` / ``detect_label2id``: the detection tag vocabulary.

    Construction order matters:

    1. ``PretrainedConfig.__init__(**kwargs)`` runs first.
    2. If ``base_config`` is given, all of its instance attributes are copied
       onto this config. They override anything set from ``**kwargs``.
    3. The GECToR-specific arguments are assigned last. They override any
       same-named attribute copied from ``base_config``.

    Args:
        model_id: Stored as ``self.model_id``. Presumably the name or path of
            the underlying encoder; this is not verified here.
        id2label: Index -> tag mapping for the main tag set.
        label2id: Tag -> index mapping for the main tag set.
        detect_id2label: Index -> tag mapping for the detection tag set.
        detect_label2id: Tag -> index mapping for the detection tag set.
        classifier_dropout: Stored as ``self.classifier_dropout``.
        label_pad_token: Stored as ``self.label_pad_token``.
        label_unknown_token: Stored as ``self.label_unknown_token``.
        detect_pad_token_id: Padding id for the detection tags. With the
            defaults it is the last index (3 of ``num_detect_tags=4``).
        correct_pad_token_id: Padding id for the correction tags. With the
            defaults it is the last index (5001 of ``num_correct_tags=5002``).
        num_detect_tags: Size of the detection tag set.
        num_correct_tags: Size of the correction tag set.
        max_length: Stored as ``self.max_length``.
        label_smoothing: Stored as ``self.label_smoothing``.
        special_tokens_fix: Stored as ``self.special_tokens_fix``.
        delete_confidence: Stored as ``self.delete_confidence``.
        additional_confidence: Stored as ``self.additional_confidence``.
        base_config: Optional config instance whose attributes are merged in
            (see above). The object itself is not kept.
        verb_form_vocab: Stored as ``self.verb_form_vocab``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "gector"

    def _merge_base_config(self, base_config: Optional[PretrainedConfig]) -> None:
        """Copy every instance attribute of ``base_config`` onto ``self``.

        Does nothing when ``base_config`` is falsy (e.g. ``None``).
        Same-named attributes already on ``self`` are overwritten.
        """
        if base_config:
            self.__dict__.update(base_config.__dict__)

    # Backward-compatible alias for the original name. Dunder names are
    # reserved for Python itself, so the implementation now lives under a
    # private name.
    __subclassconfig__ = _merge_base_config

    def __init__(
        self,
        model_id: Optional[str] = None,
        id2label: Optional[dict] = None,
        label2id: Optional[dict] = None,
        detect_id2label: Optional[dict] = None,
        detect_label2id: Optional[dict] = None,
        classifier_dropout: float = 0,
        label_pad_token: str = "",
        label_unknown_token: str = "",
        detect_pad_token_id: int = 3,
        correct_pad_token_id: int = 5001,
        num_detect_tags: int = 4,
        num_correct_tags: int = 5002,
        max_length: int = 128,
        label_smoothing: float = 0.0,
        special_tokens_fix: bool = False,
        delete_confidence: float = 0.0,
        additional_confidence: float = 0.2,
        base_config: Optional[PretrainedConfig] = None,
        verb_form_vocab: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Pull in the base model's attributes before our own assignments, so
        # the GECToR-specific values below take precedence.
        self._merge_base_config(base_config)

        self.model_id = model_id
        # NOTE(review): id2label/label2id are intercepted here instead of
        # being forwarded to PretrainedConfig.__init__. Any normalisation the
        # base class applies to them (e.g. key types) is therefore skipped.
        # Confirm callers cope with string keys after a JSON round-trip.
        self.label2id = label2id
        self.id2label = id2label
        self.detect_label2id = detect_label2id
        self.detect_id2label = detect_id2label
        # These were previously accepted but silently dropped. They are
        # named parameters, so they never reach **kwargs. Store them so they
        # survive serialization.
        self.label_pad_token = label_pad_token
        self.label_unknown_token = label_unknown_token
        self.detect_pad_token_id = detect_pad_token_id
        self.correct_pad_token_id = correct_pad_token_id
        self.num_detect_tags = num_detect_tags
        self.num_correct_tags = num_correct_tags
        self.classifier_dropout = classifier_dropout
        self.max_length = max_length
        self.label_smoothing = label_smoothing
        self.special_tokens_fix = special_tokens_fix
        self.delete_confidence = delete_confidence
        self.additional_confidence = additional_confidence
        self.verb_form_vocab = verb_form_vocab

    # NOTE: disabled override, kept for reference. If re-enabled, it would
    # also write ``verb_form_vocab`` to ``verb_form_vocab.json`` in the save
    # directory before delegating to ``PretrainedConfig.save_pretrained``.
    # def save_pretrained(
    #     self,
    #     save_directory: Union[str, os.PathLike],
    #     push_to_hub: bool = False,
    #     **kwargs,
    # ):
    #     if os.path.isfile(save_directory):
    #         raise AssertionError(
    #             f"Provided path ({save_directory}) should be a directory, not a file"
    #         )
    #     os.makedirs(save_directory, exist_ok=True)
    #     if self.verb_form_vocab:
    #         verb_form_vocab_file = os.path.join(save_directory, "verb_form_vocab.json")
    #         with open(verb_form_vocab_file, "w", encoding="utf-8") as writer:
    #             writer.write(json.dumps(self.verb_form_vocab, indent=2, sort_keys=True) + "\n")
    #     super().save_pretrained(save_directory, push_to_hub, **kwargs)


class GectorOnnxConfig(OnnxConfig):
    """ONNX export configuration for the ``gector`` model."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Declare the model inputs and their dynamic axes.

        Returns an ordered mapping with ``input_ids`` followed by
        ``attention_mask``. Each input has a dynamic batch dimension (axis 0)
        and a dynamic sequence dimension (axis 1).
        """
        dynamic_axis = {0: "batch", 1: "sequence"}
        # collections.OrderedDict is the real class. typing.OrderedDict is a
        # deprecated annotation alias and should not be used as a constructor.
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )