bbexx commited on
Commit
15d27c4
·
1 Parent(s): 7a14932
Files changed (5) hide show
  1. README.md +95 -0
  2. configuration_vitamin.py +158 -0
  3. model.py +741 -0
  4. timm_model.py +151 -0
  5. vitamin.py +828 -0
README.md CHANGED
@@ -1,3 +1,98 @@
1
  ---
2
  license: mit
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ datasets:
4
+ - mlfoundations/datacomp_1b
5
+ pipeline_tag: feature-extraction
6
  ---
7
+
8
+ # Model card for ViTamin-XL-336px
9
+
10
+ Official huggingface models of **ViTamin**, from the following CVPR 2024 paper:
11
+
12
+ [ViTamin: Design Scalable Vision Models in the Vision-language Era](https://arxiv.org/pdf/2404.02132.pdf).\
13
+ ✨  [Jieneng Chen](https://beckschen.github.io), [Qihang Yu](https://yucornetto.github.io/), [Xiaohui Shen](https://xiaohuishen.github.io/), [Alan Yuille](https://www.cs.jhu.edu/~ayuille/) and [Liang-Chieh Chen](http://liangchiehchen.com/)\
14
+ 🏠  Johns Hopkins University, Bytedance
15
+
16
+ 🔥 This ViTamin-XL-336px is the pre-trained model transferred to open-vocabulary detection and segmentation, and large multi-modal models in our paper.
17
+
18
+ Load from HuggingFace with transformers.AutoModel:
19
+ ```python
20
+ import torch
21
+ import open_clip
22
+ from PIL import Image
23
+ from transformers import AutoModel, CLIPImageProcessor
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ model = AutoModel.from_pretrained(
27
+ 'jienengchen/ViTamin-L2-336px',
28
+ trust_remote_code=True).to(device).eval()
29
+
30
+ image = Image.open('./image.png').convert('RGB')
31
+ image_processor = CLIPImageProcessor.from_pretrained('jienengchen/ViTamin-L2-336px')
32
+
33
+ pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
34
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
35
+
36
+ tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
37
+ text = tokenizer(["a photo of vitamin", "a dog", "a cat"]).to(device)
38
+
39
+ with torch.no_grad(), torch.cuda.amp.autocast():
40
+ image_features, text_features, logit_scale = model(pixel_values, text)
41
+ text_probs = (100.0 * image_features @ text_features.to(torch.float).T).softmax(dim=-1)
42
+
43
+ print("Label probs:", text_probs)
44
+ ```
45
+
46
+ ## Main Results with CLIP Pre-training on DataComp-1B
47
+
48
+
49
+ | image encoder | image size | num patches | text encoder depth/width | seen samples (B) | trainable params Image+Text (M) | MACs Image+Text (G) | ImageNet Acc. | avg. 38 datasets | ImageNet dist. shift. | VTAB | retrieval |
50
+ |---------------|------------|-------------|--------------------------|-------------------|---------------------------------|----------------------|---------------|------------------|-----------------------|------|-----------|
51
+ | ViTamin-L | 224 | 196 | 12/768 | 12.8 | 333.3+123.7 | 72.6+6.6 | 80.8 | 66.7 | 69.8 | 65.3 | 60.3 |
52
+ | ViTamin-L | 256 | 256 | 12/768 | 12.8+0.2 | 333.4+123.7 | 94.8+6.6 | 81.2 | 67.0 | 71.1 | 65.3 | 61.2 |
53
+ | ViTamin-L | 336 | 441 | 12/768 | 12.8+0.2 | 333.6+123.7 | 163.4+6.6 | 81.6 | 67.0 | 72.1 | 64.4 | 61.6 |
54
+ | ViTamin-L | 384 | 576 | 12/768 | 12.8+0.2 | 333.7+123.7 | 213.4+6.6 | 81.8 | 67.2 | 72.4 | 64.7 | 61.8 |
55
+ | ViTamin-L2 | 224 | 196 | 24/1024 | 12.8 | 333.6+354.0 | 72.6+23.3 | 80.9 | 66.4 | 70.6 | 63.4 | 61.5 |
56
+ | ViTamin-L2 | 256 | 256 | 24/1024 | 12.8+0.5 | 333.6+354.0 | 94.8+23.3 | 81.5 | 67.4 | 71.9 | 64.1 | 63.1 |
57
+ | ViTamin-L2 | 336 | 441 | 24/1024 | 12.8+0.5 | 333.8+354.0 | 163.4+23.3 | 81.8 | 67.8 | 73.0 | 64.5 | 63.6 |
58
+ | ViTamin-L2 | 384 | 576 | 24/1024 | 12.8+0.5 | 334.0+354.0 | 213.4+23.3 | 82.1 | 68.1 | 73.4 | 64.8 | 63.7 |
59
+ | ViTamin-XL | 256 | 256 | 27/1152 | 12.8+0.5 | 436.1+488.7 | 125.3+33.1 | 82.1 | 67.6 | 72.3 | 65.4 | 62.7 |
60
+ | ViTamin-XL | 384 | 576 | 27/1152 | 12.8+0.5 | 436.1+488.7 | 281.9+33.1 | 82.6 | 68.1 | 73.6 | 65.6 | 63.8 |
61
+ | ViTamin-XL | 256 | 256 | 27/1152 | 40 | 436.1+488.7 | 125.3+33.1 | 82.3 | 67.5 | 72.8 | 64.0 | 62.1 |
62
+ | ViTamin-XL | 336 | 441 | 27/1152 | 40+1 | 436.1+488.7 | 215.9+33.1 | 82.7 | 68.0 | 73.9 | 64.1 | 62.6 |
63
+ | ViTamin-XL | 384 | 576 | 27/1152 | 40+1 | 436.1+488.7 | 281.9+33.1 | 82.9 | 68.1 | 74.1 | 64.0 | 62.5 |
64
+
65
+ ## Main Results on Downstream tasks
66
+ **Open-Vocab Detection**
67
+ | image encoder | detector | OV-COCO (AP<sub>50</sub><sup>novel</sup>) | OV-LVIS (AP<sub>r</sub>) |
68
+ |---------------|----------|---------------------------------------|-----------------------|
69
+ | ViT-L/14 | Sliding F-ViT | 36.1 | 32.5 |
70
+ | ViTamin-L | Sliding F-ViT | 37.5 | 35.6 |
71
+
72
+ **Open-Vocab Segmentation**
73
+
74
+ | image encoder | segmentor | ADE | Cityscapes | MV | A-150 | A-847 | PC-459 | PC-59 | PAS-21 |
75
+ |---------------|-------------|----------------|--------------|------|-------|-------|--------|-------|--------------------|
76
+ | ViT-L/14 | Sliding FC-CLIP | 24.6 | 40.7 | 16.5 | 31.8 | 14.3 | 18.3 | 55.1 | 81.5 |
77
+ | ViTamin-L | Sliding FC-CLIP | 27.3 | 44.0 | 18.2 | 35.6 | 16.1 | 20.4 | 58.4 | 83.4 |
78
+
79
+ Note: Panoptic dataset (ADE, CityScapes, MV) are with the metric of PQ. Semantic dataset (A-150, A-847, PC-459, PC-59, PAS-21) are with the metric of mIoU.
80
+
81
+ **Large Multi-modal Models**
82
+
83
+ | image encoder | image size | VQAv2 | GQA | VizWiz | SQA | T-VQA | POPE | MME | MM-Bench | MM-B-CN | SEED | LLaVA-Wild | MM-Vet |
84
+ |---------------|----------|-------|------|--------|------|-------|------|------|----------|---------|------|------------|--------|
85
+ | ViTamin-L | 224 | 78.4 | 61.6 | 51.1 | 66.9 | 58.7 | 84.6 | 1421 | 65.4 | 58.4 | 57.7 | 64.5 | 33.6 |
86
+ | ViTamin-L | 384 | 78.9 | 61.6 | 55.4 | 67.6 | 59.8 | 85.5 | 1447 | 64.5 | 58.3 | 57.9 | 66.1 | 33.6 |
87
+
88
+
89
+ ## Citing ViTamin
90
+
91
+ ```
92
+ @inproceedings{chen2024vitamin,
93
+ title={ViTamin: Design Scalable Vision Models in the Vision-language Era},
94
+ author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, ALan and Chen, Liang-Chieh},
95
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
96
+ year={2024}
97
+ }
98
+ ```
configuration_vitamin.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ ViTamin
2
+
3
+ Paper: Designing Scalable Vison Models in the Vision-Language Era
4
+
5
+ @misc{chen2023designing,
6
+ title={Designing Scalable Vison Models in the Vision-Language Era},
7
+ author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen},
8
+ year={2023},
9
+ archivePrefix={arXiv},
10
+ primaryClass={cs.CV}
11
+ }
12
+
13
+ Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin
14
+
15
+ by Jieneng Chen 2024
16
+ """
17
+
18
+ import copy
19
+ import os
20
+ from collections import OrderedDict
21
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.processing_utils import ProcessorMixin
26
+ from transformers.utils import TensorType
27
+
28
+ from transformers.configuration_utils import PretrainedConfig
29
+ from transformers.utils import logging
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ class ViTaminTextConfig(PretrainedConfig):
34
+ model_type = "vitamin_text_model"
35
+
36
+ def __init__(
37
+ self,
38
+ context_length = 77,
39
+ vocab_size = 49408,
40
+ width = 1024,
41
+ heads = 16,
42
+ layers = 24,
43
+ **kwargs,
44
+ ):
45
+ super().__init__(**kwargs)
46
+
47
+ self.vocab_size = vocab_size
48
+ self.context_length = context_length
49
+ self.width = width
50
+ self.heads = heads
51
+ self.layers = layers
52
+
53
+ @classmethod
54
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
55
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
56
+
57
+ if 'text_config' in config_dict:
58
+ config_dict = config_dict['text_config']
59
+
60
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
61
+ logger.warning(
62
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
63
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
64
+ )
65
+
66
+ return cls.from_dict(config_dict, **kwargs)
67
+
68
+
69
+ class ViTaminVisionConfig(PretrainedConfig):
70
+
71
+ model_type = "vitamin_vision_model"
72
+
73
+ def __init__(
74
+ self,
75
+ timm_model_name = "vitamin_large",
76
+ timm_model_pretrained = False,
77
+ timm_pool = "",
78
+ timm_proj = "linear",
79
+ timm_drop = 0.0,
80
+ timm_drop_path = 0.1,
81
+ image_size = 256,
82
+ timm_proj_bias = False,
83
+ patch_dropout = 0.0,
84
+ drop_path = None,
85
+ **kwargs,
86
+ ):
87
+ super().__init__(**kwargs)
88
+
89
+ self.timm_model_name = timm_model_name
90
+ self.timm_model_pretrained = timm_model_pretrained
91
+ self.timm_pool = timm_pool
92
+ self.timm_proj = timm_proj
93
+ self.timm_drop = timm_drop
94
+ self.timm_drop_path = timm_drop_path
95
+ self.timm_proj_bias = timm_proj_bias
96
+ self.patch_dropout = patch_dropout
97
+ self.image_size = image_size
98
+
99
+
100
+ @classmethod
101
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
102
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
103
+
104
+ if 'vision_config' in config_dict:
105
+ config_dict = config_dict['vision_config']
106
+
107
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
108
+ logger.warning(
109
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
110
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
111
+ )
112
+
113
+ return cls.from_dict(config_dict, **kwargs)
114
+
115
+
116
+
117
+ class ViTaminConfig(PretrainedConfig):
118
+ model_type = "vitamin"
119
+ is_composition = True
120
+
121
+ def __init__(
122
+ self, text_config=None, vision_config=None, embed_dim=512, **kwargs
123
+ ):
124
+ super().__init__(**kwargs)
125
+ if text_config is None:
126
+ text_config = {}
127
+ logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
128
+
129
+ if vision_config is None:
130
+ vision_config = {}
131
+ logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
132
+
133
+ self.embed_dim = embed_dim
134
+ self.text_config = ViTaminTextConfig(**text_config)
135
+ self.vision_config = ViTaminVisionConfig(**vision_config)
136
+
137
+ @classmethod
138
+ def from_text_vision_configs(cls, text_config: ViTaminTextConfig, vision_config: ViTaminVisionConfig, **kwargs):
139
+ r"""
140
+ Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
141
+ configuration.
142
+ Returns:
143
+ [`CLIPConfig`]: An instance of a configuration object
144
+ """
145
+
146
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
147
+
148
+ def to_dict(self):
149
+ """
150
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
151
+ Returns:
152
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
153
+ """
154
+ output = copy.deepcopy(self.__dict__)
155
+ output["text_config"] = self.text_config.to_dict()
156
+ output["vision_config"] = self.vision_config.to_dict()
157
+ output["model_type"] = self.__class__.model_type
158
+ return output
model.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ ViTamin
2
+
3
+ Paper: Designing Scalable Vison Models in the Vision-Language Era
4
+
5
+ @misc{chen2023designing,
6
+ title={Designing Scalable Vison Models in the Vision-Language Era},
7
+ author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen},
8
+ year={2023},
9
+ archivePrefix={arXiv},
10
+ primaryClass={cs.CV}
11
+ }
12
+
13
+ Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin
14
+
15
+ by Jieneng Chen 2024
16
+
17
+ Reference: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
18
+ """
19
+
20
+ from dataclasses import dataclass
21
+ import logging
22
+ import math
23
+ from typing import Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn
29
+ from torch.utils.checkpoint import checkpoint
30
+ from functools import partial
31
+ from open_clip.hf_model import HFTextEncoder
32
+ from open_clip.modified_resnet import ModifiedResNet
33
+ from open_clip.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
34
+ from open_clip.utils import to_2tuple
35
+ import time
36
+ import timm
37
+ from timm.models.vision_transformer import _create_vision_transformer
38
+ from .timm_model import TimmModel
39
+ from .vitamin import *
40
+ # from .vitamin import HybridEmbed, MbConvStages, VitCfg, VitConvCfg
41
+ from .vitamin import GeGluMlp, ViTamin, HybridEmbed, MbConvStages, VitCfg, VitConvCfg
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from .configuration_vitamin import ViTaminConfig, ViTaminVisionConfig
44
+
45
+ @dataclass
46
+ class CLIPVisionCfg:
47
+ layers: Union[Tuple[int, int, int, int], int] = 12
48
+ width: int = 768
49
+ head_width: int = 64
50
+ mlp_ratio: float = 4.0
51
+ patch_size: int = 16
52
+ image_size: Union[Tuple[int, int], int] = 224
53
+
54
+ ls_init_value: Optional[float] = None
55
+ patch_dropout: float = 0.
56
+ input_patchnorm: bool = False
57
+ global_average_pool: bool = False
58
+ attentional_pool: bool = False
59
+ n_queries: int = 256
60
+ attn_pooler_heads: int = 8
61
+ output_tokens: bool = False
62
+
63
+ timm_model_name: str = None
64
+ timm_model_pretrained: bool = False
65
+ timm_pool: str = 'avg'
66
+ timm_proj: str = 'linear'
67
+ timm_proj_bias: bool = False
68
+ timm_drop: float = 0.
69
+ timm_drop_path: Optional[float] = None
70
+
71
+
72
+ @dataclass
73
+ class CLIPTextCfg:
74
+ context_length: int = 77
75
+ vocab_size: int = 49408
76
+ width: int = 512
77
+ heads: int = 8
78
+ layers: int = 12
79
+ ls_init_value: Optional[float] = None # layer scale initial value
80
+ hf_model_name: str = None
81
+ hf_tokenizer_name: str = None
82
+ hf_model_pretrained: bool = True
83
+ proj: str = 'mlp'
84
+ pooler_type: str = 'mean_pooler'
85
+ embed_cls: bool = False
86
+ pad_id: int = 0
87
+ output_tokens: bool = False
88
+ text_mask: str = 'first' # default first truncate in bpe_tokenizer
89
+
90
+
91
+ def get_cast_dtype(precision: str):
92
+ cast_dtype = None
93
+ if precision == 'bf16':
94
+ cast_dtype = torch.bfloat16
95
+ elif precision == 'fp16':
96
+ cast_dtype = torch.float16
97
+ return cast_dtype
98
+
99
+
100
+ def get_input_dtype(precision: str):
101
+ input_dtype = None
102
+ if precision in ('bf16', 'pure_bf16'):
103
+ input_dtype = torch.bfloat16
104
+ elif precision in ('fp16', 'pure_fp16'):
105
+ input_dtype = torch.float16
106
+ return input_dtype
107
+
108
+
109
+ def _build_vision_tower(
110
+ embed_dim: int,
111
+ vision_cfg: CLIPVisionCfg,
112
+ quick_gelu: bool = False,
113
+ cast_dtype: Optional[torch.dtype] = None
114
+ ):
115
+ if isinstance(vision_cfg, dict):
116
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
117
+
118
+ act_layer = QuickGELU if quick_gelu else nn.GELU
119
+
120
+ if vision_cfg.timm_model_name:
121
+ visual = TimmModel(
122
+ vision_cfg.timm_model_name,
123
+ pretrained=vision_cfg.timm_model_pretrained,
124
+ pool=vision_cfg.timm_pool,
125
+ proj=vision_cfg.timm_proj,
126
+ proj_bias=vision_cfg.timm_proj_bias,
127
+ drop=vision_cfg.timm_drop,
128
+ drop_path=vision_cfg.timm_drop_path,
129
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
130
+ embed_dim=embed_dim,
131
+ image_size=vision_cfg.image_size,
132
+ )
133
+ elif isinstance(vision_cfg.layers, (tuple, list)):
134
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
135
+ visual = ModifiedResNet(
136
+ layers=vision_cfg.layers,
137
+ output_dim=embed_dim,
138
+ heads=vision_heads,
139
+ image_size=vision_cfg.image_size,
140
+ width=vision_cfg.width,
141
+ )
142
+ else:
143
+ vision_heads = vision_cfg.width // vision_cfg.head_width
144
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
145
+ visual = VisionTransformer(
146
+ image_size=vision_cfg.image_size,
147
+ patch_size=vision_cfg.patch_size,
148
+ width=vision_cfg.width,
149
+ layers=vision_cfg.layers,
150
+ heads=vision_heads,
151
+ mlp_ratio=vision_cfg.mlp_ratio,
152
+ ls_init_value=vision_cfg.ls_init_value,
153
+ patch_dropout=vision_cfg.patch_dropout,
154
+ input_patchnorm=vision_cfg.input_patchnorm,
155
+ global_average_pool=vision_cfg.global_average_pool,
156
+ attentional_pool=vision_cfg.attentional_pool,
157
+ n_queries=vision_cfg.n_queries,
158
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
159
+ output_tokens=vision_cfg.output_tokens,
160
+ output_dim=embed_dim,
161
+ act_layer=act_layer,
162
+ norm_layer=norm_layer,
163
+ )
164
+
165
+ return visual
166
+
167
+
168
+ def _build_text_tower(
169
+ embed_dim: int,
170
+ text_cfg: CLIPTextCfg,
171
+ quick_gelu: bool = False,
172
+ cast_dtype: Optional[torch.dtype] = None,
173
+ ):
174
+ if isinstance(text_cfg, dict):
175
+ text_cfg = CLIPTextCfg(**text_cfg)
176
+
177
+ if text_cfg.hf_model_name:
178
+ text = HFTextEncoder(
179
+ text_cfg.hf_model_name,
180
+ output_dim=embed_dim,
181
+ proj=text_cfg.proj,
182
+ pooler_type=text_cfg.pooler_type,
183
+ pretrained=text_cfg.hf_model_pretrained,
184
+ output_tokens=text_cfg.output_tokens,
185
+ )
186
+ else:
187
+ act_layer = QuickGELU if quick_gelu else nn.GELU
188
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
189
+
190
+ text = TextTransformer(
191
+ context_length=text_cfg.context_length,
192
+ vocab_size=text_cfg.vocab_size,
193
+ width=text_cfg.width,
194
+ heads=text_cfg.heads,
195
+ layers=text_cfg.layers,
196
+ ls_init_value=text_cfg.ls_init_value,
197
+ output_dim=embed_dim,
198
+ embed_cls=text_cfg.embed_cls,
199
+ output_tokens=text_cfg.output_tokens,
200
+ pad_id=text_cfg.pad_id,
201
+ act_layer=act_layer,
202
+ norm_layer=norm_layer,
203
+ )
204
+ return text
205
+
206
+
207
+ class CLIP(nn.Module):
208
+ output_dict: torch.jit.Final[bool]
209
+
210
+ def __init__(
211
+ self,
212
+ embed_dim: int,
213
+ vision_cfg: CLIPVisionCfg,
214
+ text_cfg: CLIPTextCfg,
215
+ quick_gelu: bool = False,
216
+ cast_dtype: Optional[torch.dtype] = None,
217
+ output_dict: bool = False,
218
+ ):
219
+ super().__init__()
220
+ self.output_dict = output_dict
221
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
222
+
223
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
224
+ self.transformer = text.transformer
225
+ self.context_length = text.context_length
226
+ self.vocab_size = text.vocab_size
227
+ self.token_embedding = text.token_embedding
228
+ self.positional_embedding = text.positional_embedding
229
+
230
+ self.ln_final = text.ln_final
231
+ self.text_projection = text.text_projection
232
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
233
+
234
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
235
+
236
+ self.method_lock_text_tower = text.lock
237
+ self.text_no_grad = False
238
+
239
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
240
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
241
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
242
+
243
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
244
+ # added by jieneng
245
+ self.method_lock_text_tower(unlocked_layers, freeze_layer_norm)
246
+ self.text_no_grad = True
247
+
248
+ @torch.jit.ignore
249
+ def set_grad_checkpointing(self, enable=True, enable_text=True):
250
+ self.visual.set_grad_checkpointing(enable)
251
+ self.transformer.grad_checkpointing = enable_text
252
+
253
+ def encode_image(self, image, normalize: bool = False):
254
+ features = self.visual(image)
255
+ return F.normalize(features, dim=-1) if normalize else features
256
+
257
+ def encode_text(self, text, normalize: bool = False):
258
+ cast_dtype = self.transformer.get_cast_dtype()
259
+
260
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
261
+
262
+ x = x + self.positional_embedding.to(cast_dtype)
263
+ x = x.permute(1, 0, 2) # NLD -> LND
264
+ x = self.transformer(x, attn_mask=self.attn_mask)
265
+ x = x.permute(1, 0, 2) # LND -> NLD
266
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
267
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
268
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
269
+ return F.normalize(x, dim=-1) if normalize else x
270
+
271
+ def forward(
272
+ self,
273
+ image: Optional[torch.Tensor] = None,
274
+ text: Optional[torch.Tensor] = None,
275
+ ):
276
+ # torch.cuda.synchronize()
277
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
278
+
279
+ if self.text_no_grad:
280
+ with torch.no_grad():
281
+ text_features = self.encode_text(text, normalize=True).detach() if text is not None else None
282
+ else:
283
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
284
+
285
+
286
+ if self.output_dict:
287
+ return {
288
+ "image_features": image_features,
289
+ "text_features": text_features,
290
+ "logit_scale": self.logit_scale.exp()
291
+ }
292
+ return image_features, text_features, self.logit_scale.exp()
293
+
294
+
295
+ # class CustomTextCLIP(nn.Module):
296
+
297
+
298
+ class CustomTextCLIP(nn.Module):
299
+ output_dict: torch.jit.Final[bool]
300
+
301
+ def __init__(
302
+ self,
303
+ embed_dim: int,
304
+ vision_cfg: CLIPVisionCfg,
305
+ text_cfg: CLIPTextCfg,
306
+ quick_gelu: bool = False,
307
+ cast_dtype: Optional[torch.dtype] = None,
308
+ output_dict: bool = False,
309
+ ):
310
+ super().__init__()
311
+ self.output_dict = output_dict
312
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
313
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
314
+ self.context_length = self.text.context_length
315
+ self.vocab_size = self.text.vocab_size
316
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
317
+ self.text_no_grad = False
318
+
319
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
320
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
321
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
322
+
323
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj = False):
324
+ self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
325
+ self.text_no_grad = True
326
+
327
+
328
+ @torch.jit.ignore
329
+ def set_grad_checkpointing(self, enable=True, enable_text=True):
330
+ self.visual.set_grad_checkpointing(enable)
331
+ self.text.set_grad_checkpointing(enable_text)
332
+
333
+
334
+ def encode_image(self, image, normalize: bool = False):
335
+ features = self.visual(image)
336
+ return F.normalize(features, dim=-1) if normalize else features
337
+
338
+ def encode_text(self, text, normalize: bool = False):
339
+ features = self.text(text)
340
+ return F.normalize(features, dim=-1) if normalize else features
341
+
342
+ def forward(
343
+ self,
344
+ image: Optional[torch.Tensor] = None,
345
+ text: Optional[torch.Tensor] = None,
346
+ ):
347
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
348
+ # if self.text_no_grad:
349
+ # with torch.no_grad():
350
+ # text_features = self.encode_text(text, normalize=True).detach() if text is not None else None
351
+ # else:
352
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
353
+
354
+ if self.output_dict:
355
+ return {
356
+ "image_features": image_features,
357
+ "text_features": text_features,
358
+ "logit_scale": self.logit_scale.exp()
359
+ }
360
+ return image_features, text_features, self.logit_scale.exp()
361
+
362
+
363
+ class ViTaminPreTrainedModel(PreTrainedModel):
364
+ """
365
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
366
+ models.
367
+ """
368
+
369
+ config_class = ViTaminConfig
370
+ base_model_prefix = 'vitamin'
371
+
372
+
373
+ # hack CLIPVisionModel for llava: https://github.com/huggingface/transformers/blob/9acce7de1cb8229304a467938ebb47727d60cdb2/src/transformers/models/clip/modeling_clip.py#L878
374
+ class ViTaminVisionModel(PreTrainedModel):
375
+ config_class = ViTaminVisionConfig
376
+ main_input_name = 'pixel_values'
377
+
378
+ def __init__(self, config: ViTaminVisionConfig):
379
+ super().__init__(config)
380
+
381
+ self.visual = _build_vision_tower(config.embed_dim, config)
382
+
383
+ def forward(
384
+ self,
385
+ pixel_values: Optional[torch.FloatTensor] = None,
386
+ select_layer = -2,
387
+ ):
388
+ assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
389
+ x = self.visual.trunk.patch_embed.backbone.stem(pixel_values)
390
+ x = self.visual.trunk.patch_embed.backbone.stages[0](x)
391
+ x = self.visual.trunk.patch_embed.backbone.stages[1](x)
392
+ x = self.visual.trunk.patch_embed.backbone.pool(x)
393
+ x = self.visual.trunk.patch_embed.proj(x)
394
+ x = x.flatten(2).transpose(1, 2)
395
+ x = self.visual.trunk.patch_drop(x)
396
+ x = self.visual.trunk.norm_pre(x)
397
+ x = self.visual.trunk.blocks[:select_layer+1](x)
398
+ return x
399
+
400
+
401
+ class ViTaminCLIP(ViTaminPreTrainedModel):
402
+ output_dict: torch.jit.Final[bool]
403
+ config_class: ViTaminConfig
404
+
405
+ def __init__(
406
+ self,
407
+ config: ViTaminConfig
408
+ ):
409
+ super().__init__(config)
410
+
411
+ embed_dim=config.embed_dim #: int,
412
+ vision_cfg=config.vision_cfg #: CLIPVisionCfg,
413
+ text_cfg=config.text_cfg #: CLIPTextCfg,
414
+ quick_gelu=False
415
+ cast_dtype=None
416
+ output_dict=False
417
+
418
+ self.config = config
419
+ self.output_dict = output_dict
420
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
421
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
422
+ self.context_length = self.text.context_length
423
+ self.vocab_size = self.text.vocab_size
424
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
425
+ self.text_no_grad = False
426
+
427
+ def forward_visual4llava(
428
+ self,
429
+ pixel_values: Optional[torch.FloatTensor] = None,
430
+ select_layer = -2,
431
+ ):
432
+ assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
433
+ x = self.visual.trunk.patch_embed.backbone.stem(pixel_values)
434
+ x = self.visual.trunk.patch_embed.backbone.stages[0](x)
435
+ x = self.visual.trunk.patch_embed.backbone.stages[1](x)
436
+ x = self.visual.trunk.patch_embed.backbone.pool(x)
437
+ x = self.visual.trunk.patch_embed.proj(x)
438
+ x = x.flatten(2).transpose(1, 2)
439
+ x = self.visual.trunk.patch_drop(x)
440
+ x = self.visual.trunk.norm_pre(x)
441
+ x = self.visual.trunk.blocks[:select_layer+1](x)
442
+ return x
443
+
444
+ def encode_image(self, image, normalize: bool = False):
445
+ features = self.visual(image)
446
+ return F.normalize(features, dim=-1) if normalize else features
447
+
448
+ def encode_text(self, text, normalize: bool = False):
449
+ features = self.text(text)
450
+ return F.normalize(features, dim=-1) if normalize else features
451
+
452
+ def forward_pixel(
453
+ self,
454
+ image: Optional[torch.Tensor] = None,
455
+ text: Optional[torch.Tensor] = None,
456
+ ):
457
+
458
+ x = self.visual.trunk.patch_embed.backbone.stem(image)
459
+ x = self.visual.trunk.patch_embed.backbone.stages[0](x)
460
+ x = self.visual.trunk.patch_embed.backbone.stages[1](x)
461
+ x = self.visual.trunk.patch_embed.backbone.pool(x)
462
+ x = self.visual.trunk.patch_embed.proj(x)
463
+ x = x.flatten(2).transpose(1, 2)
464
+ x = self.visual.trunk.patch_drop(x)
465
+ x = self.visual.trunk.norm_pre(x)
466
+ x = self.visual.trunk.blocks(x)
467
+ x = self.visual.trunk.fc_norm(x)
468
+ x = self.visual.head.proj(x)
469
+ image_features = F.normalize(x, dim=-1)
470
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
471
+
472
+ if self.output_dict:
473
+ return {
474
+ "image_features": image_features,
475
+ "text_features": text_features,
476
+ "logit_scale": self.logit_scale.exp()
477
+ }
478
+ return image_features, text_features, self.logit_scale.exp()
479
+
480
+ def forward(
481
+ self,
482
+ image: Optional[torch.Tensor] = None,
483
+ text: Optional[torch.Tensor] = None,
484
+ ):
485
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
486
+ # if self.text_no_grad:
487
+ # with torch.no_grad():
488
+ # text_features = self.encode_text(text, normalize=True).detach() if text is not None else None
489
+ # else:
490
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
491
+
492
+ if self.output_dict:
493
+ return {
494
+ "image_features": image_features,
495
+ "text_features": text_features,
496
+ "logit_scale": self.logit_scale.exp()
497
+ }
498
+ return image_features, text_features, self.logit_scale.exp()
499
+
500
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
501
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
502
+
503
+ def _convert_weights(l):
504
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
505
+ l.weight.data = l.weight.data.to(dtype)
506
+ if l.bias is not None:
507
+ l.bias.data = l.bias.data.to(dtype)
508
+
509
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
510
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
511
+ tensor = getattr(l, attr)
512
+ if tensor is not None:
513
+ tensor.data = tensor.data.to(dtype)
514
+
515
+ if isinstance(l, (CLIP, TextTransformer)):
516
+ # convert text nn.Parameter projections
517
+ attr = getattr(l, "text_projection", None)
518
+ if attr is not None:
519
+ attr.data = attr.data.to(dtype)
520
+
521
+ if isinstance(l, VisionTransformer):
522
+ # convert vision nn.Parameter projections
523
+ attr = getattr(l, "proj", None)
524
+ if attr is not None:
525
+ attr.data = attr.data.to(dtype)
526
+
527
+ model.apply(_convert_weights)
528
+
529
+
530
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
531
+
532
+
533
+ # used to maintain checkpoint compatibility
534
+ def convert_to_custom_text_state_dict(state_dict: dict):
535
+ if 'text_projection' in state_dict:
536
+ # old format state_dict, move text tower -> .text
537
+ new_state_dict = {}
538
+ for k, v in state_dict.items():
539
+ if any(k.startswith(p) for p in (
540
+ 'text_projection',
541
+ 'positional_embedding',
542
+ 'token_embedding',
543
+ 'transformer',
544
+ 'ln_final',
545
+ )):
546
+ k = 'text.' + k
547
+ new_state_dict[k] = v
548
+ return new_state_dict
549
+ return state_dict
550
+
551
+
552
+ def build_model_from_openai_state_dict(
553
+ state_dict: dict,
554
+ quick_gelu=True,
555
+ cast_dtype=torch.float16,
556
+ ):
557
+ vit = "visual.proj" in state_dict
558
+
559
+ if vit:
560
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
561
+ vision_layers = len(
562
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
563
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
564
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
565
+ image_size = vision_patch_size * grid_size
566
+ else:
567
+ counts: list = [
568
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
569
+ vision_layers = tuple(counts)
570
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
571
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
572
+ vision_patch_size = None
573
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
574
+ image_size = output_width * 32
575
+
576
+ embed_dim = state_dict["text_projection"].shape[1]
577
+ context_length = state_dict["positional_embedding"].shape[0]
578
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
579
+ transformer_width = state_dict["ln_final.weight"].shape[0]
580
+ transformer_heads = transformer_width // 64
581
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
582
+
583
+ vision_cfg = CLIPVisionCfg(
584
+ layers=vision_layers,
585
+ width=vision_width,
586
+ patch_size=vision_patch_size,
587
+ image_size=image_size,
588
+ )
589
+ text_cfg = CLIPTextCfg(
590
+ context_length=context_length,
591
+ vocab_size=vocab_size,
592
+ width=transformer_width,
593
+ heads=transformer_heads,
594
+ layers=transformer_layers,
595
+ )
596
+ model = CLIP(
597
+ embed_dim,
598
+ vision_cfg=vision_cfg,
599
+ text_cfg=text_cfg,
600
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
601
+ cast_dtype=cast_dtype,
602
+ )
603
+
604
+ for key in ["input_resolution", "context_length", "vocab_size"]:
605
+ state_dict.pop(key, None)
606
+
607
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
608
+ model.load_state_dict(state_dict)
609
+ return model.eval()
610
+
611
+
612
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
613
+ model.eval()
614
+ image_size = model.visual.image_size
615
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
616
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
617
+ model = torch.jit.trace_module(
618
+ model,
619
+ inputs=dict(
620
+ forward=(example_images, example_text),
621
+ encode_text=(example_text,),
622
+ encode_image=(example_images,)
623
+ ))
624
+ model.visual.image_size = image_size
625
+ return model
626
+
627
+ def resize_pos_embed_timm(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
628
+ # Rescale the grid of position embeddings when loading from state_dict
629
+ old_pos_embed = state_dict.get('visual.trunk.pos_embed', None) # 1, 196, 1024]
630
+ if old_pos_embed is None:
631
+ return
632
+
633
+ grid_size = to_2tuple(model.visual.trunk.patch_embed.grid_size)
634
+
635
+
636
+ if hasattr(model.visual.trunk, 'cls_token') and model.visual.trunk.cls_token is not None:
637
+ return
638
+ # extra_tokens?
639
+ raise NotImplementedError
640
+
641
+ new_seq_len = grid_size[0] * grid_size[1]
642
+ if new_seq_len == old_pos_embed.shape[0]:
643
+ return
644
+
645
+ pos_emb_img = old_pos_embed
646
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img[0]))))
647
+ old_pos_emb_img = pos_emb_img
648
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) # Resizing position embedding grid-size from (1, 1) to (21, 21)
649
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
650
+
651
+ pos_emb_img = F.interpolate(
652
+ pos_emb_img,
653
+ size=grid_size,
654
+ mode=interpolation,
655
+ antialias=antialias,
656
+ align_corners=False,
657
+ )
658
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)
659
+ state_dict['visual.trunk.pos_embed'] = pos_emb_img
660
+
661
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
662
+ # Rescale the grid of position embeddings when loading from state_dict
663
+ pe_key_name = 'visual.positional_embedding'
664
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
665
+ if old_pos_embed is None:
666
+ pe_key_name = 'visual.trunk.pos_embed'
667
+ old_pos_embed = state_dict.get('visual.trunk.pos_embed', None) # 1, 196, 1024]
668
+
669
+ if old_pos_embed is None:
670
+ return
671
+
672
+ if hasattr(model.visual, 'grid_size'):
673
+ grid_size = to_2tuple(model.visual.grid_size)
674
+ elif hasattr(model.visual.trunk.patch_embed, 'grid_size'):
675
+ grid_size = to_2tuple(model.visual.trunk.patch_embed.grid_size)
676
+ else:
677
+ return
678
+
679
+ if hasattr(model.visual.trunk, 'cls_token') and model.visual.trunk.cls_token is not None:
680
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
681
+ else:
682
+ extra_tokens = 0
683
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
684
+
685
+ if new_seq_len == old_pos_embed.shape[0]:
686
+ return
687
+
688
+ if extra_tokens:
689
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
690
+ else:
691
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
692
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
693
+ old_pos_emb_img = pos_emb_img
694
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) # Resizing position embedding grid-size from (1, 1) to (21, 21)
695
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
696
+
697
+
698
+ pos_emb_img = F.interpolate(
699
+ pos_emb_img,
700
+ size=grid_size,
701
+ mode=interpolation,
702
+ antialias=antialias,
703
+ align_corners=False,
704
+ )
705
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
706
+ if pos_emb_tok is not None:
707
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
708
+ else:
709
+ new_pos_embed = pos_emb_img
710
+ state_dict[pe_key_name] = new_pos_embed
711
+
712
+ def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
713
+ old_pos_embed = state_dict.get('positional_embedding', None)
714
+ if old_pos_embed is None:
715
+ return
716
+ # FIXME add support for text cls_token
717
+ model_pos_embed = getattr(model, 'positional_embedding', None)
718
+ if model_pos_embed is None:
719
+ model_pos_embed = getattr(model.text, 'positional_embedding', None)
720
+
721
+ old_num_pos = old_pos_embed.shape[0]
722
+ old_width = old_pos_embed.shape[1]
723
+ num_pos = model_pos_embed.shape[0]
724
+ width = model_pos_embed.shape[1]
725
+ assert old_width == width, 'text pos_embed width changed!'
726
+ if old_num_pos == num_pos:
727
+ return
728
+
729
+ logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
730
+ old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
731
+ old_pos_embed = F.interpolate(
732
+ old_pos_embed,
733
+ size=num_pos,
734
+ mode=interpolation,
735
+ antialias=antialias,
736
+ align_corners=False,
737
+ )
738
+ old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
739
+ new_pos_embed = old_pos_embed
740
+
741
+ state_dict['positional_embedding'] = new_pos_embed
timm_model.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model (OpenCLIP).
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch, sys
9
+ import torch.nn as nn
10
+ import timm
11
+
12
+ try:
13
+ import timm
14
+ from timm.models.layers import Mlp, to_2tuple
15
+ try:
16
+ # old timm imports < 0.8.1
17
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
18
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
19
+ except ImportError:
20
+ # new timm imports >= 0.8.1
21
+ from timm.layers import RotAttentionPool2d
22
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
23
+ except ImportError:
24
+ timm = None
25
+ from timm.models import create_model
26
+ from open_clip.utils import freeze_batch_norm_2d
27
+
28
+ from .vitamin import *
29
+
30
+ class TimmModel(nn.Module):
31
+ """ timm model adapter
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ model_name,
37
+ embed_dim,
38
+ image_size=224,
39
+ pool='avg',
40
+ proj='linear',
41
+ proj_bias=False,
42
+ drop=0.,
43
+ drop_path=None,
44
+ patch_drop=None,
45
+ pretrained=False,
46
+ ):
47
+ super().__init__()
48
+ if timm is None:
49
+ raise RuntimeError("Please `pip install timm` to use timm models.")
50
+ self.image_size = to_2tuple(image_size)
51
+
52
+ # setup kwargs that may not be common across all models
53
+ timm_kwargs = {}
54
+ if drop_path is not None:
55
+ timm_kwargs['drop_path_rate'] = drop_path
56
+ if patch_drop is not None:
57
+ timm_kwargs['patch_drop_rate'] = patch_drop
58
+
59
+ custom_pool = pool in ('abs_attn', 'rot_attn')
60
+ if not proj and not custom_pool:
61
+ # use network classifier head as projection if no proj specified and no custom pooling used
62
+ self.trunk = timm.create_model(
63
+ model_name,
64
+ num_classes=embed_dim,
65
+ global_pool=pool,
66
+ pretrained=pretrained,
67
+ **timm_kwargs,
68
+ )
69
+ prev_chs = embed_dim
70
+ else:
71
+ self.trunk = timm.create_model(
72
+ model_name,
73
+ pretrained=pretrained,
74
+ **timm_kwargs,
75
+ )
76
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
77
+ feature_ndim = 1 if not feat_size else 2
78
+ if custom_pool:
79
+ assert feature_ndim == 2
80
+ # if attn pooling used, remove both classifier and default pool
81
+ self.trunk.reset_classifier(0, global_pool='')
82
+ else:
83
+ # reset global pool if pool config set, otherwise leave as network default
84
+ reset_kwargs = dict(global_pool=pool) if pool else {}
85
+ self.trunk.reset_classifier(0, **reset_kwargs)
86
+ prev_chs = self.trunk.num_features
87
+
88
+ head_layers = OrderedDict()
89
+
90
+ # Add custom pooling to head
91
+ if pool == 'abs_attn':
92
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
93
+ prev_chs = embed_dim
94
+ elif pool == 'rot_attn':
95
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
96
+ prev_chs = embed_dim
97
+
98
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
99
+ if proj == 'linear':
100
+ head_layers['drop'] = nn.Dropout(drop)
101
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
102
+ elif proj == 'mlp':
103
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
104
+ else:
105
+ assert not proj, f'Unknown projection type {proj}.'
106
+
107
+ self.head = nn.Sequential(head_layers)
108
+
109
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
110
+ """ lock modules
111
+ Args:
112
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
113
+ """
114
+ if not unlocked_groups:
115
+ # lock full model
116
+ for param in self.trunk.parameters():
117
+ param.requires_grad = False
118
+ if freeze_bn_stats:
119
+ freeze_batch_norm_2d(self.trunk)
120
+ else:
121
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
122
+ try:
123
+ # FIXME import here until API stable and in an official release
124
+ from timm.models.helpers import group_parameters, group_modules
125
+ except ImportError:
126
+ raise RuntimeError(
127
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
128
+ matcher = self.trunk.group_matcher()
129
+ gparams = group_parameters(self.trunk, matcher)
130
+ max_layer_id = max(gparams.keys())
131
+ max_layer_id = max_layer_id - unlocked_groups
132
+ for group_idx in range(max_layer_id + 1):
133
+ group = gparams[group_idx]
134
+ for param in group:
135
+ self.trunk.get_parameter(param).requires_grad = False
136
+ if freeze_bn_stats:
137
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
138
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
139
+ freeze_batch_norm_2d(self.trunk, gmodules)
140
+
141
+ @torch.jit.ignore
142
+ def set_grad_checkpointing(self, enable=True):
143
+ try:
144
+ self.trunk.set_grad_checkpointing(enable)
145
+ except Exception as e:
146
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
147
+
148
+ def forward(self, x):
149
+ x = self.trunk(x)
150
+ x = self.head(x)
151
+ return x
vitamin.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ ViTamin
2
+
3
+ Paper: Designing Scalable Vison Models in the Vision-Language Era
4
+
5
+ @misc{chen2023designing,
6
+ title={Designing Scalable Vison Models in the Vision-Language Era},
7
+ author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen},
8
+ year={2023},
9
+ archivePrefix={arXiv},
10
+ primaryClass={cs.CV}
11
+ }
12
+
13
+ Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
14
+
15
+ Modifications and timm support by Jieneng Chen 2023
16
+
17
+ Adapted from timm codebase, thanks!
18
+ """
19
+
20
+ from functools import partial
21
+ from typing import List, Tuple
22
+ from dataclasses import dataclass, replace
23
+ from typing import Callable, Optional, Union, Tuple, List, Sequence
24
+ import math, time
25
+ from torch.jit import Final
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import timm
30
+ from torch.utils.checkpoint import checkpoint
31
+ from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
32
+
33
+
34
+ from timm.layers import to_2tuple, DropPath, Format # , trunc_normal_
35
+ from timm.layers.norm_act import _create_act
36
+ from timm.models._registry import register_model
37
+ from timm.models._manipulate import named_apply, checkpoint_seq
38
+ from timm.models._builder import build_model_with_cfg
39
+ from timm.models.vision_transformer import get_act_layer, Type, LayerType, Mlp, Block, PatchEmbed, VisionTransformer, checkpoint_filter_fn, get_init_weights_vit, init_weights_vit_timm, _load_weights
40
+ import logging
41
+ from collections import OrderedDict
42
+
43
+
44
+
45
+ @dataclass
46
+ class VitConvCfg:
47
+ expand_ratio: float = 4.0
48
+ expand_output: bool = True # calculate expansion channels from output (vs input chs)
49
+ kernel_size: int = 3
50
+ group_size: int = 1 # 1 == depthwise
51
+ pre_norm_act: bool = False # activation after pre-norm
52
+ stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
53
+ pool_type: str = 'avg2'
54
+ downsample_pool_type: str = 'avg2'
55
+ act_layer: str = 'gelu' # stem & stage 1234
56
+ norm_layer: str = ''
57
+ norm_layer_cl: str = ''
58
+ norm_eps: Optional[float] = None
59
+ down_shortcut: Optional[bool] = True
60
+ mlp: str = 'mlp'
61
+
62
+ def __post_init__(self):
63
+ use_mbconv = True
64
+ if not self.norm_layer:
65
+ self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
66
+ if not self.norm_layer_cl and not use_mbconv:
67
+ self.norm_layer_cl = 'layernorm'
68
+ if self.norm_eps is None:
69
+ self.norm_eps = 1e-5 if use_mbconv else 1e-6
70
+ self.downsample_pool_type = self.downsample_pool_type or self.pool_type
71
+
72
+ @dataclass
73
+ class VitCfg:
74
+ embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
75
+ depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
76
+ stem_width: int = 64
77
+ conv_cfg: VitConvCfg = VitConvCfg()
78
+ weight_init: str = 'vit_eff'
79
+ head_type: str = ""
80
+ stem_type: str = "stem"
81
+
82
+ def _init_conv(module, name, scheme=''):
83
+ if isinstance(module, nn.Conv2d):
84
+ fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
85
+ fan_out //= module.groups
86
+ nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
87
+ if module.bias is not None:
88
+ nn.init.zeros_(module.bias)
89
+
90
+ class Stem(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_chs: int,
94
+ out_chs: int,
95
+ act_layer: str = 'gelu',
96
+ norm_layer: str = 'layernorm2d',
97
+ norm_eps: float = 1e-6,
98
+ bias: bool = True,
99
+ ):
100
+ super().__init__()
101
+ self.grad_checkpointing=False
102
+ norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
103
+ self.out_chs = out_chs
104
+ self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
105
+ self.norm1 = norm_act_layer(out_chs)
106
+ self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
107
+ named_apply(_init_conv, self)
108
+
109
+ def forward(self, x):
110
+ if self.grad_checkpointing:
111
+ x = checkpoint(self.conv1, x)
112
+ x = self.norm1(x)
113
+ x = checkpoint(self.conv2, x)
114
+ else:
115
+ x = self.conv1(x)
116
+ x = self.norm1(x)
117
+ x = self.conv2(x)
118
+
119
+ return x
120
+
121
+ class Downsample2d(nn.Module):
122
+ def __init__(
123
+ self,
124
+ dim: int,
125
+ dim_out: int,
126
+ bias: bool = True,
127
+ ):
128
+ super().__init__()
129
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
130
+ if dim != dim_out:
131
+ self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv
132
+ else:
133
+ self.expand = nn.Identity()
134
+
135
+ def forward(self, x):
136
+ x = self.pool(x)
137
+ x = self.expand(x)
138
+ return x
139
+
140
+
141
+ class StridedConv(nn.Module):
142
+ """ downsample 2d as well
143
+ """
144
+ def __init__(
145
+ self,
146
+ kernel_size=3,
147
+ stride=2,
148
+ padding=1,
149
+ in_chans=3,
150
+ embed_dim=768,
151
+ ):
152
+ super().__init__()
153
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
154
+ norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
155
+ self.norm = norm_layer(in_chans)
156
+
157
+ def forward(self, x):
158
+ x = self.norm(x)
159
+ x = self.proj(x)
160
+ return x
161
+
162
+
163
+ class MbConvLNBlock(nn.Module):
164
+ def __init__(
165
+ self,
166
+ in_chs: int,
167
+ out_chs: int,
168
+ stride: int = 1,
169
+ drop_path: float = 0.,
170
+ kernel_size: int = 3,
171
+ norm_layer: str = 'layernorm2d',
172
+ norm_eps: float = 1e-6,
173
+ act_layer: str = 'gelu',
174
+ expand_ratio: float = 4.0,
175
+ ):
176
+ super(MbConvLNBlock, self).__init__()
177
+ self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
178
+ mid_chs = make_divisible(out_chs * expand_ratio)
179
+ prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
180
+
181
+ if stride == 2:
182
+ self.shortcut = Downsample2d(in_chs, out_chs, bias=True)
183
+ elif in_chs != out_chs:
184
+ self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
185
+ else:
186
+ self.shortcut = nn.Identity()
187
+
188
+ self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
189
+ self.down = nn.Identity()
190
+ self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
191
+ self.act1 = _create_act(act_layer, inplace=True)
192
+ self.act2 = _create_act(act_layer, inplace=True)
193
+
194
+ self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
195
+ self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
196
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
197
+
198
+
199
+ def init_weights(self, scheme=''):
200
+ named_apply(partial(_init_conv, scheme=scheme), self)
201
+
202
+ def forward(self, x):
203
+ shortcut = self.shortcut(x)
204
+
205
+ x = self.pre_norm(x)
206
+ x = self.down(x) # nn.Identity()
207
+
208
+ # 1x1 expansion conv & act
209
+ x = self.conv1_1x1(x)
210
+ x = self.act1(x)
211
+
212
+ # (strided) depthwise 3x3 conv & act
213
+ x = self.conv2_kxk(x)
214
+ x = self.act2(x)
215
+
216
+ # 1x1 linear projection to output width
217
+ x = self.conv3_1x1(x)
218
+ x = self.drop_path(x) + shortcut
219
+
220
+ return x
221
+
222
+
223
+ class MbConvStages(nn.Module):
224
+ """ stage 1 and stage 2 of ViTamin: MBConv-LN blocks
225
+ """
226
+ def __init__(
227
+ self,
228
+ cfg: VitCfg,
229
+ img_size: Union[int, Tuple[int, int]] = 224, # place holder
230
+ in_chans: int = 3,
231
+ ):
232
+ super().__init__()
233
+ self.grad_checkpointing = False
234
+ self.stem = Stem(
235
+ in_chs=in_chans,
236
+ out_chs=cfg.stem_width,
237
+ )
238
+ stages = []
239
+ self.num_stages = len(cfg.embed_dim)
240
+ for s, dim in enumerate(cfg.embed_dim[:2]):
241
+ blocks = []
242
+ stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
243
+ for d in range(cfg.depths[s]):
244
+ blocks += [MbConvLNBlock(
245
+ in_chs = stage_in_chs if d==0 else dim,
246
+ out_chs = dim,
247
+ stride = 2 if d == 0 else 1,
248
+ )]
249
+ blocks = nn.Sequential(*blocks)
250
+ stages += [blocks]
251
+
252
+ self.stages = nn.ModuleList(stages)
253
+ self.pool = StridedConv(
254
+ stride=2,
255
+ in_chans=cfg.embed_dim[1],
256
+ embed_dim=cfg.embed_dim[2]
257
+ )
258
+
259
+ def forward(self, x):
260
+ x = self.stem(x)
261
+ if self.grad_checkpointing and not torch.jit.is_scripting():
262
+ for stage in self.stages:
263
+ x = checkpoint_seq(stage, x)
264
+ x = checkpoint(self.pool, x)
265
+ else:
266
+ for stage in self.stages:
267
+ x = stage(x)
268
+ x = self.pool(x)
269
+
270
+ return x
271
+
272
+ class GeGluMlp(nn.Module):
273
+ def __init__(
274
+ self,
275
+ in_features,
276
+ hidden_features,
277
+ act_layer = None,
278
+ drop = 0.0,
279
+ ):
280
+ super().__init__()
281
+ norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
282
+ self.norm = norm_layer(in_features)
283
+ self.act = nn.GELU()
284
+ self.w0 = nn.Linear(in_features, hidden_features)
285
+ self.w1 = nn.Linear(in_features, hidden_features)
286
+ self.w2 = nn.Linear(hidden_features, in_features)
287
+
288
+ def forward(self, x):
289
+ x = self.norm(x)
290
+ x = self.act(self.w0(x)) * self.w1(x)
291
+ x = self.w2(x)
292
+ return x
293
+
294
+ class HybridEmbed(nn.Module):
295
+ """
296
+ Extract feature map from stage 1-2, flatten, project to embedding dim.
297
+ """
298
+ def __init__(
299
+ self,
300
+ backbone,
301
+ img_size=224,
302
+ patch_size=1,
303
+ feature_size=None,
304
+ in_chans=3,
305
+ embed_dim=1024,
306
+ bias=True,
307
+ dynamic_img_pad=False,
308
+ ):
309
+ super().__init__()
310
+ assert isinstance(backbone, nn.Module)
311
+ img_size = to_2tuple(img_size)
312
+ patch_size = to_2tuple(patch_size)
313
+ self.img_size = img_size
314
+ self.patch_size = patch_size
315
+ self.backbone = backbone
316
+ if feature_size is None:
317
+ feature_size = img_size[0] // 16
318
+ feature_size = to_2tuple(feature_size)
319
+ if hasattr(self.backbone, 'feature_info'):
320
+ feature_dim = self.backbone.feature_info.channels()[-1]
321
+ elif hasattr(self.backbone, 'num_features'):
322
+ feature_dim = self.backbone.num_features
323
+ else:
324
+ feature_dim = embed_dim
325
+ assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
326
+ self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
327
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
328
+ self.proj = nn.Identity()
329
+
330
+ def forward(self, x):
331
+ x = self.backbone(x)
332
+ if isinstance(x, (list, tuple)):
333
+ x = x[-1] # last feature if backbone outputs list/tuple of features
334
+ x = self.proj(x)
335
+ x = x.flatten(2).transpose(1, 2)
336
+ return x
337
+
338
+ def _trunc_normal_(tensor, mean, std, a, b):
339
+ # rewrite timm trunc normal
340
+ def norm_cdf(x):
341
+ # Computes standard normal cumulative distribution function
342
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
343
+
344
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
345
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
346
+ "The distribution of values may be incorrect.",
347
+ stacklevel=2)
348
+
349
+ l = norm_cdf((a - mean) / std)
350
+ u = norm_cdf((b - mean) / std)
351
+
352
+ # Uniformly fill tensor with values from [l, u], then translate to
353
+ # [2l-1, 2u-1].
354
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
355
+
356
+ # Use inverse cdf transform for normal distribution to get truncated standard normal
357
+ # tensor.erfinv_() # NOTE: deleted as "erfinv_cuda" not implemented for 'BFloat16'
358
+
359
+ # Transform to proper mean, std
360
+ tensor.mul_(std * math.sqrt(2.))
361
+ tensor.add_(mean)
362
+
363
+ # Clamp to ensure it's in the proper range
364
+ tensor.clamp_(min=a, max=b)
365
+ return tensor
366
+
367
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
368
+ with torch.no_grad():
369
+ return _trunc_normal_(tensor, mean, std, a, b)
370
+
371
+ class ViTamin(nn.Module):
372
+ """ hack timm VisionTransformer
373
+ """
374
+ dynamic_img_size: Final[bool]
375
+
376
+ def __init__(
377
+ self,
378
+ img_size: Union[int, Tuple[int, int]] = 224,
379
+ patch_size: Union[int, Tuple[int, int]] = 16,
380
+ in_chans: int = 3,
381
+ num_classes: int = 1000,
382
+ global_pool = 'token',
383
+ embed_dim: int = 768,
384
+ depth: int = 12,
385
+ num_heads: int = 12,
386
+ mlp_ratio: float = 4.,
387
+ qkv_bias: bool = True,
388
+ qk_norm: bool = False,
389
+ init_values: Optional[float] = None,
390
+ class_token: bool = True,
391
+ no_embed_class: bool = False,
392
+ reg_tokens: int = 0,
393
+ pre_norm: bool = False,
394
+ fc_norm: Optional[bool] = None,
395
+ dynamic_img_size: bool = False,
396
+ dynamic_img_pad: bool = False,
397
+ drop_rate: float = 0.,
398
+ pos_drop_rate: float = 0.,
399
+ patch_drop_rate: float = 0.,
400
+ proj_drop_rate: float = 0.,
401
+ attn_drop_rate: float = 0.,
402
+ drop_path_rate: float = 0.,
403
+ weight_init = '',
404
+ fix_init: bool = False,
405
+ embed_layer: Callable = PatchEmbed,
406
+ norm_layer: Optional[LayerType] = None,
407
+ act_layer: Optional[LayerType] = None,
408
+ block_fn: Type[nn.Module] = Block,
409
+ mlp_layer: Type[nn.Module] = Mlp,
410
+ is_pos_embed: bool = True
411
+ ) -> None:
412
+ """
413
+ Args:
414
+ img_size: Input image size.
415
+ patch_size: Patch size.
416
+ in_chans: Number of image input channels.
417
+ num_classes: Mumber of classes for classification head.
418
+ global_pool: Type of global pooling for final sequence (default: 'token').
419
+ embed_dim: Transformer embedding dimension.
420
+ depth: Depth of transformer.
421
+ num_heads: Number of attention heads.
422
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
423
+ qkv_bias: Enable bias for qkv projections if True.
424
+ init_values: Layer-scale init values (layer-scale enabled if not None).
425
+ class_token: Use class token.
426
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
427
+ reg_tokens: Number of register tokens.
428
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
429
+ drop_rate: Head dropout rate.
430
+ pos_drop_rate: Position embedding dropout rate.
431
+ attn_drop_rate: Attention dropout rate.
432
+ drop_path_rate: Stochastic depth rate.
433
+ weight_init: Weight initialization scheme.
434
+ fix_init: Apply weight initialization fix (scaling w/ layer index).
435
+ embed_layer: Patch embedding layer.
436
+ norm_layer: Normalization layer.
437
+ act_layer: MLP activation layer.
438
+ block_fn: Transformer block layer.
439
+ """
440
+ super().__init__()
441
+ assert global_pool in ('', 'avg', 'token', 'map')
442
+ assert class_token or global_pool != 'token'
443
+ use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
444
+ norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
445
+ act_layer = get_act_layer(act_layer) or nn.GELU
446
+
447
+ self.num_classes = num_classes
448
+ self.global_pool = global_pool
449
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
450
+ self.num_prefix_tokens = 1 if class_token else 0
451
+ self.num_prefix_tokens += reg_tokens
452
+ self.num_reg_tokens = reg_tokens
453
+ self.has_class_token = class_token
454
+ self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
455
+ self.dynamic_img_size = dynamic_img_size
456
+ self.grad_checkpointing = False
457
+ self.is_pos_embed = is_pos_embed
458
+ embed_args = {}
459
+ if dynamic_img_size:
460
+ # flatten deferred until after pos embed
461
+ embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
462
+
463
+ # stage_1_2 = MbConvStages(cfg=VitCfg(
464
+ # embed_dim=(160, 320, 1024),
465
+ # depths=(2, 4, 1),
466
+ # stem_width=160,
467
+ # conv_cfg = VitConvCfg(
468
+ # norm_layer='layernorm2d',
469
+ # norm_eps=1e-6,
470
+ # ),
471
+ # head_type='1d',
472
+ # ),
473
+ # )
474
+ # self.patch_embed = HybridEmbed(
475
+ # stage_1_2,
476
+ # img_size=img_size,
477
+ # patch_size=1,
478
+ # in_chans=in_chans,
479
+ # embed_dim=embed_dim,
480
+ # bias=not pre_norm,
481
+ # dynamic_img_pad=dynamic_img_pad,
482
+ # **embed_args,)
483
+ self.patch_embed = embed_layer(
484
+ img_size=img_size,
485
+ patch_size=patch_size,
486
+ in_chans=in_chans,
487
+ embed_dim=embed_dim,
488
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
489
+ )
490
+
491
+ num_patches = self.patch_embed.num_patches
492
+
493
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
494
+ self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
495
+
496
+ if self.is_pos_embed:
497
+ embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
498
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
499
+ else:
500
+ self.pos_embed = None
501
+
502
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
503
+ if patch_drop_rate > 0:
504
+ self.patch_drop = PatchDropout(
505
+ patch_drop_rate,
506
+ num_prefix_tokens=self.num_prefix_tokens,
507
+ )
508
+ else:
509
+ self.patch_drop = nn.Identity()
510
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
511
+
512
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
513
+ self.blocks = nn.Sequential(*[
514
+ block_fn(
515
+ dim=embed_dim,
516
+ num_heads=num_heads,
517
+ mlp_ratio=mlp_ratio,
518
+ qkv_bias=qkv_bias,
519
+ qk_norm=qk_norm,
520
+ init_values=init_values,
521
+ proj_drop=proj_drop_rate,
522
+ attn_drop=attn_drop_rate,
523
+ drop_path=dpr[i],
524
+ norm_layer=norm_layer,
525
+ act_layer=act_layer,
526
+ mlp_layer=mlp_layer,
527
+ )
528
+ for i in range(depth)])
529
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
530
+
531
+ # Classifier Head
532
+ if global_pool == 'map':
533
+ self.attn_pool = AttentionPoolLatent(
534
+ self.embed_dim,
535
+ num_heads=num_heads,
536
+ mlp_ratio=mlp_ratio,
537
+ norm_layer=norm_layer,
538
+ )
539
+ else:
540
+ self.attn_pool = None
541
+
542
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
543
+ self.head_drop = nn.Dropout(drop_rate)
544
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
545
+
546
+ if weight_init != 'skip':
547
+ self.init_weights(weight_init)
548
+ if fix_init:
549
+ self.fix_init_weight()
550
+
551
+ def init_weights(self, mode=''):
552
+ assert mode in ('jax', 'jax_nlhb', 'moco', '')
553
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
554
+ if self.is_pos_embed:
555
+ trunc_normal_(self.pos_embed, std=.02)
556
+ if self.cls_token is not None:
557
+ nn.init.normal_(self.cls_token, std=1e-6)
558
+ named_apply(get_init_weights_vit(mode, head_bias), self)
559
+
560
+ def _init_weights(self, m):
561
+ # this fn left here for compat with downstream users
562
+ init_weights_vit_timm(m)
563
+
564
+ @torch.jit.ignore()
565
+ def load_pretrained(self, checkpoint_path, prefix=''):
566
+ _load_weights(self, checkpoint_path, prefix)
567
+
568
+ @torch.jit.ignore
569
+ def no_weight_decay(self):
570
+ if self.is_pos_embed:
571
+ return {'pos_embed', 'cls_token', 'dist_token'}
572
+ else:
573
+ return {'cls_token', 'dist_token'}
574
+
575
+ @torch.jit.ignore
576
+ def group_matcher(self, coarse=False):
577
+ return dict(
578
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
579
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
580
+ )
581
+
582
+ @torch.jit.ignore
583
+ def set_grad_checkpointing(self, enable=True):
584
+ self.grad_checkpointing = enable
585
+ self.patch_embed.backbone.stem.grad_checkpointing = enable # disable https://blog.csdn.net/lhx526080338/article/details/127894671?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-1-127894671-blog-125562110.235^v38^pc_relevant_anti_t3_base&spm=1001.2101.3001.4242.2&utm_relevant_index=4
586
+ self.patch_embed.backbone.grad_checkpointing = enable
587
+
588
+ @torch.jit.ignore
589
+ def get_classifier(self):
590
+ return self.head
591
+
592
+ def reset_classifier(self, num_classes: int, global_pool=None):
593
+ self.num_classes = num_classes
594
+ if global_pool is not None:
595
+ assert global_pool in ('', 'avg', 'token')
596
+ self.global_pool = global_pool
597
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
598
+
599
+ def _pos_embed(self, x):
600
+ if self.no_embed_class:
601
+ # deit-3, updated JAX (big vision)
602
+ # position embedding does not overlap with class token, add then concat
603
+ x = x + self.pos_embed
604
+ if self.cls_token is not None:
605
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
606
+ else:
607
+ # original timm, JAX, and deit vit impl
608
+ # pos_embed has entry for class token, concat then add
609
+ if self.cls_token is not None:
610
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
611
+ x = x + self.pos_embed
612
+ return self.pos_drop(x)
613
+
614
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
615
+ x = self.patch_embed(x)
616
+ if self.is_pos_embed:
617
+ x = self._pos_embed(x)
618
+ x = self.patch_drop(x)
619
+ x = self.norm_pre(x)
620
+ if self.grad_checkpointing and not torch.jit.is_scripting():
621
+ x = checkpoint_seq(self.blocks, x)
622
+ else:
623
+ x = self.blocks(x)
624
+ x = self.norm(x)
625
+ return x
626
+
627
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
628
+ if self.attn_pool is not None:
629
+ x = self.attn_pool(x)
630
+ elif self.global_pool == 'avg':
631
+ x = x[:, self.num_prefix_tokens:].mean(dim=1)
632
+ elif self.global_pool:
633
+ x = x[:, 0] # class token
634
+ x = self.fc_norm(x)
635
+ x = self.head_drop(x)
636
+ return x if pre_logits else self.head(x)
637
+
638
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
639
+ x = self.forward_features(x)
640
+ x = self.forward_head(x)
641
+ return x
642
+
643
+ def _create_vision_transformer(variant, pretrained=False, **kwargs):
644
+ if kwargs.get('features_only', None):
645
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
646
+
647
+ return build_model_with_cfg(
648
+ ViTamin, # ViTamin
649
+ variant,
650
+ pretrained,
651
+ pretrained_filter_fn=checkpoint_filter_fn,
652
+ **kwargs,
653
+ )
654
+
655
+
656
+ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
657
+ embed_layer = partial(HybridEmbed, backbone=backbone)
658
+ kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
659
+ return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
660
+
661
+
662
+ @register_model
663
+ def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
664
+ stage_1_2 = MbConvStages(cfg=VitCfg(
665
+ embed_dim=(64, 128, 384),
666
+ depths=(2, 4, 1),
667
+ stem_width=64,
668
+ conv_cfg = VitConvCfg(
669
+ norm_layer='layernorm2d',
670
+ norm_eps=1e-6,
671
+ ),
672
+ head_type='1d',
673
+ ),
674
+ )
675
+ stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
676
+ model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
677
+ return model
678
+
679
+
680
+ @register_model
681
+ def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
682
+ stage_1_2 = MbConvStages(cfg=VitCfg(
683
+ embed_dim=(128, 256, 768),
684
+ depths=(2, 4, 1),
685
+ stem_width=128,
686
+ conv_cfg = VitConvCfg(
687
+ norm_layer='layernorm2d',
688
+ norm_eps=1e-6,
689
+ ),
690
+ head_type='1d',
691
+ ),
692
+ )
693
+ stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
694
+ model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
695
+ return model
696
+
697
+
698
+ @register_model
699
+ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
700
+ stage_1_2 = MbConvStages(cfg=VitCfg(
701
+ embed_dim=(160, 320, 1024),
702
+ depths=(2, 4, 1),
703
+ stem_width=160,
704
+ conv_cfg = VitConvCfg(
705
+ norm_layer='layernorm2d',
706
+ norm_eps=1e-6,
707
+ ),
708
+ head_type='1d',
709
+ ),
710
+ )
711
+ stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
712
+ model = _create_vision_transformer_hybrid(
713
+ 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
714
+ return model
715
+
716
+
717
+ @register_model
718
+ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
719
+ backbone = MbConvStages(cfg=VitCfg(
720
+ embed_dim=(160, 320, 1024),
721
+ depths=(2, 4, 1),
722
+ stem_width=160,
723
+ conv_cfg = VitConvCfg(
724
+ norm_layer='layernorm2d',
725
+ norm_eps=1e-6,
726
+ ),
727
+ head_type='1d',
728
+ ),
729
+ )
730
+ model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
731
+ model = _create_vision_transformer_hybrid(
732
+ 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
733
+ return model
734
+
735
+
736
+ @register_model
737
+ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
738
+ backbone = MbConvStages(cfg=VitCfg(
739
+ embed_dim=(160, 320, 1024),
740
+ depths=(2, 4, 1),
741
+ stem_width=160,
742
+ conv_cfg = VitConvCfg(
743
+ norm_layer='layernorm2d',
744
+ norm_eps=1e-6,
745
+ ),
746
+ head_type='1d',
747
+ ),
748
+ )
749
+ model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
750
+ model = _create_vision_transformer_hybrid(
751
+ 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
752
+ return model
753
+
754
+
755
+ @register_model
756
+ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
757
+ backbone = MbConvStages(cfg=VitCfg(
758
+ embed_dim=(160, 320, 1024),
759
+ depths=(2, 4, 1),
760
+ stem_width=160,
761
+ conv_cfg = VitConvCfg(
762
+ norm_layer='layernorm2d',
763
+ norm_eps=1e-6,
764
+ ),
765
+ head_type='1d',
766
+ ),
767
+ )
768
+ model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, is_pos_embed=False, global_pool='avg')
769
+ model = _create_vision_transformer_hybrid(
770
+ 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
771
+ return model
772
+
773
+
774
+ @register_model
775
+ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
776
+ backbone = MbConvStages(cfg=VitCfg(
777
+ embed_dim=(192, 384, 1152),
778
+ depths=(2, 4, 1),
779
+ stem_width=192,
780
+ conv_cfg = VitConvCfg(
781
+ norm_layer='layernorm2d',
782
+ norm_eps=1e-6,
783
+ ),
784
+ head_type='1d',
785
+ ),
786
+ )
787
+ model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, is_pos_embed=False, global_pool='avg')
788
+ model = _create_vision_transformer_hybrid(
789
+ 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
790
+ return model
791
+
792
+
793
+ @register_model
794
+ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
795
+ backbone = MbConvStages(cfg=VitCfg(
796
+ embed_dim=(192, 384, 1152),
797
+ depths=(2, 4, 1),
798
+ stem_width=192,
799
+ conv_cfg = VitConvCfg(
800
+ norm_layer='layernorm2d',
801
+ norm_eps=1e-6,
802
+ ),
803
+ head_type='1d',
804
+ ),
805
+ )
806
+ model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, is_pos_embed=False, global_pool='avg')
807
+ model = _create_vision_transformer_hybrid(
808
+ 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
809
+ return model
810
+
811
+
812
+ def count_params(model: nn.Module):
813
+ return sum([m.numel() for m in model.parameters()])
814
+
815
+
816
+ def count_stage_params(model: nn.Module, prefix='none'):
817
+ collections = []
818
+ for name, m in model.named_parameters():
819
+ print(name)
820
+ if name.startswith(prefix):
821
+ collections.append(m.numel())
822
+ return sum(collections)
823
+
824
+
825
+ if __name__ == "__main__":
826
+ model = timm.create_model('vitamin_large', num_classes=10).cuda()
827
+ # x = torch.rand([2,3,224,224]).cuda()
828
+ check_keys(model)