|
from typing import List, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from open_clip import create_model, get_tokenizer |
|
from open_clip.transform import PreprocessCfg, image_transform_v2 |
|
from PIL import Image |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
class MatryoshkaNllbClipConfig(PretrainedConfig):
    """Configuration for the Matryoshka NLLB-CLIP model.

    Args:
        clip_model_name: open_clip model name; the model class in this file
            passes it to ``create_model`` and ``get_tokenizer``.
        target_resolution: input width of every Matryoshka projection head
            (each head is ``nn.Linear(target_resolution, r)``); presumably the
            CLIP embedding size — TODO confirm.
        mrl_resolutions: output widths; one projection head is built per entry.
            ``None`` means an empty list.
        preprocess_cfg: dict read for the keys ``size``, ``mean``, ``std``,
            ``interpolation`` and ``resize_mode`` when building the eval image
            transform.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        clip_model_name: str = "",
        target_resolution: int = -1,
        mrl_resolutions: Union[List[int], None] = None,
        preprocess_cfg: Union[dict, None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.target_resolution = target_resolution
        # Build a fresh list per instance: a literal ``[]`` default would be one
        # list object shared (and mutable) across every config that omitted it.
        self.mrl_resolutions = (
            [] if mrl_resolutions is None else list(mrl_resolutions)
        )
        self.preprocess_cfg = preprocess_cfg
|
|
|
|
|
class MatryoshkaLayer(nn.Module):
    """Bank of linear heads that project one embedding to several widths.

    A separate ``nn.Linear(target_resolution, width)`` head is registered for
    every ``width`` in ``resolutions``, stored in ``self.layers`` under the
    string key ``str(width)``.
    """

    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        super().__init__()
        self.resolutions = resolutions
        self.layers = nn.ModuleDict()
        for width in resolutions:
            # ModuleDict only accepts string keys.
            self.layers[str(width)] = nn.Linear(target_resolution, width)

    def forward(self, x, resolution: Union[int, None] = None):
        """Project ``x`` with one head, or with all heads.

        Args:
            x: input fed to the selected ``nn.Linear`` head(s).
            resolution: width of the head to apply; ``None`` applies every head.

        Returns:
            The single projected tensor when ``resolution`` is given, otherwise
            a list of projections ordered like ``self.resolutions``.

        Raises:
            ValueError: if ``resolution`` is not one of ``self.resolutions``.
        """
        if resolution is None:
            return [self.layers[str(width)](x) for width in self.resolutions]
        if resolution not in self.resolutions:
            raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
        head = self.layers[str(resolution)]
        return head(x)
|
|
|
|
|
class MatryoshkaNllbClip(PreTrainedModel):
    """open_clip NLLB-CLIP model with Matryoshka (MRL) projection heads.

    Wraps the open_clip model named by ``config.clip_model_name`` together with
    a ``MatryoshkaLayer`` that can project image/text embeddings to any width
    listed in ``config.mrl_resolutions``.
    """

    config_class = MatryoshkaNllbClipConfig

    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        """Build the CLIP model, eval image transform, MRL heads and tokenizer.

        Args:
            config: model configuration. ``config.preprocess_cfg`` must be a
                dict with ``size``, ``mean``, ``std``, ``interpolation`` and
                ``resize_mode`` keys.
            device: ``torch.device`` or device string. The CLIP model and MRL
                heads are moved there, and so are inputs at call time.

        Raises:
            ValueError: if ``config.preprocess_cfg`` is ``None``.
        """
        super().__init__(config)
        if isinstance(device, str):
            device = torch.device(device)
        if config.preprocess_cfg is None:
            # Fail fast with a clear message, before the expensive
            # create_model call, instead of "'NoneType' object is not
            # subscriptable" from the lookups below.
            raise ValueError(
                "config.preprocess_cfg is required to build the image transform"
            )
        self.config = config
        self.model = create_model(config.clip_model_name, output_dict=True)
        pp = config.preprocess_cfg
        pp_cfg = PreprocessCfg(
            size=pp["size"],
            mean=pp["mean"],
            std=pp["std"],
            interpolation=pp["interpolation"],
            resize_mode=pp["resize_mode"],
        )
        self.transform = image_transform_v2(pp_cfg, is_train=False)
        self._device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

    def _check_resolution(self, resolution: Union[int, None]) -> None:
        """Raise ``ValueError`` unless ``resolution`` is ``None`` or has a head."""
        if (
            resolution is not None
            and resolution not in self.matryoshka_layer.resolutions
        ):
            raise ValueError(
                f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
            )

    def _project(self, features, resolution: Union[int, None]):
        """Apply the MRL head for ``resolution`` to ``features`` (no-op if None).

        Callers must invoke this inside ``torch.inference_mode()``. The tower
        output is then an inference tensor, and feeding it to an ``nn.Linear``
        whose weights require grad *outside* that mode raises "Inference
        tensors cannot be saved for backward".
        """
        if resolution is None:
            return features
        return self.matryoshka_layer.layers[str(resolution)](features)

    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        """Run both towers and, if requested, the MRL head for ``resolution``.

        Args:
            image_inputs: image batch; moved to the model device.
            input_ids: token ids; moved to the model device.
            resolution: optional MRL width applied to both feature sets.

        Returns:
            dict with ``image_features``/``text_features`` and ``logit_scale``
            as returned by the open_clip model; ``mrl_image_features`` and
            ``mrl_text_features`` (``None`` when ``resolution`` is ``None``);
            and ``logit_bias`` (``None`` when the open_clip output has none).

        Raises:
            ValueError: if ``resolution`` has no projection head.
        """
        image_inputs = image_inputs.to(self._device)
        input_ids = input_ids.to(self._device)
        outputs = self.model(image=image_inputs, text=input_ids)

        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            mrl_image_features = self.matryoshka_layer(
                outputs["image_features"], resolution
            )
            mrl_text_features = self.matryoshka_layer(
                outputs["text_features"], resolution
            )

        return {
            "image_features": outputs["image_features"],
            "text_features": outputs["text_features"],
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            # open_clip may omit "logit_bias" for models without a bias
            # (get_logits below also allows model.logit_bias to be None), so
            # avoid a KeyError.
            "logit_bias": outputs.get("logit_bias"),
        }

    def encode_image(
        self,
        image,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-preprocessed image batch with the visual tower.

        Args:
            image: tensor accepted by ``self.model.visual``; moved to the model
                device.
            normalize: L2-normalize along the last dimension when true.
            resolution: optional MRL width; must have a projection head.

        Raises:
            ValueError: if ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        image = image.to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.visual(image), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(
        self,
        text,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-tokenized text batch with the text tower.

        Args:
            text: token-id tensor accepted by ``self.model.text``; moved to the
                model device.
            normalize: L2-normalize along the last dimension when true.
            resolution: optional MRL width; must have a projection head.

        Raises:
            ValueError: if ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        text = text.to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.text(text), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def image_features(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Preprocess PIL images with ``self.transform`` and embed them.

        Args:
            images: PIL images; each is transformed and stacked into one batch.
            normalize: L2-normalize along the last dimension when true.
            resolution: optional MRL width; must have a projection head.

        Raises:
            ValueError: if ``resolution`` has no projection head.
        """
        self._check_resolution(resolution)
        batch = torch.stack([self.transform(image) for image in images], dim=0)
        batch = batch.to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.visual(batch), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def text_features(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Tokenize raw texts (prefixed with their language code) and embed them.

        Args:
            texts: raw strings.
            langs: one language code per text, prepended verbatim to that text;
                defaults to ``"eng_Latn"`` for every text.
            normalize: L2-normalize along the last dimension when true.
            resolution: optional MRL width; must have a projection head.

        Raises:
            ValueError: if ``langs`` and ``texts`` differ in length, or if
                ``resolution`` has no projection head.
        """
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        elif len(langs) != len(texts):
            # zip() would otherwise silently drop the unmatched texts.
            raise ValueError(
                f"Got {len(langs)} language codes for {len(texts)} texts"
            )
        self._check_resolution(resolution)

        prefixed = [f"{lang}{text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            prefixed, return_tensors="pt", padding="longest", add_special_tokens=False
        )["input_ids"].to(self._device)
        with torch.inference_mode():
            features = self._project(self.model.text(input_ids), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        """Return ``(image_logits, text_logits)`` image-text similarity matrices.

        ``image_logits`` is ``exp(logit_scale)`` times the dot products of the
        L2-normalized image and text features (rows = images, columns = texts),
        plus ``logit_bias`` when the model has one. ``text_logits`` is its
        transpose.

        Raises:
            ValueError: propagated from ``image_features``/``text_features``.
        """
        image_features = self.image_features(
            images, normalize=True, resolution=resolution
        )
        text_features = self.text_features(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            if self.model.logit_bias is not None:
                image_logits += self.model.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits