from transformers import PretrainedConfig


class AugViTConfig(PretrainedConfig):
    """Configuration class for the AugViT model.

    Stores the hyperparameters needed to build the model: input image
    geometry, transformer depth/width, and dropout rates.
    """

    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 224,      # input image resolution (square images assumed)
        patch_size: int = 32,       # side length of each image patch
        num_classes: int = 1000,    # size of the classification head
        dim: int = 128,             # transformer embedding dimension
        depth: int = 2,             # number of transformer blocks
        heads: int = 16,            # number of attention heads
        mlp_dim: int = 256,         # hidden dimension of the feed-forward MLP
        dropout: float = 0.1,       # dropout rate inside the transformer blocks
        emb_dropout: float = 0.1,   # dropout rate on the patch/position embeddings
        num_channels: int = 3,      # number of input image channels
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.num_channels = num_channels
        super().__init__(**kwargs)
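

# Minimal usage sketch: exercises only the config class defined above via the
# standard PretrainedConfig save/load round trip. The output path and the
# num_classes override are illustrative values, not part of the original file.
if __name__ == "__main__":
    config = AugViTConfig(num_classes=10)
    config.save_pretrained("./augvit-config")               # writes config.json
    reloaded = AugViTConfig.from_pretrained("./augvit-config")
    assert reloaded.num_classes == 10 and reloaded.model_type == "augvit"
    print(reloaded)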