import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import PreTrainedModel

from src import loss
from src import vision_model
from src.config import TinyCLIPConfig
from src.config import TinyCLIPTextConfig
from src.config import TinyCLIPVisionConfig


class Projection(nn.Module):
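    """Residual projection block mapping d_in to d_out: the first linear layer's
    output is summed with a GELU -> linear -> dropout branch and layer-normalized.
    """
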
    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
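    """Stack Projection blocks separated by GELU activations; only the final block maps d_in to d_out."""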
    layers = []
    for _ in range(num_layers - 1):
        layers.extend([Projection(d_in, d_in), nn.GELU()])
    layers += [Projection(d_in, d_out)]
    return nn.Sequential(*layers)


def mean_pooling(
    text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
) -> torch.FloatTensor:
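    """Mean-pool token embeddings over the sequence dimension, weighted by the attention mask."""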
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float()
    return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )  # type: ignore


class TinyCLIPTextEncoder(PreTrainedModel):
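    """Text tower: a pretrained Hugging Face transformer followed by a projection head.

    The last hidden state is pooled via the CLS token (if `cls_type` is set) or
    attention-masked mean pooling, then projected and L2-normalized.
    """
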
    config_class = TinyCLIPTextConfig

    def __init__(self, config: TinyCLIPTextConfig):
        super().__init__(config)
        self.base = transformers.AutoModel.from_pretrained(config.text_model)
        self.cls_type = config.cls_type
        self.projection = projection_layers(
            self.base.config.hidden_size, config.embed_dims, config.projection_layers
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        out = self.base(**x).last_hidden_state
        if self.cls_type:
            out = out[:, 0]  # get CLS token output
        else:
            out = mean_pooling(out, x["attention_mask"])  # type: ignore

        projected_vec = self.projection(out)
        return F.normalize(projected_vec, dim=-1)


class TinyCLIPVisionEncoder(PreTrainedModel):
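    """Vision tower: a backbone from `vision_model.get_vision_base` followed by a
    projection head; backbone features are projected and L2-normalized.
    """
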
    config_class = TinyCLIPVisionConfig

    def __init__(self, config: TinyCLIPVisionConfig):
        super().__init__(config)
        base, num_features = vision_model.get_vision_base(config)
        self.base = base
        self.projection = projection_layers(
            num_features, config.embed_dims, config.projection_layers
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        projected_vec = self.projection(self.base(images))
        return F.normalize(projected_vec, dim=-1)


class TinyCLIP(PreTrainedModel):
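    """Dual-encoder CLIP-style model combining the text and vision towers.

    Returns L2-normalized embeddings for both modalities and, optionally, the loss
    selected by `config.loss_type`.
    """
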
    config_class = TinyCLIPConfig

    def __init__(self, config: TinyCLIPConfig):
        super().__init__(config)
        self.text_encoder = TinyCLIPTextEncoder(config.text_config)
        self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config)

        if config.freeze_text_base:
            # Freeze only the pretrained text backbone; the projection head stays trainable.
            self.text_encoder.base.eval()
            for param in self.text_encoder.base.parameters():
                param.requires_grad = False

        if config.freeze_vision_base:
            # Freeze only the pretrained vision backbone; the projection head stays trainable.
            self.vision_encoder.base.eval()
            for param in self.vision_encoder.base.parameters():
                param.requires_grad = False

        self.loss_fn = loss.get_loss(config.loss_type)

    def forward(
        self,
        text_input: dict[str, torch.Tensor],
        vision_input: torch.Tensor,
        return_loss: bool = False,
    ) -> dict[str, torch.Tensor]:
        text_output = self.text_encoder(text_input)
        vision_output = self.vision_encoder(vision_input)

        out = {"text_output": text_output, "vision_output": vision_output}

        if return_loss:
            out["loss"] = self.loss_fn(vision_output, text_output)

        return out


if __name__ == "__main__":
    model = TinyCLIP(TinyCLIPConfig())
    print(model)
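    # Minimal smoke test of the forward pass. This is a sketch only: it assumes the
    # default TinyCLIPTextConfig names a Hugging Face checkpoint with a matching
    # tokenizer, and that the vision backbone accepts 3x224x224 inputs; adjust both
    # if your config defaults differ.
    tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.text_config.text_model)
    text_input = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
    vision_input = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out = model(dict(text_input), vision_input, return_loss=True)
    print(out["text_output"].shape, out["vision_output"].shape, out["loss"])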
    print("Done!")