Spaces:
Runtime error
Runtime error
import yaml | |
import torch | |
import random | |
import numpy as np | |
import os | |
import sys | |
import matplotlib.pyplot as plt | |
from einops import repeat | |
import cv2 | |
import time | |
import torch.nn.functional as F | |
# Names exported by ``from <this module> import *``.
__all__ = [
    'decode_mask_to_onehot',
    'encode_onehot_to_mask',
    'Logger',
    'get_coords_grid',
    'get_coords_grid_float',
    'draw_bboxes',
    'Infos',
    'inv_normalize_img',
    'make_numpy_img',
    'get_metrics',
]
class Infos(object):
    """Running train/val state for the object-detection ('od') phase.

    Holds the progress counters written through the ``set_*`` methods,
    accumulates per-batch losses (``loss_all``) and decoded results
    (``result_all``), and converts the accumulated results into COCO bbox
    metrics for logging.
    """

    def __init__(self, phase, class_names=None):
        """
        Args:
            phase: must be 'od'; any other value fails the assertion.
            class_names: sequence of class names. Its indices are used as the
                COCO category ids in cal_metrics().
        """
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Initialised here so print_epoch_state_infos() does not raise
        # AttributeError when set_epoch_training_time() was never called.
        self.epoch_training_time = 0

    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop all accumulated results and losses."""
        self.register()

    def get_val_metric(self):
        return self.val_metric

    def _build_coco_api(self, boxes_per_image, has_score):
        """Wrap per-image box arrays into an indexed ``COCO`` object.

        Args:
            boxes_per_image: list of 2-D arrays, one per image. Rows are
                ``class_idx, l, t, r, b`` or, when ``has_score`` is true,
                ``class_idx, score, l, t, r, b``.
            has_score: whether column 1 holds a confidence score.
        """
        # NOTE(review): COCO / COCOeval are not imported in the visible part of
        # this file; they presumably come from pycocotools - confirm the import.
        coco_api = COCO()
        images = []
        annotations = []
        first = 2 if has_score else 1  # column where (l, t) starts
        ann_id = 0
        for img_idx, boxes in enumerate(boxes_per_image):
            # NOTE(review): as before, an image is only registered when it has
            # at least one box, so images without boxes are absent from the
            # evaluated image ids. It is now registered once, not once per box.
            if boxes.shape[0] > 0:
                images.append({'id': img_idx})
            for j in range(boxes.shape[0]):
                left_top = boxes[j, first:first + 2]
                width_height = boxes[j, first + 2:first + 4] - left_top
                annotation = {
                    'image_id': img_idx,
                    "category_id": int(boxes[j, 0]),
                    # COCO boxes are (x, y, w, h); inputs are (l, t, r, b).
                    "bbox": np.hstack([left_top, width_height]),
                    "area": np.prod(width_height),
                    "id": ann_id,
                    "iscrowd": 0
                }
                if has_score:
                    annotation['score'] = boxes[j, 1]
                annotations.append(annotation)
                ann_id += 1
        coco_api.dataset['images'] = images
        coco_api.dataset['annotations'] = annotations
        coco_api.dataset['categories'] = [{"id": i, "supercategory": c, "name": c} for i, c in
                                          enumerate(self.class_names)]
        coco_api.createIndex()
        return coco_api

    def cal_metrics(self):
        """Run COCO bbox evaluation on the accumulated results.

        Sets ``self.metrics`` (the 12 COCO summary numbers) and
        ``self.val_metric`` (entry 1 of that summary).
        """
        if self.phase == 'od':
            coco_api_gt = self._build_coco_api(self.result_all['target_all'], has_score=False)
            coco_api_pred = self._build_coco_api(self.result_all['pred_all'], has_score=True)
            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            summary = coco_eval.summarize()
            # Stock pycocotools returns None from summarize() and keeps the
            # numbers in `.stats`; a patched build may return them directly.
            self.metrics = summary if summary is not None else coco_eval.stats
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the epoch summary, run cal_metrics() and log the metric table."""
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time,
                     np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')
        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))
        if self.phase == 'od':
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            row_fmt = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(row_fmt.format(titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i],
                                            self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 epoch metrics to ``writer`` via ``add_scalar``."""
        step = self.epoch_id
        keys = [
            'AP_m_all_100',
            'AP_50_all_100',
            'AP_75_all_100',
            'AP_m_small_100',
            'AP_m_medium_100',
            'AP_m_large_100',
            'AR_m_all_1',
            'AR_m_all_10',
            'AR_m_all_100',
            'AR_m_small_100',
            'AR_m_medium_100',
            'AR_m_large_100',
        ]
        for i, key in enumerate(keys):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log one line with the batch counters, lr, fps and latest loss."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
                     self.batch_num, self.lr, self.fps_data_load, self.fps)
        # add loss
        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write lr and the latest value of every loss to ``writer``."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save image / prediction / target visualisations for ~30% of a batch.

        Args:
            img_batch: tensor whose dim 0 is the batch
                (presumably B x C x H x W normalised images - TODO confirm).
            prior_mean, prior_std: passed to inv_normalize_img().
            vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        """
        batch_size = img_batch.size(0)
        k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size)
        ids = np.random.choice(range(batch_size), k, replace=False)
        out_dir = os.path.join(vis_dir, self.pattern)
        os.makedirs(out_dir, exist_ok=True)
        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # Negative index: assumes this batch's results were appended last.
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]
            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # Keywords are required: draw_bboxes takes `color` before
            # `class_names`, so the old positional call swapped the two.
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)
            vis = np.concatenate([img / 255., pred_draw / 255., target_draw / 255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(out_dir, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset both accumulators to the empty, unregistered state."""
        self.is_registered_result = False
        self.result_all = {}
        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        for key in data.keys():
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend every result list with the entries in ``data``.

        A ``None`` value (e.g. no targets available) is skipped rather than
        raising ``TypeError`` on ``list += None``.
        """
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            if value is None:
                continue
            self.result_all[key] += value

    def register_loss(self, data: dict):
        for key in data.keys():
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor, detached and converted to NumPy."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
# draw bboxes on image, bboxes with classID
def draw_bboxes(img, bboxes, color=(255, 0, 0), class_names=None, is_show_score=True):
    '''
    Draw boxes (and optionally a small label) on a copy of ``img``.

    Args:
        img: image array or tensor; it is converted to a uint8 NumPy copy, so
            the input is never modified.
        bboxes: [n, 5], class_idx, l, t, r, b
                [n, 6], class_idx, score, l, t, r, b
        color: colour of the box outline.
        class_names: optional sequence indexed by class_idx. When omitted the
            label falls back to the numeric class index.
        is_show_score: when true, draw a filled label tag above the box with
            the class name (plus the score for 6-column rows).
    Returns:
        The annotated uint8 NumPy image.
    '''
    assert img is not None, "In draw_bboxes, img is None"
    if torch.is_tensor(img):
        img = img.cpu().numpy()
    img = img.astype(np.uint8).copy()
    if torch.is_tensor(bboxes):
        bboxes = bboxes.cpu().numpy()
    for bbox in bboxes:
        class_idx = int(bbox[0])
        # Previously `class_name` was unbound when class_names was None.
        class_name = class_names[class_idx] if class_names else str(class_idx)
        has_score = len(bbox) == 6
        coords = bbox[2:] if has_score else bbox[1:]
        # `np.int` was removed in NumPy 1.24; plain ints also keep OpenCV happy.
        left, top, right, bottom = (int(v) for v in coords[:4])
        if is_show_score:
            cv2.rectangle(img, pt1=(left - 2, top - 15), pt2=(left + 15, top + 1),
                          color=(0, 0, 255), thickness=-1)
            if has_score:
                label = '%s:%.2f' % (class_name, bbox[1])
            else:
                label = '%s' % class_name
            cv2.putText(img, text=label, org=(left - 1, top - 7),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.2, color=(255, 255, 255), thickness=1)
        cv2.rectangle(img, pt1=(left, top), pt2=(right, bottom), color=color, thickness=2)
    return img
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    """Return a (2, H, W) tensor of grid coordinates.

    Channel 0 holds the x (width) coordinate of every cell and channel 1 the
    y (height) coordinate. When the step counts are omitted the grid has one
    sample per integer between start and end, inclusive. With
    ``is_normalize`` the coordinates are divided by ``w_end`` / ``h_end``.
    """
    n_rows = int(h_end - h_start) + 1 if h_steps is None else h_steps
    n_cols = int(w_end - w_start) + 1 if w_steps is None else w_steps
    ys = torch.linspace(h_start, h_end, n_rows)
    xs = torch.linspace(w_start, w_end, n_cols)
    if is_normalize:
        ys, xs = ys / h_end, xs / w_end
    grid_y, grid_x = torch.meshgrid(ys, xs)
    # x first, then y.
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_grid_float(ht, wd, scale, is_normalize=False):
    """Return a (2, ht, wd) grid of interior sample points in (0, scale).

    Each axis is sampled with ``n + 2`` evenly spaced points on [0, scale] and
    the two end points are dropped. Channel 0 is x, channel 1 is y. With
    ``is_normalize`` the coordinates are divided by ``scale``.
    """
    ys = torch.linspace(0, scale, ht + 2)
    xs = torch.linspace(0, scale, wd + 2)
    if is_normalize:
        ys, xs = ys / scale, xs / scale
    inner_y, inner_x = ys[1:-1], xs[1:-1]
    grid_y, grid_x = torch.meshgrid(inner_y, inner_x)
    return torch.stack((grid_x, grid_y), dim=0)
def get_coords_vector_float(len, scale, is_normalize=False):
    """Return a (2, len, 1) tensor: channel 0 is all zeros, channel 1 holds
    ``len`` interior points of [0, scale] (divided by ``scale`` when
    ``is_normalize`` is set).
    """
    samples = torch.linspace(0, scale, len + 2)
    if is_normalize:
        samples = samples / scale
    interior = samples[1:-1]
    zero = torch.tensor([0.])
    grid_pos, grid_zero = torch.meshgrid(interior, zero)
    return torch.stack((grid_zero, grid_pos), dim=0)
class Logger(object):
    """Append messages to a log file and, optionally, echo them to stdout.

    The file is opened in append mode on construction. Call close() (or use
    the instance as a context manager) to release the file handle.
    """

    def __init__(self, filename="Default.log", is_terminal_show=True):
        """
        Args:
            filename: path of the log file, opened with mode "a".
            is_terminal_show: when true, every message is also written to
                ``sys.stdout``.
        """
        self.is_terminal_show = is_terminal_show
        if self.is_terminal_show:
            self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the enabled sinks and flush immediately."""
        if self.is_terminal_show:
            self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        if self.is_terminal_show:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; safe to call more than once."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
class ParamsParser:
    """Expose the top-level keys of a YAML file as attributes.

    Unknown keys resolve to ``None`` instead of raising.
    """

    def __init__(self, project_file):
        """Load ``project_file`` with ``yaml.safe_load``."""
        # `with` closes the file; the old open(...).read() leaked the handle.
        with open(project_file) as stream:
            # An empty YAML document loads as None; fall back to an empty dict
            # so attribute lookups keep returning None rather than crashing.
            self.params = yaml.safe_load(stream) or {}

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If `params` itself is
        # missing (e.g. during copy/unpickle before __init__), looking it up
        # here would recurse forever; dunder probes must also fail normally.
        if item == 'params' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self.params.get(item, None)
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into a single level.

    Non-dict values keep their own key; nested dicts are flattened
    recursively and merged in, dropping the nested dict's own key. When a key
    occurs more than once, the value met last in iteration order wins.
    """
    flat = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() handles non-string keys, unlike dict(..., **nested),
            # and avoids rebuilding the whole result on every merge.
            flat.update(get_all_dict(value))
        else:
            flat[key] = value
    return flat
def make_numpy_img(tensor_data):
    """Convert a tensor to an H x W x 3 NumPy array.

    Accepted layouts: H x W (replicated to 3 channels), 1 x H x W (moved to
    channels-last and replicated), 3 x H x W (moved to channels-last) and
    H x W x 3 (used as is). Anything else raises ``Exception``.
    """
    if tensor_data.dim() == 2:
        hwc = tensor_data.unsqueeze(2).repeat(1, 1, 3)
    elif tensor_data.size(0) == 1:
        hwc = tensor_data.permute((1, 2, 0)).repeat(1, 1, 3)
    elif tensor_data.size(0) == 3:
        hwc = tensor_data.permute((1, 2, 0))
    elif tensor_data.size(2) == 3:
        hwc = tensor_data
    else:
        raise Exception('tensor_data apply to make_numpy_img error')
    return hwc.detach().cpu().numpy()
def print_infos(logger, writer, infos: dict):
    """Log one batch-status line and mirror the values to ``writer``.

    The first eight values of ``infos`` (in insertion order) fill the fixed
    part of the line; every further entry is appended as ``, key: value`` and
    also written as a scalar. The code reads the keys 'pattern', 'lr',
    'epoch_id', 'batch_num' and 'batch_id' by name.
    """
    names = list(infos)
    vals = list(infos.values())
    extras = list(zip(names[8:], vals[8:]))

    line = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % tuple(vals[:8])
    line += ''.join(f', {x}: {y:.4f}' for x, y in extras)
    logger.write(line + '\n')

    step = infos['epoch_id'] * infos['batch_num'] + infos['batch_id']
    writer.add_scalar('%s/lr' % infos['pattern'], infos['lr'], step)
    for name, val in extras:
        writer.add_scalar('%s/%s' % (infos['pattern'], name), val, step)
def invert_affine(origin_imgs, preds, pattern='train', input_size=512):
    """Map predicted boxes from network-input scale back to image scale.

    Only acts when ``pattern == 'val'``; otherwise ``preds`` is returned
    untouched. The ``'rois'`` arrays are rescaled in place.

    Args:
        origin_imgs: sequence of original images; ``shape`` must unpack as
            (height, width, channels).
        preds: sequence of dicts with a ``'rois'`` array whose columns 0/2 are
            x coordinates and columns 1/3 are y coordinates.
        pattern: run mode.
        input_size: side length the images were resized to before inference
            (previously hard-coded to 512).
    Returns:
        ``preds`` (the same object that was passed in).
    """
    if pattern != 'val':
        return preds
    for i, pred in enumerate(preds):
        rois = pred['rois']
        if len(rois) == 0:
            continue
        old_h, old_w, _ = origin_imgs[i].shape
        rois[:, [0, 2]] = rois[:, [0, 2]] / (input_size / old_w)
        rois[:, [1, 3]] = rois[:, [1, 3]] / (input_size / old_h)
    return preds
def save_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id):
    """Save flow visualisations for a random ~20% of the batch.

    For each sampled item the saved image stacks, top to bottom: image 1,
    image 2, image 2 warped by the flow, and the min-max normalised flow
    magnitude.

    Args:
        input: mapping with 'ori_img1' and 'ori_img2' image batches
            (presumably in the 0-255 range, since they are divided by 255 -
            TODO confirm).
        output: ``(flows, pf1s, pf2s)``; only ``flows[0]`` is used.
        vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        pattern, epoch_id, batch_id: used to build the file name.
    """
    flows, pf1s, pf2s = output
    flow_batch = flows[0]
    n_items = len(flow_batch)
    k = np.clip(int(0.2 * n_items), a_min=2, a_max=n_items)
    ids = np.random.choice(range(n_items), k, replace=False)
    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)
    for img_id in ids:
        img1 = input['ori_img1'][img_id:img_id + 1].to(flow_batch.device)
        img2 = input['ori_img2'][img_id:img_id + 1].to(flow_batch.device)
        flow = flow_batch[img_id:img_id + 1]
        # NOTE(review): flow_to_warp / resample are not defined in the visible
        # part of this file - confirm they are imported.
        warps = flow_to_warp(flow)
        warped_img2 = resample(img2, warps)
        ori_img1 = make_numpy_img(img1[0]) / 255.
        ori_img2 = make_numpy_img(img2[0]) / 255.
        warped_img2 = make_numpy_img(warped_img2[0]) / 255.
        flow_amplitude = torch.sqrt(flow[0, 0:1, ...] ** 2 + flow[0, 1:2, ...] ** 2)
        flow_amplitude = make_numpy_img(flow_amplitude)
        # Min-max normalise; the epsilon guards against a constant flow field.
        flow_amplitude = (flow_amplitude - np.min(flow_amplitude)) / (
            np.max(flow_amplitude) - np.min(flow_amplitude) + 1e-10)
        vis = np.concatenate([ori_img1, ori_img2, warped_img2, flow_amplitude], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so the sampled items no longer overwrite
        # each other (previously every iteration wrote the same file).
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def inv_normalize_img(img, prior_mean=(0, 0, 0), prior_std=(1, 1, 1)):
    """Undo mean/std normalisation and rescale to the 0-255 range.

    Computes ``(img * std + mean) * 255`` per channel and clamps the result to
    [0, 255].

    Args:
        img: channels-first tensor, either C x H x W or B x C x H x W.
        prior_mean: per-channel mean used for normalisation.
        prior_std: per-channel std used for normalisation.
    Returns:
        A float tensor with the same shape as ``img``.
    """
    # (C, 1, 1) broadcasts against both C x H x W and B x C x H x W; the old
    # view(img.size(0), 1, 1) only worked for unbatched input.
    mean = torch.tensor(prior_mean, dtype=torch.float).to(img.device).view(-1, 1, 1)
    std = torch.tensor(prior_std, dtype=torch.float).to(img.device).view(-1, 1, 1)
    img = img * std + mean
    img = img * 255.
    return torch.clamp(img, min=0, max=255)
def save_seg_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id, prior_mean, prior_std):
    """Save segmentation visualisations for a random ~20% of the batch.

    For each sampled item the saved image stacks, top to bottom: the
    de-normalised input image, the predicted label map and the target label
    map. Everything is clipped to [0, 1] before saving.

    Args:
        input: mapping with 'img' and 'label' batches; 'label' is passed to
            encode_onehot_to_mask().
        output: logits; argmax over dim 1 gives the predicted labels.
        vis_dir: root output directory; files go to ``vis_dir/<pattern>/``.
        pattern, epoch_id, batch_id: used to build the file name.
        prior_mean, prior_std: passed to inv_normalize_img().
    """
    pred_label = torch.argmax(output, 1)
    batch_size = len(pred_label)
    # Upper bound is the batch size (was len(pred_label[0]), i.e. the height).
    k = np.clip(int(0.2 * batch_size), a_min=1, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    out_dir = os.path.join(vis_dir, pattern)
    os.makedirs(out_dir, exist_ok=True)
    for img_id in ids:
        img = input['img'][img_id].to(pred_label.device)
        target = input['label'][img_id].to(pred_label.device)
        img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std)) / 255.
        target = make_numpy_img(encode_onehot_to_mask(target))
        pred = make_numpy_img(pred_label[img_id])
        vis = np.concatenate([img, pred, target], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so the sampled items no longer overwrite
        # each other (previously every iteration wrote the same file).
        file_name = os.path.join(out_dir, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    With the default ``requires_grad=False`` this freezes the networks so no
    gradients are computed for them.

    Parameters:
        nets (network, list or tuple) -- a single network or a sequence of
                                         networks; ``None`` entries are skipped
        requires_grad (bool)          -- whether the networks require gradients or not
    """
    # Tuples are accepted too; previously a tuple was wrapped as one element
    # and failed on `.parameters()`.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
def boolean_string(s):
    """Parse the exact strings 'True' / 'False' into a bool.

    Raises:
        ValueError: for any other value (the match is case-sensitive).
    """
    parsed = {'True': True, 'False': False}
    if s in parsed:
        return parsed[s]
    raise ValueError('Not a valid boolean string')
def cpt_pxl_cls_acc(pred_idx, target):
    """Return the fraction of elements where ``pred_idx`` equals ``target``.

    Both tensors are flattened and cast to int before the comparison; the
    result is a 0-dim float tensor.
    """
    flat_pred = pred_idx.reshape(-1).int()
    flat_target = target.reshape(-1).int()
    hits = torch.eq(flat_pred, flat_target)
    return hits.float().mean()
def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
    """Return the mean PSNR over a batch of 4-D tensors.

    The squared error is averaged over dims 1-3 for every batch item, turned
    into ``20 * log10(PIXEL_MAX / sqrt(mse))`` and averaged over the batch.
    """
    squared_err = (img - img_gt) ** 2
    per_item_mse = torch.mean(squared_err, dim=[1, 2, 3])
    ratio = PIXEL_MAX / torch.sqrt(per_item_mse)
    per_item_psnr = 20 * torch.log10(ratio)
    return torch.mean(per_item_psnr)
def cpt_psnr(img, img_gt, PIXEL_MAX):
    """Return the PSNR of two NumPy arrays:
    ``20 * log10(PIXEL_MAX / sqrt(mean((img - img_gt) ** 2)))``.
    """
    diff = img - img_gt
    rmse = np.sqrt(np.mean(diff ** 2))
    return 20 * np.log10(PIXEL_MAX / rmse)
def cpt_rgb_ssim(img, img_gt):
    """Return the SSIM averaged over the first three channels of the last
    axis. Both inputs go through ``clip_01`` first.
    """
    # NOTE(review): clip_01 and sk_cpt_ssim are not defined in the visible part
    # of this file - confirm they are imported.
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    total = sum(sk_cpt_ssim(clipped[:, :, c], clipped_gt[:, :, c]) for c in range(3))
    return total / 3.0
def cpt_ssim(img, img_gt):
    """Return the SSIM of two images after passing both through ``clip_01``."""
    # NOTE(review): clip_01 and sk_cpt_ssim are not defined in the visible part
    # of this file - confirm they are imported.
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    return sk_cpt_ssim(clipped, clipped_gt)
def decode_mask_to_onehot(mask, n_class):
    '''
    Convert an integer label mask into a one-hot tensor.

    mask    : BxWxH or WxH
    n_class : n
    return  : BxnxWxH or nxWxH (no batch dim when the input had none)
    '''
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Remember the input rank: after unsqueeze() the mask is always 3-D, so
    # the old post-hoc `len(mask.shape) == 2` check could never be true.
    had_no_batch = len(mask.shape) == 2
    if had_no_batch:
        mask = mask.unsqueeze(0)
    onehot = torch.zeros((mask.size(0), n_class, mask.size(1), mask.size(2))).to(mask.device)
    for i in range(n_class):
        onehot[:, i, ...] = mask == i
    if had_no_batch:
        onehot = onehot.squeeze(0)
    return onehot
def encode_onehot_to_mask(onehot):
    '''
    Collapse the class axis of a one-hot (or score) tensor with argmax.

    onehot: tensor, BxnxWxH or nxWxH
    output: tensor, BxWxH or WxH
    '''
    n_dims = len(onehot.shape)
    assert n_dims in [3, 4], "encode_onehot_to_mask error!"
    class_axis = n_dims - 3  # 0 without a batch dim, 1 with one
    return onehot.argmax(dim=class_axis)
def _decode_box_map(coarse, refine, prior_box_wh, grid_size, stride):
    """Sum two regression maps, scale them by the prior box size and add the
    per-cell grid offsets (cell index * stride)."""
    loc = coarse + refine  # new tensor: the in-place ops below never touch `pred`
    loc[:, 0::2, ...] *= prior_box_wh[0]
    loc[:, 1::2, ...] *= prior_box_wh[1]
    grid = get_coords_grid(grid_size - 1, grid_size - 1, 0, 0)
    grid *= stride
    # The (x, y) offsets are applied to both corners: channels (x, y, x, y).
    grid = torch.cat([grid, grid], dim=0)
    loc += grid.to(loc.device)
    return loc


def decode(pred, target=None, *args, **kwargs):
    """
    Turn raw detection-head outputs into per-image box arrays.

    Args:
        pred: big_cls_1(0), big_reg_1, small_cls_1(2), small_reg_1, big_cls_2(4), big_reg_2, small_cls_2(6), small_reg_2
        target: [[n,5], [n,5]] list of tensor
        **kwargs:
            phase: 'od' (the only supported value)
            img_size: pair used to normalise box coordinates before fusion
            prior_box_wh: prior box width / height multiplying the regression
            conf_thres: minimum foreground probability for a cell to be kept
            iou_thres, conf_type: forwarded to weighted_boxes_fusion
    Returns:
        {"pred_all": list of [m, 6] arrays (label, score, 4 box coords),
         "target_all": list of arrays, or None when ``target`` is None}
    Raises:
        ValueError: if ``phase`` is not 'od' (this used to fail with an
            unbound-local error at the return statement).
    """
    phase = kwargs['phase']
    img_size = kwargs['img_size']
    if phase != 'od':
        raise ValueError("decode only supports phase 'od', got %r" % (phase,))

    prior_box_wh = kwargs['prior_box_wh']
    conf_thres = kwargs['conf_thres']
    iou_thres = kwargs['iou_thres']
    conf_type = kwargs['conf_type']

    # Foreground probability = channel 1 of the softmax over dim 1.
    pred_conf_32_2 = F.softmax(pred[4], dim=1)[:, 1, ...]  # B H W
    pred_conf_64_2 = F.softmax(pred[6], dim=1)[:, 1, ...]  # B H W
    obj_mask_32_2 = pred_conf_32_2 > conf_thres  # B H W
    obj_mask_64_2 = pred_conf_64_2 > conf_thres  # B H W

    # NOTE(review): grid sizes 32/64 and strides 8/4 are hard-coded, so the
    # regression maps are assumed to be 32x32 and 64x64 - TODO confirm.
    pre_loc_32_2 = _decode_box_map(pred[1], pred[5], prior_box_wh, grid_size=32, stride=8)  # B 4 H W
    # The grid now follows this tensor's own device (it was copied to the
    # 32-grid tensor's device before, a copy-paste slip).
    pre_loc_64_2 = _decode_box_map(pred[3], pred[7], prior_box_wh, grid_size=64, stride=4)  # B 4 H W

    pred_all = []
    for i in range(pre_loc_32_2.size(0)):
        score_32 = pred_conf_32_2[i][obj_mask_32_2[i]]  # N
        score_64 = pred_conf_64_2[i][obj_mask_64_2[i]]  # M
        loc_32 = pre_loc_32_2[i].permute((1, 2, 0))[obj_mask_32_2[i]]  # Nx4
        loc_64 = pre_loc_64_2[i].permute((1, 2, 0))[obj_mask_64_2[i]]  # Mx4
        score_list = torch.cat((score_32, score_64), dim=0).detach().cpu().numpy()
        boxes_list = torch.cat((loc_32, loc_64), dim=0).detach().cpu().numpy()
        # Normalise coordinates by the image size before fusion.
        boxes_list[:, 0::2] /= img_size[0]
        boxes_list[:, 1::2] /= img_size[1]
        label_list = np.ones_like(score_list)
        # Cap the number of candidate detections at 150.
        # NOTE(review): the first 150 candidates in scan order are kept, not
        # the 150 highest-scoring ones - confirm this is intended.
        boxes_list = boxes_list[:150, :]
        score_list = score_list[:150]
        label_list = label_list[:150]
        # NOTE(review): weighted_boxes_fusion is not imported in the visible
        # part of this file - confirm the import.
        boxes, scores, labels = weighted_boxes_fusion([boxes_list], [score_list], [label_list], weights=None,
                                                      iou_thr=iou_thres, conf_type=conf_type)
        # Back to pixel coordinates.
        boxes[:, 0::2] *= img_size[0]
        boxes[:, 1::2] *= img_size[1]
        pred_boxes = np.concatenate((labels.reshape(-1, 1), scores.reshape(-1, 1), boxes), axis=1)
        pred_all.append(pred_boxes)

    if target is not None:
        target_all = [x.cpu().numpy() for x in target]
    else:
        target_all = None
    return {"pred_all": pred_all, "target_all": target_all}
def get_metrics(phase, pred, target):
    '''
    Compute per-class segmentation metrics for one batch.

    phase: only 'seg' is implemented
    pred: logits, tensor, nBatch*nClass*W*H
    target: labels, tensor, nBatch*nClass*W*H
    return: (IoU, OA, F1_score), each reduced over dims 0, 2 and 3, i.e. one
            value per class
    raises: ValueError for any other phase (this used to fail with an
            unbound-local error at the return statement)
    '''
    if phase != 'seg':
        raise ValueError("get_metrics only supports phase 'seg', got %r" % (phase,))

    # Hard predictions, re-encoded as one-hot so they line up with `target`.
    pred = torch.argmax(pred.detach(), dim=1)
    pred = decode_mask_to_onehot(pred, target.size(1))
    # positive samples in ground truth
    gt_pos_sum = torch.sum(target == 1, dim=(0, 2, 3))
    # positive prediction in predict mask
    pred_pos_sum = torch.sum(pred == 1, dim=(0, 2, 3))
    # cal true positive sample
    true_pos_sum = torch.sum((target == 1) * (pred == 1), dim=(0, 2, 3))
    # The 1e-15 terms guard against division by zero for absent classes.
    # Precision
    precision = true_pos_sum / (pred_pos_sum + 1e-15)
    # Recall
    recall = true_pos_sum / (gt_pos_sum + 1e-15)
    # IoU
    IoU = true_pos_sum / (pred_pos_sum + gt_pos_sum - true_pos_sum + 1e-15)
    # OA: 1 - (false positives + false negatives) / number of target entries >= 0
    OA = 1 - (pred_pos_sum + gt_pos_sum - 2 * true_pos_sum) / torch.sum(target >= 0, dim=(0, 2, 3))
    # F1-score
    F1_score = 2 * precision * recall / (precision + recall + 1e-15)
    return IoU, OA, F1_score