Object Detection
YOLOP / lib /utils /plot.py
Riser's picture
First model version
67bb36a
## Process the prediction-result .json files and draw plots
import matplotlib.pyplot as plt
import cv2
import numpy as np
import random
def plot_img_and_mask(img, mask, index, epoch, save_dir):
    """Save a side-by-side figure of an input image and its mask channel(s).

    The figure holds one panel for ``img`` followed by one panel per mask
    class and is written to ``<save_dir>/batch_<epoch>_<index>_seg.png``.

    Args:
        img: Image accepted by ``Axes.imshow``.
        mask: Mask array. A 3-D array is treated as one class per slice
            along its last axis; anything else is shown as a single mask.
        index: Value placed after ``epoch`` in the output file name.
        epoch: Value placed first in the output file name.
        save_dir: Directory the PNG is written into.
    """
    classes = mask.shape[2] if len(mask.shape) > 2 else 1
    fig, ax = plt.subplots(1, classes + 1)
    try:
        ax[0].set_title('Input image')
        ax[0].imshow(img)
        if classes > 1:
            for i in range(classes):
                ax[i + 1].set_title(f'Output mask (class {i+1})')
                ax[i + 1].imshow(mask[:, :, i])
        else:
            ax[1].set_title('Output mask')
            ax[1].imshow(mask)
        # pyplot's tick helpers act on the current axes only.
        # NOTE(review): this presumably clears ticks on a single panel rather
        # than all of them - confirm whether that is intended.
        plt.xticks([])
        plt.yticks([])
        # plt.show()
        fig.savefig(save_dir + "/batch_{}_{}_seg.png".format(epoch, index))
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive and
        # repeated calls (e.g. once per batch) steadily leak memory.
        plt.close(fig)
def show_seg_result(img, result, index, epoch, save_dir=None, is_ll=False, palette=None, is_demo=False, is_gt=False):
    """Blend a colour-coded segmentation result onto ``img`` and resize it.

    Pixels that receive a non-black colour are mixed 50/50 with the image;
    all other pixels are left as they are. The blended image is converted to
    ``uint8``, resized to 1280x720 and returned. Outside demo mode it is also
    written to ``save_dir`` when a directory is given.

    Note that ``img`` is modified in place by the blending step before the
    converted/resized copy is produced.

    Args:
        img: Image array to draw on; indexed with an (H, W) boolean mask.
        result: When ``is_demo`` is false, a 2-D label map whose values are
            matched against the palette row indices 0..2. When ``is_demo`` is
            true, a pair of 2-D maps: pixels equal to 1 in ``result[0]`` and
            ``result[1]`` are painted with two fixed colours (the second
            overrides the first where both are set).
        index: Value placed after ``epoch`` in the output file name.
        epoch: Value placed first in the output file name.
        save_dir: Output directory. If ``None`` nothing is written.
        is_ll: Selects the ``ll`` (true) or ``da`` (false) file-name tag
            (presumably lane-line vs. drivable-area - confirm with callers).
        palette: Optional 3x3 array-like with one colour per label. Defaults
            to black / green / red rows. Ignored for colouring in demo mode.
        is_demo: Selects the two-mask demo colouring and disables file output.
        is_gt: Selects the ``seg_gt`` (true) or ``segresult`` (false)
            file-name tag.

    Returns:
        The blended ``uint8`` image resized to 1280x720.

    Raises:
        AssertionError: If the palette is not a 3x3 array.
    """
    # img = mmcv.imread(img)
    # img = img.copy()
    # seg = result[0]
    if palette is None:
        # Fixed default colours; no random draw, so global RNG state is untouched.
        palette = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]
    # np.array copies, so a caller-supplied palette is never modified.
    palette = np.array(palette)
    assert len(palette.shape) == 2
    assert palette.shape[0] == 3  # len(classes)
    assert palette.shape[1] == 3

    if not is_demo:
        color_seg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[result == label, :] = color
    else:
        color_seg = np.zeros((result[0].shape[0], result[0].shape[1], 3), dtype=np.uint8)
        color_seg[result[0] == 1] = [0, 255, 0]
        color_seg[result[1] == 1] = [255, 0, 0]

    # convert to BGR (reverse the channel axis)
    color_seg = color_seg[..., ::-1]
    # A pixel is painted when its mean colour is non-zero; compute this once.
    painted = np.mean(color_seg, 2) != 0
    img[painted] = img[painted] * 0.5 + color_seg[painted] * 0.5
    img = img.astype(np.uint8)
    img = cv2.resize(img, (1280, 720), interpolation=cv2.INTER_LINEAR)

    # Demo mode never writes; without a directory there is nowhere to write.
    if not is_demo and save_dir is not None:
        branch = "ll" if is_ll else "da"
        kind = "seg_gt" if is_gt else "segresult"
        cv2.imwrite(save_dir + "/batch_{}_{}_{}_{}.png".format(epoch, index, branch, kind), img)
    return img
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one rectangle outline on ``img`` in place.

    Args:
        x: Indexable with four entries; ``x[0], x[1]`` and ``x[2], x[3]`` are
            converted to int and used as the two opposite rectangle corners.
        img: Image array drawn on by ``cv2.rectangle``.
        color: Rectangle colour. A falsy value picks three random channel
            values in 0..255.
        label: Accepted for interface compatibility; label rendering is
            currently disabled, so the value is not used.
        line_thickness: Outline thickness. A falsy value derives it from the
            image height and width.
    """
    if line_thickness:
        thickness = line_thickness
    else:
        # line/font thickness scaled from the mean of image height and width
        thickness = round(0.0001 * (img.shape[0] + img.shape[1]) / 2) + 1

    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    corner_a = (int(x[0]), int(x[1]))
    corner_b = (int(x[2]), int(x[3]))
    cv2.rectangle(img, corner_a, corner_b, color, thickness=thickness, lineType=cv2.LINE_AA)
if __name__ == "__main__":
    # Running this file directly does nothing; the guard body is a no-op.
    pass
# def plot():
# cudnn.benchmark = cfg.CUDNN.BENCHMARK
# torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
# torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
# device = select_device(logger, batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU) if not cfg.DEBUG \
# else select_device(logger, 'cpu')
# if args.local_rank != -1:
# assert torch.cuda.device_count() > args.local_rank
# torch.cuda.set_device(args.local_rank)
# device = torch.device('cuda', args.local_rank)
# dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
# model = get_net(cfg).to(device)
# model_file = '/home/zwt/DaChuang/weights/epoch--2.pth'
# checkpoint = torch.load(model_file)
# model.load_state_dict(checkpoint['state_dict'])
# if rank == -1 and torch.cuda.device_count() > 1:
# model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
# if rank != -1:
# model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)