"""Configuration for the tiny CLIP project: paths, constants, and the model/data/trainer configs."""

import pathlib

import pydantic
from transformers import PretrainedConfig

MAX_DOWNLOAD_TIME = 0.2  # timeout (seconds) for downloading a single image

IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
MODEL_PATH = pathlib.Path("/tmp/models")
VISION_MODEL_PATH = MODEL_PATH / "vision"
TEXT_MODEL_PATH = MODEL_PATH / "text"

# Create every output directory up front so downstream code can assume it exists.
IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
MODEL_PATH.mkdir(parents=True, exist_ok=True)
VISION_MODEL_PATH.mkdir(parents=True, exist_ok=True)
TEXT_MODEL_PATH.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "tiny_clip"
REPO_ID = "sachin/clip-model"

WANDB_ENTITY = "sachinruk"


class DataConfig(pydantic.BaseModel):
    """Dataset selection plus shuffle-buffer and train/validation split sizes."""

    buffer_size: int = 1000
    data_len: int = 100
    train_len: int = 90
    small_dataset: str = "laion/220k-gpt4vision-captions-from-livis"
    large_dataset: str = "laion/laion400m"
    dataset: str = small_dataset


class TinyCLIPTextConfig(PretrainedConfig):
    """Config for the text tower: a Hugging Face base model plus a projection head."""

    model_type = "text"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        projection_layers: int = 3,
        embed_dims: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        **kwargs,
    ):
        self.text_model = text_model
        self.projection_layers = projection_layers
        self.embed_dims = embed_dims
        self.max_len = max_len
        self.cls_type = cls_type
        super().__init__(**kwargs)


class TinyCLIPVisionConfig(PretrainedConfig):
    """Config for the vision tower: a timm backbone plus a projection head."""

    model_type = "vision"

    def __init__(
        self,
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dims: int = 512,
        **kwargs,
    ):
        self.vision_model = vision_model
        self.projection_layers = projection_layers
        self.embed_dims = embed_dims
        super().__init__(**kwargs)


class TinyCLIPConfig(PretrainedConfig):
    """Composite config that bundles the text and vision tower configs."""

    model_type = "clip"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dim: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        freeze_vision_base: bool = False,
        freeze_text_base: bool = True,
        loss_type: str = "cyclip",
        **kwargs,
    ):
        self.text_config = TinyCLIPTextConfig(
            text_model=text_model,
            projection_layers=projection_layers,
            embed_dims=embed_dim,
            max_len=max_len,
            cls_type=cls_type,
        )
        self.vision_config = TinyCLIPVisionConfig(
            vision_model=vision_model, projection_layers=projection_layers, embed_dims=embed_dim
        )
        self.freeze_vision_base = freeze_vision_base
        self.freeze_text_base = freeze_text_base
        self.loss_type = loss_type
        super().__init__(**kwargs)

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        # Sub-configs are serialized as plain dicts; rebuild them into config
        # objects before passing everything back through the constructor.
        text_config_dict = config_dict.pop("text_config", {})
        text_config = TinyCLIPTextConfig.from_dict(text_config_dict)

        vision_config_dict = config_dict.pop("vision_config", {})
        vision_config = TinyCLIPVisionConfig.from_dict(vision_config_dict)

        return cls(text_config=text_config, vision_config=vision_config, **config_dict, **kwargs)


class TrainerConfig(pydantic.BaseModel):
    """Training hyperparameters, with the model and data configs attached as private attributes."""

    epochs: int = 20
    batch_size: int = 64
    learning_rate: float = 5e-4
    lr_scheduler: bool = True
    accumulate_grad_batches: int = 1
    temperature: float = 1.0
    vision_freeze_layers: int = 2
    # Weights for the CyCLIP consistency regularizers (used when loss_type == "cyclip").
    lambda_1: float = 1.0
    lambda_2: float = 1.0

    val_check_interval: int = 1000
    log_every_n_steps: int = 100
    debug: bool = False

    run_openai_clip: bool = False

    _model_config: TinyCLIPConfig = TinyCLIPConfig()
    _data_config: DataConfig = DataConfig()

    def __init__(self, **data):
        super().__init__(**data)
        if "_model_config" in data:
            self._model_config = TinyCLIPConfig.from_dict(data["_model_config"])
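

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It
    # round-trips the composite config through a plain dict: the path the
    # overridden from_dict handles when a config is reloaded from JSON.
    # Assumes a transformers version whose PretrainedConfig.to_dict()
    # serializes nested sub-configs to dicts, and pydantic v2-style
    # private attributes on TrainerConfig.
    config = TinyCLIPConfig(embed_dim=256)
    restored = TinyCLIPConfig.from_dict(config.to_dict())
    assert restored.text_config.embed_dims == 256
    assert restored.vision_config.embed_dims == 256

    # TrainerConfig likewise accepts the model config as a serialized dict.
    trainer_config = TrainerConfig(batch_size=32, _model_config=config.to_dict())
    print(trainer_config._model_config.loss_type)  # -> "cyclip"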