|
import argparse
import functools
import os
import pickle
import threading
import time

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import ImageFolder
from tqdm import tqdm
|
|
|
def concat(x):
    """Concatenate a sequence of arrays along the first axis (axis 0)."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Return tensor ``x`` as a NumPy array on the CPU, detached from autograd.

    Note: for a tensor that is already on the CPU, the returned array shares
    memory with the tensor (this is how ``Tensor.numpy`` behaves).
    """
    return x.detach().to("cpu").numpy()
|
|
|
|
|
class Wrapper(torch.nn.Module):
    """Return the output of a backbone's ``avgpool`` layer instead of its logits.

    A forward hook on ``model.avgpool`` records that layer's output on every
    forward pass. ``forward`` runs the whole wrapped model, discards the
    model's own return value, and returns the recorded activation.

    Args:
        model: a module exposing an ``avgpool`` submodule (for example a
            torchvision ResNet).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Most recent activation recorded by the hook; None until the first
        # forward pass.
        self.avgpool_output = None
        # NOTE(review): ``query`` and ``cossim_value`` are not used inside this
        # class; kept in case other code reads or writes them.
        self.query = None
        self.cossim_value = {}

        def fw_hook(module, input, output):
            # squeeze() drops *every* size-1 dim, so e.g. an (N, C, 1, 1)
            # pooled output yields shape (C,) when N == 1 rather than (1, C).
            # Callers such as QueryToEmbedding appear to rely on that shape.
            self.avgpool_output = output.squeeze()

        # Keep the handle so the hook can be detached from the wrapped model
        # (``self.hook_handle.remove()``) if the backbone is reused elsewhere;
        # previously the handle was discarded and the hook could never be removed.
        self.hook_handle = self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        """Run ``self.model`` on ``input`` and return the recorded ``avgpool`` output."""
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return "Wrapper"
|
|
|
|
|
# Preprocessing for query images: resize, 224x224 center crop, tensor
# conversion, and per-channel normalization with ImageNet mean/std.
_QUERY_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# The cached Wrapper stores each forward's activation on the instance (via its
# hook), so concurrent calls must be serialized or they could read each
# other's output.
_QUERY_EMBEDDER_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _query_embedder():
    """Build the pretrained ResNet-50 feature extractor once and reuse it."""
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    return Wrapper(model)


def QueryToEmbedding(query_pil):
    """Embed a PIL query image with ResNet-50's average-pool features.

    Previously the pretrained model was rebuilt (and its weights reloaded) on
    every call; it is now constructed once and shared across calls.

    Args:
        query_pil: a PIL image. Any mode is accepted; it is converted to RGB
            before preprocessing.

    Returns:
        ``np.ndarray`` of shape ``(1, D)``: the pooled feature vector wrapped in
        an outer list, where ``D`` is the pooled feature size (2048 for
        torchvision's ResNet-50).
    """
    # Force 3 channels: grayscale, RGBA, palette, etc. images would otherwise
    # fail the 3-channel Normalize (and YCbCr would be misread as RGB).
    query_pt = _QUERY_TRANSFORM(query_pil.convert("RGB"))

    with _QUERY_EMBEDDER_LOCK, torch.no_grad():
        embedder = _query_embedder()
        # Batch of one: Wrapper's squeeze() yields a 1-D (D,) vector here.
        embedding = to_np(embedder(query_pt.unsqueeze(0)))

    return np.asarray([embedding])
|
|