from typing import Optional

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig


class RelikReaderConfig(PretrainedConfig):
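    """Configuration class for the RelikReader model.

    Args:
        transformer_model: Hugging Face model id of the underlying encoder.
        additional_special_symbols: number of additional special symbols
            added to the vocabulary.
        additional_special_symbols_types: number of additional special
            symbols reserved for entity types.
        num_layers: number of encoder layers to use (all when ``None``).
        activation: activation function used by the projection layers.
        linears_hidden_size: hidden size of the projection layers.
        use_last_k_layers: how many of the encoder's last layers to use.
        threshold: score threshold applied at prediction time.
        entity_type_loss: whether to train with an entity-type loss term.
        add_entity_embedding: whether to add an entity embedding; when left
            ``None`` it defaults to ``True`` if ``entity_type_loss`` is set.
        training: whether the model is instantiated for training.
        default_reader_class: optional import path of the default reader class.
    """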

    model_type = "relik-reader"

    def __init__(
        self,
        transformer_model: str = "microsoft/deberta-v3-base",
        additional_special_symbols: int = 101,
        additional_special_symbols_types: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        threshold: Optional[float] = 0.5,
        entity_type_loss: bool = False,
        add_entity_embedding: Optional[bool] = None,
        training: bool = False,
        default_reader_class: Optional[str] = None,
        **kwargs,
    ) -> None:
        self.transformer_model = transformer_model
        self.additional_special_symbols = additional_special_symbols
        self.additional_special_symbols_types = additional_special_symbols_types
        self.num_layers = num_layers
        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers
        self.entity_type_loss = entity_type_loss
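        # add_entity_embedding defaults to True when entity_type_loss is
        # enabled and no explicit value was passed; otherwise the given
        # value (possibly None) is kept as-is.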
        self.add_entity_embedding = (
            True
            if add_entity_embedding is None and entity_type_loss
            else add_entity_embedding
        )
        self.training = training
        self.threshold = threshold
        self.default_reader_class = default_reader_class
        super().__init__(**kwargs)


AutoConfig.register("relik-reader", RelikReaderConfig)
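

# Minimal usage sketch, run only when this module is executed directly. The
# AutoConfig.register call above lets AutoConfig resolve the "relik-reader"
# model_type from a saved config directory. The save path is a hypothetical
# example, not part of the library.
if __name__ == "__main__":
    config = RelikReaderConfig(entity_type_loss=True)
    # add_entity_embedding was left unset, so entity_type_loss turns it on.
    assert config.add_entity_embedding is True

    config.save_pretrained("/tmp/relik-reader-config")  # example path
    reloaded = AutoConfig.from_pretrained("/tmp/relik-reader-config")
    assert isinstance(reloaded, RelikReaderConfig)
    print(reloaded.transformer_model)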