marqo-fashionCLIP / marqo_fashionCLIP.py
DavidJung's picture
Add support for AutoModel
8478037
raw
history blame
2.31 kB
import torch
from open_clip import create_model
from transformers import PretrainedConfig, PreTrainedModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
from typing import Optional, Tuple, Union
class MarqoFashionCLIPConfig(PretrainedConfig):
    """Configuration for :class:`MarqoFashionCLIP`.

    Adds a single field on top of ``PretrainedConfig``: the open_clip model
    name that ``MarqoFashionCLIP`` hands to ``open_clip.create_model``.
    """

    def __init__(self, open_clip_model_name: str = "", **kwargs):
        """Store the open_clip model name; pass everything else to the base.

        Args:
            open_clip_model_name: Identifier later given to
                ``open_clip.create_model``. Defaults to an empty string.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.open_clip_model_name = open_clip_model_name
class MarqoFashionCLIP(PreTrainedModel):
    """``PreTrainedModel`` wrapper around an open_clip model.

    The wrapped model is built with ``open_clip.create_model`` from
    ``config.open_clip_model_name`` and stored as ``self.model``. Image and
    text encoding are delegated to its ``encode_image`` / ``encode_text``
    methods, always under ``torch.inference_mode()``.
    """

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        """Build the underlying open_clip model and switch it to eval mode.

        Args:
            config: Configuration whose ``open_clip_model_name`` selects the
                open_clip architecture to instantiate.
        """
        super().__init__(config)
        self.config = config
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        """Encode images with the wrapped open_clip model.

        Args:
            pixel_values: Image tensor passed straight to
                ``self.model.encode_image``.
            normalize: Forwarded to ``encode_image`` as its ``normalize``
                argument.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``encode_image`` returns, computed under
            ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            image_features = self.model.encode_image(pixel_values, normalize=normalize)
        return image_features

    def get_text_features(
        self,
        input_ids: torch.Tensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        """Encode tokenized text with the wrapped open_clip model.

        Args:
            input_ids: Token tensor passed straight to
                ``self.model.encode_text``.
            normalize: Forwarded to ``encode_text`` as its ``normalize``
                argument.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``encode_text`` returns, computed under
            ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            text_features = self.model.encode_text(input_ids, normalize=normalize)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Compute text/image embeddings and their pairwise dot products.

        Both encoders are called with ``normalize=True``; the logits are the
        plain matrix product of the resulting embeddings (no extra scaling is
        applied here).

        Args:
            input_ids: Tokenized text. Required despite the ``None`` default.
            pixel_values: Image tensor. Required despite the ``None`` default.
            return_dict: When truthy a ``CLIPOutput`` is returned; when falsy
                (including the default ``None``) a plain tuple is returned.

        Returns:
            ``(logits_per_image, logits_per_text, text_embeds, image_embeds)``
            as a tuple, or the same values wrapped in a ``CLIPOutput``.

        Raises:
            ValueError: If ``input_ids`` or ``pixel_values`` is ``None``.
        """
        # Both modalities are needed to form the similarity matrix; fail early
        # with a clear message instead of erroring inside the encoders.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)

        # Rows index texts, columns index images; the transpose gives the
        # image-major view of the same scores.
        logits_per_text = text_outputs @ vision_outputs.T
        logits_per_image = logits_per_text.T

        if not return_dict:
            return logits_per_image, logits_per_text, text_outputs, vision_outputs

        return CLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_outputs,
            image_embeds=vision_outputs
        )