# stdlib
import os

# 3p
import cv2
import numpy as np
import torch
from skimage import io

# Linear RGB -> XYZ matrix in column-vector orientation (X = row 0 . [R, G, B]).
# Not referenced by the functions below; presumably used by other modules.
# NOTE(review): its orientation differs from rgb_from_xyz below.
xyz_from_rgb = np.array(
    [[0.412453, 0.357580, 0.180423],
     [0.212671, 0.715160, 0.072169],
     [0.019334, 0.119193, 0.950227]]
)

# XYZ -> linear RGB matrix, stored transposed for row-vector use
# (``pixels @ rgb_from_xyz`` with pixels shaped (N, 3)), as in tensor_lab2rgb.
rgb_from_xyz = np.array(
    [[3.24048134, -0.96925495, 0.05564664],
     [-1.53715152, 1.87599, -0.20404134],
     [-0.49853633, 0.04155593, 1.05731107]]
)

# Lower-cased file extensions recognised by list_files.
_IMAGE_EXTS = ('.jpg', '.jpeg', '.gif', '.png', '.pgm')
_MASK_EXTS = ('.bmp',)
_GT_EXTS = ('.xml', '.gt', '.txt')


def tensor_lab2rgb(input):
    """Convert a batch of CIE L*a*b* images to sRGB.

    Args:
        input: tensor of shape (n, 3, h, w) whose channels are L, a, b.

    Returns:
        A new tensor of shape (n, 3, h, w) holding sRGB values clamped to
        [0, 1], with the same dtype/device as ``input``. ``input`` itself is
        not modified.
    """
    # (n, 3, h, w) -> (n, h, w, 3) so the channel axis is last.
    input_trans = input.transpose(1, 2).transpose(2, 3)
    L, a, b = input_trans[:, :, :, 0:1], input_trans[:, :, :, 1:2], input_trans[:, :, :, 2:]

    # Lab -> companded (f-space) X, Y, Z.
    y = (L + 16.0) / 116.0
    x = (a / 500.0) + y
    z = y - (b / 200.0)
    # Clip negative z (out-of-range Lab input) to 0.
    neg_mask = z.data < 0
    z[neg_mask] = 0
    xyz = torch.cat((x, y, z), dim=3)

    # Undo the Lab companding: cube above the 6/29 (~0.2068966) knee,
    # linear segment below it.
    mask = xyz.data > 0.2068966
    mask_xyz = xyz.clone()
    mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
    mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787
    # Scale by the reference white (D65: Xn = 0.95047, Zn = 1.08883, Yn = 1).
    mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047
    mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883

    # XYZ -> linear RGB as a single (n*h*w, 3) @ (3, 3) product.
    rgb_trans = torch.mm(
        mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)
    ).view(input.size(0), input.size(2), input.size(3), 3)
    # (n, h, w, 3) -> (n, 3, h, w)
    rgb = rgb_trans.transpose(2, 3).transpose(1, 2)

    # Linear RGB -> sRGB transfer curve (gamma segment / linear toe).
    mask = rgb > 0.0031308
    mask_rgb = rgb.clone()
    mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
    mask_rgb[~mask] = rgb[~mask] * 12.92

    # Clamp to the displayable range [0, 1].
    neg_mask = mask_rgb.data < 0
    large_mask = mask_rgb.data > 1
    mask_rgb[neg_mask] = 0
    mask_rgb[large_mask] = 1
    return mask_rgb


def get_files(img_dir):
    """Return ``(images, masks, ground_truths)`` path lists found under *img_dir*.

    Thin wrapper around :func:`list_files`.
    """
    imgs, masks, xmls = list_files(img_dir)
    return imgs, masks, xmls


def list_files(in_path):
    """Recursively collect file paths under *in_path*, grouped by extension.

    Extensions are compared case-insensitively:

    * images: .jpg, .jpeg, .gif, .png, .pgm
    * masks: .bmp
    * ground truth: .xml, .gt, .txt

    Every other file (including .zip archives) is ignored.

    Returns:
        A tuple ``(img_files, mask_files, gt_files)`` of path lists in
        ``os.walk`` traversal order (not sorted).
    """
    img_files = []
    mask_files = []
    gt_files = []
    for dirpath, _dirnames, filenames in os.walk(in_path):
        for file in filenames:
            ext = os.path.splitext(file)[1].lower()
            if ext in _IMAGE_EXTS:
                img_files.append(os.path.join(dirpath, file))
            elif ext in _MASK_EXTS:
                mask_files.append(os.path.join(dirpath, file))
            elif ext in _GT_EXTS:
                gt_files.append(os.path.join(dirpath, file))
    return img_files, mask_files, gt_files


def load_image(img_file):
    """Read *img_file* and return it as an (H, W, 3) RGB numpy array.

    Grayscale inputs, with or without an alpha channel, are expanded to three
    identical channels. RGBA inputs have their alpha channel dropped. The
    pixel dtype returned by ``skimage.io.imread`` is preserved.
    """
    img = io.imread(img_file)  # RGB order
    # Some files decode as two stacked frames; keep only the first.
    # NOTE(review): a genuine 2-pixel-tall image also matches this test and
    # gets mangled — consider checking the frame count more explicitly.
    if img.shape[0] == 2:
        img = img[0]
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] in (1, 2):
        # Gray (+ alpha): previously returned as-is, i.e. not RGB.
        # Drop alpha and replicate the gray channel into R, G and B.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    if img.shape[2] == 4:
        img = img[:, :, :3]
    # Copy so the result is a contiguous array even after the slicing above.
    img = np.array(img)
    return img