Spaces:
Runtime error
Runtime error
File size: 693 Bytes
dcafc9b d008b16 dcafc9b d008b16 dcafc9b d008b16 dcafc9b d008b16 dcafc9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
class VitLaionPreProcess(torch.nn.Module):
    """Wrap an image ``processor`` callable so it can be used as a torch module/transform.

    The processor is called as ``processor(images=img, return_tensors="pt")``.
    Its result must expose a ``.data`` mapping containing a ``'pixel_values'``
    tensor. This is presumably a HuggingFace ``BatchFeature``; confirm
    against the processor actually passed in.
    """

    def __init__(self, processor):
        """Store ``processor``, the callable used by :meth:`forward`."""
        super().__init__()
        self.processor = processor

    def forward(self, img):
        """Run the processor on ``img`` and return its ``pixel_values`` tensor.

        Args:
            img: Whatever the wrapped processor accepts as ``images``.

        Returns:
            The ``pixel_values`` tensor with its leading axis removed when that
            axis has size 1. A single-image call yielding ``(1, C, H, W)``
            therefore returns ``(C, H, W)``. Batches with more than one image
            are returned unchanged.
        """
        out = self.processor(images=img, return_tensors="pt")
        # Only drop the leading (batch) axis: a bare ``squeeze()`` would also
        # collapse any other size-1 axis (e.g. a single channel) and silently
        # change the tensor's rank.
        return out.data['pixel_values'].squeeze(0)
class VitLaionFeatureExtractor(torch.nn.Module):
    """Embed a pair of images with ``model.get_image_features``.

    ``transforms`` is a :class:`VitLaionPreProcess` built from ``processor``.
    It is exposed as an attribute for callers but is NOT applied inside
    :meth:`forward`, so inputs must already be ``pixel_values`` tensors.
    """

    def __init__(self, model, processor):
        """Keep ``model`` as ``vit_model`` and wrap ``processor`` as ``transforms``."""
        super().__init__()
        self.vit_model = model
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Return ``(features(first), features(second))`` for a 2-item input ``x``.

        Each element of ``x`` is passed as ``pixel_values=`` to
        ``vit_model.get_image_features``, the first element before the second.
        """
        first, second = x
        embed = self.vit_model.get_image_features
        first_features = embed(pixel_values=first)
        second_features = embed(pixel_values=second)
        return first_features, second_features
|