import torch
import numpy as np

from isegm.model.metrics import _compute_iou
from .brs_losses import BRSMaskLoss


class BaseOptimizer:
    """Loss/gradient functor for BRS refinement.

    `__call__` takes a flat numpy vector of optimization variables and
    returns ``[f_val, f_grad]``, the (scalar loss, flattened gradient) pair
    expected by SciPy-style optimizers such as `fmin_l_bfgs_b`.
    """

    def __init__(self, optimizer_params,
                 prob_thresh=0.49,
                 reg_weight=1e-3,
                 min_iou_diff=0.01,
                 brs_loss=None,
                 with_flip=False,
                 flip_average=False,
                 **kwargs):
        # Avoid a shared mutable default argument: build the default loss here.
        self.brs_loss = brs_loss if brs_loss is not None else BRSMaskLoss()
        self.optimizer_params = optimizer_params
        self.prob_thresh = prob_thresh
        self.reg_weight = reg_weight
        self.min_iou_diff = min_iou_diff
        self.with_flip = with_flip
        self.flip_average = flip_average

        self.best_prediction = None
        self._get_prediction_logits = None
        self._opt_shape = None
        self._best_loss = None
        self._click_masks = None
        self._last_mask = None
        self.device = None

    def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
        # Reset per-click state before a new round of optimization.
        self.best_prediction = None
        self._get_prediction_logits = get_prediction_logits
        self._click_masks = (pos_mask, neg_mask)
        self._opt_shape = shape
        self._last_mask = None
        self.device = device

    def __call__(self, x):
        opt_params = torch.from_numpy(x).float().to(self.device)
        opt_params.requires_grad_(True)

        with torch.enable_grad():
            opt_vars, reg_loss = self.unpack_opt_params(opt_params)
            result_before_sigmoid = self._get_prediction_logits(*opt_vars)
            result = torch.sigmoid(result_before_sigmoid)

            pos_mask, neg_mask = self._click_masks
            if self.with_flip and self.flip_average:
                # The batch holds the original and horizontally flipped inputs
                # stacked along dim 0; average the prediction with its
                # un-flipped counterpart.
                result, result_flipped = torch.chunk(result, 2, dim=0)
                result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
                pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]

            loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
            loss = loss + reg_loss

        f_val = loss.detach().cpu().numpy()
        if self.best_prediction is None or f_val < self._best_loss:
            self.best_prediction = result_before_sigmoid.detach()
            self._best_loss = f_val

        # Early exit 1: every positive click is already above prob_thresh and
        # every negative click below it, so skip the backward pass and return
        # a zero gradient.
        if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
            return [f_val, np.zeros_like(x)]

        # Early exit 2: the thresholded mask has effectively stopped changing
        # (IoU with the previous mask exceeds 1 - min_iou_diff).
        current_mask = result > self.prob_thresh
        if self._last_mask is not None and self.min_iou_diff > 0:
            diff_iou = _compute_iou(current_mask, self._last_mask)
            if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
                return [f_val, np.zeros_like(x)]
        self._last_mask = current_mask

        loss.backward()
        # np.float was removed in NumPy 1.24; L-BFGS-B expects float64 gradients.
        f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float64)

        return [f_val, f_grad]

    def unpack_opt_params(self, opt_params):
        raise NotImplementedError


class InputOptimizer(BaseOptimizer):
    """Optimizes an additive tensor in the input space directly."""

    def unpack_opt_params(self, opt_params):
        opt_params = opt_params.view(self._opt_shape)
        if self.with_flip:
            opt_params_flipped = torch.flip(opt_params, dims=[3])
            opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
        # L2 penalty keeps the optimized perturbation small.
        reg_loss = self.reg_weight * torch.sum(opt_params**2)

        return (opt_params,), reg_loss


class ScaleBiasOptimizer(BaseOptimizer):
    """Optimizes per-channel scale and bias applied to intermediate features;
    `scale_act` optionally bounds the scale via tanh or sin."""

    def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_act = scale_act
        self.reg_bias_weight = reg_bias_weight

    def unpack_opt_params(self, opt_params):
        scale, bias = torch.chunk(opt_params, 2, dim=0)
        reg_loss = self.reg_weight * (torch.sum(scale**2) +
                                      self.reg_bias_weight * torch.sum(bias**2))

        if self.scale_act == 'tanh':
            scale = torch.tanh(scale)
        elif self.scale_act == 'sin':
            scale = torch.sin(scale)

        # The scale is parameterized around identity: the model sees 1 + scale.
        return (1 + scale, bias), reg_loss
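

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module).
# It shows how a functor plugs into `scipy.optimize.fmin_l_bfgs_b`, whose
# `func` argument expects exactly the [f_val, f_grad] pair returned by
# `BaseOptimizer.__call__`. The identity stand-in for the network and the
# contents of `optimizer_params` are assumptions made for the example.
#
#     import numpy as np
#     import torch
#     from scipy.optimize import fmin_l_bfgs_b
#
#     opt_functor = InputOptimizer(
#         optimizer_params=dict(m=20, factr=0, pgtol=1e-8, maxfun=20))
#
#     h, w = 8, 8
#     pos_mask = torch.zeros(1, 1, h, w)
#     pos_mask[0, 0, 2, 2] = 1          # one positive click
#     neg_mask = torch.zeros(1, 1, h, w)
#     neg_mask[0, 0, 6, 6] = 1          # one negative click
#
#     # Identity "network": the optimized map itself serves as the logits.
#     opt_functor.init_click(lambda t: t, pos_mask, neg_mask,
#                            device=torch.device('cpu'), shape=(1, 1, h, w))
#
#     x0 = np.zeros(h * w, dtype=np.float64)
#     x_opt, f_opt, _ = fmin_l_bfgs_b(func=opt_functor, x0=x0,
#                                     **opt_functor.optimizer_params)
#     refined_logits = opt_functor.best_prediction
# ---------------------------------------------------------------------------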