# Scraped page header: author KyanChen, commit ab01e4a ("add model"), file size 23.7 kB.
import yaml
import torch
import random
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from einops import repeat
import cv2
import time
import torch.nn.functional as F
# Public names exported by `from <this module> import *`.
__all__ = [
    "decode_mask_to_onehot",
    "encode_onehot_to_mask",
    "Logger",
    "get_coords_grid",
    "get_coords_grid_float",
    "draw_bboxes",
    "Infos",
    "inv_normalize_img",
    "make_numpy_img",
    "get_metrics",
]
class Infos(object):
    """Collect per-batch losses/results and report training progress and metrics.

    Only the object-detection phase ``'od'`` is supported. Epoch metrics are the
    12 COCO bbox AP/AR summary numbers computed with ``COCO`` / ``COCOeval``.
    NOTE(review): ``COCO`` and ``COCOeval`` are not imported in this file's
    visible import block; presumably they come from pycocotools -- confirm they
    are in scope wherever this class is used.
    """

    def __init__(self, phase, class_names=None):
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Defined up front so reporting never hits AttributeError;
        # `metrics` is filled in by cal_metrics().
        self.epoch_training_time = 0
        self.metrics = None

    # ---- simple setters for the progress fields used in log lines ----
    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop all accumulated results and losses."""
        self.register()

    def get_val_metric(self):
        """Return the last validation metric (AP@0.50 once cal_metrics() has run)."""
        return self.val_metric

    def _build_coco(self, per_image_boxes, has_score):
        """Build an in-memory COCO object from a list of per-image box arrays.

        Rows are ``[class, l, t, r, b]`` (targets) or
        ``[class, score, l, t, r, b]`` (predictions, ``has_score=True``).
        Every image index is registered exactly once, including images without
        boxes, so false positives on such images are counted by COCOeval.
        """
        coco_api = COCO()
        coco_api.dataset['images'] = [{'id': i} for i in range(len(per_image_boxes))]
        annotations = []
        offset = 2 if has_score else 1
        for i, boxes in enumerate(per_image_boxes):
            for row in boxes:
                lt = row[offset:offset + 2]
                rb = row[offset + 2:offset + 4]
                ann = {
                    'image_id': i,
                    "category_id": int(row[0]),
                    # COCO boxes are [x, y, w, h]
                    "bbox": np.hstack([lt, rb - lt]),
                    "area": np.prod(rb - lt),
                    # Ids must start at 1: COCOeval stores matched ids in
                    # arrays where 0 means "unmatched".
                    "id": len(annotations) + 1,
                    "iscrowd": 0,
                }
                if has_score:
                    ann['score'] = row[1]
                annotations.append(ann)
        coco_api.dataset['annotations'] = annotations
        coco_api.dataset['categories'] = [{"id": i, "supercategory": c, "name": c}
                                          for i, c in enumerate(self.class_names)]
        coco_api.createIndex()
        return coco_api

    def cal_metrics(self):
        """Run COCO bbox evaluation over all accumulated predictions and targets.

        Sets ``self.metrics`` to the 12 summary numbers and ``self.val_metric``
        to ``metrics[1]`` (AP at IoU 0.50).
        """
        if self.phase == 'od':
            coco_api_gt = self._build_coco(self.result_all['target_all'], has_score=False)
            coco_api_pred = self._build_coco(self.result_all['pred_all'], has_score=True)
            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            summary = coco_eval.summarize()
            # Stock pycocotools' summarize() prints and returns None, keeping the
            # numbers in `stats`; accept forks that return them directly too.
            self.metrics = summary if summary is not None else coco_eval.stats
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the mean epoch loss, run cal_metrics() and log the 12 COCO summary lines."""
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
            (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time, np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')
        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))
        if self.phase == 'od':
            # Row labels for the 12 numbers in self.metrics, in order.
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            line_fmt = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(line_fmt.format(titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i],
                                             self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 COCO summary numbers to ``writer`` at step ``epoch_id``."""
        step = self.epoch_id
        keys = [
            'AP_m_all_100',
            'AP_50_all_100',
            'AP_75_all_100',
            'AP_m_small_100',
            'AP_m_medium_100',
            'AP_m_large_100',
            'AR_m_all_1',
            'AR_m_all_10',
            'AR_m_all_100',
            'AR_m_small_100',
            'AR_m_medium_100',
            'AR_m_large_100',
        ]
        for i, key in enumerate(keys):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log one progress line with the most recent 'loss' value."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
            (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
             self.batch_num, self.lr, self.fps_data_load, self.fps)
        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write lr and the latest value of every loss at a global batch step."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save visualisations for a random ~30% subset (at least one) of the latest batch.

        Each PNG stacks, top to bottom: the de-normalised image, predictions drawn
        in red and targets drawn in green. The current batch's results are taken
        to be the last ``batch_size`` entries of ``result_all``.
        """
        batch_size = img_batch.size(0)
        k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size)
        ids = np.random.choice(range(batch_size), k, replace=False)
        out_dir = os.path.join(vis_dir, self.pattern)
        os.makedirs(out_dir, exist_ok=True)
        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # Negative index -> position within the most recently appended batch.
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]
            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # Keywords: draw_bboxes' positional order is (img, bboxes, color, class_names).
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)
            vis = np.concatenate([img / 255., pred_draw / 255., target_draw / 255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(out_dir, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset the result and loss accumulators."""
        self.is_registered_result = False
        self.result_all = {}
        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        """Create an empty result list for every key in ``data``."""
        for key in data.keys():
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend each result list with the list/iterable given for that key."""
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            self.result_all.setdefault(key, []).extend(value)

    def register_loss(self, data: dict):
        """Create an empty loss history for every key in ``data``."""
        for key in data.keys():
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor (detached, moved to CPU, as numpy) to its history."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
# draw bboxes on image, bboxes with classID
def draw_bboxes(img, bboxes, color=(255, 0, 0), class_names=None, is_show_score=True):
    '''Draw boxes (and optional class/score labels) on a uint8 copy of ``img``.

    Args:
        img: H x W x C image, numpy array or torch tensor; it is cast to uint8
            and copied, so the caller's image is not modified.
        bboxes: one row per box, numpy array or tensor:
            [n, 5]: class_idx, l, t, r, b
            [n, 6]: class_idx, score, l, t, r, b
        color: colour of the box outline, passed to cv2 unchanged.
        class_names: optional sequence mapping class index -> label text. When
            missing or too short, the numeric class index is shown instead.
        is_show_score: when True, draw a filled label patch with the class name
            (plus score for 6-column rows) at the box's top-left corner.
    Returns:
        The annotated uint8 image.
    '''
    assert img is not None, "In draw_bboxes, img is None"
    if torch.is_tensor(img):
        img = img.cpu().numpy()
    img = img.astype(np.uint8).copy()
    if torch.is_tensor(bboxes):
        bboxes = bboxes.cpu().numpy()
    for bbox in bboxes:
        class_idx = int(bbox[0])
        if class_names is not None and 0 <= class_idx < len(class_names):
            class_name = class_names[class_idx]
        else:
            # Previously a NameError when class_names was not given.
            class_name = str(class_idx)
        has_score = len(bbox) == 6
        coords = bbox[2:] if has_score else bbox[1:]
        # np.int was removed in NumPy 1.24; cv2 also expects plain Python ints.
        left, top, right, bottom = (int(v) for v in coords[:4])
        if is_show_score:
            cv2.rectangle(img, pt1=(left - 2, top - 15), pt2=(left + 15, top + 1), color=(0, 0, 255), thickness=-1)
            if has_score:
                text = '%s:%.2f' % (class_name, bbox[1])
            else:
                text = '%s' % class_name
            cv2.putText(img, text=text, org=(left - 1, top - 7), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.2, color=(255, 255, 255), thickness=1)
        cv2.rectangle(img, pt1=(left, top), pt2=(right, bottom), color=color, thickness=2)
    return img
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    """Return a dense 2 x H x W coordinate grid: channel 0 is x, channel 1 is y.

    Rows sample ``linspace(h_start, h_end, h_steps)`` and columns
    ``linspace(w_start, w_end, w_steps)``. Step counts default to one sample
    per integer position, ``int(end - start) + 1``. With ``is_normalize`` each
    axis is divided by its end value.
    """
    n_rows = int(h_end - h_start) + 1 if h_steps is None else h_steps
    n_cols = int(w_end - w_start) + 1 if w_steps is None else w_steps
    ys = torch.linspace(h_start, h_end, n_rows)
    xs = torch.linspace(w_start, w_end, n_cols)
    if is_normalize:
        ys, xs = ys / h_end, xs / w_end
    # Explicit broadcast (same values as an 'ij' meshgrid): x varies along W, y along H.
    grid_x = xs.unsqueeze(0).expand(len(ys), -1)
    grid_y = ys.unsqueeze(1).expand(-1, len(xs))
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_grid_float(ht, wd, scale, is_normalize=False):
    """Return a 2 x ht x wd grid of interior sample points on [0, scale].

    Each axis uses ``linspace(0, scale, n + 2)`` with both endpoints dropped.
    Channel 0 is x (varies along the width) and channel 1 is y. With
    ``is_normalize`` the values are divided by ``scale``.
    """
    ys = torch.linspace(0, scale, ht + 2)
    xs = torch.linspace(0, scale, wd + 2)
    if is_normalize:
        ys = ys / scale
        xs = xs / scale
    inner_y = ys[1:-1]
    inner_x = xs[1:-1]
    grid_x = inner_x.unsqueeze(0).expand(len(inner_y), -1)
    grid_y = inner_y.unsqueeze(1).expand(-1, len(inner_x))
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_vector_float(len, scale, is_normalize=False):
    """Return a 2 x len x 1 tensor of interior sample points on [0, scale].

    Channel 1 holds ``linspace(0, scale, len + 2)[1:-1]``, optionally divided
    by ``scale``. Channel 0 is all zeros.
    """
    # NOTE: the parameter name shadows the builtin `len`; kept for caller compatibility.
    samples = torch.linspace(0, scale, len + 2)
    if is_normalize:
        samples = samples / scale
    column = samples[1:-1].unsqueeze(1)
    return torch.stack((torch.zeros_like(column), column), dim=0)
class Logger(object):
    """Tee writer: appends every message to a log file and optionally echoes it to stdout.

    Provides ``write``/``flush`` so it can stand in for a text stream. Call
    ``close()`` or use it as a context manager to release the file handle.
    """

    def __init__(self, filename="Default.log", is_terminal_show=True):
        self.is_terminal_show = is_terminal_show
        if self.is_terminal_show:
            # Capture the stream now, so that a later `sys.stdout = logger`
            # does not make write() call back into itself.
            self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the terminal (if enabled) and the log file, then flush."""
        if self.is_terminal_show:
            self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        """Flush the terminal stream (if enabled) and the log file."""
        if self.is_terminal_show:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; safe to call more than once."""
        if not self.log.closed:
            self.flush()
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
class ParamsParser:
def __init__(self, project_file):
self.params = yaml.safe_load(open(project_file).read())
def __getattr__(self, item):
return self.params.get(item, None)
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into one level that holds only the leaf entries.

    Nested dict values are recursed into and their leaves merged into the
    result; the intermediate keys are dropped. When a leaf key appears more
    than once, the one visited last wins and keeps the first key's position.
    """
    flat = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() accepts any hashable key. The old dict(..., **sub) raised
            # TypeError for non-string keys and copied the dict on every merge.
            flat.update(get_all_dict(value))
        else:
            flat[key] = value
    return flat
def make_numpy_img(tensor_data):
    """Convert an image tensor to an H x W x 3 numpy array for visualisation.

    Accepted layouts, checked in this order: H x W (grey), 1 x H x W (grey,
    channel-first), 3 x H x W (channel-first) and H x W x 3 (channel-last).
    Grey inputs are replicated into three identical channels. Anything else
    raises ``Exception``.
    """
    if len(tensor_data.shape) == 2:
        # H x W -> H x W x 1 -> H x W x 3
        hwc = tensor_data.unsqueeze(2).repeat(1, 1, 3)
    elif tensor_data.size(0) == 1:
        hwc = tensor_data.permute((1, 2, 0)).repeat(1, 1, 3)
    elif tensor_data.size(0) == 3:
        hwc = tensor_data.permute((1, 2, 0))
    elif tensor_data.size(2) == 3:
        hwc = tensor_data
    else:
        raise Exception('tensor_data apply to make_numpy_img error')
    return hwc.detach().cpu().numpy()
def print_infos(logger, writer, infos: dict):
    """Log one progress line and send lr plus any extra scalars to ``writer``.

    The first eight values of ``infos`` (in insertion order) fill the format
    placeholders; judging by the labels these are presumably pattern,
    epoch_id, max_epoch, batch_id, batch_num, lr, fps_data_load and fps. The
    keys 'pattern', 'lr', 'epoch_id', 'batch_num' and 'batch_id' are also read
    by name. Every entry after the eighth is appended as ``, name: value`` and
    logged to ``writer`` under ``<pattern>/<name>``.
    """
    entries = list(infos.items())
    head = tuple(value for _, value in entries[:8])
    extras = entries[8:]
    line = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % head
    line += ''.join(f', {name}: {val:.4f}' for name, val in extras)
    logger.write(line + '\n')
    pattern = infos['pattern']
    lr = infos['lr']
    step = infos['epoch_id'] * infos['batch_num'] + infos['batch_id']
    writer.add_scalar('%s/lr' % pattern, lr, step)
    for name, val in extras:
        writer.add_scalar('%s/%s' % (pattern, name), val, step)
def invert_affine(origin_imgs, preds, pattern='train', input_size=512):
    """Rescale predicted boxes from the square network input back to each original image.

    Only applied when ``pattern == 'val'``; otherwise ``preds`` is returned
    untouched. Columns 0/2 of ``preds[i]['rois']`` are scaled by
    ``old_w / input_size`` and columns 1/3 by ``old_h / input_size``, in place.

    Args:
        origin_imgs: sequence of original images; only ``shape[:2]`` (H, W) is
            read, so both H x W x C and H x W images work.
        preds: list of dicts whose ``'rois'`` is an [n, 4] array-like.
        pattern: phase name; only ``'val'`` triggers rescaling.
        input_size: side length the images were resized to (default 512, the
            value that used to be hard-coded).
    Returns:
        ``preds``, modified in place when rescaled.
    """
    if pattern != 'val':
        return preds
    for i, pred in enumerate(preds):
        rois = pred['rois']
        if len(rois) == 0:
            continue
        old_h, old_w = origin_imgs[i].shape[:2]
        rois[:, [0, 2]] = rois[:, [0, 2]] / (input_size / old_w)
        rois[:, [1, 3]] = rois[:, [1, 3]] / (input_size / old_h)
    return preds
def save_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id):
    """Save optical-flow visualisations for a random subset of a batch.

    Each saved JPEG stacks, top to bottom: image 1, image 2, image 2 warped by
    the predicted flow, and the min-max normalised flow magnitude.

    Args:
        input: dict with batched ``'ori_img1'`` / ``'ori_img2'`` tensors
            (presumably B x C x H x W on a 0..255 scale -- TODO confirm).
        output: ``(flows, pf1s, pf2s)``; only ``flows[0]`` is used, with
            channels 0/1 read as the two flow components.
        vis_dir, pattern: output goes to ``<vis_dir>/<pattern>/`` (created if missing).
        epoch_id, batch_id: used in the file names.
    NOTE(review): relies on ``flow_to_warp`` and ``resample``, which are not
    defined or imported in this file's visible code.
    """
    flows, _, _ = output
    batch_size = len(flows[0])
    # ~20% of the batch, at least 2, never more than the batch holds.
    k = min(max(int(0.2 * batch_size), 2), batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)
    device = flows[0].device
    for img_id in ids:
        img1 = input['ori_img1'][img_id:img_id + 1].to(device)
        img2 = input['ori_img2'][img_id:img_id + 1].to(device)
        flow = flows[0][img_id:img_id + 1]
        warped_img2 = resample(img2, flow_to_warp(flow))
        ori_img1 = make_numpy_img(img1[0]) / 255.
        ori_img2 = make_numpy_img(img2[0]) / 255.
        warped_img2 = make_numpy_img(warped_img2[0]) / 255.
        flow_amplitude = torch.sqrt(flow[0, 0:1, ...] ** 2 + flow[0, 1:2, ...] ** 2)
        flow_amplitude = make_numpy_img(flow_amplitude)
        # Min-max normalise to [0, 1]; the epsilon avoids 0/0 for a constant field.
        flow_amplitude = (flow_amplitude - np.min(flow_amplitude)) / \
            (np.max(flow_amplitude) - np.min(flow_amplitude) + 1e-10)
        vis = np.concatenate([ori_img1, ori_img2, warped_img2, flow_amplitude], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id in the name: previously every sampled image overwrote the same file.
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def inv_normalize_img(img, prior_mean=(0, 0, 0), prior_std=(1, 1, 1)):
    """Undo per-channel mean/std normalisation and rescale an image tensor to [0, 255].

    Args:
        img: float tensor laid out as C x H x W, optionally with leading batch
            dims (... x C x H x W); the channel axis is ``img.size(-3)``.
        prior_mean, prior_std: per-channel statistics of length C, in the same
            units as ``img`` before the x255 rescale.
    Returns:
        ``clamp((img * std + mean) * 255, 0, 255)`` as a new float tensor.
    """
    # Tuple defaults avoid shared mutable defaults; lists are still accepted.
    channels = img.size(-3)
    mean = torch.tensor(prior_mean, dtype=torch.float).to(img.device).view(channels, 1, 1)
    std = torch.tensor(prior_std, dtype=torch.float).to(img.device).view(channels, 1, 1)
    img = img * std + mean
    img = img * 255.
    return torch.clamp(img, min=0, max=255)
def save_seg_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id, prior_mean, prior_std):
    """Save image / prediction / target visualisations for a random subset of a batch.

    Args:
        input: dict with batched ``'img'`` (normalised) and ``'label'`` tensors;
            each ``label`` sample is decoded with ``encode_onehot_to_mask``.
        output: per-class scores, B x nClass x H x W (argmax over dim 1).
        vis_dir, pattern: output goes to ``<vis_dir>/<pattern>/`` (created if missing).
        epoch_id, batch_id: used in the file names.
        prior_mean, prior_std: statistics passed to ``inv_normalize_img``.
    """
    pred_label = torch.argmax(output, 1)
    batch_size = len(pred_label)
    # ~20% of the batch, at least one. The cap must be the batch size; it used
    # to be len(pred_label[0]), i.e. the mask height.
    k = int(np.clip(int(0.2 * batch_size), a_min=1, a_max=batch_size))
    ids = np.random.choice(range(batch_size), k, replace=False)
    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)
    for img_id in ids:
        img = input['img'][img_id].to(pred_label.device)
        target = input['label'][img_id].to(pred_label.device)
        img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std)) / 255.
        target = make_numpy_img(encode_onehot_to_mask(target))
        pred = make_numpy_img(pred_label[img_id])
        # NOTE(review): pred/target hold raw class indices, so after the clip to
        # [0, 1] every class >= 1 renders identically. Kept as-is.
        vis = np.concatenate([img, pred, target], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id in the name: previously every sampled image overwrote the same file.
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` (the default) avoids computing gradients
    for those networks.

    Parameters:
        nets (network or list of networks) -- modules to update; ``None`` entries are skipped
        requires_grad (bool) -- value assigned to each parameter's ``requires_grad``
    """
    net_list = nets if isinstance(nets, list) else [nets]
    for module in (m for m in net_list if m is not None):
        for parameter in module.parameters():
            parameter.requires_grad = requires_grad
def boolean_string(s):
    """Parse exactly ``'True'`` or ``'False'`` (case-sensitive) into a bool.

    Raises:
        ValueError: for any other value.
    """
    parsed = {'True': True, 'False': False}.get(s)
    if parsed is None:
        raise ValueError('Not a valid boolean string')
    return parsed
def cpt_pxl_cls_acc(pred_idx, target):
    """Pixel accuracy: the fraction of positions where prediction equals target.

    Both tensors are flattened and cast to ``int`` before comparing. Returns a
    0-dim float tensor.
    """
    hits = pred_idx.reshape(-1).int().eq(target.reshape(-1).int())
    return hits.float().mean()
def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
    """Mean PSNR (dB) over a batch, using a per-sample MSE over all non-batch dims.

    Accepts any tensor with a leading batch dimension plus at least one more
    dimension (previously only 4-D). For B x C x H x W this reduces over dims
    (1, 2, 3) exactly as before. Identical samples give ``inf``.

    Raises:
        ValueError: if the inputs have no dimension besides the batch dimension.
    """
    squared_error = (img - img_gt) ** 2
    if squared_error.dim() < 2:
        raise ValueError('cpt_batch_psnr expects a batch dimension plus at least one data dimension')
    reduce_dims = tuple(range(1, squared_error.dim()))
    mse = torch.mean(squared_error, dim=reduce_dims)
    psnr = 20 * torch.log10(PIXEL_MAX / torch.sqrt(mse))
    return torch.mean(psnr)
def cpt_psnr(img, img_gt, PIXEL_MAX):
    """PSNR in dB between two numpy arrays: ``20 * log10(PIXEL_MAX / sqrt(MSE))``."""
    squared_error = (img - img_gt) ** 2
    root_mse = np.sqrt(np.mean(squared_error))
    return 20 * np.log10(PIXEL_MAX / root_mse)
def cpt_rgb_ssim(img, img_gt):
    """Average single-channel SSIM over channels 0..2 of two H x W x C images.

    Both inputs first go through ``clip_01``; each channel is scored with
    ``sk_cpt_ssim``. Neither helper is defined in this file's visible code.
    """
    clipped, clipped_gt = clip_01(img), clip_01(img_gt)
    total = sum(sk_cpt_ssim(clipped[:, :, ch], clipped_gt[:, :, ch]) for ch in range(3))
    return total / 3.0
def cpt_ssim(img, img_gt):
    """SSIM of two images after passing both through ``clip_01``.

    Delegates to ``sk_cpt_ssim``; that helper and ``clip_01`` are not defined
    in this file's visible code.
    """
    return sk_cpt_ssim(clip_01(img), clip_01(img_gt))
def decode_mask_to_onehot(mask, n_class):
    '''
    Convert a class-index mask into a one-hot tensor.

    mask : BxWxH or WxH tensor of class indices
    n_class : n; channel i is 1 where mask == i, else 0
    return : BxnxWxH or nxWxH, in the default float dtype, on mask's device.
             A 2-D mask gives a 3-D result with no batch axis.
    '''
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Record the rank before unsqueezing. The old code re-checked it afterwards,
    # so 2-D input kept a spurious batch axis.
    is_single = len(mask.shape) == 2
    if is_single:
        mask = mask.unsqueeze(0)
    classes = torch.arange(n_class, device=mask.device).view(1, n_class, 1, 1)
    onehot = (mask.unsqueeze(1) == classes).to(torch.get_default_dtype())
    if is_single:
        onehot = onehot.squeeze(0)
    return onehot
def encode_onehot_to_mask(onehot):
    '''
    Collapse a one-hot (or per-class score) tensor to class indices via argmax.

    onehot: tensor, BxnxWxH or nxWxH
    output: tensor, BxWxH or WxH
    '''
    assert len(onehot.shape) in [3, 4], "encode_onehot_to_mask error!"
    # In both accepted layouts the class axis is third from the end.
    return onehot.argmax(dim=-3)
def _decode_level(cls_logits, reg_a, reg_b, prior_box_wh, stride, conf_thres):
    """Decode one feature level into per-cell scores, an object mask and absolute boxes.

    Returns ``(conf, mask, loc)``: ``conf``/``mask`` are B x H x W, and ``loc`` is
    B x 4 x H x W with (l, t, r, b) in input-image pixels.
    """
    # Channel 1 of the 2-way softmax is used as the object score.
    conf = F.softmax(cls_logits, dim=1)[:, 1, ...]
    mask = conf > conf_thres
    loc = reg_a + reg_b  # B 4 H W; scaled by the prior box size below
    loc[:, 0::2, ...] *= prior_box_wh[0]
    loc[:, 1::2, ...] *= prior_box_wh[1]
    h, w = loc.shape[-2:]
    # Cell positions in pixels (channel 0 = x, channel 1 = y), tiled for (l, t) and (r, b).
    grid = get_coords_grid(h - 1, w - 1, 0, 0) * stride
    loc += torch.cat([grid, grid], dim=0).to(loc.device)
    return conf, mask, loc


def decode(pred, target=None, *args, **kwargs):
    """
    Turn the raw outputs of the two-level detection head into per-image boxes.

    Args:
        pred: 8 tensors, in order: big_cls_1(0), big_reg_1, small_cls_1(2), small_reg_1,
            big_cls_2(4), big_reg_2, small_cls_2(6), small_reg_2. Scores come from the
            ``*_cls_2`` maps; boxes are ``reg_1 + reg_2`` per level.
        target: optional [[n,5], [n,5]] list of tensors, converted to numpy.
        kwargs:
            phase: only 'od' is supported.
            img_size: img_size[0] normalises x coordinates, img_size[1] y coordinates.
            prior_box_wh, conf_thres, iou_thres, conf_type: decoding / box-fusion settings.
            strides (optional): pixel stride of the big and small maps, default (8, 4).
            max_boxes (optional): max candidates per image passed to fusion, default 150.
    Returns:
        {"pred_all": list of [m, 6] arrays (label, score, l, t, r, b),
         "target_all": list of numpy arrays, or None}
    Raises:
        ValueError: if ``phase`` is not 'od'.
    NOTE(review): ``weighted_boxes_fusion`` is not imported in this file's visible code.
    """
    phase = kwargs['phase']
    img_size = kwargs['img_size']
    if phase != 'od':
        # Previously fell through to an UnboundLocalError on `pred_all`.
        raise ValueError("decode: unsupported phase %r (only 'od' is implemented)" % (phase,))
    prior_box_wh = kwargs['prior_box_wh']
    conf_thres = kwargs['conf_thres']
    iou_thres = kwargs['iou_thres']
    conf_type = kwargs['conf_type']
    stride_big, stride_small = kwargs.get('strides', (8, 4))
    max_boxes = kwargs.get('max_boxes', 150)

    # The grid size is now taken from each map's shape (was hard-coded to 32 / 64).
    conf_big, mask_big, loc_big = _decode_level(pred[4], pred[1], pred[5], prior_box_wh, stride_big, conf_thres)
    conf_small, mask_small, loc_small = _decode_level(pred[6], pred[3], pred[7], prior_box_wh, stride_small,
                                                      conf_thres)

    pred_all = []
    for i in range(loc_big.size(0)):
        score_big = conf_big[i][mask_big[i]]  # N
        score_small = conf_small[i][mask_small[i]]  # M
        box_big = loc_big[i].permute((1, 2, 0))[mask_big[i]]  # N x 4
        box_small = loc_small[i].permute((1, 2, 0))[mask_small[i]]  # M x 4
        score_list = torch.cat((score_big, score_small), dim=0).detach().cpu().numpy()
        boxes_list = torch.cat((box_big, box_small), dim=0).detach().cpu().numpy()
        boxes_list[:, 0::2] /= img_size[0]
        boxes_list[:, 1::2] /= img_size[1]
        # Cap the candidates at `max_boxes` (150 by default), keeping the
        # highest-scoring ones. The old code kept the first ones in raster order.
        keep = np.argsort(-score_list, kind='stable')[:max_boxes]
        boxes_list = boxes_list[keep]
        score_list = score_list[keep]
        label_list = np.ones_like(score_list)
        boxes, scores, labels = weighted_boxes_fusion([boxes_list], [score_list], [label_list], weights=None,
                                                      iou_thr=iou_thres, conf_type=conf_type)
        boxes[:, 0::2] *= img_size[0]
        boxes[:, 1::2] *= img_size[1]
        pred_boxes = np.concatenate((labels.reshape(-1, 1), scores.reshape(-1, 1), boxes), axis=1)
        pred_all.append(pred_boxes)

    target_all = [x.cpu().numpy() for x in target] if target is not None else None
    return {"pred_all": pred_all, "target_all": target_all}
def get_metrics(phase, pred, target):
    '''
    Per-class segmentation metrics for one batch.

    pred: logits, tensor, nBatch*nClass*W*H (argmax over dim 1 gives the label)
    target: one-hot labels, tensor, nBatch*nClass*W*H
    Returns (IoU, OA, F1_score), each a length-nClass tensor. OA is the
    per-class binary accuracy 1 - (FP + FN) / number of target entries >= 0.
    Raises ValueError for any phase other than 'seg'.
    '''
    if phase != 'seg':
        # Previously fell through to an UnboundLocalError at the return.
        raise ValueError("get_metrics: unsupported phase %r (only 'seg' is implemented)" % (phase,))
    pred = torch.argmax(pred.detach(), dim=1)
    pred = decode_mask_to_onehot(pred, target.size(1))
    gt_pos = target == 1
    pred_pos = pred == 1
    reduce_dims = (0, 2, 3)  # everything except the class axis
    # positive samples in ground truth / in the prediction
    gt_pos_sum = torch.sum(gt_pos, dim=reduce_dims)
    pred_pos_sum = torch.sum(pred_pos, dim=reduce_dims)
    # true positives
    true_pos_sum = torch.sum(gt_pos & pred_pos, dim=reduce_dims)
    # the 1e-15 terms keep empty classes from dividing by zero
    precision = true_pos_sum / (pred_pos_sum + 1e-15)
    recall = true_pos_sum / (gt_pos_sum + 1e-15)
    IoU = true_pos_sum / (pred_pos_sum + gt_pos_sum - true_pos_sum + 1e-15)
    OA = 1 - (pred_pos_sum + gt_pos_sum - 2 * true_pos_sum) / torch.sum(target >= 0, dim=reduce_dims)
    F1_score = 2 * precision * recall / (precision + recall + 1e-15)
    return IoU, OA, F1_score