|
|
|
import matplotlib.pyplot as plt
|
|
import cv2
|
|
import numpy as np
|
|
import random
|
|
|
|
|
|
def plot_img_and_mask(img, mask, index, epoch, save_dir):
    """Save a side-by-side figure of an input image and its mask(s).

    A mask with more than two dimensions is treated as one channel per class
    along its third axis, and each channel gets its own panel; otherwise the
    mask is drawn in a single panel. The figure is written to
    ``<save_dir>/batch_<epoch>_<index>_seg.png`` and then closed.

    Args:
        img: image passed to ``Axes.imshow`` for the first panel.
        mask: 2-D mask, or a 3-D array indexed as ``mask[:, :, class]``.
        index: index used in the output file name.
        epoch: epoch used in the output file name.
        save_dir: directory prefix for the output file (joined with "/").
    """
    classes = mask.shape[2] if len(mask.shape) > 2 else 1
    fig, ax = plt.subplots(1, classes + 1)
    try:
        ax[0].set_title('Input image')
        ax[0].imshow(img)
        if classes > 1:
            for i in range(classes):
                ax[i + 1].set_title(f'Output mask (class {i + 1})')
                ax[i + 1].imshow(mask[:, :, i])
        else:
            ax[1].set_title('Output mask')
            # imshow rejects (H, W, 1) arrays, so drop a singleton class axis.
            ax[1].imshow(mask[:, :, 0] if len(mask.shape) > 2 else mask)
        # Hide ticks on every panel, not only on the current (last) axes.
        for axis in ax:
            axis.set_xticks([])
            axis.set_yticks([])
        fig.savefig(save_dir + "/batch_{}_{}_seg.png".format(epoch, index))
    finally:
        # Release the figure; otherwise each call leaves one figure open.
        plt.close(fig)
|
|
|
|
def _colorize_seg(result, palette, is_demo):
    """Paint a segmentation result into a uint8 color image using palette rows.

    Non-demo: ``result`` is a 2-D label map and label ``k`` gets ``palette[k]``.
    Demo: ``result`` is a pair of 2-D masks; pixels equal to 1 in ``result[0]``
    get ``palette[1]`` and pixels equal to 1 in ``result[1]`` get
    ``palette[2]`` (the second mask wins where both are set).
    """
    if not is_demo:
        color_seg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[result == label, :] = color
        return color_seg

    first, second = result[0], result[1]
    color_seg = np.zeros((first.shape[0], first.shape[1], 3), dtype=np.uint8)
    color_seg[first == 1] = palette[1]
    color_seg[second == 1] = palette[2]
    return color_seg


def _seg_result_path(save_dir, epoch, index, is_ll, is_gt):
    """Return the PNG path show_seg_result writes to in non-demo mode."""
    head = "ll" if is_ll else "da"
    tail = "seg_gt" if is_gt else "segresult"
    return save_dir + "/batch_{}_{}_{}_{}.png".format(epoch, index, head, tail)


def show_seg_result(img, result, index, epoch, save_dir=None, is_ll=False,
                    palette=None, is_demo=False, is_gt=False, *,
                    out_size=(1280, 720)):
    """Overlay a segmentation result on ``img`` and optionally save it.

    Pixels whose colorized segmentation has any non-zero channel are blended
    50/50 with that color; all other pixels are left unchanged. The palette
    channel order is reversed before blending (presumably RGB palette onto a
    BGR OpenCV image). The result is cast to uint8 and resized to
    ``out_size``. Unless ``is_demo`` is set or ``save_dir`` is None, it is
    then written to
    ``<save_dir>/batch_<epoch>_<index>_{da|ll}_{segresult|seg_gt}.png``.

    Args:
        img: image array whose first two dims match the mask(s) in ``result``.
            NOTE: it is modified in place by the blend. Presumably an HxWx3
            image as loaded by cv2 — confirm against callers.
        result: non-demo mode: 2-D label map (label ``k`` gets ``palette[k]``).
            Demo mode: pair ``(mask_a, mask_b)``; pixels equal to 1 get
            ``palette[1]`` / ``palette[2]`` respectively, ``mask_b`` winning
            where both are set.
        index: used only in the output file name.
        epoch: used only in the output file name.
        save_dir: directory to write into; nothing is written when None.
        is_ll: selects the ``ll`` (otherwise ``da``) file-name tag.
        palette: array-like of shape (K, 3) with K >= 3, one color per label.
            Defaults to black / green / red. The caller's object is never
            modified.
        is_demo: use the two-mask ``result`` form and skip saving.
        is_gt: selects the ``seg_gt`` (otherwise ``segresult``) file suffix.
        out_size: ``(width, height)`` passed to ``cv2.resize``; None skips
            resizing.

    Returns:
        The blended uint8 image (resized unless ``out_size`` is None).

    Raises:
        ValueError: if ``palette`` does not have shape (K, 3) with K >= 3.
    """
    if palette is None:
        palette = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]
    # np.array copies, so the caller's palette is never mutated.
    palette = np.array(palette)
    if palette.ndim != 2 or palette.shape[1] != 3 or palette.shape[0] < 3:
        raise ValueError(
            "palette must have shape (K, 3) with K >= 3, got {}".format(palette.shape))

    # Reverse channel order so the palette colors match the image's layout.
    color_seg = _colorize_seg(result, palette, is_demo)[..., ::-1]
    # Equivalent to mean(color_seg, 2) != 0 because the values are unsigned.
    colored = color_seg.any(axis=2)
    # NOTE: this blend writes into the caller's ``img`` array in place.
    img[colored] = img[colored] * 0.5 + color_seg[colored] * 0.5
    img = img.astype(np.uint8)

    if out_size is not None:
        img = cv2.resize(img, tuple(out_size), interpolation=cv2.INTER_LINEAR)

    if not is_demo and save_dir is not None:
        cv2.imwrite(_seg_result_path(save_dir, epoch, index, is_ll, is_gt), img)
    return img
|
|
|
|
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one axis-aligned box onto ``img`` in place with ``cv2.rectangle``.

    Args:
        x: box corners read as ``(x[0], x[1])`` and ``(x[2], x[3])``; each
            value is truncated with ``int()``. Presumably ``(x1, y1, x2, y2)``
            in pixels — confirm against callers.
        img: image array that ``cv2.rectangle`` draws on (modified in place).
        color: color passed to ``cv2.rectangle``; a random 3-component color
            is chosen when None.
        label: accepted for call compatibility but not drawn.
        line_thickness: rectangle thickness. When falsy it is derived from the
            image size; the formula gives 1 for typical image sizes.
    """
    thickness = line_thickness or round(0.0001 * (img.shape[0] + img.shape[1]) / 2) + 1
    # Test for None explicitly: ``color or ...`` raises ValueError for
    # NumPy-array colors (ambiguous truth value).
    if color is None:
        color = [random.randint(0, 255) for _ in range(3)]
    c1 = (int(x[0]), int(x[1]))
    c2 = (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=thickness, lineType=cv2.LINE_AA)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # This module only defines plotting helpers; running it directly does nothing.
    ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|