jina-clip-implementation / modeling_clip.py
bwang0911's picture
refactor: refine encode_text
136fb28
raw
history blame
13.7 kB
# coding=utf-8
#
# Code mainly copied from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
# and adjusted for Jina CLIP
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as f
import torch.utils.checkpoint
from torch import nn
from transformers import (
    AutoTokenizer,
    BatchEncoding,
    BatchFeature,
    PreTrainedModel,
    logging,
)
from transformers.models.clip.modeling_clip import (
    CLIPOutput,
    CLIPTextModelOutput,
    CLIPVisionModelOutput,
    clip_loss,
)

try:
    from tqdm.autonotebook import trange

    has_tqdm = True
except ImportError:
    has_tqdm = False

from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
from .eva_model import EVAVisionTransformer
from .hf_model import HFTextEncoder
# Module-level logger obtained from transformers' logging utilities
# (``logging`` here is ``transformers.logging``, not the stdlib module).
logger = logging.get_logger(__name__)
""" Jina CLIP model implementation """
class LayerNorm(nn.LayerNorm):
    """Drop-in ``nn.LayerNorm`` whose output keeps the dtype of its input.

    The normalisation itself is delegated to ``F.layer_norm`` with this
    module's shape, affine parameters and epsilon; the result is then cast
    back to whatever dtype the incoming tensor had.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = f.layer_norm(
            x,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normed.to(input_dtype)
def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
    """Construct the ``HFTextEncoder`` text tower described by ``config``.

    Model name, output dimension, pooler and projection settings and extra
    HF model-config kwargs are taken from ``config``; the encoder is built
    with ``pretrained=False``, ``output_tokens=False``,
    ``trust_remote_code=True`` and ``revision=None``.
    """
    encoder_kwargs = dict(
        model_name_or_path=config.hf_model_name_or_path,
        output_dim=config.embed_dim,
        pooler_type=config.pooler_type,
        proj_type=config.proj_type,
        proj_bias=config.proj_bias,
        pretrained=False,
        output_tokens=False,
        trust_remote_code=True,
        revision=None,
        model_config_kwargs=config.hf_model_config_kwargs,
    )
    return HFTextEncoder(**encoder_kwargs)
def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
    """Construct the ``EVAVisionTransformer`` image tower described by ``config``.

    The normalisation layer is this module's dtype-preserving ``LayerNorm``
    unless ``config.fused_layer_norm`` is set and apex's ``FusedLayerNorm``
    can be imported; if the import fails a warning is logged and the plain
    ``LayerNorm`` is kept. Either way the layer is built with ``eps=1e-6``.
    """
    norm_cls = LayerNorm
    if config.fused_layer_norm:
        try:
            from apex.normalization import FusedLayerNorm
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            logger.warning('Please install apex to use fused layer norm, ignoring')
        else:
            norm_cls = FusedLayerNorm
    norm_layer = partial(norm_cls, eps=1e-6)

    # Number of attention heads is derived from the total width and the
    # per-head width.
    head_count = config.width // config.head_width

    return EVAVisionTransformer(
        img_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.embed_dim,
        use_mean_pooling=False,
        init_values=config.ls_init_value,
        patch_dropout=config.patch_dropout,
        embed_dim=config.width,
        depth=config.layers,
        num_heads=head_count,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        drop_path_rate=config.drop_path_rate,
        norm_layer=norm_layer,
        xattn=config.x_attention,
        rope=config.rope_embeddings,
        postnorm=config.post_norm,
        pt_hw_seq_len=config.pt_hw_seq_len,
        intp_freq=config.intp_freq,
        naiveswiglu=config.naive_swiglu,
        subln=config.subln,
        proj_type=config.proj_type,
    )
class JinaCLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = JinaCLIPConfig
    base_model_prefix = 'clip'
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        - ``JinaCLIPModel``: each projection that is an ``nn.Linear`` gets
          normal-distributed weights with std
          ``embed_dim ** -0.5 * config.initializer_factor``, using the text or
          vision embedding dimension respectively. Non-linear projections
          (e.g. ``nn.Identity``) are skipped.
        - ``nn.LayerNorm``: weight set to 1 and bias to 0, when present.
        - ``nn.Linear``: bias zeroed, when present.
        """
        if isinstance(module, JinaCLIPModel):
            if isinstance(module.text_projection, nn.Linear):
                nn.init.normal_(
                    module.text_projection.weight,
                    std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
                )
            # Check the visual projection itself (previously the text
            # projection was tested twice); the two are separate modules.
            if isinstance(module.visual_projection, nn.Linear):
                nn.init.normal_(
                    module.visual_projection.weight,
                    std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
                )
        if isinstance(module, nn.LayerNorm):
            # Layer norms built without affine parameters (or without a bias)
            # expose ``None`` here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
    """Text-only Jina CLIP model wrapping the HF-based text tower."""

    config_class = JinaCLIPTextConfig

    def __init__(self, config: JinaCLIPTextConfig):
        super().__init__(config)
        self.text_model = _build_text_tower(config)
        self.post_init()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """Encode token ids with the text tower.

        Args:
            input_ids: token ids, or a ``BatchEncoding`` whose ``input_ids``
                field is used.
            return_dict: return a ``CLIPTextModelOutput`` when truthy, its tuple
                form otherwise; ``None`` defers to ``config.use_return_dict``.

        Returns:
            ``CLIPTextModelOutput`` with only ``text_embeds`` populated, or
            the equivalent tuple.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        token_ids = input_ids
        if isinstance(token_ids, BatchEncoding):
            token_ids = token_ids.input_ids

        output = CLIPTextModelOutput(text_embeds=self.text_model(x=token_ids))
        if not return_dict:
            return output.to_tuple()
        return output
class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
    """Image-only Jina CLIP model wrapping the EVA vision tower."""

    config_class = JinaCLIPVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: JinaCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = _build_vision_tower(config)
        self.post_init()

    def forward(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Encode images with the vision tower.

        Args:
            pixel_values: image tensor, or a ``BatchFeature`` whose
                ``pixel_values`` field is used.
            return_dict: return a ``CLIPVisionModelOutput`` when truthy, its
                tuple form otherwise; ``None`` defers to
                ``config.use_return_dict``.

        Returns:
            ``CLIPVisionModelOutput`` with only ``image_embeds`` populated, or
            the equivalent tuple.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        images = pixel_values
        if isinstance(images, BatchFeature):
            images = images.pixel_values

        output = CLIPVisionModelOutput(image_embeds=self.vision_model(x=images))
        return output if return_dict else output.to_tuple()
class JinaCLIPModel(JinaCLIPPreTrainedModel):
    """Dual-encoder Jina CLIP model: text tower, vision tower and optional
    linear projections into a shared embedding space."""

    config_class = JinaCLIPConfig

    def __init__(self, config: JinaCLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, JinaCLIPTextConfig):
            raise ValueError(
                'Attribute config.text_config is expected to be of type '
                f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
            )

        if not isinstance(config.vision_config, JinaCLIPVisionConfig):
            raise ValueError(
                'Attribute config.vision_config is expected to be of type '
                f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.add_projections = config.add_projections
        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.embed_dim
        self.vision_embed_dim = vision_config.embed_dim

        self.text_model = _build_text_tower(text_config)
        self.vision_model = _build_vision_tower(vision_config)
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        if self.add_projections:
            self.visual_projection = nn.Linear(
                self.vision_embed_dim, self.projection_dim, bias=False
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim, self.projection_dim, bias=False
            )
        else:
            self.visual_projection = nn.Identity()
            self.text_projection = nn.Identity()

        # Tokenizer used by ``encode_text``; loaded from the same location as
        # the config (``config._name_or_path``).
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.post_init()

    def get_text_features(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Run the text tower and apply ``text_projection`` (an ``nn.Identity``
        when ``add_projections`` is False).

        ``input_ids`` may be a tensor or a ``BatchEncoding``; for the latter
        only its ``input_ids`` field is passed on.
        """
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
        return self.text_projection(self.text_model(x=x))

    def get_image_features(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """Run the vision tower and apply ``visual_projection`` (an
        ``nn.Identity`` when ``add_projections`` is False).

        ``pixel_values`` may be a tensor or a ``BatchFeature``; for the latter
        only its ``pixel_values`` field is passed on.
        """
        x = (
            pixel_values.pixel_values
            if isinstance(pixel_values, BatchFeature)
            else pixel_values
        )
        return self.visual_projection(self.vision_model(x=x))

    @torch.inference_mode()
    def encode_text(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        output_value: str = 'sentence_embedding',
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Optional[torch.device] = None,
        normalize_embeddings: bool = False,
        **tokenizer_kwargs,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """Tokenize and encode sentences into text embeddings, in batches.

        Sentences are sorted longest-first before batching, and the
        embeddings are put back into the caller's order before returning.
        Side effects: the model is switched to eval mode and, if ``device`` is
        given, moved to that device; both persist after the call.

        Args:
            sentences: a single string or a list of strings.
            batch_size: number of sentences per forward pass.
            show_progress_bar: show a tqdm bar (only if tqdm is installed).
                Defaults to True when this module's logger level is INFO or
                DEBUG.
            output_value: only ``'sentence_embedding'`` is supported;
                ``'token_embeddings'`` and ``None`` raise
                ``NotImplementedError``.
            convert_to_numpy: return a numpy array. Ignored when
                ``convert_to_tensor`` is set or ``output_value`` is not
                ``'sentence_embedding'``.
            convert_to_tensor: return one stacked tensor.
            device: device to move the model to before encoding.
            normalize_embeddings: L2-normalize each embedding.
            **tokenizer_kwargs: forwarded to the tokenizer; ``padding=True``,
                ``max_length=512`` and ``truncation=True`` unless overridden.

        Returns:
            A list of per-sentence tensors, a numpy array, or a stacked tensor
            depending on the conversion flags. For a single string input only
            that sentence's embedding is returned.
        """
        self.eval()

        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (
                logging.INFO,
                logging.DEBUG,
            )

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != 'sentence_embedding':
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
            sentences = [sentences]
            input_was_string = True

        if device is not None:
            self.to(device)

        # Longest-first ordering groups similar lengths into each batch
        # (less padding); ``inverse_permutation`` restores the input order.
        permutation = np.argsort([-len(s) for s in sentences])
        inverse_permutation = np.argsort(permutation)
        sentences = [sentences[idx] for idx in permutation]

        tokenizer_kwargs.setdefault('padding', True)
        tokenizer_kwargs.setdefault('max_length', 512)
        tokenizer_kwargs.setdefault('truncation', True)

        if has_tqdm:
            range_iter = trange(
                0,
                len(sentences),
                batch_size,
                desc="Encoding",
                disable=not show_progress_bar,
            )
        else:
            range_iter = range(0, len(sentences), batch_size)

        all_embeddings = []
        for start in range_iter:
            encoded_input = self.tokenizer(
                sentences[start : start + batch_size],
                return_tensors='pt',
                **tokenizer_kwargs,
            ).to(self.device)
            if output_value == 'token_embeddings':
                raise NotImplementedError
            elif output_value is None:
                raise NotImplementedError
            embeddings = self.get_text_features(input_ids=encoded_input)
            if normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            if convert_to_numpy:
                embeddings = embeddings.cpu()
            all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]

        if convert_to_tensor:
            # torch.stack rejects an empty list, so empty input yields an
            # empty tensor instead.
            if all_embeddings:
                all_embeddings = torch.stack(all_embeddings)
            else:
                all_embeddings = torch.Tensor()
        elif convert_to_numpy:
            # numpy has no bfloat16 dtype; upcast such tensors to float32.
            all_embeddings = np.asarray(
                [
                    emb.float().numpy()
                    if emb.dtype == torch.bfloat16
                    else emb.numpy()
                    for emb in all_embeddings
                ]
            )

        if input_was_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings

    def encode_image(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """Return projected image features wrapped in ``CLIPVisionModelOutput``
        (``image_embeds`` only), or its tuple form when ``return_dict`` is
        falsy; ``None`` defers to ``config.use_return_dict``."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        feats = self.get_image_features(pixel_values=pixel_values)
        out = CLIPVisionModelOutput(image_embeds=feats)
        return out if return_dict else out.to_tuple()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
        """Compute text/image similarity logits.

        Both feature sets are L2-normalized, their matrix product is scaled
        by ``exp(logit_scale)``, and, when ``return_loss`` is truthy, the
        contrastive loss is computed with ``clip_loss``. Returns a
        ``CLIPOutput`` or, when ``return_dict`` is falsy, a tuple (prefixed
        by the loss when it was computed).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(input_ids=input_ids)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                None,
                None,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=None,
            vision_model_output=None,
        )