|
|
|
|
|
|
|
|
|
""" Model schema in open_clip format for inference only. """ |
|
import math |
|
from typing import Any, Optional, Dict |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from mobileclip.text_encoder import ( |
|
TextTransformer, |
|
) |
|
|
|
from .image_encoder import MCi |
|
|
|
|
|
class CLIP(nn.Module):
    """Image-text model pairing an image encoder with a text encoder.

    The module owns an ``MCi`` image encoder, a ``TextTransformer`` text
    encoder (both built with the same projection dimension) and a learnable
    scalar ``logit_scale`` that is stored in log space.
    """

    def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None:
        """Build both encoders from a model config.

        Args:
            cfg: Model config. The keys read here are ``"embed_dim"``,
                ``"image_cfg"`` (its ``"model_name"`` entry) and ``"text_cfg"``.
            output_dict: When True, ``forward`` returns a dict instead of a
                tuple.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``"embed_dim"`` is absent from ``cfg`` or is None.
        """
        super().__init__()
        self.output_dict = output_dict

        # Use ``.get`` so that a config without the key reports the same
        # actionable message as an explicit ``None`` rather than a bare
        # ``KeyError``.
        self.projection_dim = cfg.get("embed_dim")
        if self.projection_dim is None:
            raise ValueError("Please specify `embed_dim` in model config.")

        self.image_encoder = MCi(
            model_name=cfg["image_cfg"]["model_name"],
            projection_dim=self.projection_dim,
        )
        self.text_encoder = TextTransformer(
            cfg=cfg["text_cfg"], projection_dim=self.projection_dim
        )
        # Stored as a log value; exponentiated on use, so the initial
        # effective scale is 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))

    def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
        """Return ``exp(logit_scale)`` clamped to the range ``[0, max_scale]``."""
        scale = self.logit_scale.exp()
        return torch.clamp(scale, 0, max_scale)

    def encode_image(self, image: torch.Tensor, normalize: bool = False):
        """Run the image encoder and optionally normalize its output.

        Args:
            image: Input passed unchanged to ``self.image_encoder``.
            normalize: When True, the features are passed through
                ``F.normalize`` along the last dimension.

        Returns:
            The encoder output, or its ``"logits"`` entry when the encoder
            returns a dict, normalized if requested.
        """
        image_encoder_out = self.image_encoder(image)
        # The encoder may hand back either the features directly or a dict
        # carrying them under "logits".
        if isinstance(image_encoder_out, dict):
            features = image_encoder_out["logits"]
        else:
            features = image_encoder_out
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text: torch.Tensor, normalize: bool = False):
        """Run the text encoder and optionally normalize its output.

        Args:
            text: Passed to ``self.text_encoder`` as ``text_tokens``; no key
                padding mask is supplied.
            normalize: When True, the features are passed through
                ``F.normalize`` along the last dimension.

        Returns:
            The text encoder output, normalized if requested.
        """
        text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
        return F.normalize(text_features, dim=-1) if normalize else text_features

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> Any:
        """Encode whichever of ``image`` / ``text`` is provided.

        Args:
            image: Optional image input; skipped when None.
            text: Optional text input; skipped when None.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            If ``self.output_dict`` is True, a dict with keys
            ``"image_features"``, ``"text_features"`` and ``"logit_scale"``;
            otherwise the tuple ``(image_features, text_features,
            logit_scale)``. Features are normalized; the entry for a missing
            input is None. ``logit_scale`` is the exponentiated, clamped scale.
        """
        image_embeddings = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_embeddings = (
            self.encode_text(text, normalize=True) if text is not None else None
        )

        if self.output_dict:
            return {
                "image_features": image_embeddings,
                "text_features": text_embeddings,
                "logit_scale": self._exponentiate_and_clip_logits(),
            }
        return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
|
|