Spaces:
Runtime error
Runtime error
import torch | |
from torch.autograd import Variable as V | |
import torchvision.models as models | |
from torchvision import transforms as trn | |
from torch.nn import functional as F | |
import os | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
def recursion_change_bn(module): | |
if isinstance(module, torch.nn.BatchNorm2d): | |
module.track_running_stats = 1 | |
else: | |
for i, (name, module1) in enumerate(module._modules.items()): | |
module1 = recursion_change_bn(module1) | |
return module | |
def load_labels(): | |
# prepare all the labels | |
# scene category relevant | |
file_name_category = 'categories_places365.txt' | |
classes = list() | |
with open(file_name_category) as class_file: | |
for line in class_file: | |
classes.append(line.strip().split(' ')[0][3:]) | |
classes = tuple(classes) | |
# indoor and outdoor relevant | |
file_name_IO = 'IO_places365.txt' | |
with open(file_name_IO) as f: | |
lines = f.readlines() | |
labels_IO = [] | |
for line in lines: | |
items = line.rstrip().split() | |
labels_IO.append(int(items[-1]) -1) # 0 is indoor, 1 is outdoor | |
labels_IO = np.array(labels_IO) | |
# scene attribute relevant | |
file_name_attribute = 'labels_sunattribute.txt' | |
with open(file_name_attribute) as f: | |
lines = f.readlines() | |
labels_attribute = [item.rstrip() for item in lines] | |
file_name_W = 'W_sceneattribute_wideresnet18.npy' | |
W_attribute = np.load(file_name_W) | |
return classes, labels_IO, labels_attribute, W_attribute | |
def hook_feature(module, input, output): | |
return np.squeeze(output.data.cpu().numpy()) | |
def returnCAM(feature_conv, weight_softmax, class_idx): | |
# generate the class activation maps upsample to 256x256 | |
size_upsample = (256, 256) | |
nc, h, w = feature_conv.shape | |
output_cam = [] | |
for idx in class_idx: | |
cam = weight_softmax[class_idx].dot(feature_conv.reshape((nc, h*w))) | |
cam = cam.reshape(h, w) | |
cam = cam - np.min(cam) | |
cam_img = cam / np.max(cam) | |
cam_img = np.uint8(255 * cam_img) | |
output_cam.append(cv2.resize(cam_img, size_upsample)) | |
return output_cam | |
def returnTF(): | |
# load the image transformer | |
tf = trn.Compose([ | |
trn.Resize((224,224)), | |
trn.ToTensor(), | |
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
]) | |
return tf | |
def load_model(): | |
# this model has a last conv feature map as 14x14 | |
model_file = 'wideresnet18_places365.pth.tar' | |
import wideresnet | |
model = wideresnet.resnet18(num_classes=365) | |
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage) | |
state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()} | |
model.load_state_dict(state_dict) | |
# hacky way to deal with the upgraded batchnorm2D and avgpool layers... | |
for i, (name, module) in enumerate(model._modules.items()): | |
module = recursion_change_bn(model) | |
model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0) | |
model.eval() | |
# hook the feature extractor | |
features_names = ['layer4','avgpool'] # this is the last conv layer of the resnet | |
for name in features_names: | |
model._modules.get(name).register_forward_hook(hook_feature) | |
return model | |
# load the labels | |
classes, labels_IO, labels_attribute, W_attribute = load_labels() | |
# load the model | |
features_blobs = [] | |
model = load_model() | |
# load the transformer | |
tf = returnTF() # image transformer | |
# get the softmax weight | |
params = list(model.parameters()) | |
weight_softmax = params[-2].data.numpy() | |
weight_softmax[weight_softmax<0] = 0 | |
def predict(img): | |
#img = Image.open('6.jpg') | |
input_img = V(tf(img).unsqueeze(0)) | |
logit = model.forward(input_img) | |
h_x = F.softmax(logit, 1).data.squeeze() | |
probs, idx = h_x.sort(0, True) | |
probs = probs.numpy() | |
idx = idx.numpy() | |
io_image = np.mean(labels_IO[idx[:10]]) # vote for the indoor or outdoor | |
env_image = [] | |
if io_image < 0.5: | |
env_image.append('Indoor') | |
#print('--TYPE OF ENVIRONMENT: indoor') | |
else: | |
env_image.append('Outdoor') | |
#print('--TYPE OF ENVIRONMENT: outdoor') | |
# output the prediction of scene category | |
#print('--SCENE CATEGORIES:') | |
scene_cat=[] | |
for i in range(0, 5): | |
scene_cat.append('{:.3f} -> {}'.format(probs[i], classes[idx[i]])) | |
#print('{:.3f} -> {}'.format(probs[i], classes[idx[i]])) | |
return env_image,scene_cat |