from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel

from .MyConfig import MnistConfig  # local import


class MnistModel(PreTrainedModel):
    """Small two-convolution CNN classifier with 10 output classes.

    Layer widths come from :class:`MnistConfig`: ``config.conv1`` and
    ``config.conv2`` are the output channel counts of the two convolutions.
    """

    # pass the previously defined config class to the model
    config_class = MnistConfig

    def __init__(self, config: MnistConfig):
        """Build the layers from ``config``.

        Args:
            config: model configuration; ``conv1`` and ``conv2`` are read here.
        """
        # instantiate the model using the configuration
        super().__init__(config)
        # use the config to instantiate our model
        self.conv1 = nn.Conv2d(1, config.conv1, kernel_size=5)
        self.conv2 = nn.Conv2d(config.conv1, config.conv2, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Flattened feature size after the conv stack. The 4x4 spatial size
        # presumably assumes 28x28 inputs (28 -conv5-> 24 -pool2-> 12 -conv5->
        # 8 -pool2-> 4), matching the previously hard-coded 320 = 20 * 4 * 4.
        # Deriving the channel count from the config keeps fc1 consistent with
        # conv2 when config.conv2 != 20.
        self.fc1 = nn.Linear(config.conv2 * 4 * 4, 50)
        self.fc2 = nn.Linear(50, 10)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(
        self, x: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the classifier.

        Args:
            x: single-channel image batch (conv1 takes 1 input channel);
                presumably shaped (batch, 1, 28, 28) -- TODO confirm.
            labels: optional targets for ``nn.CrossEntropyLoss`` (e.g. class
                indices of shape (batch,)). Passing them allows the model to
                be fine-tuned with the Trainer API.

        Returns:
            Softmax class probabilities over the last dimension when
            ``labels`` is None; otherwise a dict ``{"loss": ..., "logits": ...}``
            where ``"logits"`` holds those same probabilities.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        scores = self.fc2(x)  # raw, unnormalized class scores
        probs = self.softmax(scores)
        if labels is not None:
            # this will make your AI compatible with the trainer API.
            # CrossEntropyLoss applies log-softmax internally, so it must be
            # given the raw scores; feeding it softmax output would normalize
            # twice and weaken the gradients.
            loss = self.criterion(scores, labels)
            return {"loss": loss, "logits": probs}
        return probs