|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as f |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from transformers import BatchEncoding, BatchFeature, PreTrainedModel, logging |
|
from transformers.models.clip.modeling_clip import ( |
|
CLIPOutput, |
|
CLIPTextModelOutput, |
|
CLIPVisionModelOutput, |
|
clip_loss, |
|
) |
|
|
|
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig |
|
from .eva_model import EVAVisionTransformer |
|
from .hf_model import HFTextEncoder |
|
|
|
logger = logging.get_logger(__name__)

# Jina CLIP model implementation.
# (Kept as a comment: a bare string placed after the imports is not picked up
# as the module docstring -- it would only be a discarded expression.)
|
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input.

    Normalization is delegated to ``torch.nn.functional.layer_norm`` using this
    module's ``normalized_shape``, ``weight``, ``bias`` and ``eps``; the result
    is then converted to whatever dtype ``x`` had on entry.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = f.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        # Restore the caller's dtype regardless of what layer_norm produced.
        return normed.to(input_dtype)
|
|
|
|
|
def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
    """Instantiate the Hugging Face based text encoder described by ``config``.

    Model identity, output size, pooling and projection settings come from the
    text config. The remaining options are fixed here: ``pretrained=False``,
    ``output_tokens=False``, ``trust_remote_code=True`` and ``revision=None``.

    NOTE(review): ``trust_remote_code=True`` allows code shipped with the HF
    model repo referenced by ``config.hf_model_name_or_path`` to execute;
    only point it at trusted repositories.
    """
    config_driven = dict(
        model_name_or_path=config.hf_model_name_or_path,
        output_dim=config.embed_dim,
        pooler_type=config.pooler_type,
        proj_type=config.proj_type,
        proj_bias=config.proj_bias,
        model_config_kwargs=config.hf_model_config_kwargs,
    )
    return HFTextEncoder(
        pretrained=False,
        output_tokens=False,
        trust_remote_code=True,
        revision=None,
        **config_driven,
    )
|
|
|
|
|
def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
    """Instantiate the EVA vision transformer described by ``config``.

    Normalization layers are created with ``eps=1e-6``. When
    ``config.fused_layer_norm`` is set, apex's ``FusedLayerNorm`` is used if it
    can be imported; otherwise a warning is logged and this module's
    dtype-preserving ``LayerNorm`` is used instead.
    """
    norm_cls = LayerNorm
    if config.fused_layer_norm:
        try:
            from apex.normalization import FusedLayerNorm
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            logger.warning('Please install apex to use fused layer norm, ignoring')
        else:
            norm_cls = FusedLayerNorm
    norm_layer = partial(norm_cls, eps=1e-6)

    return EVAVisionTransformer(
        img_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.embed_dim,
        use_mean_pooling=False,
        init_values=config.ls_init_value,
        patch_dropout=config.patch_dropout,
        embed_dim=config.width,
        depth=config.layers,
        # Head count is derived from total width and per-head width.
        num_heads=config.width // config.head_width,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        drop_path_rate=config.drop_path_rate,
        norm_layer=norm_layer,
        xattn=config.x_attention,
        rope=config.rope_embeddings,
        postnorm=config.post_norm,
        pt_hw_seq_len=config.pt_hw_seq_len,
        intp_freq=config.intp_freq,
        naiveswiglu=config.naive_swiglu,
        subln=config.subln,
        proj_type=config.proj_type,
    )
|
|
|
|
|
class JinaCLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = JinaCLIPConfig
    base_model_prefix = 'clip'
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        Per-submodule hook used by ``PreTrainedModel`` (presumably reached via
        ``post_init`` -- confirm against the transformers version in use).

        - ``JinaCLIPModel``: each projection that is an ``nn.Linear`` gets its
          weight drawn from a normal distribution with
          ``std = embed_dim ** -0.5 * config.initializer_factor``.
        - ``nn.LayerNorm``: bias zeroed, weight set to 1 (when present).
        - ``nn.Linear`` with a bias: bias zeroed.
        """
        if isinstance(module, JinaCLIPModel):
            # Projections are nn.Identity when add_projections is disabled,
            # in which case there is no weight to initialize.
            if isinstance(module.text_projection, nn.Linear):
                nn.init.normal_(
                    module.text_projection.weight,
                    std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
                )
            # Guard on the visual projection itself (not the text projection).
            if isinstance(module.visual_projection, nn.Linear):
                nn.init.normal_(
                    module.visual_projection.weight,
                    std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
                )
        if isinstance(module, nn.LayerNorm):
            # A LayerNorm without affine parameters has weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
    """Text-only Jina CLIP model.

    Wraps the text tower built from a ``JinaCLIPTextConfig`` and exposes its
    output as ``CLIPTextModelOutput.text_embeds``.
    """

    config_class = JinaCLIPTextConfig

    def __init__(self, config: JinaCLIPTextConfig):
        super().__init__(config)
        self.text_model = _build_text_tower(config)
        self.post_init()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """Encode token ids with the text tower.

        Args:
            input_ids: token id tensor, or a ``BatchEncoding`` whose
                ``input_ids`` entry is used.
            return_dict: when falsy, a plain tuple is returned; ``None`` falls
                back to ``config.use_return_dict``.
            *_, **__: extra arguments are accepted and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if isinstance(input_ids, BatchEncoding):
            input_ids = input_ids.input_ids
        output = CLIPTextModelOutput(text_embeds=self.text_model(x=input_ids))
        if not return_dict:
            return output.to_tuple()
        return output
|
|
|
|
|
class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
    """Image-only Jina CLIP model.

    Wraps the vision tower built from a ``JinaCLIPVisionConfig`` and exposes
    its output as ``CLIPVisionModelOutput.image_embeds``.
    """

    config_class = JinaCLIPVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: JinaCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = _build_vision_tower(config)
        self.post_init()

    def forward(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Encode images with the vision tower.

        Args:
            pixel_values: image tensor, or a ``BatchFeature`` whose
                ``pixel_values`` entry is used.
            return_dict: when falsy, a plain tuple is returned; ``None`` falls
                back to ``config.use_return_dict``.
            *_, **__: extra arguments are accepted and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if isinstance(pixel_values, BatchFeature):
            pixel_values = pixel_values.pixel_values
        output = CLIPVisionModelOutput(image_embeds=self.vision_model(x=pixel_values))
        if not return_dict:
            return output.to_tuple()
        return output
|
|
|
|
|
class JinaCLIPModel(JinaCLIPPreTrainedModel):
    """Dual-encoder Jina CLIP model.

    Holds a text tower, a vision tower, a learnable ``logit_scale`` and a pair
    of projections (bias-free ``nn.Linear`` layers when
    ``config.add_projections`` is set, ``nn.Identity`` otherwise).
    """

    config_class = JinaCLIPConfig

    def __init__(self, config: JinaCLIPConfig):
        super().__init__(config)

        text_config = config.text_config
        if not isinstance(text_config, JinaCLIPTextConfig):
            raise ValueError(
                'Attribute config.text_config is expected to be of type '
                f'JinaCLIPTextConfig but is of type {type(text_config)}.'
            )
        vision_config = config.vision_config
        if not isinstance(vision_config, JinaCLIPVisionConfig):
            raise ValueError(
                'Attribute config.vision_config is expected to be of type '
                f'JinaCLIPVisionConfig but is of type {type(vision_config)}.'
            )

        self.add_projections = config.add_projections
        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.embed_dim
        self.vision_embed_dim = vision_config.embed_dim

        # Submodules are created in a fixed order (text, vision, scale, visual
        # projection, text projection): it determines registration order and
        # the order in which default initializers consume the RNG.
        self.text_model = _build_text_tower(text_config)
        self.vision_model = _build_vision_tower(vision_config)
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        if self.add_projections:
            def make_projection(in_dim):
                return nn.Linear(in_dim, self.projection_dim, bias=False)
        else:
            def make_projection(in_dim):
                return nn.Identity()
        self.visual_projection = make_projection(self.vision_embed_dim)
        self.text_projection = make_projection(self.text_embed_dim)

        self.post_init()

    @staticmethod
    def _extract_input_ids(input_ids):
        """Return the raw token ids, unpacking a ``BatchEncoding`` if given."""
        if isinstance(input_ids, BatchEncoding):
            return input_ids.input_ids
        return input_ids

    @staticmethod
    def _extract_pixel_values(pixel_values):
        """Return the raw pixel tensor, unpacking a ``BatchFeature`` if given."""
        if isinstance(pixel_values, BatchFeature):
            return pixel_values.pixel_values
        return pixel_values

    def get_text_features(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Run the text tower and the text projection (no normalization)."""
        tower_out = self.text_model(x=self._extract_input_ids(input_ids))
        return self.text_projection(tower_out)

    def get_image_features(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Run the vision tower and the visual projection (no normalization)."""
        tower_out = self.vision_model(x=self._extract_pixel_values(pixel_values))
        return self.visual_projection(tower_out)

    def encode_text(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """Projected text features wrapped in ``CLIPTextModelOutput``.

        ``return_dict=None`` falls back to ``config.use_return_dict``; a falsy
        value yields a plain tuple instead.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        output = CLIPTextModelOutput(
            text_embeds=self.get_text_features(input_ids=input_ids)
        )
        if not return_dict:
            return output.to_tuple()
        return output

    def encode_image(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Projected image features wrapped in ``CLIPVisionModelOutput``.

        ``return_dict=None`` falls back to ``config.use_return_dict``; a falsy
        value yields a plain tuple instead.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        output = CLIPVisionModelOutput(
            image_embeds=self.get_image_features(pixel_values=pixel_values)
        )
        if not return_dict:
            return output.to_tuple()
        return output

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
        """Compute text/image similarity logits and, optionally, the CLIP loss.

        Both embeddings are divided by their L2 norm over the last dimension,
        multiplied together and scaled by ``exp(logit_scale)``. The loss is
        computed with ``clip_loss`` only when ``return_loss`` is truthy. The
        per-tower model outputs are not retained (returned as ``None``).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Image features are computed before text features.
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(input_ids=input_ids)

        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits_per_text = (text_embeds @ image_embeds.t()) * scale
        logits_per_image = logits_per_text.t()

        loss = clip_loss(logits_per_text) if return_loss else None

        if return_dict:
            return CLIPOutput(
                loss=loss,
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=None,
                vision_model_output=None,
            )

        output = (
            logits_per_image,
            logits_per_text,
            text_embeds,
            image_embeds,
            None,
            None,
        )
        if loss is None:
            return output
        return (loss,) + output
|
|