from transformers import PreTrainedModel
from .MyConfig import MnistConfig
from torch import nn
import torch.nn.functional as F


class MnistModel(PreTrainedModel):
    """Small two-conv-layer image classifier packaged as a Hugging Face model.

    The channel widths of the two convolutional layers are read from
    ``config.conv1`` and ``config.conv2``; the head is fixed at 50 hidden
    units and 10 output classes.
    """

    config_class = MnistConfig

    # Side length of each feature map after the conv/pool stack. Assumes
    # 28x28 inputs (MNIST-sized): 28 -conv5-> 24 -pool2-> 12 -conv5-> 8
    # -pool2-> 4. The original hard-coded 320 equals 20 * 4 * 4, i.e. this
    # value with conv2 == 20. TODO(review): confirm no other input size is used.
    _FEATURE_MAP_SIZE = 4

    def __init__(self, config):
        super().__init__(config)
        # use the config to instantiate our model
        self.conv1 = nn.Conv2d(1, config.conv1, kernel_size=5)
        self.conv2 = nn.Conv2d(config.conv1, config.conv2, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Flattened width follows conv2's channel count instead of a fixed 320,
        # so a non-default ``config.conv2`` yields a correctly sized fc1.
        # Unchanged (320) for conv2 == 20, so existing weights still load.
        self.fc1 = nn.Linear(config.conv2 * self._FEATURE_MAP_SIZE ** 2, 50)
        self.fc2 = nn.Linear(50, 10)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, labels=None):
        """Run the classifier and return per-class probabilities.

        Args:
            x: Input image batch with a single channel (conv1 has
                ``in_channels=1``); presumably shaped (batch, 1, 28, 28) —
                see ``_FEATURE_MAP_SIZE``.
            labels: Optional targets. Currently not used for any computation;
                when given, only a placeholder message is printed.

        Returns:
            Softmax probabilities of shape (batch, 10). Note these are
            probabilities, not logits: ``nn.CrossEntropyLoss`` expects logits,
            so a future loss computation should use the pre-softmax values.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten each sample while keeping the batch dimension explicit;
        # view(-1, 320) could silently fold features into the batch axis.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        output = self.softmax(x)
        if labels is not None:
            # TODO: compute a training loss here; left as a placeholder so the
            # return value stays the same for existing callers.
            print("continue training script here")
        return output