# japanese-clip-vit-b-32-roberta-base / configuration_japanese_clip.py
# Last commit: f68c08c ("update") by hidehisa-arai
from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig, RobertaConfig
class JapaneseCLIPVisionConfig(PretrainedConfig):
    """Configuration for the image (vision) tower of a Japanese CLIP model.

    Every constructor argument is stored unchanged as an attribute of the same
    name. The only derived attribute is ``heads = width // head_width``. The
    exact meaning of most fields is defined by the vision model that consumes
    this config (not part of this module). The descriptions below follow the
    field names and should be confirmed against that model.

    Args:
        image_size: Input image size.
        patch_size: Patch size used to split the image into tokens.
        width: Transformer hidden width.
        layers: Number of transformer layers.
        head_width: Width of one attention head; ``heads`` is derived from it.
        mlp_ratio: MLP hidden size relative to ``width``.
        ls_init_value: Presumably a LayerScale initial value (default ``None``).
        attentional_pool: Presumably enables attentional pooling (default ``False``).
        attn_pooler_queries: Number of attentional-pooler queries (default 256).
        attn_pooler_heads: Number of attentional-pooler heads (default 8).
        output_dim: Output embedding dimension (default 512).
        patch_dropout: Patch dropout rate (default 0.0).
        no_ln_pre: Presumably disables the pre-transformer LayerNorm (default ``False``).
        pool_type: Pooling type (default ``"tok"``).
        final_ln_after_pool: Presumably applies the final LayerNorm after pooling
            (default ``False``).
        output_tokens: Presumably also returns token-level outputs (default ``False``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): "vit" is also the model_type of transformers' built-in
    # ViTConfig; presumably used here only as a label. Confirm it does not
    # collide with any AutoConfig registration.
    model_type = "vit"
    # NOTE(review): presumably set so transformers does not try to build a
    # default instance of this class (it has required arguments). Confirm
    # against the transformers version in use.
    is_composition = True

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        head_width: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.0,
        no_ln_pre: bool = False,
        pool_type: str = "tok",
        final_ln_after_pool: bool = False,
        output_tokens: bool = False,
        **kwargs
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.head_width = head_width
        # Derived value: number of attention heads (integer division).
        self.heads = width // head_width
        self.mlp_ratio = mlp_ratio
        self.ls_init_value = ls_init_value
        self.attentional_pool = attentional_pool
        self.attn_pooler_queries = attn_pooler_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.output_dim = output_dim
        self.patch_dropout = patch_dropout
        self.no_ln_pre = no_ln_pre
        self.pool_type = pool_type
        self.final_ln_after_pool = final_ln_after_pool
        self.output_tokens = output_tokens
        # Any remaining generic config options are handled by the base class.
        super().__init__(**kwargs)
class JapaneseCLIPConfig(PretrainedConfig):
    """Composite configuration for a Japanese CLIP dual-encoder model.

    Bundles a :class:`JapaneseCLIPVisionConfig` for the image tower, a
    :class:`~transformers.RobertaConfig` for the text tower, and the text
    ``max_length``.
    """

    model_type = "japanese_clip"
    # NOTE(review): presumably marks this as a composite config so transformers
    # does not try to build a default instance (the sub-configs are required).
    # Confirm against the transformers version in use.
    is_composition = True

    def __init__(
        self,
        max_length: int = 77,
        vision_config: Optional[Union[Dict[str, Any], PretrainedConfig]] = None,
        text_config: Optional[Union[Dict[str, Any], PretrainedConfig]] = None,
        **kwargs
    ):
        """Build the composite config from vision and text sub-configs.

        Args:
            max_length: Stored as ``self.max_length`` (default 77). Presumably
                the maximum text sequence length; confirm with the model code.
            vision_config: Required. Vision tower settings, given either as a
                dict of :class:`JapaneseCLIPVisionConfig` keyword arguments or
                as a config instance (serialized with ``to_dict()``).
            text_config: Required. Text tower settings, given either as a dict
                of :class:`~transformers.RobertaConfig` keyword arguments or as
                a config instance (serialized with ``to_dict()``).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``vision_config`` or ``text_config`` is not provided.
        """
        # Validate the required sub-configs before any base-class setup.
        if vision_config is None:
            raise ValueError("vision_config must be provided")
        if text_config is None:
            raise ValueError("text_config must be provided")

        # The sub-configs are explicit parameters, so they are not part of
        # ``kwargs``. The base class never receives the raw dicts.
        super().__init__(**kwargs)

        # NOTE(review): assigned after super().__init__() on purpose. Some
        # transformers versions set a generation default ``max_length`` inside
        # PretrainedConfig.__init__, which would overwrite this value if the
        # order were swapped. Keep this ordering.
        self.max_length = max_length

        # Accept ready-made config objects as well as plain dicts, using the
        # same to_dict() round-trip as from_vision_text_configs.
        if isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()
        if isinstance(text_config, PretrainedConfig):
            text_config = text_config.to_dict()

        self.vision_config = JapaneseCLIPVisionConfig(**vision_config)
        self.text_config = RobertaConfig(**text_config)

    @classmethod
    def from_vision_text_configs(
        cls,
        vision_config: PretrainedConfig,
        text_config: PretrainedConfig,
        **kwargs
    ):
        r"""
        Instantiate a [`JapaneseCLIPConfig`] (or a derived class) from a vision model configuration and a text
        model configuration.

        Args:
            vision_config: Configuration of the vision tower; serialized with `to_dict()`.
            text_config: Configuration of the text tower; serialized with `to_dict()`.
            **kwargs: Extra keyword arguments forwarded to the constructor (e.g. `max_length`).

        Returns:
            [`JapaneseCLIPConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )