Spaces:
Runtime error
Runtime error
import torch | |
class VitLaionPreProcess(torch.nn.Module):
    """Wrap an image ``processor`` so it can be used as a ``torch.nn.Module``.

    ``processor`` is presumably a Hugging Face image/CLIP processor (TODO
    confirm). It must accept ``images=`` and ``return_tensors=`` keyword
    arguments and return an object whose ``.data`` mapping holds a
    ``'pixel_values'`` tensor.
    """

    def __init__(self, processor):
        super().__init__()
        self.processor = processor

    def forward(self, img):
        """Return the processor's ``pixel_values`` tensor for ``img``.

        The processor is asked for PyTorch tensors (``return_tensors="pt"``).
        A leading size-1 axis (presumably the batch axis when a single image is
        passed) is removed; every other axis is kept as-is.
        """
        out = self.processor(images=img, return_tensors="pt")
        pixel_values = out.data['pixel_values']
        # Squeeze only dim 0: a bare ``squeeze()`` would also collapse any other
        # size-1 axis (e.g. a single channel), so the output rank would depend on
        # the image content instead of only on whether one image was passed.
        return pixel_values.squeeze(0)
class VitLaionFeatureExtractor(torch.nn.Module):
    """Compute image embeddings for a pair of inputs with one shared model.

    Attributes:
        vit_model: object exposing ``get_image_features(pixel_values=...)``;
            presumably a Hugging Face CLIP-style model (TODO confirm).
        transforms: ``VitLaionPreProcess`` built from ``processor``. It is not
            applied inside ``forward``; presumably callers use it to prepare
            the inputs (verify against call sites).
    """

    def __init__(self, model, processor):
        super().__init__()
        self.vit_model = model
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Embed both members of the pair ``x``.

        Args:
            x: an iterable of exactly two pixel-value inputs; unpacking raises
                ``ValueError`` for any other length.

        Returns:
            A 2-tuple ``(features_a, features_b)`` from
            ``vit_model.get_image_features``, computed first-then-second.
        """
        first_pixels, second_pixels = x
        embed = self.vit_model.get_image_features
        features_a = embed(pixel_values=first_pixels)
        features_b = embed(pixel_values=second_pixels)
        return features_a, features_b