Spaces:
Runtime error
Runtime error
Update TEED/utils/img_processing.py
Browse files- TEED/utils/img_processing.py +307 -307
TEED/utils/img_processing.py
CHANGED
|
@@ -1,307 +1,307 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
import cv2
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
import kornia as kn
|
| 7 |
-
import cv2_ext
|
| 8 |
-
|
| 9 |
-
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
|
| 10 |
-
from sklearn.metrics import mean_absolute_error
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def image_normalization(img, img_min=0, img_max=255,
|
| 14 |
-
epsilon=1e-12):
|
| 15 |
-
"""This is a typical image normalization function
|
| 16 |
-
where the minimum and maximum of the image is needed
|
| 17 |
-
source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
|
| 18 |
-
|
| 19 |
-
:param img: an image could be gray scale or color
|
| 20 |
-
:param img_min: for default is 0
|
| 21 |
-
:param img_max: for default is 255
|
| 22 |
-
|
| 23 |
-
:return: a normalized image, if max is 255 the dtype is uint8
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
img = np.float32(img)
|
| 27 |
-
# whenever an inconsistent image
|
| 28 |
-
img = (img - np.min(img)) * (img_max - img_min) / \
|
| 29 |
-
((np.max(img) - np.min(img)) + epsilon) + img_min
|
| 30 |
-
return img
|
| 31 |
-
|
| 32 |
-
def count_parameters(model=None):
|
| 33 |
-
if model is not None:
|
| 34 |
-
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 35 |
-
else:
|
| 36 |
-
print("Error counting model parameters line 32 img_processing.py")
|
| 37 |
-
raise NotImplementedError
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def save_image_batch_to_disk(tensor, output_dir, file_names, img_shape=None, arg=None, is_inchannel=False):
|
| 41 |
-
|
| 42 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 43 |
-
predict_all = arg.predict_all
|
| 44 |
-
if not arg.is_testing:
|
| 45 |
-
assert len(tensor.shape) == 4, tensor.shape
|
| 46 |
-
img_height,img_width = img_shape[0].item(),img_shape[1].item()
|
| 47 |
-
|
| 48 |
-
for tensor_image, file_name in zip(tensor, file_names):
|
| 49 |
-
image_vis = kn.utils.tensor_to_image(
|
| 50 |
-
torch.sigmoid(tensor_image))#[..., 0]
|
| 51 |
-
image_vis = (255.0*(1.0 - image_vis)).astype(np.uint8)
|
| 52 |
-
output_file_name = os.path.join(output_dir, file_name)
|
| 53 |
-
# print('image vis size', image_vis.shape)
|
| 54 |
-
image_vis =cv2.resize(image_vis, (img_width, img_height))
|
| 55 |
-
assert cv2_ext.imwrite(output_file_name, image_vis)
|
| 56 |
-
assert cv2_ext.imwrite('checkpoints/current_res/'+file_name, image_vis)
|
| 57 |
-
# print(f"Image saved in {output_file_name}")
|
| 58 |
-
else:
|
| 59 |
-
if is_inchannel:
|
| 60 |
-
|
| 61 |
-
tensor, tensor2 = tensor
|
| 62 |
-
fuse_name = 'fusedCH'
|
| 63 |
-
av_name='avgCH'
|
| 64 |
-
is_2tensors=True
|
| 65 |
-
edge_maps2 = []
|
| 66 |
-
for i in tensor2:
|
| 67 |
-
tmp = torch.sigmoid(i).cpu().detach().numpy()
|
| 68 |
-
edge_maps2.append(tmp)
|
| 69 |
-
tensor2 = np.array(edge_maps2)
|
| 70 |
-
else:
|
| 71 |
-
fuse_name = 'fused'
|
| 72 |
-
# av_name = 'avg'
|
| 73 |
-
tensor2=None
|
| 74 |
-
tmp_img2 = None
|
| 75 |
-
|
| 76 |
-
# output_dir_f = os.path.join(output_dir, fuse_name)# normal execution
|
| 77 |
-
output_dir_f = output_dir# for DMRIR
|
| 78 |
-
# output_dir_a = os.path.join(output_dir, av_name)
|
| 79 |
-
os.makedirs(output_dir_f, exist_ok=True)
|
| 80 |
-
# os.makedirs(output_dir_a, exist_ok=True)
|
| 81 |
-
if predict_all:
|
| 82 |
-
all_data_dir = os.path.join(output_dir, "all_edges")
|
| 83 |
-
os.makedirs(all_data_dir, exist_ok=True)
|
| 84 |
-
out1_dir = os.path.join(all_data_dir,"o1")
|
| 85 |
-
out2_dir = os.path.join(all_data_dir,"o2")
|
| 86 |
-
out3_dir = os.path.join(all_data_dir,"o3")# TEED =output 3
|
| 87 |
-
out4_dir = os.path.join(all_data_dir,"o4") # TEED = average
|
| 88 |
-
out5_dir = os.path.join(all_data_dir,"o5")# fusion # TEED
|
| 89 |
-
out6_dir = os.path.join(all_data_dir,"o6") # fusion
|
| 90 |
-
os.makedirs(out1_dir, exist_ok=True)
|
| 91 |
-
os.makedirs(out2_dir, exist_ok=True)
|
| 92 |
-
os.makedirs(out3_dir, exist_ok=True)
|
| 93 |
-
os.makedirs(out4_dir, exist_ok=True)
|
| 94 |
-
os.makedirs(out5_dir, exist_ok=True)
|
| 95 |
-
os.makedirs(out6_dir, exist_ok=True)
|
| 96 |
-
|
| 97 |
-
# 255.0 * (1.0 - em_a)
|
| 98 |
-
edge_maps = []
|
| 99 |
-
for i in tensor:
|
| 100 |
-
tmp = torch.sigmoid(i).cpu().detach().numpy()
|
| 101 |
-
edge_maps.append(tmp)
|
| 102 |
-
tensor = np.array(edge_maps)
|
| 103 |
-
# print(f"tensor shape: {tensor.shape}")
|
| 104 |
-
|
| 105 |
-
image_shape = [x.cpu().detach().numpy() for x in img_shape]
|
| 106 |
-
# (H, W) -> (W, H)
|
| 107 |
-
image_shape = [[y, x] for x, y in zip(image_shape[0], image_shape[1])]
|
| 108 |
-
|
| 109 |
-
assert len(image_shape) == len(file_names)
|
| 110 |
-
|
| 111 |
-
idx = 0
|
| 112 |
-
for i_shape, file_name in zip(image_shape, file_names):
|
| 113 |
-
tmp = tensor[:, idx, ...]
|
| 114 |
-
tmp2 = tensor2[:, idx, ...] if tensor2 is not None else None
|
| 115 |
-
# tmp = np.transpose(np.squeeze(tmp), [0, 1, 2])
|
| 116 |
-
tmp = np.squeeze(tmp)
|
| 117 |
-
tmp2 = np.squeeze(tmp2) if tensor2 is not None else None
|
| 118 |
-
|
| 119 |
-
# Iterate our all 7 NN outputs for a particular image
|
| 120 |
-
preds = []
|
| 121 |
-
fuse_num = tmp.shape[0]-1
|
| 122 |
-
for i in range(tmp.shape[0]):
|
| 123 |
-
tmp_img = tmp[i]
|
| 124 |
-
tmp_img = np.uint8(image_normalization(tmp_img))
|
| 125 |
-
tmp_img = cv2.bitwise_not(tmp_img)
|
| 126 |
-
# tmp_img[tmp_img < 0.0] = 0.0
|
| 127 |
-
# tmp_img = 255.0 * (1.0 - tmp_img)
|
| 128 |
-
if tmp2 is not None:
|
| 129 |
-
tmp_img2 = tmp2[i]
|
| 130 |
-
tmp_img2 = np.uint8(image_normalization(tmp_img2))
|
| 131 |
-
tmp_img2 = cv2.bitwise_not(tmp_img2)
|
| 132 |
-
|
| 133 |
-
# Resize prediction to match input image size
|
| 134 |
-
if not tmp_img.shape[1] == i_shape[0] or not tmp_img.shape[0] == i_shape[1]:
|
| 135 |
-
tmp_img = cv2.resize(tmp_img, (i_shape[0], i_shape[1]))
|
| 136 |
-
tmp_img2 = cv2.resize(tmp_img2, (i_shape[0], i_shape[1])) if tmp2 is not None else None
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
if tmp2 is not None:
|
| 140 |
-
tmp_mask = np.logical_and(tmp_img>128,tmp_img2<128)
|
| 141 |
-
tmp_img= np.where(tmp_mask, tmp_img2, tmp_img)
|
| 142 |
-
preds.append(tmp_img)
|
| 143 |
-
|
| 144 |
-
else:
|
| 145 |
-
preds.append(tmp_img)
|
| 146 |
-
|
| 147 |
-
if i == fuse_num:
|
| 148 |
-
# print('fuse num',tmp.shape[0], fuse_num, i)
|
| 149 |
-
fuse = tmp_img
|
| 150 |
-
fuse = fuse.astype(np.uint8)
|
| 151 |
-
if tmp_img2 is not None:
|
| 152 |
-
fuse2 = tmp_img2
|
| 153 |
-
fuse2 = fuse2.astype(np.uint8)
|
| 154 |
-
# fuse = fuse-fuse2
|
| 155 |
-
fuse_mask=np.logical_and(fuse>128,fuse2<128)
|
| 156 |
-
fuse = np.where(fuse_mask,fuse2, fuse)
|
| 157 |
-
|
| 158 |
-
# print(fuse.shape, fuse_mask.shape)
|
| 159 |
-
|
| 160 |
-
# Save predicted edge maps
|
| 161 |
-
average = np.array(preds, dtype=np.float32)
|
| 162 |
-
average = np.uint8(np.mean(average, axis=0))
|
| 163 |
-
output_file_name_f = os.path.join(output_dir_f, file_name)
|
| 164 |
-
# output_file_name_a = os.path.join(output_dir_a, file_name)
|
| 165 |
-
cv2.imwrite(output_file_name_f, fuse)
|
| 166 |
-
# cv2_ext.imwrite(output_file_name_a, average)
|
| 167 |
-
if predict_all:
|
| 168 |
-
cv2.imwrite(os.path.join(out1_dir,file_name),preds[0])
|
| 169 |
-
cv2.imwrite(os.path.join(out2_dir,file_name),preds[1])
|
| 170 |
-
cv2.imwrite(os.path.join(out3_dir,file_name),preds[2])
|
| 171 |
-
cv2.imwrite(os.path.join(out4_dir,file_name),average)
|
| 172 |
-
cv2.imwrite(os.path.join(out5_dir,file_name),fuse)
|
| 173 |
-
cv2.imwrite(os.path.join(out6_dir,file_name),fuse)
|
| 174 |
-
|
| 175 |
-
idx += 1
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def restore_rgb(config, I, restore_rgb=False):
|
| 179 |
-
"""
|
| 180 |
-
:param config: [args.channel_swap, args.mean_pixel_value]
|
| 181 |
-
:param I: and image or a set of images
|
| 182 |
-
:return: an image or a set of images restored
|
| 183 |
-
"""
|
| 184 |
-
|
| 185 |
-
if len(I) > 3 and not type(I) == np.ndarray:
|
| 186 |
-
I = np.array(I)
|
| 187 |
-
I = I[:, :, :, 0:3]
|
| 188 |
-
n = I.shape[0]
|
| 189 |
-
for i in range(n):
|
| 190 |
-
x = I[i, ...]
|
| 191 |
-
x = np.array(x, dtype=np.float32)
|
| 192 |
-
x += config[1]
|
| 193 |
-
if restore_rgb:
|
| 194 |
-
x = x[:, :, config[0]]
|
| 195 |
-
x = image_normalization(x)
|
| 196 |
-
I[i, :, :, :] = x
|
| 197 |
-
elif len(I.shape) == 3 and I.shape[-1] == 3:
|
| 198 |
-
I = np.array(I, dtype=np.float32)
|
| 199 |
-
I += config[1]
|
| 200 |
-
if restore_rgb:
|
| 201 |
-
I = I[:, :, config[0]]
|
| 202 |
-
I = image_normalization(I)
|
| 203 |
-
else:
|
| 204 |
-
print("Sorry the input data size is out of our configuration")
|
| 205 |
-
return I
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def visualize_result(imgs_list, arg):
|
| 209 |
-
"""
|
| 210 |
-
data 2 image in one matrix
|
| 211 |
-
:param imgs_list: a list of prediction, gt and input data
|
| 212 |
-
:param arg:
|
| 213 |
-
:return: one image with the whole of imgs_list data
|
| 214 |
-
"""
|
| 215 |
-
n_imgs = len(imgs_list)
|
| 216 |
-
data_list = []
|
| 217 |
-
for i in range(n_imgs):
|
| 218 |
-
tmp = imgs_list[i]
|
| 219 |
-
# print(tmp.shape)
|
| 220 |
-
if tmp.shape[0] == 3:
|
| 221 |
-
tmp = np.transpose(tmp, [1, 2, 0])
|
| 222 |
-
tmp = restore_rgb([
|
| 223 |
-
arg.channel_swap,
|
| 224 |
-
arg.mean_train[:3]
|
| 225 |
-
], tmp)
|
| 226 |
-
tmp = np.uint8(image_normalization(tmp))
|
| 227 |
-
else:
|
| 228 |
-
tmp = np.squeeze(tmp)
|
| 229 |
-
if len(tmp.shape) == 2:
|
| 230 |
-
tmp = np.uint8(image_normalization(tmp))
|
| 231 |
-
tmp = cv2.bitwise_not(tmp)
|
| 232 |
-
tmp = cv2.cvtColor(tmp, cv2.COLOR_GRAY2BGR)
|
| 233 |
-
else:
|
| 234 |
-
tmp = np.uint8(image_normalization(tmp))
|
| 235 |
-
data_list.append(tmp)
|
| 236 |
-
# print(i,tmp.shape)
|
| 237 |
-
img = data_list[0]
|
| 238 |
-
if n_imgs % 2 == 0:
|
| 239 |
-
imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
|
| 240 |
-
* (n_imgs // 2) + ((n_imgs // 2 - 1) * 5), 3))
|
| 241 |
-
else:
|
| 242 |
-
imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
|
| 243 |
-
* ((1 + n_imgs) // 2) + ((n_imgs // 2) * 5), 3))
|
| 244 |
-
n_imgs += 1
|
| 245 |
-
|
| 246 |
-
k = 0
|
| 247 |
-
imgs = np.uint8(imgs)
|
| 248 |
-
i_step = img.shape[0] + 10
|
| 249 |
-
j_step = img.shape[1] + 5
|
| 250 |
-
for i in range(2):
|
| 251 |
-
for j in range(n_imgs // 2):
|
| 252 |
-
if k < len(data_list):
|
| 253 |
-
imgs[i * i_step:i * i_step+img.shape[0],
|
| 254 |
-
j * j_step:j * j_step+img.shape[1],
|
| 255 |
-
:] = data_list[k]
|
| 256 |
-
k += 1
|
| 257 |
-
else:
|
| 258 |
-
pass
|
| 259 |
-
return imgs
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
if __name__ == '__main__':
|
| 264 |
-
|
| 265 |
-
img_base_dir='tmp_edge'
|
| 266 |
-
gt_base_dir='C:/Users/xavysp/dataset/BIPED/edges/edge_maps/test/rgbr'
|
| 267 |
-
# gt_base_dir='C:/Users/xavysp/dataset/BRIND/test_edges'
|
| 268 |
-
# gt_base_dir='C:/Users/xavysp/dataset/UDED/gt'
|
| 269 |
-
vers = 'TEED model in BIPED'
|
| 270 |
-
list_img = os.listdir(img_base_dir)
|
| 271 |
-
list_gt = os.listdir(gt_base_dir)
|
| 272 |
-
mse_list=[]
|
| 273 |
-
psnr_list=[]
|
| 274 |
-
mae_list=[]
|
| 275 |
-
|
| 276 |
-
for img_name, gt_name in zip(list_img,list_gt):
|
| 277 |
-
|
| 278 |
-
# print(img_name, ' ', gt_name)
|
| 279 |
-
tmp_img = cv2.imread(os.path.join(img_base_dir,img_name),0)
|
| 280 |
-
tmp_img = cv2.bitwise_not(tmp_img) # if the image's background
|
| 281 |
-
# is white uncomment this line
|
| 282 |
-
tmp_gt = cv2.imread(os.path.join(gt_base_dir,gt_name),0)
|
| 283 |
-
# print(f"image {img_name} {tmp_img.shape}")
|
| 284 |
-
# print(f"gt {gt_name} {tmp_gt.shape}")
|
| 285 |
-
a = tmp_img.copy()
|
| 286 |
-
tmp_img = image_normalization(tmp_img, img_max=1.)
|
| 287 |
-
tmp_gt = image_normalization(tmp_gt, img_max=1.)
|
| 288 |
-
psnr = peak_signal_noise_ratio(tmp_gt, tmp_img)
|
| 289 |
-
mse = mean_squared_error(tmp_gt, tmp_img)
|
| 290 |
-
mae = mean_absolute_error(tmp_gt, tmp_img)
|
| 291 |
-
# a = cv2.bitwise_not(a) # save data
|
| 292 |
-
# cv2_ext.imwrite(os.path.join("tmp_res",img_name), a) # save data
|
| 293 |
-
|
| 294 |
-
psnr_list.append(psnr)
|
| 295 |
-
mse_list.append(mse)
|
| 296 |
-
mae_list.append(mae)
|
| 297 |
-
print(f"PSNR= {psnr} in {img_name}")
|
| 298 |
-
|
| 299 |
-
av_psnr =np.array(psnr_list).mean()
|
| 300 |
-
av_mse =np.array(mse_list).mean()
|
| 301 |
-
av_mae =np.array(mae_list).mean()
|
| 302 |
-
print(" MSE results: mean ", av_mse)
|
| 303 |
-
print(" MAE results: mean ", av_mae)
|
| 304 |
-
# print(mse_list)
|
| 305 |
-
print(" PSNR results: mean", av_psnr)
|
| 306 |
-
# print(psnr_list)
|
| 307 |
-
print('version: ',vers)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import kornia as kn
|
| 7 |
+
#import cv2_ext
|
| 8 |
+
|
| 9 |
+
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
|
| 10 |
+
from sklearn.metrics import mean_absolute_error
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def image_normalization(img, img_min=0, img_max=255,
|
| 14 |
+
epsilon=1e-12):
|
| 15 |
+
"""This is a typical image normalization function
|
| 16 |
+
where the minimum and maximum of the image is needed
|
| 17 |
+
source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
|
| 18 |
+
|
| 19 |
+
:param img: an image could be gray scale or color
|
| 20 |
+
:param img_min: for default is 0
|
| 21 |
+
:param img_max: for default is 255
|
| 22 |
+
|
| 23 |
+
:return: a normalized image, if max is 255 the dtype is uint8
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
img = np.float32(img)
|
| 27 |
+
# whenever an inconsistent image
|
| 28 |
+
img = (img - np.min(img)) * (img_max - img_min) / \
|
| 29 |
+
((np.max(img) - np.min(img)) + epsilon) + img_min
|
| 30 |
+
return img
|
| 31 |
+
|
| 32 |
+
def count_parameters(model=None):
|
| 33 |
+
if model is not None:
|
| 34 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 35 |
+
else:
|
| 36 |
+
print("Error counting model parameters line 32 img_processing.py")
|
| 37 |
+
raise NotImplementedError
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def save_image_batch_to_disk(tensor, output_dir, file_names, img_shape=None, arg=None, is_inchannel=False):
|
| 41 |
+
|
| 42 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 43 |
+
predict_all = arg.predict_all
|
| 44 |
+
if not arg.is_testing:
|
| 45 |
+
assert len(tensor.shape) == 4, tensor.shape
|
| 46 |
+
img_height,img_width = img_shape[0].item(),img_shape[1].item()
|
| 47 |
+
|
| 48 |
+
for tensor_image, file_name in zip(tensor, file_names):
|
| 49 |
+
image_vis = kn.utils.tensor_to_image(
|
| 50 |
+
torch.sigmoid(tensor_image))#[..., 0]
|
| 51 |
+
image_vis = (255.0*(1.0 - image_vis)).astype(np.uint8)
|
| 52 |
+
output_file_name = os.path.join(output_dir, file_name)
|
| 53 |
+
# print('image vis size', image_vis.shape)
|
| 54 |
+
image_vis =cv2.resize(image_vis, (img_width, img_height))
|
| 55 |
+
assert cv2_ext.imwrite(output_file_name, image_vis)
|
| 56 |
+
assert cv2_ext.imwrite('checkpoints/current_res/'+file_name, image_vis)
|
| 57 |
+
# print(f"Image saved in {output_file_name}")
|
| 58 |
+
else:
|
| 59 |
+
if is_inchannel:
|
| 60 |
+
|
| 61 |
+
tensor, tensor2 = tensor
|
| 62 |
+
fuse_name = 'fusedCH'
|
| 63 |
+
av_name='avgCH'
|
| 64 |
+
is_2tensors=True
|
| 65 |
+
edge_maps2 = []
|
| 66 |
+
for i in tensor2:
|
| 67 |
+
tmp = torch.sigmoid(i).cpu().detach().numpy()
|
| 68 |
+
edge_maps2.append(tmp)
|
| 69 |
+
tensor2 = np.array(edge_maps2)
|
| 70 |
+
else:
|
| 71 |
+
fuse_name = 'fused'
|
| 72 |
+
# av_name = 'avg'
|
| 73 |
+
tensor2=None
|
| 74 |
+
tmp_img2 = None
|
| 75 |
+
|
| 76 |
+
# output_dir_f = os.path.join(output_dir, fuse_name)# normal execution
|
| 77 |
+
output_dir_f = output_dir# for DMRIR
|
| 78 |
+
# output_dir_a = os.path.join(output_dir, av_name)
|
| 79 |
+
os.makedirs(output_dir_f, exist_ok=True)
|
| 80 |
+
# os.makedirs(output_dir_a, exist_ok=True)
|
| 81 |
+
if predict_all:
|
| 82 |
+
all_data_dir = os.path.join(output_dir, "all_edges")
|
| 83 |
+
os.makedirs(all_data_dir, exist_ok=True)
|
| 84 |
+
out1_dir = os.path.join(all_data_dir,"o1")
|
| 85 |
+
out2_dir = os.path.join(all_data_dir,"o2")
|
| 86 |
+
out3_dir = os.path.join(all_data_dir,"o3")# TEED =output 3
|
| 87 |
+
out4_dir = os.path.join(all_data_dir,"o4") # TEED = average
|
| 88 |
+
out5_dir = os.path.join(all_data_dir,"o5")# fusion # TEED
|
| 89 |
+
out6_dir = os.path.join(all_data_dir,"o6") # fusion
|
| 90 |
+
os.makedirs(out1_dir, exist_ok=True)
|
| 91 |
+
os.makedirs(out2_dir, exist_ok=True)
|
| 92 |
+
os.makedirs(out3_dir, exist_ok=True)
|
| 93 |
+
os.makedirs(out4_dir, exist_ok=True)
|
| 94 |
+
os.makedirs(out5_dir, exist_ok=True)
|
| 95 |
+
os.makedirs(out6_dir, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
# 255.0 * (1.0 - em_a)
|
| 98 |
+
edge_maps = []
|
| 99 |
+
for i in tensor:
|
| 100 |
+
tmp = torch.sigmoid(i).cpu().detach().numpy()
|
| 101 |
+
edge_maps.append(tmp)
|
| 102 |
+
tensor = np.array(edge_maps)
|
| 103 |
+
# print(f"tensor shape: {tensor.shape}")
|
| 104 |
+
|
| 105 |
+
image_shape = [x.cpu().detach().numpy() for x in img_shape]
|
| 106 |
+
# (H, W) -> (W, H)
|
| 107 |
+
image_shape = [[y, x] for x, y in zip(image_shape[0], image_shape[1])]
|
| 108 |
+
|
| 109 |
+
assert len(image_shape) == len(file_names)
|
| 110 |
+
|
| 111 |
+
idx = 0
|
| 112 |
+
for i_shape, file_name in zip(image_shape, file_names):
|
| 113 |
+
tmp = tensor[:, idx, ...]
|
| 114 |
+
tmp2 = tensor2[:, idx, ...] if tensor2 is not None else None
|
| 115 |
+
# tmp = np.transpose(np.squeeze(tmp), [0, 1, 2])
|
| 116 |
+
tmp = np.squeeze(tmp)
|
| 117 |
+
tmp2 = np.squeeze(tmp2) if tensor2 is not None else None
|
| 118 |
+
|
| 119 |
+
# Iterate our all 7 NN outputs for a particular image
|
| 120 |
+
preds = []
|
| 121 |
+
fuse_num = tmp.shape[0]-1
|
| 122 |
+
for i in range(tmp.shape[0]):
|
| 123 |
+
tmp_img = tmp[i]
|
| 124 |
+
tmp_img = np.uint8(image_normalization(tmp_img))
|
| 125 |
+
tmp_img = cv2.bitwise_not(tmp_img)
|
| 126 |
+
# tmp_img[tmp_img < 0.0] = 0.0
|
| 127 |
+
# tmp_img = 255.0 * (1.0 - tmp_img)
|
| 128 |
+
if tmp2 is not None:
|
| 129 |
+
tmp_img2 = tmp2[i]
|
| 130 |
+
tmp_img2 = np.uint8(image_normalization(tmp_img2))
|
| 131 |
+
tmp_img2 = cv2.bitwise_not(tmp_img2)
|
| 132 |
+
|
| 133 |
+
# Resize prediction to match input image size
|
| 134 |
+
if not tmp_img.shape[1] == i_shape[0] or not tmp_img.shape[0] == i_shape[1]:
|
| 135 |
+
tmp_img = cv2.resize(tmp_img, (i_shape[0], i_shape[1]))
|
| 136 |
+
tmp_img2 = cv2.resize(tmp_img2, (i_shape[0], i_shape[1])) if tmp2 is not None else None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
if tmp2 is not None:
|
| 140 |
+
tmp_mask = np.logical_and(tmp_img>128,tmp_img2<128)
|
| 141 |
+
tmp_img= np.where(tmp_mask, tmp_img2, tmp_img)
|
| 142 |
+
preds.append(tmp_img)
|
| 143 |
+
|
| 144 |
+
else:
|
| 145 |
+
preds.append(tmp_img)
|
| 146 |
+
|
| 147 |
+
if i == fuse_num:
|
| 148 |
+
# print('fuse num',tmp.shape[0], fuse_num, i)
|
| 149 |
+
fuse = tmp_img
|
| 150 |
+
fuse = fuse.astype(np.uint8)
|
| 151 |
+
if tmp_img2 is not None:
|
| 152 |
+
fuse2 = tmp_img2
|
| 153 |
+
fuse2 = fuse2.astype(np.uint8)
|
| 154 |
+
# fuse = fuse-fuse2
|
| 155 |
+
fuse_mask=np.logical_and(fuse>128,fuse2<128)
|
| 156 |
+
fuse = np.where(fuse_mask,fuse2, fuse)
|
| 157 |
+
|
| 158 |
+
# print(fuse.shape, fuse_mask.shape)
|
| 159 |
+
|
| 160 |
+
# Save predicted edge maps
|
| 161 |
+
average = np.array(preds, dtype=np.float32)
|
| 162 |
+
average = np.uint8(np.mean(average, axis=0))
|
| 163 |
+
output_file_name_f = os.path.join(output_dir_f, file_name)
|
| 164 |
+
# output_file_name_a = os.path.join(output_dir_a, file_name)
|
| 165 |
+
cv2.imwrite(output_file_name_f, fuse)
|
| 166 |
+
# cv2_ext.imwrite(output_file_name_a, average)
|
| 167 |
+
if predict_all:
|
| 168 |
+
cv2.imwrite(os.path.join(out1_dir,file_name),preds[0])
|
| 169 |
+
cv2.imwrite(os.path.join(out2_dir,file_name),preds[1])
|
| 170 |
+
cv2.imwrite(os.path.join(out3_dir,file_name),preds[2])
|
| 171 |
+
cv2.imwrite(os.path.join(out4_dir,file_name),average)
|
| 172 |
+
cv2.imwrite(os.path.join(out5_dir,file_name),fuse)
|
| 173 |
+
cv2.imwrite(os.path.join(out6_dir,file_name),fuse)
|
| 174 |
+
|
| 175 |
+
idx += 1
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def restore_rgb(config, I, restore_rgb=False):
|
| 179 |
+
"""
|
| 180 |
+
:param config: [args.channel_swap, args.mean_pixel_value]
|
| 181 |
+
:param I: and image or a set of images
|
| 182 |
+
:return: an image or a set of images restored
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
if len(I) > 3 and not type(I) == np.ndarray:
|
| 186 |
+
I = np.array(I)
|
| 187 |
+
I = I[:, :, :, 0:3]
|
| 188 |
+
n = I.shape[0]
|
| 189 |
+
for i in range(n):
|
| 190 |
+
x = I[i, ...]
|
| 191 |
+
x = np.array(x, dtype=np.float32)
|
| 192 |
+
x += config[1]
|
| 193 |
+
if restore_rgb:
|
| 194 |
+
x = x[:, :, config[0]]
|
| 195 |
+
x = image_normalization(x)
|
| 196 |
+
I[i, :, :, :] = x
|
| 197 |
+
elif len(I.shape) == 3 and I.shape[-1] == 3:
|
| 198 |
+
I = np.array(I, dtype=np.float32)
|
| 199 |
+
I += config[1]
|
| 200 |
+
if restore_rgb:
|
| 201 |
+
I = I[:, :, config[0]]
|
| 202 |
+
I = image_normalization(I)
|
| 203 |
+
else:
|
| 204 |
+
print("Sorry the input data size is out of our configuration")
|
| 205 |
+
return I
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def visualize_result(imgs_list, arg):
|
| 209 |
+
"""
|
| 210 |
+
data 2 image in one matrix
|
| 211 |
+
:param imgs_list: a list of prediction, gt and input data
|
| 212 |
+
:param arg:
|
| 213 |
+
:return: one image with the whole of imgs_list data
|
| 214 |
+
"""
|
| 215 |
+
n_imgs = len(imgs_list)
|
| 216 |
+
data_list = []
|
| 217 |
+
for i in range(n_imgs):
|
| 218 |
+
tmp = imgs_list[i]
|
| 219 |
+
# print(tmp.shape)
|
| 220 |
+
if tmp.shape[0] == 3:
|
| 221 |
+
tmp = np.transpose(tmp, [1, 2, 0])
|
| 222 |
+
tmp = restore_rgb([
|
| 223 |
+
arg.channel_swap,
|
| 224 |
+
arg.mean_train[:3]
|
| 225 |
+
], tmp)
|
| 226 |
+
tmp = np.uint8(image_normalization(tmp))
|
| 227 |
+
else:
|
| 228 |
+
tmp = np.squeeze(tmp)
|
| 229 |
+
if len(tmp.shape) == 2:
|
| 230 |
+
tmp = np.uint8(image_normalization(tmp))
|
| 231 |
+
tmp = cv2.bitwise_not(tmp)
|
| 232 |
+
tmp = cv2.cvtColor(tmp, cv2.COLOR_GRAY2BGR)
|
| 233 |
+
else:
|
| 234 |
+
tmp = np.uint8(image_normalization(tmp))
|
| 235 |
+
data_list.append(tmp)
|
| 236 |
+
# print(i,tmp.shape)
|
| 237 |
+
img = data_list[0]
|
| 238 |
+
if n_imgs % 2 == 0:
|
| 239 |
+
imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
|
| 240 |
+
* (n_imgs // 2) + ((n_imgs // 2 - 1) * 5), 3))
|
| 241 |
+
else:
|
| 242 |
+
imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
|
| 243 |
+
* ((1 + n_imgs) // 2) + ((n_imgs // 2) * 5), 3))
|
| 244 |
+
n_imgs += 1
|
| 245 |
+
|
| 246 |
+
k = 0
|
| 247 |
+
imgs = np.uint8(imgs)
|
| 248 |
+
i_step = img.shape[0] + 10
|
| 249 |
+
j_step = img.shape[1] + 5
|
| 250 |
+
for i in range(2):
|
| 251 |
+
for j in range(n_imgs // 2):
|
| 252 |
+
if k < len(data_list):
|
| 253 |
+
imgs[i * i_step:i * i_step+img.shape[0],
|
| 254 |
+
j * j_step:j * j_step+img.shape[1],
|
| 255 |
+
:] = data_list[k]
|
| 256 |
+
k += 1
|
| 257 |
+
else:
|
| 258 |
+
pass
|
| 259 |
+
return imgs
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == '__main__':
|
| 264 |
+
|
| 265 |
+
img_base_dir='tmp_edge'
|
| 266 |
+
gt_base_dir='C:/Users/xavysp/dataset/BIPED/edges/edge_maps/test/rgbr'
|
| 267 |
+
# gt_base_dir='C:/Users/xavysp/dataset/BRIND/test_edges'
|
| 268 |
+
# gt_base_dir='C:/Users/xavysp/dataset/UDED/gt'
|
| 269 |
+
vers = 'TEED model in BIPED'
|
| 270 |
+
list_img = os.listdir(img_base_dir)
|
| 271 |
+
list_gt = os.listdir(gt_base_dir)
|
| 272 |
+
mse_list=[]
|
| 273 |
+
psnr_list=[]
|
| 274 |
+
mae_list=[]
|
| 275 |
+
|
| 276 |
+
for img_name, gt_name in zip(list_img,list_gt):
|
| 277 |
+
|
| 278 |
+
# print(img_name, ' ', gt_name)
|
| 279 |
+
tmp_img = cv2.imread(os.path.join(img_base_dir,img_name),0)
|
| 280 |
+
tmp_img = cv2.bitwise_not(tmp_img) # if the image's background
|
| 281 |
+
# is white uncomment this line
|
| 282 |
+
tmp_gt = cv2.imread(os.path.join(gt_base_dir,gt_name),0)
|
| 283 |
+
# print(f"image {img_name} {tmp_img.shape}")
|
| 284 |
+
# print(f"gt {gt_name} {tmp_gt.shape}")
|
| 285 |
+
a = tmp_img.copy()
|
| 286 |
+
tmp_img = image_normalization(tmp_img, img_max=1.)
|
| 287 |
+
tmp_gt = image_normalization(tmp_gt, img_max=1.)
|
| 288 |
+
psnr = peak_signal_noise_ratio(tmp_gt, tmp_img)
|
| 289 |
+
mse = mean_squared_error(tmp_gt, tmp_img)
|
| 290 |
+
mae = mean_absolute_error(tmp_gt, tmp_img)
|
| 291 |
+
# a = cv2.bitwise_not(a) # save data
|
| 292 |
+
# cv2_ext.imwrite(os.path.join("tmp_res",img_name), a) # save data
|
| 293 |
+
|
| 294 |
+
psnr_list.append(psnr)
|
| 295 |
+
mse_list.append(mse)
|
| 296 |
+
mae_list.append(mae)
|
| 297 |
+
print(f"PSNR= {psnr} in {img_name}")
|
| 298 |
+
|
| 299 |
+
av_psnr =np.array(psnr_list).mean()
|
| 300 |
+
av_mse =np.array(mse_list).mean()
|
| 301 |
+
av_mae =np.array(mae_list).mean()
|
| 302 |
+
print(" MSE results: mean ", av_mse)
|
| 303 |
+
print(" MAE results: mean ", av_mae)
|
| 304 |
+
# print(mse_list)
|
| 305 |
+
print(" PSNR results: mean", av_psnr)
|
| 306 |
+
# print(psnr_list)
|
| 307 |
+
print('version: ',vers)
|