rhendz commited on
Commit
aa9a09c
1 Parent(s): e6757c8

Upload model

Browse files
config.json CHANGED
@@ -8,6 +8,7 @@
8
  },
9
  "dropout_rate": 0.2,
10
  "hidden_size": 128,
 
11
  "kernel_size": 3,
12
  "model_type": "spicecnn",
13
  "num_classes": 10,
 
8
  },
9
  "dropout_rate": 0.2,
10
  "hidden_size": 128,
11
+ "in_channels": 1,
12
  "kernel_size": 3,
13
  "model_type": "spicecnn",
14
  "num_classes": 10,
configuration_spice_cnn.py CHANGED
@@ -26,6 +26,7 @@ class SpiceCNNConfig(PretrainedConfig):
26
 
27
  def __init__(
28
  self,
 
29
  num_classes: int = 10,
30
  dropout_rate: float = 0.2,
31
  hidden_size: int = 128,
@@ -37,6 +38,7 @@ class SpiceCNNConfig(PretrainedConfig):
37
  **kwargs
38
  ):
39
  super().__init__(**kwargs)
 
40
  self.num_classes = num_classes
41
  self.dropout_rate = dropout_rate
42
  self.hidden_size = hidden_size
 
26
 
27
  def __init__(
28
  self,
29
+ in_channels: int = 3,
30
  num_classes: int = 10,
31
  dropout_rate: float = 0.2,
32
  hidden_size: int = 128,
 
38
  **kwargs
39
  ):
40
  super().__init__(**kwargs)
41
+ self.in_channels = in_channels
42
  self.num_classes = num_classes
43
  self.dropout_rate = dropout_rate
44
  self.hidden_size = hidden_size
modeling_spice_cnn.py CHANGED
@@ -12,7 +12,7 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
12
  super().__init__(config)
13
  layers = [
14
  nn.Conv2d(
15
- 1,
16
  16,
17
  kernel_size=config.kernel_size,
18
  stride=config.stride,
 
12
  super().__init__(config)
13
  layers = [
14
  nn.Conv2d(
15
+ config.in_channels,
16
  16,
17
  kernel_size=config.kernel_size,
18
  stride=config.stride,
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e07e343bdf0c9f3c218ba10e865eca0d59c3adf4da998294d288160c382657e6
3
  size 830347
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2bd33ae4006b549f8ef4839e107525981d190d4922a7236e5de3a59190450a1
3
  size 830347