from typing import Optional

from transformers import PretrainedConfig, RobertaConfig


class JapaneseCLIPVisionConfig(PretrainedConfig):
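    """Configuration for the vision (ViT) tower of Japanese CLIP. It stores the
    hyperparameters used to build the image encoder (patch embedding, transformer
    width/depth, pooling behaviour, and the projected ``output_dim``)."""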
    model_type = "vit"
    is_composition = True

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        head_width: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.0,
        no_ln_pre: bool = False,
        pool_type: str = "tok",
        final_ln_after_pool: bool = False,
        output_tokens: bool = False,
        **kwargs
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.head_width = head_width
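        # the number of attention heads is derived from the embedding width and per-head width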
        self.heads = width // head_width
        self.mlp_ratio = mlp_ratio
        self.ls_init_value = ls_init_value
        self.attentional_pool = attentional_pool
        self.attn_pooler_queries = attn_pooler_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.output_dim = output_dim
        self.patch_dropout = patch_dropout
        self.no_ln_pre = no_ln_pre
        self.pool_type = pool_type
        self.final_ln_after_pool = final_ln_after_pool
        self.output_tokens = output_tokens
        super().__init__(**kwargs)


class JapaneseCLIPConfig(PretrainedConfig):
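    """Composite configuration that ties together a `JapaneseCLIPVisionConfig` for the
    image encoder and a `RobertaConfig` for the text encoder, plus the maximum text
    sequence length (``max_length``)."""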
    model_type = "japanese_clip"
    is_composition = True

    def __init__(
        self,
        max_length: int = 77,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.max_length = max_length

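        # `vision_config` and `text_config` are required and arrive as plain dicts (e.g. from
        # a saved config.json or `to_dict()`); they are materialized into config objects below.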
        if "vision_config" not in kwargs:
            raise ValueError("vision_config must be provided")
        if "text_config" not in kwargs:
            raise ValueError("text_config must be provided")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        self.vision_config = JapaneseCLIPVisionConfig(**vision_config)
        self.text_config = RobertaConfig(**text_config)

    @classmethod
    def from_vision_text_configs(
        cls,
        vision_config: PretrainedConfig,
        text_config: PretrainedConfig,
        **kwargs
    ):
        r"""
        Instantiate a [`VisionTextDualEncoderConfig`] (or a derived class) from text model configuration and vision
        model configuration.
        Returns:
            [`VisionTextDualEncoderConfig`]: An instance of a configuration object
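
        Example (a minimal usage sketch; the hyperparameter values below are purely
        illustrative and are not taken from a released checkpoint):

        >>> # illustrative hyperparameters only
        >>> vision_config = JapaneseCLIPVisionConfig(
        ...     image_size=224, patch_size=16, width=768, layers=12, head_width=64, mlp_ratio=4.0
        ... )
        >>> text_config = RobertaConfig()
        >>> config = JapaneseCLIPConfig.from_vision_text_configs(vision_config, text_config)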
        """

        return cls(
            vision_config=vision_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )