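# NOTE: the following lines are residue from the hosting web page, kept as
# comments so the module remains valid Python.
# Jyothirmai's picture
# Upload 13 files
# d290c84 verified
# raw
# history blame
# 5.14 kB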
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import json
from utils.build_vocab import Vocabulary, JsonReader
import numpy as np
from torchvision import transforms
import pickle
class ChestXrayDataSet(Dataset):
    """Dataset of chest X-ray images with tag labels and tokenized reports.

    Indexing yields the tuple
    ``(image, image_name, label, target, sentence_num, max_word_num)``:

    * ``image``: the RGB image, passed through ``transforms`` when given.
    * ``image_name``: file name (``<id>.png``) relative to ``image_dir``.
    * ``label``: the tag vector from ``file_list`` divided by its sum.
    * ``target``: one list of vocabulary outputs per kept sentence, each
      wrapped in ``<start>`` / ``<end>``.
    * ``sentence_num``: ``len(target)``.
    * ``max_word_num``: length of the longest entry of ``target`` (0 if none).
    """

    def __init__(self,
                 image_dir,
                 caption_json,
                 file_list,
                 vocabulary,
                 s_max=10,
                 n_max=50,
                 transforms=None):
        """Load the label list and open the caption store.

        Args:
            image_dir: Directory containing the ``.png`` images.
            caption_json: Path handed to ``JsonReader``; the reader is then
                indexed by image file name to obtain the report text.
            file_list: Text file with one ``<image id> <int> <int> ...`` row
                per sample (whitespace separated).
            vocabulary: Callable mapping a token string to its encoding
                (presumably an integer id; verify against ``Vocabulary``).
            s_max: Sentences at position ``s_max`` or later in a report are
                ignored.
            n_max: Sentences with more than ``n_max`` words are dropped.
            transforms: Optional callable applied to the PIL image.
        """
        self.image_dir = image_dir
        self.caption = JsonReader(caption_json)
        self.file_names, self.labels = self.__load_label_list(file_list)
        self.vocab = vocabulary
        self.transform = transforms
        self.s_max = s_max
        self.n_max = n_max

    def __load_label_list(self, file_list):
        """Parse ``file_list`` into parallel lists of file names and labels.

        Blank lines (e.g. a trailing newline at the end of the file) are
        skipped instead of raising ``IndexError``.

        Returns:
            ``(filename_list, labels)`` where each file name is the row's
            first field with ``.png`` appended and each label is the list of
            the remaining fields converted to ``int``.
        """
        labels = []
        filename_list = []
        with open(file_list, 'r') as f:
            for line in f:
                items = line.split()
                if not items:
                    # Nothing to parse on an empty / whitespace-only line.
                    continue
                filename_list.append('{}.png'.format(items[0]))
                labels.append([int(i) for i in items[1:]])
        return filename_list, labels

    def _encode_report(self, text):
        """Tokenize a report into per-sentence token lists.

        The report is split on ``'. '``. Only the first ``s_max`` split
        positions are examined (skipped sentences still count towards that
        cap), and sentences with fewer than two or more than ``n_max`` words
        are dropped.

        Returns:
            ``(target, max_word_num)``: the list of encoded sentences and the
            length of the longest one, including the start/end markers.
        """
        target = []
        max_word_num = 0
        for i, sentence in enumerate(text.split('. ')):
            if i >= self.s_max:
                break
            words = sentence.split()
            if len(words) <= 1 or len(words) > self.n_max:
                continue
            tokens = [self.vocab('<start>')]
            tokens.extend(self.vocab(word) for word in words)
            tokens.append(self.vocab('<end>'))
            max_word_num = max(max_word_num, len(tokens))
            target.append(tokens)
        return target, max_word_num

    @staticmethod
    def _normalize_label(label):
        """Scale ``label`` so its entries sum to 1.

        An all-zero (or empty) label is returned as zeros rather than being
        divided by zero, which would otherwise produce NaNs.
        """
        values = np.asarray(label, dtype=float)
        total = values.sum()
        if total == 0:
            return list(np.zeros_like(values))
        return list(values / total)

    def __getitem__(self, index):
        """Return the sample at ``index`` (see the class docstring)."""
        image_name = self.file_names[index]
        image_path = os.path.join(self.image_dir, image_name)
        # ``convert`` returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        try:
            text = self.caption[image_name]
        except Exception:
            # Best effort: any failed caption lookup falls back to a default
            # report. Kept broad because the lookup's failure modes are not
            # visible from this file.
            text = 'normal. '
        target, max_word_num = self._encode_report(text)
        sentence_num = len(target)
        return (image, image_name, self._normalize_label(label), target,
                sentence_num, max_word_num)

    def __len__(self):
        """Return the number of samples listed in ``file_list``."""
        return len(self.file_names)
def collate_fn(data):
    """Merge a list of dataset samples into one padded batch.

    Args:
        data: Sequence of ``(image, image_id, label, caption, sentence_num,
            max_word_num)`` tuples, where ``caption`` is a list of token
            lists.

    Returns:
        ``(images, image_id, label, targets, prob)``:

        * ``images``: the per-sample images stacked along a new dimension 0.
        * ``image_id``: tuple of the per-sample ids.
        * ``label``: ``torch.Tensor`` built from the per-sample labels.
        * ``targets``: NumPy array of shape ``(batch, max sentences + 1,
          max words)``; token lists are left-aligned and zero padded.
        * ``prob``: NumPy array of shape ``(batch, max sentences + 1)`` that
          is 1 where a non-empty sentence exists and 0 elsewhere. The extra
          trailing row therefore always stays 0 (presumably a stop signal;
          confirm with the consumer).
    """
    image_batch, image_id, label, report_batch, sent_counts, word_counts = zip(*data)
    images = torch.stack(image_batch, 0)

    batch_size = len(report_batch)
    n_rows = max(sent_counts) + 1
    n_cols = max(word_counts)
    targets = np.zeros((batch_size, n_rows, n_cols))
    prob = np.zeros((batch_size, n_rows))

    for sample_idx, report in enumerate(report_batch):
        for sent_idx, token_ids in enumerate(report):
            width = len(token_ids)
            targets[sample_idx, sent_idx, :width] = token_ids
            prob[sample_idx, sent_idx] = width > 0

    return images, image_id, torch.Tensor(label), targets, prob
def get_loader(image_dir,
               caption_json,
               file_list,
               vocabulary,
               transform,
               batch_size,
               s_max=10,
               n_max=50,
               shuffle=False):
    """Build a ``DataLoader`` over a ``ChestXrayDataSet``.

    Args:
        image_dir, caption_json, file_list, vocabulary, s_max, n_max:
            Forwarded unchanged to ``ChestXrayDataSet``.
        transform: Forwarded to the dataset as its ``transforms`` argument.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``torch.utils.data.DataLoader`` that batches with ``collate_fn``.
    """
    xray_dataset = ChestXrayDataSet(image_dir,
                                    caption_json,
                                    file_list,
                                    vocabulary,
                                    s_max=s_max,
                                    n_max=n_max,
                                    transforms=transform)
    return torch.utils.data.DataLoader(xray_dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       collate_fn=collate_fn)
if __name__ == '__main__':
    # Smoke test: build a loader over the debugging split and print the
    # contents of its first batch.
    vocab_path = '../data/vocab.pkl'
    image_dir = '../data/images'
    caption_json = '../data/debugging_captions.json'
    file_list = '../data/debugging.txt'
    batch_size = 6
    resize = 256
    crop_size = 224

    preprocessing_steps = [
        transforms.Resize(resize),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # NOTE: pickle executes code on load; only use a trusted vocabulary file.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader(image_dir,
                             caption_json,
                             file_list,
                             vocab,
                             transform,
                             batch_size,
                             shuffle=False)

    # Only the first batch is inspected.
    for image, image_id, label, target, prob in data_loader:
        print(image.shape)
        for field in (image_id, label, target, prob):
            print(field)
        break