File size: 520 Bytes
9addc67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import PretrainedConfig

class Resnet50Config(PretrainedConfig):
    """Configuration object for the ResNet-50 image-classification model.

    Holds the number of output classes and forwards every other keyword
    argument to ``PretrainedConfig``. This is the only definition of
    ``Resnet50Config`` in the module, so the name always refers to the
    ``PretrainedConfig`` subclass.

    Args:
        num_classes: Number of target classes. May be given positionally
            or by keyword. Defaults to 6.
        **kwargs: Additional options passed unchanged to
            ``PretrainedConfig.__init__``.
    """

    # NOTE(review): `model_type` is intentionally not overridden here (an
    # earlier `model_type = "MobileNetV1"` assignment was disabled). Confirm
    # whether a value is required before registering with the Auto* classes.

    def __init__(self, num_classes: int = 6, **kwargs) -> None:
        # Record the task-specific field, then let the base class consume
        # the remaining keyword arguments.
        self.num_classes = num_classes
        super().__init__(**kwargs)