import torch from torch import nn from torchvision import transforms class MnistModel(nn.Module): classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.conv1 = nn.Conv2d(1, 3, 3) self.conv2 = nn.Conv2d(3, 6, 3) self.maxpool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(150, 32) self.fc2 = nn.Linear(32, 10) #self.fc3 = nn.Linear(32, 10) self.dropout = nn.Dropout(0.3) def forward(self, x): l1 = nn.ReLU()(self.conv1(x)) l1 = self.maxpool(l1) l2 = nn.ReLU()(self.conv2(l1)) l2 = self.maxpool(l2) fc = torch.flatten(l2, 1) fc1 = nn.ReLU()(self.fc1(fc)) fc1 = self.dropout(fc1) #fc2 = nn.ReLU()(self.fc2(fc1)) out = self.fc2(fc1) return out def load_model(): model = MnistModel() transforming = transforms.Compose([ transforms.Resize((28,28)), transforms.ToTensor(), transforms.Grayscale(num_output_channels=1) ]) model.load_state_dict(torch.load('best_model.pth',map_location='cpu')) return model,transforming,model.classes if __name__=='__main__': pass