SeyedAli's picture
Rename src/similarity/model_implements/ViTMS.py to src/similarity/model_implements/ViTMSN.py
58dc921
raw
history blame contribute delete
754 Bytes
from transformers import AutoImageProcessor, ViTMSNModel
from PIL import Image
import numpy as np
import torch
class ViTMS():
    """Image feature extractor backed by a ViT-MSN checkpoint.

    Loads a Hugging Face image processor and a ``ViTMSNModel`` from the same
    checkpoint and turns each input image into a flat numpy feature vector.
    """

    def __init__(self, model_name='facebook/vit-msn-small'):
        """Load the image processor and model.

        Args:
            model_name: Hugging Face checkpoint id (or local path) passed to
                both ``from_pretrained`` calls. Defaults to the previously
                hard-coded ``'facebook/vit-msn-small'``, so existing
                ``ViTMS()`` callers are unaffected.
        """
        self.feature_extractor = AutoImageProcessor.from_pretrained(model_name)
        self.model = ViTMSNModel.from_pretrained(model_name)

    def extract_feature(self, imgs):
        """Return one flattened feature vector per image.

        Args:
            imgs: iterable of images, each passed unchanged to the image
                processor as ``images=``. Presumably PIL images (PIL is
                imported at file level) — confirm against callers.

        Returns:
            list of 1-D numpy arrays, in the same order as ``imgs``. Each
            array is the model's entire ``last_hidden_state`` for that image,
            flattened: all tokens are included, not a pooled or CLS vector.
            An empty ``imgs`` yields an empty list.
        """
        features = []
        # Inference only: disable autograd once for the whole loop rather
        # than re-entering the context manager per image.
        with torch.no_grad():
            for img in imgs:
                inputs = self.feature_extractor(images=img, return_tensors="pt")
                last_hidden_state = self.model(**inputs).last_hidden_state
                # .cpu() keeps .numpy() valid even if the model was moved to an
                # accelerator. flatten() already yields 1-D, so the former
                # np.squeeze() was redundant; the element order is unchanged.
                features.append(last_hidden_state.cpu().numpy().flatten())
        return features