from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .MyConfig import MnistConfig

class MnistModel(PreTrainedModel):
    # links this model to its custom config class so from_pretrained / AutoModel
    # can rebuild it from a saved config.json
    config_class = MnistConfig

    def __init__(self, config):
        super().__init__(config)
        # use the config to instantiate our model
        self.conv1 = nn.Conv2d(1, config.conv1, kernel_size=5)
        self.conv2 = nn.Conv2d(config.conv1, config.conv2, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, labels=None):
        # two conv + max-pool stages: 1x28x28 input -> 12x12 -> 4x4 feature maps
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # flatten: 320 = config.conv2 * 4 * 4 (assumes config.conv2 == 20)
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        output = self.softmax(x)  # probabilities over the 10 digit classes
        if labels is not None:
            # labels are only passed during training; a loss would be computed here
            print("continue training script here")

        return output
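
# Usage sketch: a minimal way to exercise the model, assuming MnistConfig
# defaults to conv1=10 and conv2=20 (so the flattened feature size is
# 20 * 4 * 4 = 320) and that inputs are 1x28x28 grayscale MNIST images:
#
#     import torch
#     config = MnistConfig()                 # hypothetical defaults
#     model = MnistModel(config)
#     pixels = torch.randn(4, 1, 28, 28)     # batch of four MNIST-sized images
#     probs = model(pixels)                  # shape (4, 10); each row sums to 1
#
# When sharing the custom code on the Hub, the config and model are typically
# registered for the Auto classes first (repo name below is illustrative):
#
#     MnistConfig.register_for_auto_class()
#     MnistModel.register_for_auto_class("AutoModel")
#     model.push_to_hub("your-username/mnist-model")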