|
import os |
|
import json |
|
|
|
from typing import OrderedDict, Mapping, Union |
|
|
|
from transformers import PretrainedConfig, AutoConfig |
|
from transformers.onnx import OnnxConfig |
|
|
|
|
|
class GectorConfig(PretrainedConfig):
    """Configuration for a GECToR-style grammatical-error-correction model.

    Combines the configuration of an underlying transformer encoder
    (copied in from ``base_config``) with GECToR-specific label
    vocabularies, padding ids, and inference-confidence settings.
    """

    model_type = "gector"

    def __subclassconfig__(self, base_config: AutoConfig):
        """Merge every attribute of ``base_config`` into this config.

        NOTE(review): attributes copied here are later overwritten by the
        explicit assignments in ``__init__`` (e.g. ``id2label``) even when
        those arguments are left at their ``None`` defaults — confirm that
        precedence is intended when a populated ``base_config`` is passed.
        """
        if base_config:
            self.__dict__.update(base_config.__dict__)

    def __init__(
        self,
        model_id: Union[str, None] = None,
        id2label: Union[dict, None] = None,
        label2id: Union[dict, None] = None,
        detect_id2label: Union[dict, None] = None,
        detect_label2id: Union[dict, None] = None,
        classifier_dropout: float = 0,
        label_pad_token: str = "<PAD>",
        label_unknown_token: str = "<UNK>",
        detect_pad_token_id: int = 3,
        correct_pad_token_id: int = 5001,
        num_detect_tags: int = 4,
        num_correct_tags: int = 5002,
        max_length: int = 128,
        label_smoothing: float = 0.0,
        special_tokens_fix: bool = False,
        delete_confidence: float = 0.0,
        additional_confidence: float = 0.2,
        base_config: Union[AutoConfig, None] = None,
        verb_form_vocab: Union[dict, None] = None,
        **kwargs,
    ):
        """Build a GECToR config.

        Args:
            model_id: Identifier of the underlying pretrained model.
            id2label / label2id: Correction-tag vocabulary mappings.
            detect_id2label / detect_label2id: Detection-tag mappings.
            classifier_dropout: Dropout applied to the classification heads.
            label_pad_token / label_unknown_token: Special label strings.
            detect_pad_token_id / correct_pad_token_id: Pad ids for the two
                tag vocabularies.
            num_detect_tags / num_correct_tags: Sizes of the two tag spaces.
            max_length: Maximum sequence length.
            label_smoothing: Label-smoothing factor for the loss.
            special_tokens_fix: Whether to apply the special-tokens fix.
            delete_confidence / additional_confidence: Inference-time
                confidence biases.
            base_config: Encoder config whose attributes are copied onto
                this config.
            verb_form_vocab: Verb-form transformation vocabulary.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Inherit the wrapped encoder's attributes first; the explicit
        # assignments below take precedence over them.
        self.__subclassconfig__(base_config)

        self.model_id = model_id
        self.label2id = label2id
        self.id2label = id2label
        self.detect_label2id = detect_label2id
        self.detect_id2label = detect_id2label
        self.detect_pad_token_id = detect_pad_token_id
        self.correct_pad_token_id = correct_pad_token_id
        self.num_detect_tags = num_detect_tags
        self.num_correct_tags = num_correct_tags
        self.classifier_dropout = classifier_dropout
        self.max_length = max_length
        self.label_smoothing = label_smoothing
        self.special_tokens_fix = special_tokens_fix
        self.delete_confidence = delete_confidence
        self.additional_confidence = additional_confidence
        self.verb_form_vocab = verb_form_vocab
        # BUGFIX: these two parameters were previously accepted but never
        # stored, so they were silently dropped from the serialized config.
        self.label_pad_token = label_pad_token
        self.label_unknown_token = label_unknown_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GectorOnnxConfig(OnnxConfig):
    """ONNX export configuration for GECToR models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Describe the dynamic axes of each model input for ONNX export.

        Both inputs share the same dynamic axes: axis 0 is the batch
        dimension and axis 1 is the sequence length.
        """
        shared_axes = {0: "batch", 1: "sequence"}
        return OrderedDict(
            (input_name, shared_axes)
            for input_name in ("input_ids", "attention_mask")
        )
|
|