leftthomas commited on
Commit
901629b
1 Parent(s): 5e92eb3

Upload configuration_resnet.py

Browse files
Files changed (1) hide show
  1. configuration_resnet.py +13 -0
configuration_resnet.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ class ResnetConfig(PretrainedConfig):
6
+ model_type = 'resnet'
7
+ def __init__(self, block_type='bottleneck', layers: List[int] = [3, 4, 6, 3], num_classes: int = 1000, **kwargs):
8
+ if block_type not in ['basic', 'bottleneck']:
9
+ raise ValueError(f"`block` must be 'basic' or bottleneck', got {block_type}.")
10
+ self.block_type = block_type
11
+ self.layers = layers
12
+ self.num_classes = num_classes
13
+ super().__init__(**kwargs)