# nllb-siglip-mrl-base / nllb_mrl.py
# Author: visheratin — last commit 2f07d8b ("Update nllb_mrl.py"), 6.84 kB.
from dataclasses import dataclass
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from open_clip import create_model_and_transforms, get_tokenizer
from PIL import Image
from transformers import PretrainedConfig
@dataclass
class MatryoshkaNllbClipConfig(PretrainedConfig):
    """Configuration consumed by ``MatryoshkaNllbClip.__init__``.

    NOTE(review): ``@dataclass`` generates an ``__init__`` taking exactly the
    four fields below and this class defines no ``__post_init__``, so
    ``PretrainedConfig.__init__`` is never called and the base-class attributes
    it would set are not initialised — confirm this is intentional.
    """

    # open_clip model name; passed to create_model_and_transforms and
    # get_tokenizer.
    clip_model_name: str
    # Passed as the second positional argument to create_model_and_transforms
    # (presumably the open_clip pretrained tag — TODO confirm).
    clip_model_version: str
    # Input width (nn.Linear in_features) of every Matryoshka projection head.
    target_resolution: int
    # Output widths of the Matryoshka projection heads, one nn.Linear each.
    mrl_resolutions: List[int]
class MatryoshkaLayer(nn.Module):
    """Bank of linear heads mapping ``target_resolution``-wide features to smaller widths.

    One ``nn.Linear(target_resolution, r)`` is created for every ``r`` in
    ``resolutions`` and stored in ``self.layers`` under the key ``str(r)``.
    """

    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        """Create one projection head per requested output width.

        Args:
            resolutions: Output widths; their order is kept in ``self.resolutions``.
            target_resolution: Input width shared by all heads.
        """
        super().__init__()
        self.resolutions = resolutions
        # An ordered list of (key, module) pairs keeps insertion order (and
        # therefore state_dict key order) identical to sequential assignment.
        head_pairs = [
            (str(width), nn.Linear(target_resolution, width)) for width in resolutions
        ]
        self.layers = nn.ModuleDict(head_pairs)

    def forward(self, x, resolution: Union[int, None] = None):
        """Project ``x`` with one head, or with every head.

        Args:
            x: Input features whose last dimension is ``target_resolution``.
            resolution: Width of the single head to apply; ``None`` applies all.

        Returns:
            A tensor when ``resolution`` is given, otherwise a list with one
            tensor per entry of ``self.resolutions``, in that order.

        Raises:
            ValueError: ``resolution`` is not one of ``self.resolutions``.
        """
        if resolution is None:
            return [self.layers[str(width)](x) for width in self.resolutions]
        if resolution not in self.resolutions:
            raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
        return self.layers[str(resolution)](x)
class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
    """open_clip model with Matryoshka (MRL) projection heads on top.

    The backbone, its image transform and its tokenizer are created from
    ``config.clip_model_name`` / ``config.clip_model_version``. A
    :class:`MatryoshkaLayer` holds one linear head per width in
    ``config.mrl_resolutions``, each expecting features of width
    ``config.target_resolution``. The ``encode_*`` / ``*_features`` /
    ``get_logits`` helpers run under ``torch.inference_mode()``; ``forward``
    does not.
    """

    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        """Build the backbone, transform, tokenizer and MRL heads on ``device``.

        Args:
            config: Backbone names and Matryoshka widths.
            device: ``torch.device`` or a device string such as ``"cpu"``.
        """
        super().__init__()
        if isinstance(device, str):
            device = torch.device(device)
        self.config = config
        # Keep the model and the third return value (the transform applied to
        # PIL images in image_features); the second return value is unused.
        self.model, _, self.transform = create_model_and_transforms(
            config.clip_model_name, config.clip_model_version, output_dict=True
        )
        self.device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        """Run the full backbone on an image batch and a token batch.

        Args:
            image_inputs: Preprocessed image tensor (moved to ``self.device``).
            input_ids: Token-id tensor (moved to ``self.device``).
            resolution: If given, also project both feature sets with the MRL
                head of that width.

        Returns:
            Dict with ``image_features``, ``text_features``, ``logit_scale`` and
            ``logit_bias`` taken from the backbone output (``logit_bias`` is
            None when the backbone output has no such entry), plus
            ``mrl_image_features`` / ``mrl_text_features`` (None when
            ``resolution`` is None).

        Raises:
            ValueError: ``resolution`` is not a configured MRL width.

        NOTE(review): the MRL heads here see the backbone's forward outputs,
        while ``encode_image``/``encode_text`` feed them ``self.model.visual`` /
        ``self.model.text`` outputs directly; if the backbone forward normalises
        its features these inputs differ — confirm this is intended.
        """
        image_inputs = image_inputs.to(self.device)
        input_ids = input_ids.to(self.device)
        outputs = self.model(image=image_inputs, text=input_ids)
        image_feats = outputs["image_features"]
        text_feats = outputs["text_features"]
        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            # Calling the module (not .forward) keeps nn.Module hooks working.
            mrl_image_features = self.matryoshka_layer(image_feats, resolution)
            mrl_text_features = self.matryoshka_layer(text_feats, resolution)
        return {
            "image_features": image_feats,
            "text_features": text_feats,
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            # get_logits treats the bias as optional; a model without one must
            # not raise KeyError here.
            "logit_bias": outputs.get("logit_bias"),
        }

    def _project(self, features, normalize, resolution):
        """Apply the optional MRL head, then optional L2 normalisation (last dim).

        Raises:
            ValueError: from MatryoshkaLayer when ``resolution`` has no head.
        """
        if resolution is not None:
            features = self.matryoshka_layer(features, resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_image(
        self,
        image,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-preprocessed image tensor.

        Args:
            image: Tensor accepted by ``self.model.visual`` (presumably a batch
                produced by ``self.transform`` — TODO confirm); it is moved to
                ``self.device`` first.
            normalize: If True, L2-normalise the result along the last dim.
            resolution: If given, project with the MRL head of that width.

        Returns:
            Feature tensor computed under ``torch.inference_mode()``.

        Raises:
            ValueError: ``resolution`` is not a configured MRL width.
        """
        with torch.inference_mode():
            features = self.model.visual(image.to(self.device))
            # Projection stays in the same inference_mode scope as the backbone.
            return self._project(features, normalize, resolution)

    def encode_text(
        self,
        text,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-tokenized text batch.

        Args:
            text: Token-id tensor accepted by ``self.model.text`` (e.g. the
                ``input_ids`` built by ``text_features``); it is moved to
                ``self.device`` first.
            normalize: If True, L2-normalise the result along the last dim.
            resolution: If given, project with the MRL head of that width.

        Returns:
            Feature tensor computed under ``torch.inference_mode()``.

        Raises:
            ValueError: ``resolution`` is not a configured MRL width.
        """
        with torch.inference_mode():
            features = self.model.text(text.to(self.device))
            return self._project(features, normalize, resolution)

    def image_features(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Preprocess PIL images with ``self.transform`` and embed them as one batch.

        Args:
            images: PIL images; each is transformed and the results are stacked
                along a new first dimension.
            normalize: If True, L2-normalise the result along the last dim.
            resolution: If given, project with the MRL head of that width.

        Raises:
            ValueError: ``resolution`` is not a configured MRL width.
        """
        image_inputs = torch.stack([self.transform(image) for image in images], dim=0)
        return self.encode_image(image_inputs, normalize=normalize, resolution=resolution)

    def text_features(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Tokenize language-tagged strings and embed them as one batch.

        Each text is prefixed with its language code and a space
        (presumably the NLLB language-tag convention — TODO confirm), then
        tokenized without extra special tokens and padded to the longest item.

        Args:
            texts: Strings to embed.
            langs: One language code per text; defaults to ``"eng_Latn"`` for all.
            normalize: If True, L2-normalise the result along the last dim.
            resolution: If given, project with the MRL head of that width.

        Raises:
            ValueError: ``langs`` and ``texts`` differ in length, or
                ``resolution`` is not a configured MRL width.
        """
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        elif len(langs) != len(texts):
            # zip() would otherwise silently drop the unmatched texts.
            raise ValueError(
                f"Got {len(langs)} language codes for {len(texts)} texts"
            )
        tagged = [f"{lang} {text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            tagged, return_tensors="pt", padding="longest", add_special_tokens=False
        )["input_ids"]
        return self.encode_text(input_ids, normalize=normalize, resolution=resolution)

    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        """Score every image against every text.

        Both sides are L2-normalised (optionally after MRL projection), so the
        logits are ``logit_scale.exp()`` times cosine similarity, plus
        ``logit_bias`` when the model has one.

        Returns:
            ``(image_logits, text_logits)`` where ``text_logits`` is the
            transpose of ``image_logits``; assuming 2-D features, row ``i`` of
            ``image_logits`` corresponds to ``images[i]`` — TODO confirm.

        Raises:
            ValueError: see ``text_features`` / ``image_features``.
        """
        image_features = self.image_features(
            images, normalize=True, resolution=resolution
        )
        text_features = self.text_features(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            if self.model.logit_bias is not None:
                image_logits += self.model.logit_bias
            text_logits = image_logits.T
        return image_logits, text_logits