from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from open_clip.transform import PreprocessCfg, image_transform_v2
from PIL import Image
from transformers import PretrainedConfig, PreTrainedModel


class MatryoshkaNllbClipConfig(PretrainedConfig):
    """Configuration for :class:`MatryoshkaNllbClip`.

    Args:
        clip_model_name: Model name passed to ``open_clip.create_model`` and
            ``open_clip.get_tokenizer``.
        target_resolution: Input feature size of the Matryoshka projection
            heads (used as ``in_features`` of every ``nn.Linear`` head).
        mrl_resolutions: Output sizes of the projection heads, one head per
            entry. ``None`` (the default) means "no heads".
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        clip_model_name: str = "",
        target_resolution: int = -1,
        mrl_resolutions: Union[List[int], None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.target_resolution = target_resolution
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every config created without an explicit value.
        self.mrl_resolutions = (
            [] if mrl_resolutions is None else list(mrl_resolutions)
        )


class MatryoshkaLayer(nn.Module):
    """A set of independent linear heads, one per output resolution.

    Each head is an ``nn.Linear(target_resolution, resolution)`` stored in an
    ``nn.ModuleDict`` under the key ``str(resolution)``.

    Args:
        resolutions: Output sizes to create heads for.
        target_resolution: Input feature size shared by all heads.
    """

    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        super().__init__()
        self.resolutions = resolutions
        # ModuleDict keys must be strings, hence the ``str(...)`` conversion.
        self.layers = nn.ModuleDict(
            {
                str(resolution): nn.Linear(target_resolution, resolution)
                for resolution in resolutions
            }
        )

    def forward(self, x, resolution: Union[int, None] = None):
        """Project ``x`` with one head, or with every head.

        Args:
            x: Input tensor fed to the linear head(s).
            resolution: Head to apply. When ``None``, all heads are applied.

        Returns:
            The projected tensor when ``resolution`` is given; otherwise a
            list of projected tensors in the order of ``self.resolutions``.

        Raises:
            ValueError: If ``resolution`` is not one of ``self.resolutions``.
        """
        if resolution is None:
            return [self.layers[str(res)](x) for res in self.resolutions]
        if resolution not in self.resolutions:
            raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
        return self.layers[str(resolution)](x)


class MatryoshkaNllbClip(PreTrainedModel):
    """open_clip model wrapped with Matryoshka projection heads.

    The wrapped model is built with ``open_clip.create_model`` from
    ``config.clip_model_name``; image preprocessing and the tokenizer are
    derived from the same model.

    Args:
        config: Model configuration.
        device: Device (``torch.device`` or its string form) that the CLIP
            model and the projection heads are moved to.

    NOTE(review): inputs are moved to the device given at construction time;
    moving the module afterwards with ``.to(...)`` does not update it.
    """

    config_class = MatryoshkaNllbClipConfig

    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        super().__init__(config)
        if isinstance(device, str):
            device = torch.device(device)
        self.config = config
        self.model = create_model(config.clip_model_name, output_dict=True)
        pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
        self.transform = image_transform_v2(
            pp_cfg,
            is_train=False,
        )
        self._device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

    def _check_resolution(self, resolution: Union[int, None]) -> None:
        """Raise ``ValueError`` if ``resolution`` has no projection head."""
        if resolution is None:
            return
        if resolution not in self.matryoshka_layer.resolutions:
            raise ValueError(
                f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
            )

    def _project(self, features, resolution: Union[int, None]):
        """Apply the head for ``resolution``; pass through when it is ``None``."""
        if resolution is None:
            return features
        return self.matryoshka_layer.layers[str(resolution)](features)

    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        """Run the CLIP model and optionally project both feature sets.

        Args:
            image_inputs: Image tensor passed to the model as ``image``.
            input_ids: Token tensor passed to the model as ``text``.
            resolution: Projection head to apply to both feature sets. When
                ``None`` the ``mrl_*`` entries of the result are ``None``.

        Returns:
            A dict with keys ``image_features``, ``text_features``,
            ``mrl_image_features``, ``mrl_text_features``, ``logit_scale``
            and ``logit_bias``. ``logit_bias`` is ``None`` when the wrapped
            model does not report one.

        Raises:
            ValueError: If ``resolution`` has no projection head.
        """
        image_inputs = image_inputs.to(self._device)
        input_ids = input_ids.to(self._device)
        outputs = self.model(
            image=image_inputs,
            text=input_ids,
        )

        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            mrl_image_features = self.matryoshka_layer(
                outputs["image_features"], resolution
            )
            mrl_text_features = self.matryoshka_layer(
                outputs["text_features"], resolution
            )

        return {
            "image_features": outputs["image_features"],
            "text_features": outputs["text_features"],
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            # The key is absent for models without a logit bias; indexing
            # directly would raise KeyError for them.
            "logit_bias": outputs.get("logit_bias"),
        }

    def encode_image(
        self,
        image,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Encode an already-preprocessed image tensor.

        Args:
            image: Tensor passed directly to the visual tower.
            normalize: If true, L2-normalize the result along the last dim.
            resolution: Optional projection head to apply to the features.

        Returns:
            The (optionally projected and normalized) image features.

        Raises:
            ValueError: If ``resolution`` has no projection head.
        """
        # Validate first so a bad resolution fails before the encoder runs.
        self._check_resolution(resolution)
        with torch.inference_mode():
            features = self._project(self.model.visual(image), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(
        self,
        text,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Encode an already-tokenized text tensor.

        Args:
            text: Tensor passed directly to the text tower.
            normalize: If true, L2-normalize the result along the last dim.
            resolution: Optional projection head to apply to the features.

        Returns:
            The (optionally projected and normalized) text features.

        Raises:
            ValueError: If ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        with torch.inference_mode():
            features = self._project(self.model.text(text), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def image_features(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Preprocess PIL images and encode them.

        Args:
            images: Images to run through ``self.transform`` and stack into
                one batch.
            normalize: If true, L2-normalize the result along the last dim.
            resolution: Optional projection head to apply to the features.

        Returns:
            The (optionally projected and normalized) image features.

        Raises:
            ValueError: If ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        batch = torch.stack(
            [self.transform(image) for image in images], dim=0
        ).to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.visual(batch), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def text_features(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Tokenize raw strings and encode them.

        Each text is prefixed with its language code before tokenization.

        Args:
            texts: Strings to encode.
            langs: One language code per text; defaults to ``"eng_Latn"`` for
                every text.
            normalize: If true, L2-normalize the result along the last dim.
            resolution: Optional projection head to apply to the features.

        Returns:
            The (optionally projected and normalized) text features.

        Raises:
            ValueError: If ``langs`` and ``texts`` differ in length, or if
                ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        elif len(langs) != len(texts):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"Got {len(langs)} langs for {len(texts)} texts; "
                "lengths must match"
            )
        prefixed = [f"{lang}{text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            prefixed,
            return_tensors="pt",
            padding="longest",
            add_special_tokens=False,
        )["input_ids"].to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.text(input_ids), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        """Compute image-to-text and text-to-image logits.

        Features are L2-normalized, multiplied by ``exp(logit_scale)`` and,
        when the wrapped model has a logit bias, shifted by it.

        Args:
            images: Images to encode via :meth:`image_features`.
            texts: Strings to encode via :meth:`text_features`.
            langs: Optional language codes, one per text.
            resolution: Optional projection head applied to both sides.

        Returns:
            A tuple ``(image_logits, text_logits)`` where ``text_logits`` is
            the transpose of ``image_logits``.

        Raises:
            ValueError: See :meth:`image_features` and :meth:`text_features`.
        """
        image_features = self.image_features(
            images, normalize=True, resolution=resolution
        )
        text_features = self.text_features(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            logit_bias = getattr(self.model, "logit_bias", None)
            if logit_bias is not None:
                image_logits += logit_bias
            text_logits = image_logits.T
        return image_logits, text_logits