# total-classifier / configuration.py
# Web-page header residue kept for provenance:
#   "ianpan's picture" / "Upload model" / "1a29f83 verified"
from transformers import PretrainedConfig
class TotalClassifierConfig(PretrainedConfig):
    """Hyperparameter container for the ``total_classifier`` model type.

    Each constructor argument is stored unchanged as an instance attribute
    of the same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter descriptions below are inferred from the
    parameter names and defaults only -- confirm against the model code.

    Args:
        backbone: Identifier of the CNN backbone (presumably a timm model
            name).
        feature_dim: Presumably the size of the feature vector fed to the
            recurrent layer.
        cnn_dropout: Presumably the dropout rate used in the CNN part.
        in_chans: Presumably the number of input image channels.
        rnn_type: Name of the recurrent layer type (default ``"GRU"``).
        rnn_num_layers: Presumably the number of stacked recurrent layers.
        rnn_dropout: Presumably the dropout rate inside the recurrent part.
        num_classes: Presumably the number of output classes.
        seq_len: Presumably the sequence length the model works with.
        linear_dropout: Presumably the dropout rate before the final
            linear layer.
        image_size: Two-element image size (axis order is not visible
            here -- TODO confirm).
        **kwargs: Passed through untouched to the base class initializer.
    """

    model_type = "total_classifier"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_b0",
        feature_dim: int = 192,
        cnn_dropout: float = 0.1,
        in_chans: int = 1,
        rnn_type: str = "GRU",
        rnn_num_layers: int = 1,
        rnn_dropout: float = 0.0,
        num_classes: int = 117,
        seq_len: int = 512,
        linear_dropout: float = 0.1,
        image_size: tuple[int, int] = (256, 256),
        **kwargs,
    ):
        """Store the given hyperparameters, then run the base initializer."""
        # Model-specific fields are set first, in signature order, so they
        # already exist on the instance when the base initializer runs.

        # CNN / input-side settings.
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.cnn_dropout = cnn_dropout
        self.in_chans = in_chans

        # Recurrent-layer settings.
        self.rnn_type = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.rnn_dropout = rnn_dropout

        # Remaining settings (classes, sequence length, head dropout, size).
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.linear_dropout = linear_dropout
        self.image_size = image_size

        # Everything not consumed above is handled by PretrainedConfig.
        super().__init__(**kwargs)