from transformers import AutoImageProcessor, ViTMSNModel
from PIL import Image
import numpy as np
import torch


class ViTMS:
    """Extract flattened ViT-MSN ``last_hidden_state`` features from images.

    Wraps a Hugging Face ``AutoImageProcessor`` / ``ViTMSNModel`` pair loaded
    from the same checkpoint.
    """

    def __init__(self, model_name='facebook/vit-msn-small'):
        """Load the image processor and model.

        Args:
            model_name: Hugging Face checkpoint id (or local path) passed to
                ``from_pretrained`` for both the processor and the model.
                Defaults to ``'facebook/vit-msn-small'``.
        """
        self.feature_extractor = AutoImageProcessor.from_pretrained(model_name)
        self.model = ViTMSNModel.from_pretrained(model_name)

    def extract_feature(self, imgs):
        """Return one flattened feature vector per image.

        Each image is run through the processor and model individually,
        without gradient tracking.

        Args:
            imgs: Iterable of images accepted by the image processor
                (presumably ``PIL.Image.Image`` objects — confirm with callers).

        Returns:
            list of 1-D ``numpy.ndarray``. Each element is that image's
            ``last_hidden_state`` flattened into a single vector.
        """
        # Read the device at call time, so moving the model to another device
        # after construction (e.g. ``.to("cuda")``) keeps working.
        device = self.model.device
        features = []
        for img in imgs:
            inputs = self.feature_extractor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # .numpy() only works on CPU tensors, so copy back first.
            hidden = outputs.last_hidden_state.cpu().numpy()
            # flatten() already yields a 1-D array; a preceding squeeze would
            # not change the element order or values.
            features.append(hidden.flatten())
        return features