import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCELoss

from utils import box_ops

# Large sentinel area used to mask out invalid (point, box) pairs during target assignment.
INF = 100000000

class ObjectNormalizedL2Loss(nn.Module):
    """Squared L2 loss over density maps, normalized by the number of objects."""

    def __init__(self):
        super().__init__()

    def forward(self, output, dmap, num_objects):
        # Sum of squared errors between predicted and ground-truth density maps,
        # divided by the object count so the loss scale is per-object.
        return ((output - dmap) ** 2).sum() / num_objects
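
# A minimal usage sketch (shapes are illustrative, not mandated by this module):
#
#   criterion = ObjectNormalizedL2Loss()
#   output = torch.rand(2, 1, 64, 64)   # predicted density map
#   dmap = torch.rand(2, 1, 64, 64)     # ground-truth density map
#   loss = criterion(output, dmap, num_objects=10)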

class Detection_criterion(nn.Module):
    """FCOS-style detection criterion: assigns each feature-map location to a
    ground-truth box (or background) and applies an IoU-based regression loss."""

    def __init__(
        self, sizes, iou_loss_type, center_sample, fpn_strides, pos_radius, aux=False
    ):
        super().__init__()
        self.sizes = sizes                  # per-level (min, max) regression ranges
        self.box_loss = IOULoss(iou_loss_type)
        self.aux = aux
        self.center_sample = center_sample  # if True, positives must lie near box centres
        self.strides = fpn_strides
        self.radius = pos_radius            # centre-sampling radius, in units of stride
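
    # Hypothetical construction, assuming FCOS-style size ranges and strides
    # (the concrete values below are illustrative, not taken from this repo):
    #
    #   criterion = Detection_criterion(
    #       sizes=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
    #       iou_loss_type='giou',
    #       center_sample=True,
    #       fpn_strides=[8, 16, 32, 64, 128],
    #       pos_radius=1,
    #   )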

    def prepare_target(self, points, targets):
        # Attach each level's size-of-interest range to every point on that level.
        ex_size_of_interest = []
        for i, point_per_level in enumerate(points):
            size_of_interest_per_level = point_per_level.new_tensor(self.sizes[i])
            ex_size_of_interest.append(
                size_of_interest_per_level[None].expand(len(point_per_level), -1)
            )
        ex_size_of_interest = torch.cat(ex_size_of_interest, 0)

        n_point_per_level = [len(point_per_level) for point_per_level in points]
        point_all = torch.cat(points, dim=0)
        label, box_target = self.compute_target_for_location(
            point_all, targets, ex_size_of_interest, n_point_per_level
        )

        # Split each image's targets back into per-level chunks ...
        for i in range(len(label)):
            label[i] = torch.split(label[i], n_point_per_level, 0)
            box_target[i] = torch.split(box_target[i], n_point_per_level, 0)

        # ... then regroup so the first axis is the FPN level rather than the image.
        label_level_first = []
        box_target_level_first = []
        for level in range(len(points)):
            label_level_first.append(
                torch.cat([label_per_img[level] for label_per_img in label], 0).to(
                    points[0].device
                )
            )
            box_target_level_first.append(
                torch.cat(
                    [box_target_per_img[level] for box_target_per_img in box_target], 0
                )
            )
        return label_level_first, box_target_level_first
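
    # Shape sketch: with two levels of 100 and 25 points and a batch of 2 images,
    # label_level_first is [tensor of shape (200,), tensor of shape (50,)] and
    # box_target_level_first is [(200, 4), (50, 4)] (per level, images concatenated).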

    def get_sample_region(self, gt, strides, n_point_per_level, xs, ys, radius=1):
        """Mark locations that fall within `radius * stride` of a GT box centre."""
        n_gt = gt.shape[0]
        n_loc = len(xs)
        gt = gt[None].expand(n_loc, n_gt, 4)
        center_x = (gt[..., 0] + gt[..., 2]) / 2
        center_y = (gt[..., 1] + gt[..., 3]) / 2

        if center_x[..., 0].sum() == 0:
            # No valid boxes: nothing can be a positive sample.
            return xs.new_zeros(xs.shape, dtype=torch.bool)

        begin = 0
        center_gt = gt.new_zeros(gt.shape)
        for level, n_p in enumerate(n_point_per_level):
            end = begin + n_p
            stride = strides[level] * radius
            x_min = center_x[begin:end] - stride
            y_min = center_y[begin:end] - stride
            x_max = center_x[begin:end] + stride
            y_max = center_y[begin:end] + stride
            # Clip the centre region so it never extends beyond the GT box itself.
            center_gt[begin:end, :, 0] = torch.where(
                x_min > gt[begin:end, :, 0], x_min, gt[begin:end, :, 0]
            )
            center_gt[begin:end, :, 1] = torch.where(
                y_min > gt[begin:end, :, 1], y_min, gt[begin:end, :, 1]
            )
            center_gt[begin:end, :, 2] = torch.where(
                x_max > gt[begin:end, :, 2], gt[begin:end, :, 2], x_max
            )
            center_gt[begin:end, :, 3] = torch.where(
                y_max > gt[begin:end, :, 3], gt[begin:end, :, 3], y_max
            )
            begin = end

        left = xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - xs[:, None]
        top = ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        is_in_boxes = center_bbox.min(-1)[0] > 0
        return is_in_boxes
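
    # Centre-sampling sketch: with one GT box [0, 0, 100, 100], a single level of
    # stride 8 and radius 1, only points within 8 px of the centre (50, 50) are
    # positive (values here are illustrative):
    #
    #   gt = torch.tensor([[0., 0., 100., 100.]])
    #   xs = torch.tensor([50., 90.])
    #   ys = torch.tensor([50., 90.])
    #   criterion.get_sample_region(gt, [8], [2], xs, ys, radius=1)
    #   # -> tensor([[ True], [False]])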

    def compute_target_for_location(
        self, locations, targets, sizes_of_interest, n_point_per_level
    ):
        labels = []
        box_targets = []
        xs, ys = locations[:, 0], locations[:, 1]

        for i in range(len(targets)):
            targets_per_img = targets[i]
            targets_per_img = targets_per_img.clip(remove_empty=True)
            assert targets_per_img.mode == 'xyxy'
            targets_per_img = targets_per_img[:50]  # cap the number of GT boxes per image
            bboxes = targets_per_img.box
            # Single foreground class: every GT box gets label 1.
            labels_per_img = torch.ones(len(bboxes), dtype=torch.long, device=locations.device)
            area = targets_per_img.area()

            # Per-location distances to each box's four sides (left, top, right, bottom).
            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            box_targets_per_img = torch.stack([l, t, r, b], 2)

            if self.center_sample:
                is_in_boxes = self.get_sample_region(
                    bboxes, self.strides, n_point_per_level, xs, ys, radius=self.radius
                )
            else:
                is_in_boxes = box_targets_per_img.min(2)[0] > 0

            # A location is assigned to a level only if its largest regression
            # target falls inside that level's size-of-interest range.
            max_box_targets_per_img = box_targets_per_img.max(2)[0]
            is_cared_in_level = (
                max_box_targets_per_img >= sizes_of_interest[:, [0]]
            ) & (max_box_targets_per_img <= sizes_of_interest[:, [1]])

            # Among the remaining candidates, each location takes the smallest box.
            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_level == 0] = INF
            locations_to_min_area, locations_to_gt_id = locations_to_gt_area.min(1)

            box_targets_per_img = box_targets_per_img[
                range(len(locations)), locations_to_gt_id
            ]
            labels_per_img = labels_per_img[locations_to_gt_id]
            labels_per_img[locations_to_min_area == INF] = 0  # background

            labels.append(labels_per_img)
            box_targets.append(box_targets_per_img)

        return labels, box_targets
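
    # Shape sketch: with P total locations and G boxes in an image, the stacked
    # box_targets_per_img above is [P, G, 4]; the per-location argmin over GT
    # areas collapses it to [P, 4], and labels are 0 (background) or 1 (object).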

    def compute_centerness_targets(self, box_targets):
        left_right = box_targets[:, [0, 2]]
        top_bottom = box_targets[:, [1, 3]]
        centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * (
            top_bottom.min(-1)[0] / top_bottom.max(-1)[0]
        )
        return torch.sqrt(centerness)
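
    # Centerness follows FCOS: sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)).
    # A location at the exact box centre scores 1; locations near a border tend
    # toward 0. For example:
    #
    #   box_targets = torch.tensor([[4., 4., 4., 4.], [1., 4., 8., 4.]])
    #   criterion.compute_centerness_targets(box_targets)
    #   # -> tensor([1.0000, 0.3536])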

    def forward(self, locations, box_pred, targets):
        labels, box_targets = self.prepare_target(locations, targets)

        box_flat = []
        labels_flat = []
        box_targets_flat = []
        for i in range(len(labels)):
            # box_pred is a list of per-level maps [B, 4, H, W]; flatten each
            # level to [B*H*W, 4] so it lines up with the per-level targets.
            box_flat.append(box_pred[i].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flat.append(labels[i].reshape(-1))
            box_targets_flat.append(box_targets[i].reshape(-1, 4))

        box_flat = torch.cat(box_flat, 0)
        labels_flat = torch.cat(labels_flat, 0)
        box_targets_flat = torch.cat(box_targets_flat, 0)

        pos_id = torch.nonzero(labels_flat > 0).squeeze(1)
        box_flat = box_flat[pos_id]
        box_targets_flat = box_targets_flat[pos_id]

        if pos_id.numel() > 0:
            center_targets = self.compute_centerness_targets(box_targets_flat)
            box_loss = self.box_loss(box_flat, box_targets_flat, center_targets)
        else:
            # No positives: return a zero that still carries the graph.
            box_loss = box_flat.sum()
        return box_loss

class IOULoss(nn.Module):
    """IoU / GIoU loss over boxes encoded as (left, top, right, bottom) distances."""

    def __init__(self, loc_loss_type):
        super().__init__()
        self.loc_loss_type = loc_loss_type

    def forward(self, out, target, weight=None):
        pred_left, pred_top, pred_right, pred_bottom = out.unbind(1)
        target_left, target_top, target_right, target_bottom = target.unbind(1)

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        w_intersect = torch.min(pred_left, target_left) + torch.min(
            pred_right, target_right
        )
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(
            pred_top, target_top
        )
        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect
        ious = (area_intersect + 1) / (area_union + 1)  # +1 guards the empty-box case

        if self.loc_loss_type == 'iou':
            loss = -torch.log(ious)
        elif self.loc_loss_type == 'giou':
            g_w_intersect = torch.max(pred_left, target_left) + torch.max(
                pred_right, target_right
            )
            g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(
                pred_top, target_top
            )
            g_intersect = g_w_intersect * g_h_intersect + 1e-7  # enclosing-box area
            gious = ious - (g_intersect - area_union) / g_intersect
            loss = 1 - gious
        else:
            raise ValueError(f'unknown loc_loss_type: {self.loc_loss_type}')

        if weight is not None and weight.sum() > 0:
            # Centerness-weighted mean over positive samples.
            return (loss * weight).sum() / weight.sum()
        else:
            return loss.mean()
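
# Sanity sketch: identical (l, t, r, b) encodings give (near-)zero loss:
#
#   loss_fn = IOULoss('giou')
#   t = torch.tensor([[10., 10., 10., 10.]])
#   loss_fn(t, t)   # -> ~0 (exactly 0 up to the 1e-7 epsilon)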

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for the list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.cross_entropy = BCELoss()

    def loss_boxes(self, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask):
        """Compute the losses related to the bounding boxes: the L1 regression loss
        and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        Boxes are passed straight to box_ops.generalized_box_iou, so both predictions
        and targets are expected in (x_min, y_min, x_max, y_max) format, normalized
        by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(src_boxes, target_boxes))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def ce_loss(self, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask):
        # Note: despite its name, this computes a masked L2 loss on the predicted
        # centerness map, normalized by the number of target boxes.
        l2 = (centerness[mask > 0] - centerness_gt[mask > 0]) ** 2
        losses = {}
        losses['loss_ce'] = l2.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
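
    # Matching-index sketch: for a batch of two images where predictions (0, 3)
    # match GTs (1, 0) in image 0 and prediction 2 matches GT 0 in image 1:
    #
    #   indices = [(torch.tensor([0, 3]), torch.tensor([1, 0])),
    #              (torch.tensor([2]), torch.tensor([0]))]
    #   criterion._get_src_permutation_idx(indices)
    #   # -> (tensor([0, 0, 1]), tensor([0, 3, 2]))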

    def get_loss(self, loss, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs):
        loss_map = {
            'bboxes': self.loss_boxes,
            'ce': self.ce_loss,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)

    def generate_centerness_gt(self, indices, FN_idx, FP_idx, outputs, targets, centerness, ref_points):
        FN_bboxes = targets[0]['boxes'][FN_idx] * centerness.shape[1]
        centerness_gt = torch.zeros_like(centerness)
        mask = torch.zeros_like(centerness)

        # FP -> non-matched PRED boxes keep 0 in centerness_gt at their reference
        # point, but get 1 in the mask so the loss pushes them toward 0.
        FP_locs = ref_points.permute(1, 0)[FP_idx]
        mask[0][FP_locs[:, 0], FP_locs[:, 1]] = 1

        # FN -> non-matched GT boxes get 1 at the centre of the box.
        if len(FN_bboxes) > 0:
            FN_y_loc = torch.clamp(((FN_bboxes[:, 3] + FN_bboxes[:, 1]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            FN_x_loc = torch.clamp(((FN_bboxes[:, 2] + FN_bboxes[:, 0]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            centerness_gt[0][FN_y_loc, FN_x_loc] = 1
            mask[0][FN_y_loc, FN_x_loc] = 1

        # TP -> matched PRED boxes get 1 at their reference point.
        TP_locs = ref_points.permute(1, 0)[indices[0][0]]
        centerness_gt[0][TP_locs[:, 0], TP_locs[:, 1]] = 1
        mask[0][TP_locs[:, 0], TP_locs[:, 1]] = 1

        # Fallback: if collisions left fewer positives than GT boxes, supervise
        # every GT centre instead and open the mask over the whole map.
        if centerness_gt.sum() < targets[0]['boxes'].shape[0]:
            centerness_gt = torch.zeros_like(centerness)
            FN_bboxes = targets[0]['boxes'] * centerness.shape[1]
            FN_y_loc = torch.clamp(((FN_bboxes[:, 3] + FN_bboxes[:, 1]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            FN_x_loc = torch.clamp(((FN_bboxes[:, 2] + FN_bboxes[:, 0]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            centerness_gt[0][FN_y_loc, FN_x_loc] = 1
            mask = torch.ones_like(centerness)

        return centerness_gt, mask
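
    # Resulting supervision, in short: centerness_gt is 1 at reference points of
    # matched predictions (TP) and at centres of unmatched GT boxes (FN); mask
    # selects the pixels that contribute to ce_loss (TPs, FNs, and FP reference
    # points, which are driven toward 0). Note the batch-0 indexing throughout:
    # this generator assumes a batch size of 1.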

    def forward(self, outputs, targets, centerness, ref_points):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        # Retrieve the matching between the outputs of the last layer and the targets
        indices, FN_idx, FP_idx = self.matcher(outputs, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1).item()

        centerness_gt, mask = self.generate_centerness_gt(indices, FN_idx, FP_idx, outputs, targets, centerness, ref_points)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                # The matcher returns an (indices, FN_idx, FP_idx) triple; only the
                # matching itself is needed for the auxiliary heads.
                indices, _, _ = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    # Reuse the last layer's centerness supervision so the call
                    # matches get_loss's signature.
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            # Encoder proposals are supervised class-agnostically: every label becomes 0.
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt['labels'] = torch.zeros_like(bt['labels'])
            indices, _, _ = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)
                l_dict = {k + '_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        return losses
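
# Hedged end-to-end sketch. The matcher is assumed to return the
# (indices, FN_idx, FP_idx) triple this criterion expects; `build_matcher`,
# `outputs`, `centerness`, and `ref_points` are illustrative names, not
# definitions from this file:
#
#   matcher = build_matcher(args)
#   criterion = SetCriterion(num_classes=1, matcher=matcher,
#                            weight_dict={'loss_bbox': 5, 'loss_giou': 2, 'loss_ce': 2},
#                            losses=['bboxes', 'ce'])
#   loss_dict = criterion(outputs, targets, centerness, ref_points)
#   total = sum(criterion.weight_dict[k] * v for k, v in loss_dict.items()
#               if k in criterion.weight_dict)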