File size: 693 Bytes
dcafc9b
 
 
 
 
d008b16
dcafc9b
d008b16
dcafc9b
 
 
 
 
 
 
d008b16
dcafc9b
d008b16
 
dcafc9b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch


class VitLaionPreProcess(torch.nn.Module):
    """Expose an image processor as a ``torch.nn.Module`` transform.

    The wrapped ``processor`` is called as
    ``processor(images=img, return_tensors="pt")`` and must return an object
    whose ``.data`` mapping holds a ``'pixel_values'`` tensor.
    NOTE(review): presumably a Hugging Face image processor returning a
    ``BatchFeature`` -- confirm against the caller.
    """

    def __init__(self, processor):
        """Store the callable ``processor`` used by :meth:`forward`."""
        super().__init__()
        self.processor = processor

    def forward(self, img) -> torch.Tensor:
        """Run the processor on ``img`` and return its pixel values.

        Args:
            img: Whatever the wrapped processor accepts as ``images``.

        Returns:
            The ``'pixel_values'`` tensor with its leading dimension removed
            when that dimension has size 1; otherwise the tensor unchanged.
        """
        out = self.processor(images=img, return_tensors="pt")
        # Only drop the leading (batch) axis. A bare ``squeeze()`` would also
        # collapse any other size-1 axis (e.g. a single channel), silently
        # changing the rank of the result.
        return out.data['pixel_values'].squeeze(0)


class VitLaionFeatureExtractor(torch.nn.Module):
    """Compute image features for a pair of inputs with one shared model.

    ``model`` must provide ``get_image_features(pixel_values=...)``.
    ``transforms`` wraps ``processor`` in a ``VitLaionPreProcess``; it is
    stored on the instance but not invoked by :meth:`forward`.
    """

    def __init__(self, model, processor):
        """Keep ``model`` and build the preprocessing module from ``processor``."""
        super().__init__()
        self.vit_model = model
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Return ``(features_a, features_b)`` for the two-element input ``x``.

        ``x`` is unpacked into exactly two items; each is passed separately
        to the model as ``pixel_values``, first item first.
        """
        first, second = x
        embed = self.vit_model.get_image_features
        features_first = embed(pixel_values=first)
        features_second = embed(pixel_values=second)
        return features_first, features_second