""" ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era @misc{chen2023designing, title={Designing Scalable Vison Models in the Vision-Language Era}, author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, year={2023}, archivePrefix={arXiv}, primaryClass={cs.CV} } Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin by Jieneng Chen 2024 """ import copy import os from collections import OrderedDict from typing import TYPE_CHECKING, Any, Mapping, Optional, Union if TYPE_CHECKING: from transformers.processing_utils import ProcessorMixin from transformers.utils import TensorType from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class ViTaminTextConfig(PretrainedConfig): model_type = "vitamin_text_model" def __init__( self, context_length = 77, vocab_size = 49408, width = 1024, heads = 16, layers = 24, **kwargs, ): super().__init__(**kwargs) self.vocab_size = vocab_size self.context_length = context_length self.width = width self.heads = heads self.layers = layers @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) if 'text_config' in config_dict: config_dict = config_dict['text_config'] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) return cls.from_dict(config_dict, **kwargs) class ViTaminVisionConfig(PretrainedConfig): model_type = "vitamin_vision_model" def __init__( self, timm_model_name = "vitamin_large", timm_model_pretrained = False, timm_pool = "", timm_proj = "linear", timm_drop = 0.0, timm_drop_path = 0.1, image_size = 256, timm_proj_bias = False, patch_dropout = 0.0, drop_path = None, **kwargs, ): super().__init__(**kwargs) self.timm_model_name = timm_model_name self.timm_model_pretrained = timm_model_pretrained self.timm_pool = timm_pool self.timm_proj = timm_proj self.timm_drop = timm_drop self.timm_drop_path = timm_drop_path self.timm_proj_bias = timm_proj_bias self.patch_dropout = patch_dropout self.image_size = image_size @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) if 'vision_config' in config_dict: config_dict = config_dict['vision_config'] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) return cls.from_dict(config_dict, **kwargs) class ViTaminConfig(PretrainedConfig): model_type = "vitamin" is_composition = True def __init__( self, text_config=None, vision_config=None, embed_dim=512, **kwargs ): super().__init__(**kwargs) if text_config is None: text_config = {} logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") if vision_config is None: vision_config = {} logger.info("`vision_config` is `None`. 
class ViTaminConfig(PretrainedConfig):
    """Composite configuration holding a ViTamin text config and vision config."""

    model_type = "vitamin"
    is_composition = True

    def __init__(self, text_config=None, vision_config=None, embed_dim=512, **kwargs):
        super().__init__(**kwargs)
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `ViTaminTextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `ViTaminVisionConfig` with default values.")

        self.embed_dim = embed_dim
        self.text_config = ViTaminTextConfig(**text_config)
        self.vision_config = ViTaminVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: ViTaminTextConfig, vision_config: ViTaminVisionConfig, **kwargs):
        r"""
        Instantiate a [`ViTaminConfig`] (or a derived class) from a ViTamin text model
        configuration and a ViTamin vision model configuration.

        Returns:
            [`ViTaminConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default
        [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
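

# A minimal smoke-test sketch (assumed example values, not tied to any
# released checkpoint): compose the full config from its two halves and
# round-trip it through `to_dict`.
if __name__ == "__main__":
    text_config = ViTaminTextConfig(width=768, heads=12, layers=12)
    vision_config = ViTaminVisionConfig(timm_model_name="vitamin_large", image_size=256)

    config = ViTaminConfig.from_text_vision_configs(text_config, vision_config, embed_dim=512)

    # `to_dict` turns the nested sub-configs back into plain dicts, which is
    # the form `ViTaminConfig.__init__` expects for `text_config`/`vision_config`.
    as_dict = config.to_dict()
    restored = ViTaminConfig(
        text_config=as_dict["text_config"],
        vision_config=as_dict["vision_config"],
        embed_dim=as_dict["embed_dim"],
    )
    assert restored.text_config.width == 768
    assert restored.vision_config.image_size == 256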