dog_recognition / feat_ext.py
Daniel Bustamante Ospina
Changes in the model loading
d008b16
import torch
class VitLaionPreProcess(torch.nn.Module):
    """Adapt an image processor into a ``torch.nn.Module`` transform.

    The wrapped ``processor`` is called with ``images=`` and
    ``return_tensors="pt"``. Its ``pixel_values`` entry is returned with every
    size-1 dimension squeezed out.
    """

    def __init__(self, processor):
        super().__init__()
        # Kept as a plain attribute. Presumably a Hugging Face processor
        # (e.g. a CLIP processor); TODO confirm against the model loader.
        self.processor = processor

    def forward(self, img):
        """Preprocess ``img`` and return its pixel tensor.

        Args:
            img: Image input accepted by ``self.processor``.

        Returns:
            The processor's ``pixel_values`` entry (a tensor when the processor
            honors ``return_tensors="pt"``) with ``squeeze()`` applied.
            Assuming the processor emits a leading batch dimension of size 1
            for a single image, this yields a per-image tensor.
            NOTE(review): ``squeeze()`` drops *every* size-1 dimension, not
            only the batch one.
        """
        batch = self.processor(images=img, return_tensors="pt")
        # ``.data`` is the plain dict behind the returned mapping
        # (presumably a transformers ``BatchFeature``).
        pixels = batch.data["pixel_values"]
        return pixels.squeeze()
class VitLaionFeatureExtractor(torch.nn.Module):
    """Embed a pair of images with one shared vision model.

    The model is held as ``vit_model``. A matching preprocessing transform is
    exposed as ``transforms``; ``forward`` does not apply it. Callers
    presumably use it in their data pipeline (TODO confirm).
    """

    def __init__(self, model, processor):
        super().__init__()
        self.vit_model = model
        # Exposed for callers; not used inside forward().
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Return image features for both members of the pair ``x``.

        Args:
            x: A two-element iterable ``(img_a, img_b)``. Presumably both
                are pixel tensors already produced by ``self.transforms``;
                TODO confirm.

        Returns:
            tuple: ``(features_a, features_b)``, each computed by
            ``vit_model.get_image_features(pixel_values=...)``.

        Raises:
            ValueError: if ``x`` does not unpack into exactly two items.
        """
        left, right = x
        embed = self.vit_model.get_image_features
        # The same model (same weights) encodes both images, first then second.
        return embed(pixel_values=left), embed(pixel_values=right)