import torch
from PIL import Image
import torchvision.transforms as T

# Available DINOv2 backbones on torch.hub include:
#   dinov2_vitl14  (ViT-L/14, used below)
#   dinov2_vitg14  (ViT-g/14)
def load_dinov2():
    # Load the pretrained DINOv2 ViT-L/14 backbone from torch.hub,
    # move it to the GPU, and switch it to eval mode.
    dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').cuda()
    dinov2_vitl14.eval()
    return dinov2_vitl14
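
# A minimal sketch (not part of the original script) of loading the larger
# ViT-g/14 backbone listed above; only the hub entry-point name changes.
def load_dinov2_giant():
    dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14').cuda()
    dinov2_vitg14.eval()
    return dinov2_vitg14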

def infer_model(model, image):
    # Preprocess: resize to 196x196 (a multiple of the model's 14-pixel patch size),
    # convert to a tensor, and normalize with ImageNet statistics.
    transform = T.Compose([
        T.Resize((196, 196)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    image = transform(image).unsqueeze(0).cuda()
    # forward_features returns a dict of features (CLS token, patch tokens, ...),
    # not just the CLS token.
    with torch.no_grad():
        features = model.forward_features(image)
    return features

dinov2 = load_dinov2()
dinov2.requires_grad_(False)

image_path = "./validation_demo/3373891cdc_Image/1704429543488.jpg"
image = Image.open(image_path).convert('RGB')
# image = image.resize((64, 64))
img_embedding = infer_model(dinov2, image)
print(img_embedding["x_norm_patchtokens"].shape, img_embedding["x_norm_clstoken"].shape)
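# Expected output (an assumption based on dinov2_vitl14's 1024-dim embeddings and the
# 196x196 input above, which yields a 14x14 grid of patch tokens):
#   torch.Size([1, 196, 1024]) torch.Size([1, 1024])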