Keiser41's picture
Upload 98 files
22d8ab7
import torch
import numpy as np
# stdlib
import os
# 3p
from skimage import io
import cv2
# Linear-RGB -> XYZ coefficients (standard sRGB / D65 values). Each row yields
# one of X, Y, Z, i.e. xyz = xyz_from_rgb @ rgb for column vectors.
xyz_from_rgb = np.array(
    [
        [0.412453, 0.357580, 0.180423],
        [0.212671, 0.715160, 0.072169],
        [0.019334, 0.119193, 0.950227],
    ]
)
# XYZ -> linear-RGB coefficients stored TRANSPOSED relative to xyz_from_rgb:
# this matrix is the transpose of inv(xyz_from_rgb), so it right-multiplies
# row vectors (rgb_rows = xyz_rows @ rgb_from_xyz), as tensor_lab2rgb does.
rgb_from_xyz = np.array(
    [
        [3.24048134, -0.96925495, 0.05564664],
        [-1.53715152, 1.87599, -0.20404134],
        [-0.49853633, 0.04155593, 1.05731107],
    ]
)


def tensor_lab2rgb(input):
    """Convert a batch of CIE L*a*b* images into clipped sRGB.

    Args:
        input: tensor laid out as n x 3 x h x w; channel 0 is L, channel 1
            is a and channel 2 is b.

    Returns:
        A new n x 3 x h x w tensor of gamma-encoded sRGB values with every
        element clipped into [0, 1].
    """
    # Channel axis last: n x h x w x 3.
    channels_last = input.permute(0, 2, 3, 1)
    lightness = channels_last[:, :, :, 0:1]
    green_red = channels_last[:, :, :, 1:2]
    blue_yellow = channels_last[:, :, :, 2:]

    # Lab -> companded f(X), f(Y), f(Z); negative f(Z) is floored at zero.
    f_y = (lightness + 16.0) / 116.0
    f_x = (green_red / 500.0) + f_y
    f_z = f_y - (blue_yellow / 200.0)
    f_z[f_z.data < 0] = 0
    f_xyz = torch.cat((f_x, f_y, f_z), dim=3)

    # Undo the CIE companding: cube above ~6/29, linear segment below it.
    above_knee = f_xyz.data > 0.2068966
    xyz = f_xyz.clone()
    xyz[above_knee] = torch.pow(f_xyz[above_knee], 3.0)
    xyz[~above_knee] = (f_xyz[~above_knee] - 16.0 / 116.0) / 7.787
    # Scale by the reference white (D65 values); the Y factor is 1.0.
    xyz[:, :, :, 0] = xyz[:, :, :, 0] * 0.95047
    xyz[:, :, :, 2] = xyz[:, :, :, 2] * 1.08883

    # Linear RGB through a row-vector matrix product, then back to n x 3 x h x w.
    to_rgb = torch.from_numpy(rgb_from_xyz).type_as(f_xyz)
    linear = torch.mm(xyz.view(-1, 3), to_rgb)
    linear = linear.view(input.size(0), input.size(2), input.size(3), 3)
    linear = linear.permute(0, 3, 1, 2)

    # sRGB transfer curve: power segment above 0.0031308, linear below.
    power_part = linear > 0.0031308
    srgb = linear.clone()
    srgb[power_part] = 1.055 * torch.pow(linear[power_part], 1 / 2.4) - 0.055
    srgb[~power_part] = linear[~power_part] * 12.92

    # Clip into [0, 1]; both masks are taken before either write.
    below_zero = srgb.data < 0
    above_one = srgb.data > 1
    srgb[below_zero] = 0
    srgb[above_one] = 1
    return srgb
def get_files(img_dir):
    """Return the image, mask and ground-truth paths found under *img_dir*.

    The directory walk and extension bucketing are delegated to
    ``list_files``.

    Args:
        img_dir: Root directory to search.

    Returns:
        Tuple ``(image_paths, mask_paths, gt_paths)``.
    """
    image_paths, mask_paths, gt_paths = list_files(img_dir)
    return (image_paths, mask_paths, gt_paths)
def list_files(in_path):
    """Walk *in_path* recursively and bucket files by extension.

    Extensions are compared case-insensitively:

    * ``.jpg``, ``.jpeg``, ``.gif``, ``.png``, ``.pgm`` -> image files
    * ``.bmp`` -> mask files
    * ``.xml``, ``.gt``, ``.txt`` -> ground-truth files

    Every other file (``.zip`` archives included) is ignored.

    Args:
        in_path: Root directory to search. A path that does not exist yields
            three empty lists, because ``os.walk`` ignores errors by default.

    Returns:
        A ``(img_files, mask_files, gt_files)`` tuple of full paths, each list
        in ``os.walk`` traversal order.
    """
    image_exts = {'.jpg', '.jpeg', '.gif', '.png', '.pgm'}
    mask_exts = {'.bmp'}
    gt_exts = {'.xml', '.gt', '.txt'}

    img_files = []
    mask_files = []
    gt_files = []
    for dirpath, _dirnames, filenames in os.walk(in_path):
        for file in filenames:
            ext = os.path.splitext(file)[1].lower()
            path = os.path.join(dirpath, file)
            if ext in image_exts:
                img_files.append(path)
            elif ext in mask_exts:
                mask_files.append(path)
            elif ext in gt_exts:
                gt_files.append(path)
    return img_files, mask_files, gt_files
def load_image(img_file):
    """Read an image file and normalise it to an H x W x 3 RGB array.

    Normalisation steps:

    * a leading axis of length 2 on a non-2-D result is treated as two stacked
      entries (presumably multi-frame files -- TODO confirm) and only the first
      is kept;
    * single-channel data (2-D, or H x W x 1 / H x W x 2 gray + alpha) is
      expanded to three identical RGB channels;
    * a fourth (alpha) channel is dropped.

    Args:
        img_file: Path to the image, read with ``skimage.io.imread`` (RGB order).

    Returns:
        A freshly copied ``numpy.ndarray`` (``np.array`` copies by default).
    """
    img = io.imread(img_file)  # RGB order
    # Keep only the first of two stacked entries. A 2-D array is a plain
    # grayscale image that happens to be 2 pixels tall: slicing it would leave
    # a single 1-D row and the ``shape[2]`` check below would raise IndexError.
    if img.shape[0] == 2 and img.ndim != 2:
        img = img[0]
    # Gray (+ optional alpha) with an explicit channel axis: keep the gray plane
    # so it is expanded to RGB exactly like a 2-D grayscale image.
    if img.ndim == 3 and img.shape[2] in (1, 2):
        img = np.ascontiguousarray(img[:, :, 0])
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        img = img[:, :, :3]  # drop alpha
    return np.array(img)