# nllb-siglip-mrl-base / nllb_mrl.py
# Author: visheratin — commit 1fe0679 "Update nllb_mrl.py" (verified)
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from open_clip.transform import PreprocessCfg, image_transform_v2
from PIL import Image
from transformers import PretrainedConfig, PreTrainedModel
class MatryoshkaNllbClipConfig(PretrainedConfig):
    """Configuration for the Matryoshka NLLB-CLIP model.

    Args:
        clip_model_name: open_clip model name. The model passes it to
            ``create_model`` and ``get_tokenizer``.
        target_resolution: Input width of every Matryoshka projection head.
        mrl_resolutions: Output widths of the Matryoshka heads (one head per
            value). ``None`` (the default) means an empty list.
        preprocess_cfg: Image preprocessing settings. The model reads the
            ``size``, ``mean``, ``std``, ``interpolation`` and ``resize_mode``
            keys from it.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        clip_model_name: str = "",
        target_resolution: int = -1,
        mrl_resolutions: Union[List[int], None] = None,
        preprocess_cfg: Union[dict, None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.target_resolution = target_resolution
        # Build a fresh list per instance. The former ``= []`` default was a
        # single list object shared by every config created without this
        # argument, so mutating one config's list changed all of them.
        self.mrl_resolutions = (
            [] if mrl_resolutions is None else list(mrl_resolutions)
        )
        self.preprocess_cfg = preprocess_cfg
class MatryoshkaLayer(nn.Module):
    """Bank of linear heads that project features to Matryoshka (MRL) widths.

    One ``nn.Linear(target_resolution, r)`` head is registered for every entry
    ``r`` of ``resolutions``. Heads are stored in ``self.layers`` under the
    key ``str(r)``.

    Args:
        resolutions: Output widths, one head per value, kept in given order.
        target_resolution: Input feature width shared by every head.
    """

    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        super().__init__()
        self.resolutions = resolutions
        # A list of (key, head) pairs keeps the ModuleDict in exactly the order
        # of ``resolutions``, and creates the heads in that same order.
        heads = [
            (str(width), nn.Linear(target_resolution, width))
            for width in resolutions
        ]
        self.layers = nn.ModuleDict(heads)

    def forward(self, x, resolution: Union[int, None] = None):
        """Project ``x`` through one head, or through every head.

        Args:
            x: Input features. For ``nn.Linear`` to apply, the last dimension
                must equal the ``target_resolution`` given at construction.
            resolution: Width of the single head to apply. ``None`` applies
                all heads.

        Returns:
            A single tensor when ``resolution`` is given. Otherwise a list of
            tensors ordered like ``self.resolutions``.

        Raises:
            ValueError: If ``resolution`` is not one of ``self.resolutions``.
        """
        if resolution is None:
            return [self.layers[str(width)](x) for width in self.resolutions]
        if resolution not in self.resolutions:
            raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
        return self.layers[str(resolution)](x)
class MatryoshkaNllbClip(PreTrainedModel):
    """open_clip image/text model with optional Matryoshka (MRL) projections.

    Wraps the open_clip model named by ``config.clip_model_name`` together
    with a ``MatryoshkaLayer`` that can project image or text features to any
    width listed in ``config.mrl_resolutions``. Every input is moved to the
    device given at construction before use.
    """

    config_class = MatryoshkaNllbClipConfig

    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        """Build the open_clip model, eval transform, MRL heads and tokenizer.

        Args:
            config: Model configuration. ``config.preprocess_cfg`` must be a
                dict with ``size``, ``mean``, ``std``, ``interpolation`` and
                ``resize_mode`` keys.
            device: Target device, as a ``torch.device`` or a device string
                such as ``"cpu"``. The model and MRL heads are moved there.
        """
        super().__init__(config)
        if isinstance(device, str):
            device = torch.device(device)
        # NOTE(review): presumably redundant with PreTrainedModel.__init__;
        # kept so self.config is exactly the object passed in. Verify before
        # removing.
        self.config = config
        self.model = create_model(config.clip_model_name, output_dict=True)
        pp_cfg = PreprocessCfg(
            size=config.preprocess_cfg["size"],
            mean=config.preprocess_cfg["mean"],
            std=config.preprocess_cfg["std"],
            interpolation=config.preprocess_cfg["interpolation"],
            resize_mode=config.preprocess_cfg["resize_mode"],
        )
        # Evaluation-time (non-augmenting) preprocessing for PIL images.
        self.transform = image_transform_v2(
            pp_cfg,
            is_train=False,
        )
        self._device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

    def _project_mrl(self, features, resolution: Union[int, None]):
        """Project ``features`` to ``resolution`` via the MRL heads, or pass through.

        Returns ``features`` unchanged when ``resolution`` is ``None``.

        Raises:
            ValueError: If ``resolution`` is not a configured MRL width (raised
                by ``MatryoshkaLayer.forward``).
        """
        if resolution is None:
            return features
        return self.matryoshka_layer(features, resolution)

    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        """Run the full model on an image batch and a token batch.

        Args:
            image_inputs: Image tensor batch, presumably already preprocessed
                (e.g. by ``self.transform``). TODO confirm with callers.
            input_ids: Token id tensor batch for the text tower.
            resolution: Optional MRL width. When given, the model's image and
                text features are also projected to this width.

        Returns:
            Dict with ``image_features``, ``text_features``,
            ``mrl_image_features`` and ``mrl_text_features`` (both ``None``
            when ``resolution`` is ``None``), ``logit_scale``, and
            ``logit_bias`` (``None`` if the model output has no bias).

        Raises:
            ValueError: If ``resolution`` is not a configured MRL width.
        """
        image_inputs = image_inputs.to(self._device)
        input_ids = input_ids.to(self._device)
        outputs = self.model(
            image=image_inputs,
            text=input_ids,
        )
        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            # Call the module (not .forward) so registered hooks still run.
            mrl_image_features = self.matryoshka_layer(
                outputs["image_features"], resolution
            )
            mrl_text_features = self.matryoshka_layer(
                outputs["text_features"], resolution
            )
        return {
            "image_features": outputs["image_features"],
            "text_features": outputs["text_features"],
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            # The key can be absent for models without a logit bias (get_logits
            # below also allows a None bias). Report None instead of raising
            # KeyError.
            "logit_bias": outputs.get("logit_bias"),
        }

    def encode_image(
        self,
        image,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Encode an image tensor batch with the visual tower, without gradients.

        Args:
            image: Image tensor batch. It is moved to the model device first.
            normalize: L2-normalize the result along the last dimension.
            resolution: Optional MRL width to project the features to.

        Raises:
            ValueError: If ``resolution`` is not a configured MRL width.
        """
        # Same device handling as forward()/image_features().
        image = image.to(self._device)
        with torch.inference_mode():
            features = self._project_mrl(self.model.visual(image), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(
        self,
        text,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Encode a token id tensor batch with the text tower, without gradients.

        Args:
            text: Token id tensor batch. It is moved to the model device first.
            normalize: L2-normalize the result along the last dimension.
            resolution: Optional MRL width to project the features to.

        Raises:
            ValueError: If ``resolution`` is not a configured MRL width.
        """
        text = text.to(self._device)
        with torch.inference_mode():
            features = self._project_mrl(self.model.text(text), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def image_features(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Preprocess PIL images with ``self.transform`` and encode them.

        Args:
            images: PIL images, stacked into one batch after transformation.
            normalize: L2-normalize the result along the last dimension.
            resolution: Optional MRL width to project the features to.

        Raises:
            ValueError: If ``resolution`` is not a configured MRL width.
        """
        pixel_batch = torch.stack(
            [self.transform(image) for image in images], dim=0
        ).to(self._device)
        with torch.inference_mode():
            features = self._project_mrl(self.model.visual(pixel_batch), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def text_features(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Tokenize language-tagged texts and encode them.

        Each text is prefixed with its language code, with no separator (for
        example ``"eng_Latn" + text``). It is then tokenized without special
        tokens and padded to the longest sequence in the batch.

        Args:
            texts: Input strings.
            langs: One language code per text. Defaults to ``"eng_Latn"`` for
                every text.
            normalize: L2-normalize the result along the last dimension.
            resolution: Optional MRL width to project the features to.

        Raises:
            ValueError: If ``langs`` and ``texts`` differ in length, or if
                ``resolution`` is not a configured MRL width.
        """
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        else:
            langs = list(langs)
            # zip() would silently drop the unmatched tail; fail loudly instead.
            if len(langs) != len(texts):
                raise ValueError(
                    f"Got {len(langs)} language codes for {len(texts)} texts; "
                    "pass exactly one language code per text."
                )
        # NOTE(review): presumably the NLLB tokenizer maps the leading code to
        # a language token. TODO confirm.
        tagged = [f"{lang}{text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            tagged, return_tensors="pt", padding="longest", add_special_tokens=False
        )["input_ids"].to(self._device)
        with torch.inference_mode():
            features = self._project_mrl(self.model.text(input_ids), resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        """Compute image-text similarity logits.

        Both feature sets are L2-normalized. ``image_logits[i, j]`` is
        ``exp(logit_scale) * <image_i, text_j>``, plus ``logit_bias`` when the
        model has one.

        Returns:
            Tuple ``(image_logits, text_logits)``, where ``text_logits`` is the
            transpose of ``image_logits``.

        Raises:
            ValueError: Propagated from ``text_features`` / ``image_features``.
        """
        image_features = self.image_features(
            images, normalize=True, resolution=resolution
        )
        text_features = self.text_features(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            if self.model.logit_bias is not None:
                image_logits += self.model.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits