lenet_mnist / src /models.py
rzimmerdev's picture
fix: Removed last activation layer from LeNet model
7dc7452
raw
history blame
No virus
1.67 kB
#!/usr/bin/env python
# coding: utf-8
import torch
from torch import nn
class CNN(nn.Module):
def __init__(self, input_channels, num_classes):
super().__init__()
self.feature_layers = [input_channels, 6, 16]
self.kernels = [5, 5]
self.pools = [2, 2]
self.feature_activations = [nn.ReLU for _ in range(len(self.feature_layers) - 1)]
self.classifier_layers = [400, 120, 84, num_classes]
self.classifier_activations = [nn.ReLU for _ in range(len(self.classifier_layers) - 1)]
feature_layers = []
for idx, layer in enumerate(list(zip(self.feature_layers[:-1], self.feature_layers[1:]))):
feature_layers.append(
nn.Conv2d(in_channels=layer[0], out_channels=layer[1],
kernel_size=self.kernels[idx], padding=2 if idx == 0 else 0)
)
feature_layers.append(self.feature_activations[idx]())
feature_layers.append(nn.MaxPool2d(kernel_size=self.pools[idx]))
classifier_layers = []
for idx, layer in enumerate(list(zip(self.classifier_layers[:-1], self.classifier_layers[1:]))):
classifier_layers.append(
nn.Linear(in_features=layer[0], out_features=layer[1])
)
if idx < len(self.classifier_activations) - 1:
classifier_layers.append(self.classifier_activations[idx]())
self.feature_extractor = nn.Sequential(*feature_layers)
self.classifier = nn.Sequential(*classifier_layers)
def forward(self, x):
x = self.feature_extractor(x)
y = self.classifier(torch.flatten(x))
return y