import torch


class VitLaionPreProcess(torch.nn.Module):
    """Wrap an image processor so it can be called like a torch transform.

    The wrapped ``processor`` is called as
    ``processor(images=img, return_tensors="pt")`` and must return an object
    whose ``.data`` mapping has a ``'pixel_values'`` entry.
    NOTE(review): presumably a Hugging Face CLIP image processor -- confirm
    against the caller.
    """

    def __init__(self, processor):
        super().__init__()
        self.processor = processor

    def forward(self, img) -> torch.Tensor:
        """Run the processor on ``img`` and return its pixel values.

        Only a leading axis of size 1 is dropped (the batch axis the
        processor adds for a single image). A bare ``squeeze()`` would also
        collapse any other size-1 axis, e.g. a single-channel image, and
        silently change the rank of the result.
        """
        processed = self.processor(images=img, return_tensors="pt")
        pixel_values = processed.data['pixel_values']
        return pixel_values.squeeze(0)


class VitLaionFeatureExtractor(torch.nn.Module):
    """Compute image features for a pair of inputs with one shared model.

    ``model`` must expose ``get_image_features(pixel_values=...)``. The
    matching preprocessing is stored on ``self.transforms``; ``forward``
    does not apply it, so inputs are expected to be preprocessed already.
    """

    def __init__(self, model, processor):
        super().__init__()
        self.vit_model = model
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Return ``(features_a, features_b)`` for the pair ``x``.

        ``x`` must unpack into exactly two elements; each is passed as
        ``pixel_values`` to ``self.vit_model.get_image_features``.
        """
        img_a, img_b = x
        features_a = self.vit_model.get_image_features(pixel_values=img_a)
        features_b = self.vit_model.get_image_features(pixel_values=img_b)
        return features_a, features_b