File size: 628 Bytes
56dfd9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from transformers import PretrainedConfig
class AutextificationMTLConfig(PretrainedConfig):
    """Configuration for the Autextification multi-task (MTL) text classifier.

    Holds the hyper-parameters of the custom model. Any extra keyword
    arguments are forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        transformer_name: Name of the underlying transformer backbone.
            Presumably a Hugging Face Hub model id (default
            ``"xlm-roberta-base"``); TODO confirm against the model class.
        hidden_nodes: Positive integer hyper-parameter. Presumably the width
            of a hidden layer in the classification head; TODO confirm.
        threshold: Stored as-is and not validated here. Presumably the
            decision threshold applied to the classifier's output score;
            TODO confirm.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``hidden_nodes`` is not positive.
    """

    # Type key for this custom config class in Transformers (e.g. for
    # Auto* class registration).
    model_type = "custom-text-classifier"

    def __init__(
        self,
        transformer_name: str = "xlm-roberta-base",
        hidden_nodes: int = 64,
        threshold: float = 0.9919,
        **kwargs,
    ):
        # Fail fast on an invalid size, before any state is stored.
        if hidden_nodes <= 0:
            raise ValueError(
                f"`hidden_nodes` must be a positive integer, got {hidden_nodes}."
            )
        self.transformer_name = transformer_name
        self.hidden_nodes = hidden_nodes
        self.threshold = threshold
        # Custom attributes are set first, then the base class handles the
        # remaining standard config kwargs (pattern from the Transformers
        # custom-config guide).
        super().__init__(**kwargs)
|