# NOTE: the original paste carried hosting-page status text here
# ("Spaces:" / "Runtime error") — retained only as this comment so the
# file is valid Python.
# -*- coding: utf-8 -*-
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
ImageType = Union[np.ndarray, torch.Tensor, Image.Image] | |
# BUGFIX: `dataclass` was imported but never applied. `ModelOutput`
# subclasses must be @dataclass-decorated for transformers' field
# registration (keyword construction, tuple/dict indexing) to work.
@dataclass
class CLIPEmbedOutput(ModelOutput):
    """Bundle of CLIP encoder outputs for one modality (image or text).

    Fields:
        last_hidden_state: per-token / per-patch hidden states from the
            underlying encoder.
        pooler_output: pooled representation before projection.
        embeds: ``pooler_output`` projected into the joint image-text
            embedding space.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None
class CLIPEncoder(torch.nn.Module):
    """Frozen, inference-only wrapper around a pretrained CLIP model.

    Exposes :meth:`encode_image` and :meth:`encode_text`, each returning a
    :class:`CLIPEmbedOutput` with the encoder hidden states, the pooled
    representation, and the projected joint-space embedding.

    Args:
        model_path: Hugging Face hub id or local path of the CLIP
            checkpoint to load.
    """

    def __init__(self, model_path: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        # Load the CLIP model and tokenizer.
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # BUGFIX: the previous values were the ImageNet statistics; CLIP
        # checkpoints are trained with CLIP's own normalization constants
        # (see transformers' CLIPImageProcessor defaults).
        self.image_preprocess = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )
        # BUGFIX: `self.model.training = False` only flips the flag on the
        # top-level module; eval() propagates to every submodule (dropout,
        # etc.), which is what inference requires.
        self.model.eval()
        # Freeze all weights — this wrapper never trains the backbone.
        for p in self.model.parameters():
            p.requires_grad = False

    def encode_image(self, images: Iterable[Optional[ImageType]]) -> "CLIPEmbedOutput":
        """Encode a batch of images into CLIP embeddings.

        NOTE(review): ``images`` must already be a float tensor of shape
        (B, 3, H, W) scaled to [0, 1] — torchvision's ``Normalize`` cannot
        consume PIL images or uint8 arrays directly; confirm against callers.
        """
        pixel_values = self.image_preprocess(images)
        vision_outputs = self.model.vision_model(pixel_values=pixel_values)
        pooler_output = vision_outputs[1]  # pooled representation
        # Project into the joint image-text embedding space.
        image_features = self.model.visual_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features,
        )

    def encode_text(self, texts: List[str]) -> "CLIPEmbedOutput":
        """Tokenize and encode a batch of strings into CLIP embeddings."""
        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
        # BUGFIX: the tokenizer returns a BatchEncoding (a dict); the old
        # code passed that whole dict as `input_ids`, which fails at
        # runtime. Unpack it so both input_ids and attention_mask reach
        # the text encoder (padding masks matter for pooling).
        text_outputs = self.model.text_model(**text_inputs)
        pooler_output = text_outputs[1]  # pooled representation
        # Project into the joint image-text embedding space.
        text_features = self.model.text_projection(pooler_output)
        return CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features,
        )

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):
        """Encode both modalities; returns (visual_embeds, text_embeds)."""
        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)
        return visual_embeds, text_embeds