rhendz commited on
Commit
bfa394a
1 Parent(s): 8488c63

Upload model

Browse files
Files changed (2) hide show
  1. modeling_spice_cnn.py +10 -8
  2. pytorch_model.bin +1 -1
modeling_spice_cnn.py CHANGED
@@ -1,9 +1,11 @@
1
  import torch.nn as nn
 
2
  # from torchsummary import summary
3
 
4
  from transformers import PreTrainedModel
5
 
6
- from hf_models.models.spice_cnn.configuration_spice_cnn import SpiceCNNConfig
 
7
 
8
  class SpiceCNNModelForImageClassification(PreTrainedModel):
9
  config_class = SpiceCNNConfig
@@ -11,26 +13,25 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
11
  def __init__(self, config: SpiceCNNConfig):
12
  super().__init__(config)
13
  layers = [
14
- nn.Conv2d(config.in_channels, 16, kernel_size=config.kernel_size, padding=1),
 
 
15
  nn.BatchNorm2d(16),
16
  nn.ReLU(),
17
  nn.MaxPool2d(kernel_size=config.pooling_size),
18
-
19
  nn.Conv2d(16, 32, kernel_size=config.kernel_size, padding=1),
20
  nn.BatchNorm2d(32),
21
  nn.ReLU(),
22
  nn.MaxPool2d(kernel_size=config.pooling_size),
23
-
24
  nn.Conv2d(32, 64, kernel_size=config.kernel_size, padding=1),
25
  nn.BatchNorm2d(64),
26
  nn.ReLU(),
27
  nn.MaxPool2d(kernel_size=config.pooling_size),
28
-
29
  nn.Flatten(),
30
- nn.Linear(64*3*3, 128),
31
  nn.ReLU(),
32
  nn.Dropout(0.5),
33
- nn.Linear(128, config.num_classes)
34
  ]
35
  self.model = nn.Sequential(*layers)
36
 
@@ -41,7 +42,8 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
41
  loss = loss_fnc(logits, labels)
42
  return {"loss": loss, "logits": logits}
43
  return {"logits": logits}
44
-
 
45
  # config = SpiceCNNConfig(in_channels=1)
46
  # cnn = SpiceCNNModelForImageClassification(config)
47
  # summary(cnn, (1,28,28))
 
1
  import torch.nn as nn
2
+
3
  # from torchsummary import summary
4
 
5
  from transformers import PreTrainedModel
6
 
7
+ from .configuration_spice_cnn import SpiceCNNConfig
8
+
9
 
10
  class SpiceCNNModelForImageClassification(PreTrainedModel):
11
  config_class = SpiceCNNConfig
 
13
  def __init__(self, config: SpiceCNNConfig):
14
  super().__init__(config)
15
  layers = [
16
+ nn.Conv2d(
17
+ config.in_channels, 16, kernel_size=config.kernel_size, padding=1
18
+ ),
19
  nn.BatchNorm2d(16),
20
  nn.ReLU(),
21
  nn.MaxPool2d(kernel_size=config.pooling_size),
 
22
  nn.Conv2d(16, 32, kernel_size=config.kernel_size, padding=1),
23
  nn.BatchNorm2d(32),
24
  nn.ReLU(),
25
  nn.MaxPool2d(kernel_size=config.pooling_size),
 
26
  nn.Conv2d(32, 64, kernel_size=config.kernel_size, padding=1),
27
  nn.BatchNorm2d(64),
28
  nn.ReLU(),
29
  nn.MaxPool2d(kernel_size=config.pooling_size),
 
30
  nn.Flatten(),
31
+ nn.Linear(64 * 3 * 3, 128),
32
  nn.ReLU(),
33
  nn.Dropout(0.5),
34
+ nn.Linear(128, config.num_classes),
35
  ]
36
  self.model = nn.Sequential(*layers)
37
 
 
42
  loss = loss_fnc(logits, labels)
43
  return {"loss": loss, "logits": logits}
44
  return {"logits": logits}
45
+
46
+
47
  # config = SpiceCNNConfig(in_channels=1)
48
  # cnn = SpiceCNNModelForImageClassification(config)
49
  # summary(cnn, (1,28,28))
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c6b92faf374dee8dcdbaa0b65dd62f3fde1178284ffe3a8e254b49f3c73e249
3
  size 402812
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9099bac65074f55cf277404cf2daffecc6893d84cdda384e6e95eb6d6a257914
3
  size 402812