Woleek's picture
Init
c4e7950
raw
history blame contribute delete
607 Bytes
import torch
from diffusers import ConfigMixin, Mel, ModelMixin
class ImageEncoder(ModelMixin, ConfigMixin):
    """Pair an image processor with an encoder model to produce embeddings.

    The processor turns raw images into a ``pixel_values`` tensor, the encoder
    maps that tensor to hidden states, and :meth:`encode` returns the hidden
    state at position 0 of every sequence.
    """

    def __init__(self, image_processor, encoder_model):
        """Store the two components and switch the module to eval mode.

        Args:
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors="pt")``; its result
                must be indexable with ``"pixel_values"``.
            encoder_model: Callable applied to the pixel values; its output
                must expose a ``last_hidden_state`` attribute.
        """
        super().__init__()
        self.processor = image_processor
        self.encoder = encoder_model
        # This wrapper is used for inference only, so leave training mode.
        self.eval()

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its output unchanged."""
        return self.encoder(x)

    @torch.no_grad()
    def encode(self, image):
        """Return one embedding vector per input image.

        Gradients are disabled for the whole call.

        Args:
            image: Whatever ``self.processor`` accepts as its first argument.

        Returns:
            ``last_hidden_state[:, 0, :]`` of the encoder output, i.e. the
            hidden state at the first sequence position for each batch item
            (presumably the class token for ViT-style encoders -- confirm
            against the encoder actually used).
        """
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"]

        # The processor's output is not guaranteed to live on the same device
        # as the encoder weights; without this move, calling encode() after
        # `.to("cuda")` fails with a device-mismatch error. Modules without
        # parameters are left alone, which keeps the previous behaviour.
        reference = next(self.parameters(), None)
        if reference is not None:
            pixel_values = pixel_values.to(reference.device)

        hidden_states = self(pixel_values).last_hidden_state
        embeddings = hidden_states[:, 0, :]
        return embeddings