# Page header text captured alongside this file, kept for provenance:
# Maikou's picture
# related files and example data
# b621857
# raw
# history blame
# 2.45 kB
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List
# Union of image container types: NumPy array, torch tensor, or PIL image.
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


@dataclass
class CLIPEmbedOutput(ModelOutput):
    """Embeddings produced by one CLIP tower (vision or text).

    Subclasses ``transformers.utils.ModelOutput``. The field order below is
    the order ModelOutput uses for key/tuple access, so keep it stable. Every
    field defaults to ``None``.
    """

    # The tower's ``last_hidden_state`` output, as filled in by CLIPEncoder.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # The tower's pooled output, before CLIP's projection layer.
    pooler_output: Optional[torch.FloatTensor] = None
    # ``pooler_output`` after the visual/text projection layer.
    embeds: Optional[torch.FloatTensor] = None
class CLIPEncoder(torch.nn.Module):
    """Frozen CLIP wrapper that returns image and text embeddings.

    Loads a Hugging Face ``CLIPModel`` and ``CLIPTokenizer`` from
    ``model_path`` and freezes every CLIP parameter. It exposes
    ``encode_image`` and ``encode_text``. Each returns a ``CLIPEmbedOutput``
    holding the tower's last hidden state, its pooled output, and the
    projected embedding.

    Args:
        model_path: Hub id or local directory passed to ``from_pretrained``
            for both the model and the tokenizer.
        image_mean: Per-channel mean applied by ``Normalize`` to input pixels.
        image_std: Per-channel std applied by ``Normalize`` to input pixels.
    """

    def __init__(
        self,
        model_path="openai/clip-vit-base-patch32",
        image_mean=(0.485, 0.456, 0.406),
        image_std=(0.229, 0.224, 0.225),
    ):
        super().__init__()
        # Load the CLIP model and its tokenizer.
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # NOTE(review): the defaults are the ImageNet statistics, not the ones
        # OpenAI CLIP was trained with (mean=(0.48145466, 0.4578275,
        # 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)). They remain
        # the default so existing callers see identical inputs; pass the CLIP
        # values explicitly if downstream code expects them.
        self.image_preprocess = Normalize(mean=list(image_mean), std=list(image_std))
        # eval() propagates to every submodule. Assigning `training = False`
        # would only flip the flag on the top-level CLIPModel.
        # NOTE(review): calling .train() on this wrapper will put CLIP back
        # into train mode; override train() if that must never happen.
        self.model.eval()
        # Freeze CLIP: no parameter should receive gradients.
        self.model.requires_grad_(False)

    @torch.no_grad()
    def encode_image(self, images: torch.Tensor):
        """Embed a batch of already-prepared image tensors.

        Args:
            images: Float pixel tensor, channel-first, normalized here with
                ``image_preprocess``. It is passed straight to ``Normalize``,
                which rejects PIL images and NumPy arrays. Presumably it is
                already scaled to [0, 1] and resized to the checkpoint's input
                resolution; TODO confirm against callers.

        Returns:
            CLIPEmbedOutput with the vision tower's ``last_hidden_state`` and
            ``pooler_output``, and ``embeds`` = ``visual_projection`` of the
            pooled output.
        """
        pixel_values = self.image_preprocess(images)
        # Match the model's device and dtype. Mismatched inputs would
        # otherwise fail inside the first convolution.
        pixel_values = pixel_values.to(device=self.model.device, dtype=self.model.dtype)
        vision_outputs = self.model.vision_model(pixel_values=pixel_values)
        pooler_output = vision_outputs.pooler_output
        image_features = self.model.visual_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features,
        )

    @torch.no_grad()
    def encode_text(self, texts: List[str]):
        """Tokenize and embed a batch of strings.

        Texts are padded to the longest item in the batch. They are also
        truncated to the tokenizer's ``model_max_length``, so over-long
        prompts cannot overflow CLIP's position embeddings.

        Args:
            texts: Strings to encode; one embedding is produced per string.

        Returns:
            CLIPEmbedOutput with the text tower's ``last_hidden_state`` and
            ``pooler_output``, and ``embeds`` = ``text_projection`` of the
            pooled output.
        """
        text_inputs = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        )
        # The tokenizer returns CPU tensors; move them next to the model.
        text_inputs = text_inputs.to(self.model.device)
        # Unpack the BatchEncoding: the text tower expects the `input_ids`
        # tensor (plus the padding mask), not the dict-like container.
        text_outputs = self.model.text_model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs.get("attention_mask"),
        )
        pooler_output = text_outputs.pooler_output
        text_features = self.model.text_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features,
        )

    def forward(self, images: torch.Tensor, texts: List[str]):
        """Encode images and texts in one call.

        Args:
            images: Image tensor batch; see ``encode_image``.
            texts: Strings to encode; see ``encode_text``.

        Returns:
            ``(visual_embeds, text_embeds)``: the ``encode_image`` and
            ``encode_text`` results, in that order.
        """
        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)
        return visual_embeds, text_embeds