# coding=utf-8
#
# Code mainly copied from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
# and adjusted for Jina CLIP

from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as f
import torch.utils.checkpoint
from torch import nn
from transformers import BatchEncoding, BatchFeature, PreTrainedModel, logging
from transformers.models.clip.modeling_clip import (
    CLIPOutput,
    CLIPTextModelOutput,
    CLIPVisionModelOutput,
    clip_loss,
)

from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
from .eva_model import EVAVisionTransformer
from .hf_model import HFTextEncoder

logger = logging.get_logger(__name__)


""" Jina CLIP model implementation """


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor):
        """Apply layer normalization and return a tensor with ``x``'s dtype.

        The cast back matters if ``f.layer_norm`` produces a different dtype
        than the input. Presumably this guards mixed-precision use where the
        parameters and activations differ in precision; confirm against
        callers.
        """
        origtype = x.dtype
        x = f.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(origtype)


def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
    """Construct the text encoder described by ``config``.

    The encoder is built with ``pretrained=False``, ``output_tokens=False``,
    ``trust_remote_code=True`` and ``revision=None``. Every other setting
    (model name, output dim, pooler, projection) comes from ``config``.
    """
    return HFTextEncoder(
        model_name_or_path=config.hf_model_name_or_path,
        output_dim=config.embed_dim,
        pooler_type=config.pooler_type,
        proj_type=config.proj_type,
        proj_bias=config.proj_bias,
        pretrained=False,
        output_tokens=False,
        trust_remote_code=True,
        revision=None,
        model_config_kwargs=config.hf_model_config_kwargs,
    )


def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
    """Construct the EVA vision transformer described by ``config``.

    Normalization layers default to this module's ``LayerNorm`` with
    ``eps=1e-6``. When ``config.fused_layer_norm`` is set, apex's
    ``FusedLayerNorm`` is used instead if it can be imported. Otherwise a
    warning is logged and the default is kept.
    """
    norm_layer = partial(LayerNorm, eps=1e-6)

    if config.fused_layer_norm:
        try:
            from apex.normalization import FusedLayerNorm

            norm_layer = partial(FusedLayerNorm, eps=1e-6)
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            logger.warning('Please install apex to use fused layer norm, ignoring')

    return EVAVisionTransformer(
        img_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.embed_dim,
        use_mean_pooling=False,
        init_values=config.ls_init_value,
        patch_dropout=config.patch_dropout,
        embed_dim=config.width,
        depth=config.layers,
        # Head count is derived from the total width and per-head width.
        num_heads=config.width // config.head_width,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        drop_path_rate=config.drop_path_rate,
        norm_layer=norm_layer,
        xattn=config.x_attention,
        rope=config.rope_embeddings,
        postnorm=config.post_norm,
        pt_hw_seq_len=config.pt_hw_seq_len,
        intp_freq=config.intp_freq,
        naiveswiglu=config.naive_swiglu,
        subln=config.subln,
        proj_type=config.proj_type,
    )


class JinaCLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = JinaCLIPConfig
    base_model_prefix = 'clip'
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        This is the weight-init hook used by ``PreTrainedModel``. It is
        presumably applied to every submodule via ``post_init``.

        * On a ``JinaCLIPModel``, each projection that is an ``nn.Linear`` gets
          its weight drawn from ``N(0, std)``, where
          ``std = embed_dim ** -0.5 * config.initializer_factor``.
        * On an ``nn.LayerNorm``, the weight is reset to 1 and the bias to 0,
          for whichever of the two exist.
        * On an ``nn.Linear`` that has a bias, the bias is reset to 0.
        """
        if isinstance(module, JinaCLIPModel):
            if isinstance(module.text_projection, nn.Linear):
                nn.init.normal_(
                    module.text_projection.weight,
                    std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
                )
            # Guard on the visual projection itself rather than on
            # text_projection. Both are currently built together, but checking
            # the wrong attribute would fail silently if they ever diverged.
            if isinstance(module.visual_projection, nn.Linear):
                nn.init.normal_(
                    module.visual_projection.weight,
                    std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
                )
        if isinstance(module, nn.LayerNorm):
            # Non-affine / bias-less LayerNorms expose weight/bias as None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
    """Text-only model wrapping the text tower built from a text config."""

    config_class = JinaCLIPTextConfig

    def __init__(self, config: JinaCLIPTextConfig):
        super().__init__(config)
        self.text_model = _build_text_tower(config)
        self.post_init()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """Encode text into embeddings.

        Args:
            input_ids: token ids, either as a tensor or as a ``BatchEncoding``.
                For a ``BatchEncoding``, only its ``input_ids`` field is used.
            return_dict: return a ``CLIPTextModelOutput`` when true, or a tuple
                otherwise. Defaults to ``config.use_return_dict``.
            *_, **__: accepted for call-site compatibility and ignored.

        Returns:
            ``CLIPTextModelOutput`` with only ``text_embeds`` populated, or its
            tuple form.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
        feats = self.text_model(x=x)
        out = CLIPTextModelOutput(text_embeds=feats)
        return out if return_dict else out.to_tuple()


class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
    """Vision-only model wrapping the EVA tower built from a vision config."""

    config_class = JinaCLIPVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: JinaCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = _build_vision_tower(config)
        self.post_init()

    def forward(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Encode images into embeddings.

        Args:
            pixel_values: image tensor, or a ``BatchFeature``. For a
                ``BatchFeature``, only its ``pixel_values`` field is used.
            return_dict: return a ``CLIPVisionModelOutput`` when true, or a
                tuple otherwise. Defaults to ``config.use_return_dict``.
            *_, **__: accepted for call-site compatibility and ignored.

        Returns:
            ``CLIPVisionModelOutput`` with only ``image_embeds`` populated, or
            its tuple form.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        x = (
            pixel_values.pixel_values
            if isinstance(pixel_values, BatchFeature)
            else pixel_values
        )
        feats = self.vision_model(x=x)
        out = CLIPVisionModelOutput(image_embeds=feats)
        return out if return_dict else out.to_tuple()


class JinaCLIPModel(JinaCLIPPreTrainedModel):
    """Dual-encoder CLIP model: text tower, vision tower, optional projections
    and a learnable logit scale."""

    config_class = JinaCLIPConfig

    def __init__(self, config: JinaCLIPConfig):
        """Build both towers and the projection heads.

        Raises:
            ValueError: if ``config.text_config`` is not a
                ``JinaCLIPTextConfig``, or ``config.vision_config`` is not a
                ``JinaCLIPVisionConfig``.
        """
        super().__init__(config)

        if not isinstance(config.text_config, JinaCLIPTextConfig):
            raise ValueError(
                'Attribute config.text_config is expected to be of type '
                f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
            )

        if not isinstance(config.vision_config, JinaCLIPVisionConfig):
            raise ValueError(
                'Attribute config.vision_config is expected to be of type '
                f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.add_projections = config.add_projections
        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.embed_dim
        self.vision_embed_dim = vision_config.embed_dim

        self.text_model = _build_text_tower(text_config)
        self.vision_model = _build_vision_tower(vision_config)
        # Stored in log space; forward() exponentiates it before scaling logits.
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        if self.add_projections:
            self.visual_projection = nn.Linear(
                self.vision_embed_dim, self.projection_dim, bias=False
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim, self.projection_dim, bias=False
            )
        else:
            # Without projections the tower outputs are used as-is.
            self.visual_projection = nn.Identity()
            self.text_projection = nn.Identity()

        self.post_init()

    def get_text_features(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Return projected text features.

        These are the text tower's output passed through ``text_projection``.
        They are not L2-normalized.
        """
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
        return self.text_projection(self.text_model(x=x))

    def get_image_features(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Return projected image features.

        These are the vision tower's output passed through
        ``visual_projection``. They are not L2-normalized.
        """
        x = (
            pixel_values.pixel_values
            if isinstance(pixel_values, BatchFeature)
            else pixel_values
        )
        return self.visual_projection(self.vision_model(x=x))

    def encode_text(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """Wrap ``get_text_features`` in a ``CLIPTextModelOutput`` (or tuple)."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        feats = self.get_text_features(input_ids=input_ids)
        out = CLIPTextModelOutput(text_embeds=feats)
        return out if return_dict else out.to_tuple()

    def encode_image(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Wrap ``get_image_features`` in a ``CLIPVisionModelOutput`` (or tuple)."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        feats = self.get_image_features(pixel_values=pixel_values)
        out = CLIPVisionModelOutput(image_embeds=feats)
        return out if return_dict else out.to_tuple()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
        """Compute scaled cosine-similarity logits between texts and images.

        Both embedding sets are L2-normalized along the last dimension.
        ``logits_per_text`` is ``text @ image.T * exp(logit_scale)``, and
        ``logits_per_image`` is its transpose. When ``return_loss`` is truthy,
        ``clip_loss(logits_per_text)`` is also computed.

        Returns:
            ``CLIPOutput``. In tuple form the order is ``(loss?,
            logits_per_image, logits_per_text, text_embeds, image_embeds,
            None, None)``. The leading ``loss`` is present only when it was
            computed.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(input_ids=input_ids)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            # The two trailing Nones stand in for text/vision model outputs,
            # which this implementation does not expose.
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                None,
                None,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=None,
            vision_model_output=None,
        )