Spaces:
Runtime error
Runtime error
import torch | |
import torch.utils.data as data | |
import cv2 | |
import numpy as np | |
from os.path import join | |
class HalftoneVOC2012(data.Dataset):
    """Dataset of paired (color input, grayscale label) images.

    Every path taken from ``data_list`` is resolved relative to the ``Data``
    directory. Samples are returned as float32 tensors whose values are
    mapped from the 8-bit range [0, 255] to [-1, 1]. Color images keep the
    BGR channel order produced by ``cv2.imread``.
    """

    def __init__(self, data_list):
        """Store the image paths.

        Args:
            data_list: mapping with the keys ``'inputs'`` and ``'labels'``,
                each an iterable of paths relative to ``Data``.
        """
        super(HalftoneVOC2012, self).__init__()
        self.inputs = [join('Data', x) for x in data_list['inputs']]
        self.labels = [join('Data', x) for x in data_list['labels']]

    @staticmethod
    def _read_image(name, flags):
        """Read an image with OpenCV, failing loudly when it cannot be read.

        Raises:
            OSError: if ``cv2.imread`` returns ``None`` (missing, unreadable
                or undecodable file).
        """
        img = cv2.imread(name, flags=flags)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # without this check the error shows up later as an opaque
            # attribute/type error on None.
            raise OSError('Failed to read image: {}'.format(name))
        return img

    @staticmethod
    def _to_tensor(img):
        """Convert an 8-bit image array to a float32 tensor in [-1, 1]."""
        return torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)

    # The loaders take no ``self`` but are invoked through the instance in
    # ``__getitem__``, so they must be static methods.
    @staticmethod
    def load_input(name):
        """Load a color image as a channel-first float32 tensor.

        Args:
            name: path of the image file.

        Returns:
            Tensor with the channel axis moved to the front, values in
            [-1, 1].

        Raises:
            OSError: if the image cannot be read.
        """
        img = HalftoneVOC2012._read_image(name, cv2.IMREAD_COLOR)
        # Move the last (channel) axis to the front.
        img = img.transpose((2, 0, 1))
        return HalftoneVOC2012._to_tensor(img)

    @staticmethod
    def load_label(name):
        """Load a grayscale image as a float32 tensor with a leading axis.

        Args:
            name: path of the image file.

        Returns:
            Tensor with a singleton axis prepended, values in [-1, 1].

        Raises:
            OSError: if the image cannot be read.
        """
        img = HalftoneVOC2012._read_image(name, cv2.IMREAD_GRAYSCALE)
        # Prepend a singleton channel axis.
        img = img[np.newaxis, :, :]
        return HalftoneVOC2012._to_tensor(img)

    def __getitem__(self, index):
        """Return the ``(input, label)`` tensor pair at ``index``."""
        input_data = self.load_input(self.inputs[index])
        label_data = self.load_label(self.labels[index])
        return input_data, label_data

    def __len__(self):
        """Return the number of input images."""
        return len(self.inputs)