
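The model file below imports MnistConfig from a local MyConfig.py, which is not shown here. A minimal version of that configuration class, assuming conv1 and conv2 are the only architecture hyperparameters and that the defaults (10 and 20 channels, matching the classic PyTorch MNIST example) line up with the layer sizes used in the model, could look like this:

from transformers import PretrainedConfig

class MnistConfig(PretrainedConfig):
    # model_type identifies this architecture when the config
    # is saved and later reloaded by transformers
    model_type = "mnist"

    def __init__(self, conv1=10, conv2=20, **kwargs):
        # number of output channels of each convolutional layer
        self.conv1 = conv1
        self.conv2 = conv2
        super().__init__(**kwargs)

The model class then consumes this configuration: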
from transformers import PreTrainedModel
from .MyConfig import MnistConfig # local import
from torch import nn
import torch.nn.functional as F

class MnistModel(PreTrainedModel):
    # pass the previously defined config class to the model
    config_class = MnistConfig

    def __init__(self, config):
        # instantiate the model using the configuration
        super().__init__(config)
        # use the config values to size the layers
        self.conv1 = nn.Conv2d(1, config.conv1, kernel_size=5)
        self.conv2 = nn.Conv2d(config.conv1, config.conv2, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # after two conv+pool stages, a 28x28 input shrinks to
        # config.conv2 feature maps of size 4x4 (320 values by default)
        self.fc1 = nn.Linear(config.conv2 * 4 * 4, 50)
        self.fc2 = nn.Linear(50, 10)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        # accepting an optional labels argument is what lets the
        # Trainer API fine-tune the model without extra glue code
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(x.size(0), -1)  # flatten to (batch, conv2 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        logits = self.fc2(x)
        if labels is not None:
            # returning a dict with a "loss" key is what makes the
            # model compatible with the Trainer API; note that
            # CrossEntropyLoss expects raw logits, so the softmax
            # must not be applied before computing the loss
            loss = self.criterion(logits, labels)
            return {"loss": loss, "logits": logits}
        # at inference time, return class probabilities instead
        return self.softmax(logits)
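
With both files in place, the model behaves like any other Hugging Face model: passing labels yields a dict with a loss for training, and the standard serialization methods work out of the box. A minimal sketch, assuming the two classes are saved as MyConfig.py and MyModel.py (the file names are an assumption; the source only shows the relative import):

import torch
from MyConfig import MnistConfig   # file names assumed, not shown above
from MyModel import MnistModel

config = MnistConfig()             # conv1=10, conv2=20 by default
model = MnistModel(config)

# a dummy batch of four 28x28 grayscale images with integer labels
pixels = torch.randn(4, 1, 28, 28)
labels = torch.tensor([0, 1, 2, 3])

outputs = model(pixels, labels=labels)
print(outputs["loss"])             # scalar training loss
print(outputs["logits"].shape)     # torch.Size([4, 10])

# standard Hugging Face serialization works out of the box
model.save_pretrained("./mnist-model")
restored = MnistModel.from_pretrained("./mnist-model")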