File size: 464 Bytes
70e7cb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

from transformers import PretrainedConfig
from typing import List

class MnistConfig(PretrainedConfig):
    """Configuration for a small two-convolution MNIST classifier.

    Stores the two hyperparameters ``conv1`` and ``conv2`` as instance
    attributes and forwards every other keyword argument unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        conv1: Value stored as ``self.conv1`` (presumably the output channel
            count of the first convolution layer -- confirm against the
            model that consumes this config).
        conv2: Value stored as ``self.conv2`` (presumably the output channel
            count of the second convolution layer -- confirm against the
            model that consumes this config).
        **kwargs: Passed as-is to ``PretrainedConfig.__init__``.
    """

    # The author picked a model type "close to" an image-classification task.
    # NOTE(review): ``model_type`` is a class-level identifier that travels
    # with the serialized config; "MobileNetV1" does not appear to match the
    # identifier transformers registers for its own MobileNetV1 config
    # (believed to be "mobilenet_v1"). Confirm the choice is intentional
    # before changing it, since previously saved configs may rely on it.
    model_type = "MobileNetV1"

    def __init__(
        self,
        conv1: int = 10,
        conv2: int = 20,
        **kwargs,
    ) -> None:
        self.conv1 = conv1
        self.conv2 = conv2
        # Forward all remaining keyword arguments to the base class.
        super().__init__(**kwargs)