# File: mmir_usersim/captioning/data/dataloaderraw.py
# Last commit: 9bf9e42 ("add captioning") by yashonwu
# Size: 5.03 kB
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import h5py
import os
import numpy as np
import random
import torch
import skimage
import skimage.io
import scipy.misc
from torchvision import transforms as trn
# Per-channel normalisation applied to each image before it is handed to the
# CNN. The mean/std triples are the values conventionally used with
# ImageNet-pretrained torchvision models. ``trn.ToTensor()`` is deliberately
# not part of the pipeline: ``DataLoaderRaw.get_batch`` already builds a
# float32, channel-first tensor scaled by 1/255 before calling ``preprocess``.
preprocess = trn.Compose(
    [
        trn.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# from ..utils.resnet_utils import myResnet
# from ..utils import resnet
from captioning.utils.resnet_utils import myResnet
from captioning.utils import resnet
class DataLoaderRaw():
    """Serve CNN features for raw image files, one batch at a time.

    The list of images comes either from a COCO-style json file
    (``opt['coco_json']``) or, when that option is empty, from every image
    found under ``opt['folder_path']`` (sub-folders included).  Each image is
    pushed through a ResNet wrapped by ``myResnet`` when ``get_batch`` is
    called; nothing is pre-computed or cached.

    Attributes set by ``__init__``: ``opt``, ``coco_json``, ``folder_path``,
    ``batch_size``, ``seq_per_img``, ``device``, ``cnn_model``, ``my_resnet``,
    ``files``, ``ids``, ``N``, ``iterator``, ``dataset`` and, only when a
    json file is used, ``coco_annotation``.

    NOTE(review): ``ix_to_word`` is read by ``get_vocab``/``get_vocab_size``
    but never assigned in this class -- presumably the evaluation script
    injects it on the loader; confirm against the caller.
    """

    # Exact-case suffixes recognised as images when scanning a folder.
    _IMAGE_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG',
                         '.png', '.PNG', '.ppm', '.PPM')

    def __init__(self, opt):
        """Build the CNN and collect the image paths to iterate over.

        Args:
            opt: mapping read with ``opt.get``.  Recognised keys:
                ``coco_json`` (default ``''``), ``folder_path`` (default
                ``''``), ``batch_size`` (default ``1``) and ``cnn_model``
                (default ``'resnet101'``).

        Reads the weights from
        ``./data/imagenet_weights/<cnn_model>.pth`` relative to the current
        working directory.
        """
        self.opt = opt
        self.coco_json = opt.get('coco_json', '')
        self.folder_path = opt.get('folder_path', '')

        self.batch_size = opt.get('batch_size', 1)
        self.seq_per_img = 1

        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of crashing on an unconditional ``.cuda()``.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Load resnet
        self.cnn_model = opt.get('cnn_model', 'resnet101')
        self.my_resnet = getattr(resnet, self.cnn_model)()
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from; the model is moved to self.device below.
        self.my_resnet.load_state_dict(
            torch.load('./data/imagenet_weights/' + self.cnn_model + '.pth',
                       map_location='cpu'))
        self.my_resnet = myResnet(self.my_resnet)
        self.my_resnet.to(self.device)
        self.my_resnet.eval()

        # load the json file which contains additional information about the dataset
        print('DataLoaderRaw loading images from folder: ', self.folder_path)

        self.files = []
        self.ids = []

        print(len(self.coco_json))
        if len(self.coco_json) > 0:
            self._load_coco_file_list()
        else:
            self._scan_image_folder()

        self.N = len(self.files)
        print('DataLoaderRaw found ', self.N, ' images')

        self.iterator = 0
        # Nasty
        self.dataset = self  # to fix the bug in eval

    def _load_coco_file_list(self):
        """Append paths and ids from the json file's ``images`` list."""
        # Use self.coco_json rather than opt.coco_json: opt is accessed
        # dict-style, and attribute access on a plain dict raises.
        print('reading from ' + self.coco_json)
        # read in filenames from the coco-style json file
        with open(self.coco_json) as f:
            self.coco_annotation = json.load(f)
        for image in self.coco_annotation['images']:
            self.files.append(
                os.path.join(self.folder_path, image['file_name']))
            self.ids.append(image['id'])

    def _scan_image_folder(self):
        """Append every image under ``folder_path``; ids are '1', '2', ..."""
        # read in all the filenames from the folder
        print('listing all images in directory ' + self.folder_path)
        n = 1
        for root, _dirs, files in os.walk(self.folder_path, topdown=False):
            for name in files:
                if not name.endswith(self._IMAGE_EXTENSIONS):
                    continue
                # Join with ``root`` (the directory being walked), not the
                # top folder, so files in sub-folders get a valid path.
                self.files.append(os.path.join(root, name))
                self.ids.append(str(n))  # just order them sequentially
                n = n + 1

    def get_batch(self, split, batch_size=None):
        """Load the next ``batch_size`` images and return their features.

        Args:
            split: unused; accepted so the signature matches what callers
                pass.
            batch_size: number of images to read; falls back to
                ``self.batch_size`` when ``None`` (or 0).

        Returns:
            dict with keys ``fc_feats`` (batch_size x 2048 tensor),
            ``att_feats`` (batch_size x 196 x 2048 tensor), ``labels``
            (empty batch_size x 0 tensor), ``masks`` and ``att_masks``
            (both ``None``), ``bounds`` (iterator position, total count and
            whether the iterator wrapped during this call) and ``infos``
            (one ``{'id', 'file_path'}`` dict per image).

        Raises:
            IndexError: if no image was found at construction time.
        """
        batch_size = batch_size or self.batch_size

        # Every row is overwritten in the loop below, so no zero-fill needed.
        # The fixed shapes assume the wrapped CNN yields a 2048-d vector and
        # a 14x14x2048 map per image -- TODO confirm against myResnet.
        fc_batch = np.empty((batch_size, 2048), dtype='float32')
        att_batch = np.empty((batch_size, 14, 14, 2048), dtype='float32')
        max_index = self.N
        wrapped = False
        infos = []

        for i in range(batch_size):
            # pick an index of the datapoint to load next
            ri = self.iterator
            ri_next = ri + 1
            if ri_next >= max_index:
                # wrap back around
                ri_next = 0
                wrapped = True
            self.iterator = ri_next

            img = skimage.io.imread(self.files[ri])

            # Grayscale image: replicate the single channel three times.
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
                img = np.concatenate((img, img, img), axis=2)

            # Keep at most three channels (drops alpha) and scale by 1/255.
            img = img[:, :, :3].astype('float32') / 255.0
            # HWC -> CHW, then move to the device the CNN lives on.
            img = torch.from_numpy(img.transpose([2, 0, 1])).to(self.device)
            img = preprocess(img)
            with torch.no_grad():
                tmp_fc, tmp_att = self.my_resnet(img)

            fc_batch[i] = tmp_fc.data.cpu().float().numpy()
            att_batch[i] = tmp_att.data.cpu().float().numpy()

            infos.append({'id': self.ids[ri], 'file_path': self.files[ri]})

        data = {
            'fc_feats': fc_batch,
            # Flatten the 14x14 grid into 196 locations per image.
            'att_feats': att_batch.reshape(batch_size, -1, 2048),
            'labels': np.zeros([batch_size, 0]),
            'masks': None,
            'att_masks': None,
            'bounds': {'it_pos_now': self.iterator,
                       'it_max': self.N,
                       'wrapped': wrapped},
            'infos': infos,
        }

        # Turn all ndarray to torch tensor
        return {k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
                for k, v in data.items()}

    def reset_iterator(self, split):
        """Rewind to the first image; ``split`` is ignored."""
        self.iterator = 0

    def get_vocab_size(self):
        """Return ``len(self.ix_to_word)``; the attribute must be set externally."""
        return len(self.ix_to_word)

    def get_vocab(self):
        """Return ``self.ix_to_word``; the attribute must be set externally."""
        return self.ix_to_word