|
import torch.nn as nn |
|
|
|
|
|
|
|
from transformers import PreTrainedModel |
|
|
|
from .configuration_spice_cnn import SpiceCNNConfig |
|
|
|
|
|
class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small convolutional image classifier packaged as a Hugging Face model.

    The whole network lives in one ``nn.Sequential`` stored as ``self.model``:

    * three feature stages, each Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d,
      widening the channel count to 16, 32 and then 64;
    * a classifier head: Flatten -> Linear(64*3*3, 128) -> ReLU ->
      Dropout(0.5) -> Linear(128, config.num_classes).

    ``nn.Sequential`` names its children by position, so the layer order below
    fixes the ``model.<index>.*`` parameter keys; keep it stable so saved
    checkpoints keep loading.
    """

    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        """Build the layer stack from ``config``.

        Reads ``config.in_channels``, ``config.kernel_size``,
        ``config.pooling_size`` and ``config.num_classes``.
        """
        super().__init__(config)

        # Feature extractor: one conv/bn/relu/pool stage per output width.
        # NOTE(review): padding is fixed at 1, which only preserves the spatial
        # size when kernel_size == 3 — confirm other kernel sizes are unused.
        feature_layers = []
        in_ch = config.in_channels
        for out_ch in (16, 32, 64):
            feature_layers.extend(
                (
                    nn.Conv2d(
                        in_ch, out_ch, kernel_size=config.kernel_size, padding=1
                    ),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=config.pooling_size),
                )
            )
            in_ch = out_ch

        # NOTE(review): the head assumes the last stage emits a 64 x 3 x 3 map.
        # That holds e.g. for 28x28 inputs with kernel_size=3 and
        # pooling_size=2 (28 -> 14 -> 7 -> 3); other input sizes or configs
        # will hit a shape mismatch here — confirm against the dataset/config.
        flat_features = 64 * 3 * 3
        classifier_layers = [
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, config.num_classes),
        ]

        self.model = nn.Sequential(*feature_layers, *classifier_layers)

    def forward(self, tensor, labels=None):
        """Compute class logits and, when targets are supplied, the loss.

        Args:
            tensor: Input image batch fed straight into the first Conv2d;
                presumably shaped (batch, in_channels, H, W) — not checked here.
            labels: Optional targets for cross-entropy (presumably integer
                class indices of shape (batch,) — confirm against callers).

        Returns:
            ``{"logits": logits}`` when ``labels`` is None, otherwise
            ``{"loss": loss, "logits": logits}`` with the mean cross-entropy.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}

        # Same as a default-constructed nn.CrossEntropyLoss (mean reduction).
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
|
|
|
|
|
|
|
|
|
|
|
|