Spaces:
Runtime error
Runtime error
import argparse, os, json | |
import numpy as np | |
from imageio import imread | |
from PIL import Image | |
import torch | |
import torchvision | |
import ssl | |
# Replace the stdlib's default HTTPS context factory with the unverified one,
# so every later HTTPS request made via the default context in this process
# (presumably including torchvision's pretrained-weight download) skips
# certificate checks.
# NOTE(review): this disables TLS certificate verification process-wide, not
# just for the weight download — confirm it is still needed and consider
# scoping it more narrowly.
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)
def build_model(model='resnet101', model_stage=3):
    """Build a truncated, pretrained torchvision ResNet feature extractor.

    The returned network is the ResNet stem (``conv1``, ``bn1``, ``relu``,
    ``maxpool``) followed by the first ``model_stage`` residual stages
    (``layer1`` .. ``layer<model_stage>``), wrapped in a
    ``torch.nn.Sequential`` and switched to eval mode. No device transfer is
    performed.

    Args:
        model: Name of a constructor in ``torchvision.models``. It must build
            a ResNet-style module exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool`` and ``layer1`` .. ``layer4``.
        model_stage: Number of residual stages to keep, from 0 (stem only)
            to 4 (all stages).

    Returns:
        The truncated ``torch.nn.Sequential`` in eval mode.

    Raises:
        ValueError: If ``model_stage`` is outside ``0..4``.
    """
    # ResNet-style models define exactly layer1..layer4. Validate up front so
    # a bad value fails with a clear message instead of an AttributeError on
    # 'layer5', and so negative values don't silently yield the stem only.
    if not 0 <= model_stage <= 4:
        raise ValueError(
            'model_stage must be between 0 and 4, got %r' % (model_stage,))

    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=...``; kept so older torchvision releases still work.
    cnn = getattr(torchvision.models, model)(pretrained=True)

    layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
    for stage in range(1, model_stage + 1):
        layers.append(getattr(cnn, 'layer%d' % stage))

    feature_extractor = torch.nn.Sequential(*layers)
    # Inference mode: BatchNorm uses running statistics.
    feature_extractor.eval()
    return feature_extractor
def run_image(img, model):
    """Normalize a batch of images and run it through ``model``.

    Args:
        img: Array-like batch laid out as (N, 3, H, W) with pixel values in
            the 0-255 range. The channel axis is fixed by the (1, 3, 1, 1)
            broadcast of the normalization statistics below.
        model: Callable that takes a float32 tensor and returns a tensor. It
            is invoked with autograd disabled.

    Returns:
        ``numpy.ndarray`` copy of the model output.
    """
    # ImageNet per-channel mean/std used to normalize inputs for torchvision's
    # pretrained models.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

    image = (np.asarray(img, dtype=np.float32) / 255.0 - mean) / std
    # Dense copy so torch receives a contiguous buffer even when ``img`` is a
    # transposed view.
    tensor = torch.from_numpy(np.ascontiguousarray(image))

    # Run inference without recording autograd history. The legacy
    # ``Variable(..., volatile=True)`` flag no longer has this effect.
    with torch.no_grad():
        feats = model(tensor)
    return feats.detach().cpu().clone().numpy()
def get_img_feat(cnn_model, img, image_height=224, image_width=224):
    """Resize an image and extract a single feature map with ``cnn_model``.

    Args:
        cnn_model: Model forwarded unchanged to ``run_image``.
        img: Image array that ``PIL.Image.fromarray`` accepts after a
            ``numpy.uint8`` cast, e.g. (H, W) grayscale, (H, W, 3) RGB or
            (H, W, 4) RGBA. It is converted to 3-channel RGB.
        image_height: Target height in pixels before feature extraction.
        image_width: Target width in pixels before feature extraction.

    Returns:
        The features produced by ``run_image``, reshaped to (1, C, H, W).
    """
    # Force three channels. Grayscale input would otherwise break the
    # channels-first transpose, and RGBA would break the 3-channel
    # normalization in ``run_image``.
    pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
    # PIL's ``resize`` takes (width, height), not (height, width).
    pil_img = pil_img.resize((image_width, image_height))
    # (H, W, 3) -> (1, 3, H, W): channels first plus a leading batch axis.
    batch = np.array(pil_img).transpose(2, 0, 1)[None]
    feats = run_image(batch, cnn_model)
    _, c, h, w = feats.shape
    return feats.reshape(1, c, h, w)