from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import transformers from transformers import PreTrainedModel from src import loss from src import vision_model from src.config import TinyCLIPConfig from src.config import TinyCLIPTextConfig from src.config import TinyCLIPVisionConfig class Projection(nn.Module): def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None: super().__init__() self.linear1 = nn.Linear(d_in, d_out, bias=False) self.linear2 = nn.Linear(d_out, d_out, bias=False) self.layer_norm = nn.LayerNorm(d_out) self.drop = nn.Dropout(p) def forward(self, x: torch.Tensor) -> torch.Tensor: embed1 = self.linear1(x) embed2 = self.drop(self.linear2(F.gelu(embed1))) embeds = self.layer_norm(embed1 + embed2) return embeds def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module: layers = [] for _ in range(num_layers - 1): layers.extend([Projection(d_in, d_in), nn.GELU()]) layers += [Projection(d_in, d_out)] return nn.Sequential(*layers) def mean_pooling( text_representation: torch.FloatTensor, attention_mask: torch.LongTensor ) -> torch.FloatTensor: input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float() return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp( input_mask_expanded.sum(1), min=1e-9 ) # type: ignore class TinyCLIPTextEncoder(PreTrainedModel): config_class = TinyCLIPTextConfig def __init__(self, config: TinyCLIPTextConfig): super().__init__(config) self.base = transformers.AutoModel.from_pretrained(config.text_model) self.cls_type = config.cls_type self.projection = projection_layers( self.base.config.hidden_size, config.embed_dims, config.projection_layers ) def forward(self, x: dict[str, torch.Tensor]): out = self.base(**x).last_hidden_state if self.cls_type: out = out[:, 0] # get CLS token output else: out = mean_pooling(out, x["attention_mask"]) # type: ignore projected_vec = self.projection(out) return F.normalize(projected_vec, dim=-1) class TinyCLIPVisionEncoder(PreTrainedModel): config_class = TinyCLIPVisionConfig def __init__(self, config: TinyCLIPVisionConfig): super().__init__(config) base, num_features = vision_model.get_vision_base(config) self.base = base self.projection = projection_layers( num_features, config.embed_dims, config.projection_layers ) def forward(self, images: list[Image.Image]): x: torch.Tensor = torch.stack([self.transform(image) for image in images]) # type: ignore projected_vec = self.projection(self.base(x)) return F.normalize(projected_vec, dim=-1) class TinyCLIP(PreTrainedModel): config_class = TinyCLIPConfig def __init__(self, config: TinyCLIPConfig): super().__init__(config) self.text_encoder = TinyCLIPTextEncoder(config.text_config) self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config) if config.freeze_text_base: self.text_encoder.base.eval() for param in self.text_encoder.parameters(): param.requires_grad = False if config.freeze_vision_base: self.vision_encoder.base.eval() for param in self.vision_encoder.parameters(): param.requires_grad = False self.loss_fn = loss.get_loss(config.loss_type) def forward( self, text_input: dict[str, torch.Tensor], vision_input: list[Image.Image], return_loss: bool = False, ) -> dict[str, torch.Tensor]: text_output = self.text_encoder(text_input) vision_output = self.vision_encoder(vision_input) out = {"text_output": text_output, "vision_output": vision_output} if return_loss: out["loss"] = self.loss_fn(vision_output, text_output) return out if __name__ == "__main__": model = TinyCLIP(TinyCLIPConfig()) print(model) print("Done!")