# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
import torch.utils.data as data
import os
import pickle
import numpy as np
from PIL import Image
from build_vocab import Vocabulary  # noqa: F401 (kept so pickle can resolve the vocabulary class)
import random
import lmdb


class Recipe1MDataset(data.Dataset):

    def __init__(self, data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims,
                 transform=None, max_num_samples=-1, use_lmdb=False, suff=''):

        self.ingrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_ingrs.pkl'), 'rb'))
        self.instrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_toks.pkl'), 'rb'))
        self.dataset = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_' + split + '.pkl'), 'rb'))

        self.label2word = self.get_ingrs_vocab()

        self.use_lmdb = use_lmdb
        if use_lmdb:
            self.image_file = lmdb.open(os.path.join(aux_data_dir, 'lmdb_' + split), max_readers=1, readonly=True,
                                        lock=False, readahead=False, meminit=False)

        # Keep only the samples that have at least one image.
        self.ids = []
        self.split = split
        for i, entry in enumerate(self.dataset):
            if len(entry['images']) == 0:
                continue
            self.ids.append(i)

        self.root = os.path.join(data_dir, 'images', split)
        self.transform = transform
        self.max_num_labels = maxnumlabels
        self.max_num_instrs = maxnuminstrs
        # Effective caption length: per-instruction length times number of instructions.
        self.maxseqlen = maxseqlen * maxnuminstrs
        self.maxnumims = maxnumims
        if max_num_samples != -1:
            random.shuffle(self.ids)
            self.ids = self.ids[:max_num_samples]

    def get_instrs_vocab(self):
        return self.instrs_vocab

    def get_instrs_vocab_size(self):
        return len(self.instrs_vocab)

    def get_ingrs_vocab(self):
        # Pick the shortest synonym for each ingredient cluster.
        return [min(w, key=len) if not isinstance(w, str) else w
                for w in self.ingrs_vocab.idx2word.values()]  # includes 'pad' ingredient

    def get_ingrs_vocab_size(self):
        return len(self.ingrs_vocab)

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""

        sample = self.dataset[self.ids[index]]
        img_id = sample['id']
        captions = sample['tokenized']
        paths = sample['images'][0:self.maxnumims]

        labels = sample['ingredients']
        title = sample['title']

        tokens = []
        tokens.extend(title)
        # add fake token to separate title from recipe
        tokens.append('<eoi>')
        for c in captions:
            tokens.extend(c)
            tokens.append('<eoi>')

        # Build the ingredient label vector: unique ingredient indices,
        # padded with '<pad>' and terminated with '<end>'. maxnumlabels
        # must leave room for the '<end>' token.
        ilabels_gt = np.ones(self.max_num_labels) * self.ingrs_vocab('<pad>')
        pos = 0
        for i in range(self.max_num_labels):
            if i >= len(labels):
                label = '<pad>'
            else:
                label = labels[i]
            label_idx = self.ingrs_vocab(label)
            if label_idx not in ilabels_gt:
                ilabels_gt[pos] = label_idx
                pos += 1

        ilabels_gt[pos] = self.ingrs_vocab('<end>')
        ingrs_gt = torch.from_numpy(ilabels_gt).long()

        if len(paths) == 0:
            path = None
            image_input = torch.zeros((3, 224, 224))
        else:
            # Sample a random image during training; use the first one otherwise.
            if self.split == 'train':
                img_idx = np.random.randint(0, len(paths))
            else:
                img_idx = 0

            path = paths[img_idx]
            if self.use_lmdb:
                try:
                    with self.image_file.begin(write=False) as txn:
                        image = txn.get(path.encode())
                        # Raw uint8 bytes stored in lmdb as a 256x256x3 array.
                        image = np.frombuffer(image, dtype=np.uint8)
                        image = np.reshape(image, (256, 256, 3))
                    image = Image.fromarray(image.astype('uint8'), 'RGB')
                except Exception:
                    print("Image id not found in lmdb. Loading jpeg file...")
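                    # Fall back to reading the jpeg from disk: Recipe1M images
                    # live in nested directories named after the first four
                    # characters of the image filename.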
Loading jpeg file...") image = Image.open(os.path.join(self.root, path[0], path[1], path[2], path[3], path)).convert('RGB') else: image = Image.open(os.path.join(self.root, path[0], path[1], path[2], path[3], path)).convert('RGB') if self.transform is not None: image = self.transform(image) image_input = image # Convert caption (string) to word ids. caption = [] caption = self.caption_to_idxs(tokens, caption) caption.append(self.instrs_vocab('')) caption = caption[0:self.maxseqlen] target = torch.Tensor(caption) return image_input, target, ingrs_gt, img_id, path, self.instrs_vocab('') def __len__(self): return len(self.ids) def caption_to_idxs(self, tokens, caption): caption.append(self.instrs_vocab('')) for token in tokens: caption.append(self.instrs_vocab(token)) return caption def collate_fn(data): # Sort a data list by caption length (descending order). # data.sort(key=lambda x: len(x[2]), reverse=True) image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data) # Merge images (from tuple of 3D tensor to 4D tensor). image_input = torch.stack(image_input, 0) ingrs_gt = torch.stack(ingrs_gt, 0) # Merge captions (from tuple of 1D tensor to 2D tensor). lengths = [len(cap) for cap in captions] targets = torch.ones(len(captions), max(lengths)).long()*pad_value[0] for i, cap in enumerate(captions): end = lengths[i] targets[i, :end] = cap[:end] return image_input, targets, ingrs_gt, img_id, path def get_loader(data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims, transform, batch_size, shuffle, num_workers, drop_last=False, max_num_samples=-1, use_lmdb=False, suff=''): dataset = Recipe1MDataset(data_dir=data_dir, aux_data_dir=aux_data_dir, split=split, maxseqlen=maxseqlen, maxnumlabels=maxnumlabels, maxnuminstrs=maxnuminstrs, maxnumims=maxnumims, transform=transform, max_num_samples=max_num_samples, use_lmdb=use_lmdb, suff=suff) data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last, collate_fn=collate_fn, pin_memory=True) return data_loader, dataset