import torch
from transformers import AutoModel, AutoProcessor

# Hub checkpoint shared by the preprocessor and the feature extractor, so both
# load from the same id unless a caller overrides it.
DEFAULT_MODEL_NAME = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"


class VitLaionPreProcess(torch.nn.Module):
    """Turn raw image input into the pixel tensor produced by the checkpoint's processor."""

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the processor for ``model_name``.

        Args:
            model_name: Hub id or local path passed to
                ``AutoProcessor.from_pretrained``. Defaults to the LAION
                CLIP-ViT-bigG-14 checkpoint.
        """
        super().__init__()
        self.processor = AutoProcessor.from_pretrained(model_name)

    def forward(self, img):
        """Run the processor on ``img`` and return its ``pixel_values``.

        Args:
            img: Value passed to the processor's ``images=`` argument
                (presumably a PIL image / array or a list of them — confirm
                with callers).

        Returns:
            The ``pixel_values`` tensor with its leading axis removed when that
            axis has size 1. A single image therefore comes back unbatched,
            while a batch of several images keeps its batch axis.
        """
        out = self.processor(images=img, return_tensors="pt")
        # squeeze(0) instead of squeeze(): drop only the leading (batch) axis,
        # never another axis that happens to have size 1.
        return out.data["pixel_values"].squeeze(0)


class VitLaionFeatureExtractor(torch.nn.Module):
    """Compute image features for a pair of pixel tensors with a shared model."""

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the model (and a matching preprocessor) for ``model_name``.

        Args:
            model_name: Hub id or local path passed to
                ``AutoModel.from_pretrained`` and to ``VitLaionPreProcess``.
                Defaults to the LAION CLIP-ViT-bigG-14 checkpoint.
        """
        super().__init__()
        self.vit_model = AutoModel.from_pretrained(model_name)
        # Not used by forward(). NOTE(review): presumably read by callers to
        # preprocess raw images; built from the same model_name so the
        # processor always matches the loaded weights.
        self.transforms = VitLaionPreProcess(model_name)

    def forward(self, x):
        """Return ``get_image_features`` for both elements of the pair ``x``.

        Args:
            x: A 2-element sequence ``(img_a, img_b)`` of pixel tensors, each
                passed as ``pixel_values`` to the model (presumably the output
                format of ``VitLaionPreProcess``, batched — confirm).

        Returns:
            Tuple ``(features_a, features_b)``, computed independently by the
            same model.
        """
        img_a, img_b = x
        features_a = self.vit_model.get_image_features(pixel_values=img_a)
        features_b = self.vit_model.get_image_features(pixel_values=img_b)
        return features_a, features_b