menghanxia's picture
created the space
6e70c4a
raw
history blame
1.2 kB
import torch
import torch.utils.data as data
import cv2
import numpy as np
from os.path import join
class HalftoneVOC2012(data.Dataset):
    """Paired image dataset: color inputs and single-channel labels.

    Every sample is a ``(input, label)`` tuple of ``float32`` tensors:

    * ``input`` -- read with ``cv2.IMREAD_COLOR`` and moved to
      channel-first layout ``(C, H, W)``; OpenCV delivers the channels
      in BGR order.
    * ``label`` -- read with ``cv2.IMREAD_GRAYSCALE`` and given a leading
      channel axis, i.e. ``(1, H, W)``.

    Pixel values are mapped with ``x / 127.5 - 1.0``, which sends 8-bit
    data from ``[0, 255]`` to ``[-1, 1]``.
    """

    def __init__(self, data_list):
        """Record the file paths of all samples.

        Args:
            data_list: Mapping with the keys ``'inputs'`` and ``'labels'``.
                Each value is an iterable of paths relative to the
                ``Data`` directory; entries at the same position form
                one sample.
        """
        super().__init__()
        self.inputs = [join('Data', x) for x in data_list['inputs']]
        self.labels = [join('Data', x) for x in data_list['labels']]

    @staticmethod
    def _read_image(name, flags):
        """Read ``name`` with OpenCV, failing loudly if it cannot be read.

        ``cv2.imread`` signals a missing or undecodable file by returning
        ``None`` rather than raising, which otherwise shows up later as an
        unrelated ``AttributeError``/``TypeError`` without the file name.

        Raises:
            OSError: If OpenCV could not read or decode the file.
        """
        img = cv2.imread(name, flags=flags)
        if img is None:
            raise OSError(f'Could not read image file: {name!r}')
        return img

    @staticmethod
    def _to_tensor(img):
        """Convert an image array to a ``float32`` tensor scaled by ``x / 127.5 - 1.0``."""
        return torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)

    @staticmethod
    def load_input(name):
        """Load a color image as a normalized ``(C, H, W)`` tensor.

        Args:
            name: Path of the image file.

        Returns:
            ``float32`` tensor in channel-first layout (BGR channel order).

        Raises:
            OSError: If the file cannot be read or decoded.
        """
        img = HalftoneVOC2012._read_image(name, cv2.IMREAD_COLOR)
        # OpenCV returns (H, W, C); move channels to the front.
        img = img.transpose((2, 0, 1))
        return HalftoneVOC2012._to_tensor(img)

    @staticmethod
    def load_label(name):
        """Load a grayscale image as a normalized ``(1, H, W)`` tensor.

        Args:
            name: Path of the image file.

        Returns:
            ``float32`` tensor with a single leading channel axis.

        Raises:
            OSError: If the file cannot be read or decoded.
        """
        img = HalftoneVOC2012._read_image(name, cv2.IMREAD_GRAYSCALE)
        # Grayscale images are (H, W); add the channel axis in front.
        img = img[np.newaxis, :, :]
        return HalftoneVOC2012._to_tensor(img)

    def __getitem__(self, index):
        """Return the ``(input, label)`` tensor pair stored at ``index``."""
        input_data = self.load_input(self.inputs[index])
        label_data = self.load_label(self.labels[index])
        return input_data, label_data

    def __len__(self):
        """Return the number of samples (the number of input paths)."""
        return len(self.inputs)