File size: 1,588 Bytes
dc2a7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch.nn as nn

# from torchsummary import summary

from transformers import PreTrainedModel

from .configuration_spice_cnn import SpiceCNNConfig


class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small three-stage CNN image classifier wrapped as a Hugging Face model.

    Everything lives in ``self.model``, a single ``nn.Sequential``:

    * three feature stages, each ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``,
      widening channels ``in_channels -> 16 -> 32 -> 64``. Every conv uses
      ``config.kernel_size`` with a fixed ``padding=1``, and every pool uses
      ``config.pooling_size``.
    * a classifier head ``Flatten -> Linear(64*3*3, 128) -> ReLU ->
      Dropout(0.5) -> Linear(128, config.num_classes)``.

    NOTE(review): the head's input width is hard-coded to ``64 * 3 * 3``.
    The feature stages must therefore end on a 64-channel 3x3 map, for example
    a 28x28 input with ``kernel_size=3`` and ``pooling_size=2``
    (28 -> 14 -> 7 -> 3). Other combinations fail with a shape mismatch at the
    first ``Linear``.
    """

    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        super().__init__(config)

        # Channel widths of the three feature stages, input width first.
        widths = [config.in_channels, 16, 32, 64]

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend(
                (
                    nn.Conv2d(
                        c_in, c_out, kernel_size=config.kernel_size, padding=1
                    ),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=config.pooling_size),
                )
            )

        head = (
            nn.Flatten(),
            nn.Linear(64 * 3 * 3, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, config.num_classes),
        )

        # Keep one flat Sequential. The child indices (model.0 ... model.16)
        # are the state_dict keys, and nesting the stages would rename them.
        self.model = nn.Sequential(*stages, *head)

    def forward(self, tensor, labels=None):
        """Run the network and optionally compute the classification loss.

        Args:
            tensor: input passed straight to ``self.model``. Presumably an
                image batch of shape ``(batch, in_channels, H, W)``
                (TODO: confirm with callers).
            labels: optional targets for ``nn.CrossEntropyLoss``, for example
                class indices of shape ``(batch,)``.

        Returns:
            A plain dict: ``{"loss": ..., "logits": ...}`` when ``labels`` is
            given, otherwise ``{"logits": ...}``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        criterion = nn.CrossEntropyLoss()
        return {"loss": criterion(logits, labels), "logits": logits}


# config = SpiceCNNConfig(in_channels=1)
# cnn = SpiceCNNModelForImageClassification(config)
# summary(cnn, (1,28,28))