|
import torch |
|
import numpy as np |
|
|
|
import os |
|
|
|
from skimage import io |
|
import cv2 |
|
# 3x3 colour-space matrices.
# Each row of ``xyz_from_rgb`` holds the RGB weights of one XYZ component.
xyz_from_rgb = np.array(
    [
        [0.412453, 0.357580, 0.180423],
        [0.212671, 0.715160, 0.072169],
        [0.019334, 0.119193, 0.950227],
    ]
)

# Stored as the transpose of the inverse of ``xyz_from_rgb`` so that a batch of
# row-vector XYZ pixels of shape (num_pixels, 3) can be right-multiplied by it.
rgb_from_xyz = np.array(
    [
        [3.24048134, -0.96925495, 0.05564664],
        [-1.53715152, 1.87599, -0.20404134],
        [-0.49853633, 0.04155593, 1.05731107],
    ]
)
|
|
|
|
|
def tensor_lab2rgb(input):
    """Convert a batch of Lab images to RGB, clipped to ``[0, 1]``.

    ``input`` is indexed as ``n * 3 * h * w``; the channels along dimension 1
    are read as L, a and b, in that order.  The returned tensor has the same
    ``n * 3 * h * w`` layout and the same type as the intermediate XYZ tensor.

    The conversion goes Lab -> XYZ -> linear RGB (via the module-level
    ``rgb_from_xyz`` matrix) -> companded RGB, followed by a clip to
    ``[0, 1]``.
    """
    # Move the channel axis last: (n, 3, h, w) -> (n, h, w, 3).
    lab = input.permute(0, 2, 3, 1)
    lightness, chroma_a, chroma_b = lab[..., 0:1], lab[..., 1:2], lab[..., 2:]

    # Lab -> intermediate (pre-cube) X, Y, Z terms.
    fy = (lightness + 16.0) / 116.0
    fx = (chroma_a / 500.0) + fy
    fz = fy - (chroma_b / 200.0)

    # Negative Z terms are forced to zero before the inverse companding.
    fz[fz.detach() < 0] = 0
    fxyz = torch.cat((fx, fy, fz), dim=3)

    # Piecewise inverse: cube above the threshold, linear segment below it.
    above = fxyz.detach() > 0.2068966
    xyz = fxyz.clone()
    xyz[above] = fxyz[above].pow(3.0)
    xyz[~above] = (fxyz[~above] - 16.0 / 116.0) / 7.787

    # Scale X and Z by fixed constants (presumably the reference white point
    # -- TODO confirm which illuminant these correspond to).
    xyz[..., 0] = xyz[..., 0] * 0.95047
    xyz[..., 2] = xyz[..., 2] * 1.08883

    # XYZ -> linear RGB: flatten to (num_pixels, 3), multiply, restore layout.
    n, h, w = input.size(0), input.size(2), input.size(3)
    conversion = torch.from_numpy(rgb_from_xyz).type_as(fxyz)
    rgb_last = torch.mm(xyz.view(-1, 3), conversion).view(n, h, w, 3)
    linear_rgb = rgb_last.permute(0, 3, 1, 2)

    # Piecewise companding (looks like the sRGB transfer curve): power law
    # above 0.0031308, linear gain of 12.92 at or below it.
    bright = linear_rgb > 0.0031308
    rgb = linear_rgb.clone()
    rgb[bright] = 1.055 * linear_rgb[bright].pow(1 / 2.4) - 0.055
    rgb[~bright] = linear_rgb[~bright] * 12.92

    # Clip to [0, 1]; both masks are taken before either assignment.
    below_zero = rgb.detach() < 0
    above_one = rgb.detach() > 1
    rgb[below_zero] = 0
    rgb[above_one] = 1
    return rgb
|
|
|
def get_files(img_dir):
    """Return ``(imgs, masks, xmls)`` path lists found under ``img_dir``.

    Thin wrapper that forwards to ``list_files`` and hands back its three
    lists unchanged.
    """
    return list_files(img_dir)
|
|
|
|
|
def list_files(in_path):
    """Recursively collect files under ``in_path``, grouped by extension.

    Extensions are compared case-insensitively:

    * ``.jpg``, ``.jpeg``, ``.gif``, ``.png``, ``.pgm`` -> image files
    * ``.bmp`` -> mask files
    * ``.xml``, ``.gt``, ``.txt`` -> ground-truth files

    Files with any other extension (``.zip`` included) are ignored.

    Returns:
        A tuple ``(img_files, mask_files, gt_files)`` of path lists, each path
        being the walked directory joined with the file name.  Order follows
        ``os.walk``.
    """
    img_files, mask_files, gt_files = [], [], []

    # Extension -> destination list; anything not listed here is skipped.
    bucket_for = {
        '.jpg': img_files,
        '.jpeg': img_files,
        '.gif': img_files,
        '.png': img_files,
        '.pgm': img_files,
        '.bmp': mask_files,
        '.xml': gt_files,
        '.gt': gt_files,
        '.txt': gt_files,
    }

    for root, _subdirs, names in os.walk(in_path):
        for name in names:
            suffix = os.path.splitext(name)[1].lower()
            bucket = bucket_for.get(suffix)
            if bucket is not None:
                bucket.append(os.path.join(root, name))

    return img_files, mask_files, gt_files
|
|
|
|
|
def load_image(img_file):
    """Read ``img_file`` with ``io.imread`` and return it as a NumPy array.

    Post-processing applied in order:

    1. If the first axis has length 2, only index 0 along it is kept.
    2. A 2-D (single-channel) result is expanded with
       ``cv2.COLOR_GRAY2RGB``.
    3. If the third axis has length 4, only its first three entries are kept
       (presumably dropping an alpha channel -- confirm against inputs).

    The result is copied into a fresh array before being returned.
    """
    image = io.imread(img_file)

    # Leading axis of length 2: keep the first entry only.
    # NOTE(review): this also triggers for an image that is 2 pixels tall --
    # confirm that such inputs never reach this function.
    if image.shape[0] == 2:
        image = image[0]

    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    if image.shape[2] == 4:
        image = image[:, :, :3]

    return np.array(image)