import torch
# NOTE(review): `Mel` is not referenced anywhere in the visible code; it is
# kept because other parts of the project may rely on this import.
from diffusers import ConfigMixin, Mel, ModelMixin


class ImageEncoder(ModelMixin, ConfigMixin):
    """Bundle an image processor and an encoder model behind one interface.

    The instance is switched to evaluation mode on construction.
    """

    def __init__(self, image_processor, encoder_model):
        """Store the processor and the encoder, then enter eval mode.

        Args:
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors="pt")``; its result
                must be indexable with ``'pixel_values'``.
            encoder_model: Callable applied to the pixel tensor; its result
                must expose a ``last_hidden_state`` attribute.
        """
        super().__init__()
        self.processor = image_processor
        self.encoder = encoder_model
        self.eval()

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its raw output."""
        return self.encoder(x)

    @torch.no_grad()
    def encode(self, image):
        """Return one embedding vector per image, computed without gradients.

        Args:
            image: Whatever the wrapped image processor accepts.

        Returns:
            ``last_hidden_state[:, 0, :]`` of the encoder output, i.e. the
            slice at index 0 of the second axis (presumably the class token
            of a ViT-style encoder -- confirm against the model used).
        """
        pixel_values = self.processor(image, return_tensors="pt")['pixel_values']

        # The processor builds fresh tensors that know nothing about where
        # the model lives. Align them with the model's parameters so that
        # encoding keeps working after `.to("cuda")` / `.half()`.
        reference = next(self.parameters(), None)
        if reference is not None:
            if reference.is_floating_point() and pixel_values.is_floating_point():
                pixel_values = pixel_values.to(
                    device=reference.device, dtype=reference.dtype
                )
            else:
                pixel_values = pixel_values.to(device=reference.device)

        hidden_states = self(pixel_values).last_hidden_state
        embeddings = hidden_states[:, 0, :]
        return embeddings