# gector-xlnet-base-cased-5k / configuration_gector.py
# ktzsh's picture
# Upload folder using huggingface_hub
# 010f214
import os
import json
from typing import OrderedDict, Mapping, Union
from transformers import PretrainedConfig, AutoConfig
from transformers.onnx import OnnxConfig
class GectorConfig(PretrainedConfig):
    """Configuration object for models registered under ``model_type = "gector"``.

    The constructor first runs the ``PretrainedConfig`` initializer with any
    extra keyword arguments, then optionally copies every instance attribute
    of a ``base_config`` onto this object, and finally stores each explicit
    constructor argument as an attribute of the same name. Because the
    explicit arguments are assigned last, they take precedence over values
    coming from ``**kwargs`` or from ``base_config``.
    """

    model_type = "gector"

    def __subclassconfig__(self, base_config: AutoConfig):
        """Add the config values of the base model config to this config.

        Args:
            base_config: Config whose instance ``__dict__`` is merged into
                this object's ``__dict__`` (shallow copy: mutable values are
                shared, not duplicated). Nothing happens when it is falsy,
                e.g. ``None``.
        """
        if not base_config:
            return
        self.__dict__.update(base_config.__dict__)

    def __init__(
        self,
        model_id: str = None,
        id2label: dict = None,
        label2id: dict = None,
        detect_id2label: dict = None,
        detect_label2id: dict = None,
        classifier_dropout: float = 0,
        label_pad_token: str = "<PAD>",
        label_unknown_token: str = "<UNK>",
        detect_pad_token_id: int = 3,
        correct_pad_token_id: int = 5001,
        num_detect_tags: int = 4,
        num_correct_tags: int = 5002,
        max_length: int = 128,
        label_smoothing: float = 0.0,
        special_tokens_fix: bool = False,
        delete_confidence: float = 0.0,
        additional_confidence: float = 0.2,
        base_config: AutoConfig = None,
        verb_form_vocab: dict = None,
        **kwargs,
    ):
        """Build the configuration.

        Every named argument except ``base_config`` is stored unchanged as an
        attribute with the same name. The meaning of the individual values is
        defined by the model code that consumes this config, which is not
        part of this file.

        Args:
            model_id: Stored as ``self.model_id``.
            id2label: Stored as ``self.id2label`` (keys are not converted).
            label2id: Stored as ``self.label2id``.
            detect_id2label: Stored as ``self.detect_id2label``.
            detect_label2id: Stored as ``self.detect_label2id``.
            classifier_dropout: Stored as ``self.classifier_dropout``.
            label_pad_token: Stored as ``self.label_pad_token``.
            label_unknown_token: Stored as ``self.label_unknown_token``.
            detect_pad_token_id: Stored as ``self.detect_pad_token_id``.
            correct_pad_token_id: Stored as ``self.correct_pad_token_id``.
            num_detect_tags: Stored as ``self.num_detect_tags``.
            num_correct_tags: Stored as ``self.num_correct_tags``.
            max_length: Stored as ``self.max_length``.
            label_smoothing: Stored as ``self.label_smoothing``.
            special_tokens_fix: Stored as ``self.special_tokens_fix``.
            delete_confidence: Stored as ``self.delete_confidence``.
            additional_confidence: Stored as ``self.additional_confidence``.
            base_config: Optional config whose attributes are merged into
                this one via ``__subclassconfig__``; it is not stored itself.
            verb_form_vocab: Stored as ``self.verb_form_vocab``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # Merge the base model's attributes before assigning our own so the
        # explicit arguments below always win on a name clash.
        self.__subclassconfig__(base_config)

        self.model_id = model_id

        # Label vocabularies.
        self.label2id = label2id
        self.id2label = id2label
        self.detect_label2id = detect_label2id
        self.detect_id2label = detect_id2label

        # These two are named parameters, so they never reach ``**kwargs``;
        # unless they are assigned here the caller's values are discarded.
        self.label_pad_token = label_pad_token
        self.label_unknown_token = label_unknown_token

        # Padding ids and tag counts.
        self.detect_pad_token_id = detect_pad_token_id
        self.correct_pad_token_id = correct_pad_token_id
        self.num_detect_tags = num_detect_tags
        self.num_correct_tags = num_correct_tags

        # Remaining hyper-parameters.
        self.classifier_dropout = classifier_dropout
        self.max_length = max_length
        self.label_smoothing = label_smoothing
        self.special_tokens_fix = special_tokens_fix
        self.delete_confidence = delete_confidence
        self.additional_confidence = additional_confidence
        self.verb_form_vocab = verb_form_vocab
# def save_pretrained(
# self,
# save_directory: Union[str, os.PathLike],
# push_to_hub: bool = False,
# **kwargs,
# ):
# if os.path.isfile(save_directory):
# raise AssertionError(
# f"Provided path ({save_directory}) should be a directory, not a file"
# )
# os.makedirs(save_directory, exist_ok=True)
# if self.verb_form_vocab:
# verb_form_vocab_file = os.path.join(save_directory, "verb_form_vocab.json")
# with open(verb_form_vocab_file, "w", encoding="utf-8") as writer:
# writer.write(json.dumps(self.verb_form_vocab, indent=2, sort_keys=True) + "\n")
# super().save_pretrained(save_directory, push_to_hub, **kwargs)
class GectorOnnxConfig(OnnxConfig):
    """ONNX configuration that declares the named inputs and their axes."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the input names, in order, mapped to their dynamic axes.

        ``input_ids`` comes first and ``attention_mask`` second. Both entries
        reference the same axes mapping, in which axis 0 is named ``"batch"``
        and axis 1 is named ``"sequence"``.
        """
        axes = {0: "batch", 1: "sequence"}
        input_axes = OrderedDict()
        input_axes["input_ids"] = axes
        input_axes["attention_mask"] = axes
        return input_axes