spice-mnist / configuration_spice_cnn.py
rhendz's picture
Upload folder using huggingface_hub (#1)
dc2a7ae
raw
history blame
1.66 kB
from transformers import PretrainedConfig
# NOTE: this string follows the `transformers` import, so Python does not
# treat it as the module docstring; it only serves as a header label.
"""Spice CNN model configuration"""

# Maps each pretrained checkpoint id to the URL of its hosted config.json.
# Hub file URLs have the form
#   https://huggingface.co/<namespace>/<repo>/resolve/<revision>/<file>
# so each URL must embed the full "<namespace>/<repo>" id used as its key.
SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "spicecloud/spice-cnn-base": (
        "https://huggingface.co/spicecloud/spice-cnn-base/resolve/main/config.json"
    ),
}
# Define custom convnet configuration
class SpiceCNNConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`SpiceCNNModel`].
    It is used to instantiate a SpiceCNN model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults
    will yield a similar configuration to that of the SpiceCNN
    [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control
    the model outputs. Read the documentation from [`PretrainedConfig`] for more
    information.

    Every argument below is stored unchanged as an attribute of the same name; this
    class performs no validation. How each value is interpreted is decided by the
    model class, which is not defined here. The descriptions follow the parameter
    names and should be confirmed against `SpiceCNNModel`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels expected in the model input.
        num_classes (`int`, *optional*, defaults to 10):
            Number of output classes.
        dropout_rate (`float`, *optional*, defaults to 0.4):
            Dropout probability.
        hidden_size (`int`, *optional*, defaults to 128):
            Width of the hidden (fully connected) representation.
        num_filters (`int`, *optional*, defaults to 16):
            Number of convolution filters.
        kernel_size (`int`, *optional*, defaults to 3):
            Convolution kernel size.
        stride (`int`, *optional*, defaults to 1):
            Convolution stride.
        padding (`int`, *optional*, defaults to 1):
            Convolution padding.
        pooling_size (`int`, *optional*, defaults to 2):
            Pooling window size.
        kwargs:
            Forwarded unchanged to [`PretrainedConfig.__init__`].
    """

    # Identifier that `transformers` uses to recognise this config type
    # (e.g. for Auto* class registration and in the serialized config).
    model_type = "spicecnn"

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        dropout_rate: float = 0.4,
        hidden_size: int = 128,
        num_filters: int = 16,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        pooling_size: int = 2,
        **kwargs
    ):
        # Let the base class consume its own options (e.g. label maps, output
        # flags) before the model-specific fields are attached.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pooling_size = pooling_size