|
from transformers import PretrainedConfig |
|
|
|
"""Spice CNN model configuration""" |
|
|
|
# Maps each published checkpoint identifier ("<organisation>/<model>") to the
# URL of that checkpoint's `config.json`. The URL path must contain the full
# identifier, including the organisation prefix, so that it points at the same
# repository as the key.
SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "spicecloud/spice-cnn-base": (
        "https://huggingface.co/spicecloud/spice-cnn-base/resolve/main/config.json"
    ),
}
|
|
|
|
|
|
|
class SpiceCNNConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`SpiceCNNModel`].
    It is used to instantiate a SpiceCNN model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults
    will yield a similar configuration to that of the SpiceCNN
    [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control
    the model outputs. Read the documentation from [`PretrainedConfig`] for more
    information.

    Every argument listed below is stored unchanged on the instance as an attribute
    of the same name; no validation is performed here.

    NOTE(review): the per-argument descriptions are inferred from the parameter
    names only. The model that consumes this configuration is not defined in this
    file -- confirm each meaning against `SpiceCNNModel`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Presumably the number of channels in the input images.
        num_classes (`int`, *optional*, defaults to 10):
            Presumably the number of output classes.
        dropout_rate (`float`, *optional*, defaults to 0.4):
            Presumably the dropout probability.
        hidden_size (`int`, *optional*, defaults to 128):
            Presumably the width of the hidden fully-connected layer.
        num_filters (`int`, *optional*, defaults to 16):
            Presumably the number of convolution filters.
        kernel_size (`int`, *optional*, defaults to 3):
            Presumably the convolution kernel size.
        stride (`int`, *optional*, defaults to 1):
            Presumably the convolution stride.
        padding (`int`, *optional*, defaults to 1):
            Presumably the convolution padding.
        pooling_size (`int`, *optional*, defaults to 2):
            Presumably the pooling window size.
        **kwargs:
            Any additional keyword arguments; forwarded unchanged to
            `PretrainedConfig.__init__`.
    """

    # String identifier for this configuration type.
    model_type = "spicecnn"

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        dropout_rate: float = 0.4,
        hidden_size: int = 128,
        num_filters: int = 16,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        pooling_size: int = 2,
        **kwargs
    ):
        # The base class consumes the remaining keyword arguments first; the
        # model-specific fields are assigned afterwards, in signature order.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pooling_size = pooling_size
|
|