|
from transformers import PretrainedConfig |
|
|
|
|
|
class TotalClassifierConfig(PretrainedConfig):
    """Configuration object for the ``total_classifier`` model type.

    Every constructor argument is stored verbatim as an instance attribute of
    the same name (except ``image_size``, see below). Remaining keyword
    arguments are forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        backbone: Backbone identifier. NOTE(review): the default
            ``"tf_efficientnetv2_b0"`` looks like a timm model name -- confirm
            against the model code that consumes this config.
        feature_dim: Feature dimension (default 192); presumably the width of
            the backbone features fed onward -- confirm.
        cnn_dropout: Dropout value for the CNN stage, presumably a
            probability; not validated here.
        in_chans: Number of input channels (default 1).
        rnn_type: RNN type name (default ``"GRU"``). The accepted values are
            decided by the model code; this config does not validate them.
        rnn_num_layers: Number of RNN layers.
        rnn_dropout: Dropout value for the RNN stage; not validated here.
        num_classes: Number of output classes (default 117).
        seq_len: Sequence length (default 512).
        linear_dropout: Dropout value for the linear head; not validated here.
        image_size: Image size as a 2-element sequence, default ``(256, 256)``.
            Lists and tuples are stored as a ``tuple``; any other value is
            stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "total_classifier"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_b0",
        feature_dim: int = 192,
        cnn_dropout: float = 0.1,
        in_chans: int = 1,
        rnn_type: str = "GRU",
        rnn_num_layers: int = 1,
        rnn_dropout: float = 0.0,
        num_classes: int = 117,
        seq_len: int = 512,
        linear_dropout: float = 0.1,
        image_size: tuple[int, int] = (256, 256),
        **kwargs,
    ):
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.cnn_dropout = cnn_dropout
        self.in_chans = in_chans
        self.rnn_type = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.rnn_dropout = rnn_dropout
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.linear_dropout = linear_dropout

        # Configs are serialized to JSON, which has no tuple type, so a
        # config reloaded from disk would otherwise carry ``[256, 256]``
        # instead of ``(256, 256)``. Normalize so the attribute type is the
        # same for fresh and reloaded configs; leave other values untouched.
        if isinstance(image_size, (list, tuple)):
            image_size = tuple(image_size)
        self.image_size = image_size

        super().__init__(**kwargs)
|
|