---
tags:
- PyTorch
- CNN
datasets:
- sklearn-digits
---

A basic TinyCNN PyTorch model trained on the scikit-learn digits dataset.

```python
"""
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb
"""

import numpy as np
import torch
from torch import nn
from torch.nn.utils import prune


class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    The convolution layers can be L1-pruned so that each output neuron keeps,
    on average, at most ``n_active`` non-zero weights (10 by default), which
    should help keep the accumulator bit width low. The pruning is
    unstructured: the weight budget is enforced per layer, not per neuron.
    """

    def __init__(self, n_classes: int) -> None:
        """Construct the CNN with a configurable number of classes.

        Args:
            n_classes: Number of logits produced by the final linear layer.
        """
        super().__init__()

        # For 8x8 single-channel inputs and 10 classes this network has a total
        # complexity of 1216 MAC (conv1 648 + conv2 216 + conv3 192 + fc1 160).
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    @staticmethod
    def _weight_is_pruned(layer: nn.Module) -> bool:
        """Return True if ``layer.weight`` currently carries a pruning mask."""
        # While pruned, torch.nn.utils.prune keeps the dense tensor in ``weight_orig``.
        return hasattr(layer, "weight_orig")

    def toggle_pruning(self, enable: bool, n_active: int = 10) -> None:
        """Enable or remove pruning on the convolution layers.

        Repeated calls with the same ``enable`` value are no-ops: layers that
        are already pruned are not pruned a second time, and layers without a
        mask are skipped when disabling.

        Args:
            enable: If True, attach an L1-unstructured pruning mask to every
                convolution whose fan-in exceeds ``n_active``. If False, make
                existing masks permanent and drop the pruning reparametrization.
            n_active: Target number of non-zero weights per output neuron (on
                average over the layer). Only used when ``enable`` is True.
        """
        for layer in (self.conv1, self.conv2, self.conv3):
            already_pruned = self._weight_is_pruned(layer)

            if not enable:
                if already_pruned:
                    # When disabling pruning, the mask is multiplied with the
                    # weights and the result is stored in the weights member.
                    prune.remove(layer, "weight")
                continue

            if already_pruned:
                # Pruning again would stack a second mask and zero out even more weights.
                continue

            shape = layer.weight.shape
            # Fan-out: number of neurons (output channels) in the layer.
            fan_out = shape[0]
            # Fan-in: inputs to one neuron = in_channels x kernel height x kernel width.
            fan_in = int(np.prod(shape[1:]))

            if fan_in > n_active:
                # Registers a forward pre-hook that multiplies the weights by a
                # 0/1 mask. The (fan_in - n_active) * fan_out smallest-magnitude
                # weights of the whole layer are zeroed, leaving
                # n_active * fan_out active weights in the layer.
                prune.l1_unstructured(layer, "weight", (fan_in - n_active) * fan_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run inference on the tiny CNN, apply the decision layer on the flattened conv output.

        Args:
            x: Input batch of single-channel images. For 8x8 images the three
                convolutions reduce the spatial size 8 -> 6 -> 2 -> 1, giving
                16 features per sample.

        Returns:
            Logits of shape ``(batch, n_classes)``.
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Flatten per sample (keeping the batch dimension) instead of
        # view(-1, 16): an unexpected spatial size now fails loudly in fc1
        # rather than being silently folded into the batch dimension.
        x = torch.flatten(x, 1)
        return self.fc1(x)
```