File size: 1,656 Bytes
dc2a7ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from transformers import PretrainedConfig
"""Spice CNN model configuration"""
# Maps each pretrained checkpoint id on the Hugging Face Hub to the URL of its
# `config.json`. Hub file URLs have the form
# `https://huggingface.co/{repo_id}/resolve/{revision}/{filename}`, so each URL
# must embed the full `namespace/name` repo id used as its key.
SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "spicecloud/spice-cnn-base": "https://huggingface.co/spicecloud/spice-cnn-base/resolve/main/config.json",
}
# Define custom convnet configuration
class SpiceCNNConfig(PretrainedConfig):
    """
    Configuration class holding the hyperparameters of a [`SpiceCNNModel`].

    A SpiceCNN model is instantiated from the values stored here, which define
    its architecture. Instantiating a configuration with the default arguments
    yields a configuration similar to that of the SpiceCNN
    [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. See the [`PretrainedConfig`] documentation for
    more information.

    Every argument below is stored unchanged as an attribute of the same name.
    The descriptions of how the model uses each value are inferred from the
    names; confirm them against `SpiceCNNModel`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input (e.g. 3 for RGB images).
        num_classes (`int`, *optional*, defaults to 10):
            Number of output classes.
        dropout_rate (`float`, *optional*, defaults to 0.4):
            Dropout probability.
        hidden_size (`int`, *optional*, defaults to 128):
            Width of the hidden layer.
        num_filters (`int`, *optional*, defaults to 16):
            Number of convolutional filters.
        kernel_size (`int`, *optional*, defaults to 3):
            Size of the convolution kernel.
        stride (`int`, *optional*, defaults to 1):
            Convolution stride.
        padding (`int`, *optional*, defaults to 1):
            Convolution padding.
        pooling_size (`int`, *optional*, defaults to 2):
            Size of the pooling window.
        kwargs:
            Forwarded unchanged to [`PretrainedConfig`].
    """

    # Identifier under which this configuration type is registered/serialized.
    model_type = "spicecnn"

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        dropout_rate: float = 0.4,
        hidden_size: int = 128,
        num_filters: int = 16,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        pooling_size: int = 2,
        **kwargs,
    ) -> None:
        # Let the base class process the generic keyword arguments first.
        super().__init__(**kwargs)

        # Model input/output dimensions.
        self.in_channels = in_channels
        self.num_classes = num_classes

        # Regularization and hidden-layer width.
        self.dropout_rate = dropout_rate
        self.hidden_size = hidden_size

        # Convolution hyperparameters.
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Pooling hyperparameter.
        self.pooling_size = pooling_size
|