from transformers import PretrainedConfig | |
class ResnetFeatureExtractorConfig(PretrainedConfig):
    """Configuration for a ResNet feature extractor.

    Only the ``'resnet152'`` backbone name is accepted. Any other value is
    rejected with ``ValueError`` when the config is constructed.
    """

    # NOTE(review): "resnet" also appears to be the model_type of the built-in
    # transformers ResNet config; registering this class with AutoConfig under
    # that key would likely clash -- confirm before registering.
    model_type = "resnet"

    def __init__(self, name: str = 'resnet152', **kwargs):
        """Validate and store the backbone name, then initialise the base config.

        Args:
            name: Backbone identifier; must be exactly ``'resnet152'``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``name`` is anything other than ``'resnet152'``.
        """
        supported_backbone = 'resnet152'
        # Guard clause: reject unsupported names before touching any state.
        if name != supported_backbone:
            raise ValueError(f"`name` must be 'resnet152', got {name}.")

        # Custom attributes are assigned before delegating to the base
        # initializer, keeping the original statement order.
        self.name = name
        super().__init__(**kwargs)