Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from ..metrics import ap_per_class | |
def fitness(x):
    """Collapse each row of a metrics matrix into a single scalar fitness score.

    Only the first ``len(weights)`` columns of ``x`` are used; any further columns are
    ignored. The weight layout presumably follows the ten ``metrics/...`` entries of
    ``KEYS`` (P, R, mAP@0.5, mAP@0.5:0.95 for boxes, the same for masks, then mIoU and
    FWIoU) — confirm against the caller that builds ``x``.

    Args:
        x: 2-D array whose rows are individual metric records.

    Returns:
        1-D array with one weighted sum per row of ``x``.
    """
    weights = [
        0.0, 0.0, 0.1, 0.9,  # box group
        0.0, 0.0, 0.1, 0.9,  # mask group
        0.1, 0.9,  # semantic group
    ]
    n_cols = len(weights)
    weighted = x[:, :n_cols] * weights
    return weighted.sum(1)
def ap_per_class_box_and_mask(
    tp_m,
    tp_b,
    conf,
    pred_cls,
    target_cls,
    plot=False,
    save_dir=".",
    names=(),
):
    """Run ``ap_per_class`` once for boxes and once for masks and package the results.

    Args:
        tp_m: tp of masks.
        tp_b: tp of boxes.
        other arguments see `func: ap_per_class` (``conf``, ``pred_cls`` and
            ``target_cls`` are shared by both calls).

    Returns:
        ``{"boxes": {...}, "masks": {...}}``. Each inner dict has the keys
        ``p``, ``r``, ``ap``, ``f1`` and ``ap_class``, in exactly that order.
        ``Metric.update`` unpacks ``values()`` positionally, so the order matters.
    """

    def summarize(tp, prefix):
        # Drop the first two entries of ap_per_class's result; the remaining ones are
        # read positionally as p, r, f1, ap, ap_class.
        stats = ap_per_class(tp,
                             conf,
                             pred_cls,
                             target_cls,
                             plot=plot,
                             save_dir=save_dir,
                             names=names,
                             prefix=prefix)[2:]
        return {
            "p": stats[0],
            "r": stats[1],
            "ap": stats[3],
            "f1": stats[2],
            "ap_class": stats[4],
        }

    box_summary = summarize(tp_b, "Box")
    mask_summary = summarize(tp_m, "Mask")
    return {"boxes": box_summary, "masks": mask_summary}
class Metric:
    """Per-class precision / recall / F1 / AP container for a single task (boxes or masks).

    ``update`` stores the per-class arrays. The aggregate views below are exposed as
    read-only properties. Before ``update`` is called every array is an empty list,
    and the aggregates fall back to ``[]`` (per-class views) or ``0.0`` (means).
    """

    def __init__(self) -> None:
        self.p = []  # (nc, )
        self.r = []  # (nc, )
        self.f1 = []  # (nc, )
        self.all_ap = []  # (nc, 10)
        self.ap_class_index = []  # (nc, )

    # The aggregates must be properties. mean_results, class_result and get_maps read
    # them as attributes (self.map, self.ap[i], ...). As plain methods, those
    # expressions yield bound-method objects or raise TypeError.
    @property
    def ap50(self):
        """AP@0.5 of all classes (column 0 of ``all_ap``).

        Return:
            (nc, ) or [].
        """
        return self.all_ap[:, 0] if len(self.all_ap) else []

    @property
    def ap(self):
        """AP@0.5:0.95 of all classes (row-wise mean of ``all_ap``).

        Return:
            (nc, ) or [].
        """
        return self.all_ap.mean(1) if len(self.all_ap) else []

    @property
    def mp(self):
        """Mean precision of all classes.

        Return:
            float.
        """
        return self.p.mean() if len(self.p) else 0.0

    @property
    def mr(self):
        """Mean recall of all classes.

        Return:
            float.
        """
        return self.r.mean() if len(self.r) else 0.0

    @property
    def map50(self):
        """Mean AP@0.5 of all classes.

        Return:
            float.
        """
        return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0

    @property
    def map(self):
        """Mean AP@0.5:0.95 of all classes.

        Return:
            float.
        """
        return self.all_ap.mean() if len(self.all_ap) else 0.0

    def mean_results(self):
        """Mean of results, return (mp, mr, map50, map)."""
        return (self.mp, self.mr, self.map50, self.map)

    def class_result(self, i):
        """Class-aware result for row ``i``, return (p[i], r[i], ap50[i], ap[i])."""
        return (self.p[i], self.r[i], self.ap50[i], self.ap[i])

    def get_maps(self, nc):
        """Return an (nc, ) array of per-class mAP@0.5:0.95.

        Classes without a row in ``ap_class_index`` get the overall ``map``.
        """
        maps = np.zeros(nc) + self.map
        for i, c in enumerate(self.ap_class_index):
            maps[c] = self.ap[i]
        return maps

    def update(self, results):
        """Replace the stored statistics.

        Args:
            results: tuple(p, r, ap, f1, ap_class), in exactly this order.
        """
        p, r, all_ap, f1, ap_class_index = results
        self.p = p
        self.r = r
        self.all_ap = all_ap
        self.f1 = f1
        self.ap_class_index = ap_class_index
class Metrics:
    """Metric for boxes and masks: a pair of ``Metric`` objects updated together."""

    def __init__(self) -> None:
        self.metric_box = Metric()
        self.metric_mask = Metric()

    def update(self, results):
        """Update both metrics.

        Args:
            results: Dict{'boxes': Dict{}, 'masks': Dict{}}, e.g. as returned by
                ``ap_per_class_box_and_mask``. Each inner dict's values are passed
                positionally to ``Metric.update``, so their order must be
                p, r, ap, f1, ap_class.
        """
        self.metric_box.update(list(results["boxes"].values()))
        self.metric_mask.update(list(results["masks"].values()))

    def mean_results(self):
        """Box (mp, mr, map50, map) followed by mask (mp, mr, map50, map), as one 8-tuple."""
        return self.metric_box.mean_results() + self.metric_mask.mean_results()

    def class_result(self, i):
        """Box per-class result for row ``i`` followed by the mask one, as one 8-tuple."""
        return self.metric_box.class_result(i) + self.metric_mask.class_result(i)

    def get_maps(self, nc):
        """Element-wise sum of the box and mask per-class mAP arrays, shape (nc, )."""
        return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)

    @property
    def ap_class_index(self):
        """Class indices that have statistics.

        Exposed as a property so it reads like ``Metric.ap_class_index`` (an attribute)
        and can be iterated directly.
        """
        # boxes and masks have the same ap_class_index
        return self.metric_box.ap_class_index
class Semantic_Metrics:
    """Accumulates semantic-segmentation IoU statistics across batches.

    ``update`` records, for every (image, class) pair, the IoU between the argmax
    prediction and the target. It also keeps running per-class target-pixel,
    intersection and union counts. ``results`` turns these into mean IoU and
    frequency-weighted IoU.
    """

    def __init__(self, nc, device):
        self.nc = nc  # number of classes
        self.device = device
        self.reset()

    def update(self, pred_masks, target_masks):
        """Accumulate statistics for one batch.

        Args:
            pred_masks: per-class prediction scores, unpacked as (nb, nc, h, w). Each
                pixel is assigned to its argmax class.
            target_masks: ground-truth masks indexed as ``target_masks[b, c]``;
                non-zero pixels count as positive.
        """
        nb, nc, h, w = pred_masks.shape
        device = pred_masks.device

        for b in range(nb):
            # Convert the scores to a one-hot mask:
            # class x h x w -> (h x w) x class -> argmax -> one-hot -> class x h x w.
            scores = torch.flatten(pred_masks[b], start_dim=1).permute(1, 0)
            max_idx = scores.argmax(1)
            # scatter requires the buffer and the index on the same device, so build
            # the buffer next to max_idx rather than on self.device.
            onehot = torch.zeros(scores.shape, device=device).scatter(1, max_idx.unsqueeze(1), 1.0)
            onehot_mask = onehot.permute(1, 0).reshape(nc, h, w).int()

            for c in range(self.nc):
                pred_mask = onehot_mask[c]
                target_mask = target_masks[b, c].to(device)

                intersection = torch.logical_and(pred_mask, target_mask).sum().item()
                union = torch.logical_or(pred_mask, target_mask).sum().item()
                # A class absent from both prediction and target scores IoU 0.
                iou = 0. if union == 0 else intersection / union

                # Add plain Python ints so the counters (on self.device) never mix devices.
                self.c_bit_counts[c] += int(target_mask.int().sum().item())
                self.c_intersection_counts[c] += intersection
                self.c_union_counts[c] += union
                self.iou.append(iou)

    def results(self):
        """Return ``(miou, fwiou)`` as floats, both 0.0 when nothing has been recorded."""
        # Mean IoU. self.iou already holds one entry per (image, class) pair, so the
        # mean over it is the mIoU; dividing by nc again would shrink it by a factor of nc.
        miou = float(np.mean(self.iou)) if self.iou else 0.

        # Frequency-weighted IoU: per-class IoU (with +1 smoothing on the union)
        # weighted by each class's share of target pixels.
        c_iou = self.c_intersection_counts / (self.c_union_counts + 1)
        total_c_bit_counts = self.c_bit_counts.sum()
        if total_c_bit_counts == 0:
            fwiou = 0.
        else:
            fwiou = ((self.c_bit_counts / total_c_bit_counts) * c_iou).sum().item()
        return (miou, fwiou)

    def reset(self):
        """Clear all accumulated statistics."""
        self.iou = []
        self.c_bit_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
        self.c_intersection_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
        self.c_union_counts = torch.zeros(self.nc, dtype=torch.long, device=self.device)
# Column names, grouped by origin and concatenated in the original order.
KEYS = (
    # train losses
    ["train/box_loss", "train/seg_loss", "train/cls_loss",
     "train/dfl_loss", "train/fcl_loss", "train/dic_loss"]
    # box metrics
    + ["metrics/precision(B)", "metrics/recall(B)",
       "metrics/mAP_0.5(B)", "metrics/mAP_0.5:0.95(B)"]
    # mask metrics
    + ["metrics/precision(M)", "metrics/recall(M)",
       "metrics/mAP_0.5(M)", "metrics/mAP_0.5:0.95(M)"]
    # semantic metrics
    + ["metrics/MIOUS(S)", "metrics/FWIOUS(S)"]
    # val losses
    + ["val/box_loss", "val/seg_loss", "val/cls_loss",
       "val/dfl_loss", "val/fcl_loss", "val/dic_loss"]
    # learning rates
    + ["x/lr0", "x/lr1", "x/lr2"]
)

# Names for the best-epoch summary: the epoch, then the best box, mask and semantic metrics.
BEST_KEYS = (
    ["best/epoch"]
    + ["best/precision(B)", "best/recall(B)", "best/mAP_0.5(B)", "best/mAP_0.5:0.95(B)"]
    + ["best/precision(M)", "best/recall(M)", "best/mAP_0.5(M)", "best/mAP_0.5:0.95(M)"]
    + ["best/MIOUS(S)", "best/FWIOUS(S)"]
)