""" ViTamin

Paper: Designing Scalable Vision Models in the Vision-Language Era

@misc{chen2023designing,
      title={Designing Scalable Vision Models in the Vision-Language Era},
      author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
      year={2023},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin

by Jieneng Chen 2024
"""

import copy
import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ViTaminTextConfig(PretrainedConfig):
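    """Configuration for the ViTamin text tower, a CLIP-style text transformer
    described by its context length, vocabulary size, width, head count, and depth.

    A minimal usage sketch (the smaller values below are illustrative, not tied
    to any released checkpoint):

    ```python
    >>> config = ViTaminTextConfig(width=768, heads=12, layers=12)
    >>> config.vocab_size  # defaults to the CLIP BPE vocabulary size
    49408
    ```
    """
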
    model_type = "vitamin_text_model"

    def __init__(
        self,
        context_length=77,
        vocab_size=49408,
        width=1024,
        heads=16,
        layers=24,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.context_length = context_length
        self.width = width
        self.heads = heads
        self.layers = layers

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the text sub-config, unwrapping `text_config` when the target stores a full composite ViTamin config."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "text_config" in config_dict:
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class ViTaminVisionConfig(PretrainedConfig):
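    """Configuration for the ViTamin vision tower. The `timm_*` fields describe
    the timm backbone to build (model name, pretrained weights, pooling,
    projection head, dropout, and stochastic-depth rate).

    A minimal usage sketch (the values below are illustrative):

    ```python
    >>> config = ViTaminVisionConfig(timm_model_name="vitamin_large", image_size=384)
    >>> config.image_size
    384
    ```
    """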

    model_type = "vitamin_vision_model"

    def __init__(
        self,
        timm_model_name="vitamin_large",
        timm_model_pretrained=False,
        timm_pool="",
        timm_proj="linear",
        timm_drop=0.0,
        timm_drop_path=0.1,
        image_size=256,
        timm_proj_bias=False,
        patch_dropout=0.0,
        drop_path=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.timm_model_name = timm_model_name
        self.timm_model_pretrained = timm_model_pretrained
        self.timm_pool = timm_pool
        self.timm_proj = timm_proj
        self.timm_drop = timm_drop
        self.timm_drop_path = timm_drop_path
        self.timm_proj_bias = timm_proj_bias
        self.patch_dropout = patch_dropout
        self.image_size = image_size
        self.drop_path = drop_path  # previously accepted but silently dropped; now stored

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the vision sub-config, unwrapping `vision_config` when the target stores a full composite ViTamin config."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class ViTaminConfig(PretrainedConfig):
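    """Composite configuration pairing a `ViTaminTextConfig` and a
    `ViTaminVisionConfig` with the shared `embed_dim` of the contrastive
    embedding space.

    A minimal usage sketch (values are illustrative):

    ```python
    >>> config = ViTaminConfig(embed_dim=768, vision_config={"image_size": 384})
    >>> config.vision_config.image_size
    384
    ```
    """
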
    model_type = "vitamin"
    is_composition = True

    def __init__(self, text_config=None, vision_config=None, embed_dim=512, **kwargs):
        super().__init__(**kwargs)
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `ViTaminTextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `ViTaminVisionConfig` with default values.")

        self.embed_dim = embed_dim
        self.text_config = ViTaminTextConfig(**text_config)
        self.vision_config = ViTaminVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: ViTaminTextConfig, vision_config: ViTaminVisionConfig, **kwargs):
        r"""
        Instantiate a [`ViTaminConfig`] (or a derived class) from a ViTamin text model configuration and a ViTamin
        vision model configuration.

        Returns:
            [`ViTaminConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary, overriding the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
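

if __name__ == "__main__":
    # Usage sketch: compose the two sub-configs into a full ViTaminConfig and
    # round-trip it through to_dict(). All values here are illustrative and not
    # tied to any released checkpoint.
    text_cfg = ViTaminTextConfig(width=768, heads=12, layers=12)
    vision_cfg = ViTaminVisionConfig(timm_model_name="vitamin_base", image_size=224)
    config = ViTaminConfig.from_text_vision_configs(text_cfg, vision_cfg, embed_dim=512)

    serialized = config.to_dict()
    assert serialized["model_type"] == "vitamin"
    assert serialized["text_config"]["width"] == 768
    assert serialized["vision_config"]["image_size"] == 224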