File size: 520 Bytes
9addc67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import PretrainedConfig
class Resnet50Config(PretrainedConfig):
    """Configuration for a ResNet-50 image classifier.

    Stores the number of output classes next to the standard
    ``PretrainedConfig`` fields. Because this class subclasses
    ``PretrainedConfig``, it uses the regular Transformers config API.

    Args:
        num_classes: Number of output classes for the classification head.
            Defaults to 6.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE: no ``model_type`` is declared here. If this config is ever
    # registered with ``AutoConfig.register``, give it a unique
    # ``model_type`` string as a class attribute first.

    def __init__(self, num_classes: int = 6, **kwargs):
        # Custom fields are assigned before delegating to the base
        # constructor, matching the pattern in the Transformers
        # custom-model guide.
        self.num_classes = num_classes
        super().__init__(**kwargs)
|