import argparse, os, json import numpy as np from imageio import imread from PIL import Image import torch import torchvision import ssl ssl._create_default_https_context = ssl._create_unverified_context def build_model(model='resnet101', model_stage=3): cnn = getattr(torchvision.models, model)(pretrained=True) layers = [ cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool, ] for i in range(model_stage): name = 'layer%d' % (i + 1) layers.append(getattr(cnn, name)) model = torch.nn.Sequential(*layers) # model.cuda() model.eval() return model def run_image(img, model): mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) image = np.concatenate([img], 0).astype(np.float32) image = (image / 255.0 - mean) / std image = torch.FloatTensor(image) image = torch.autograd.Variable(image, volatile=True) feats = model(image) feats = feats.data.cpu().clone().numpy() return feats def get_img_feat(cnn_model, img, image_height=224, image_width=224): img_size = (image_height, image_width) img = np.array(Image.fromarray(np.uint8(img)).resize(img_size)) img = img.transpose(2, 0, 1)[None] feats = run_image(img, cnn_model) _, C, H, W = feats.shape feat_dset = feats.reshape(1, C, H, W) return feat_dset