Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from ..utils import ext_loader | |
# Handle to the compiled '_ext' extension. The four focal-loss kernels listed
# here are the ones invoked as attributes of `ext_module` by the autograd
# Functions in this file.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'sigmoid_focal_loss_forward',
        'sigmoid_focal_loss_backward',
        'softmax_focal_loss_forward',
        'softmax_focal_loss_backward',
    ],
)
class SigmoidFocalLossFunction(Function):
    """Autograd Function wrapping the compiled sigmoid focal-loss kernels.

    This class does three things in Python:

    * validates the inputs;
    * allocates the buffers that ``ext_module`` kernels fill in place;
    * applies the requested reduction.

    All methods are static, as ``torch.autograd.Function`` requires.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        """ONNX export hook: emit the custom ``mmcv::MMCVSigmoidFocalLoss`` op.

        The loss hyper-parameters are forwarded as node attributes.
        """
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the sigmoid focal loss.

        Args:
            ctx: Autograd context; stores tensors and hyper-parameters for
                ``backward``.
            input (Tensor): 2-D tensor of shape (N, C). Presumably raw
                logits, since no sigmoid is applied here (TODO confirm in the
                kernel).
            target (LongTensor): 1-D tensor of length N.
            gamma (float): Focal-loss gamma, passed to the kernel as float.
            alpha (float): Focal-loss alpha, passed to the kernel as float.
            weight (Tensor, optional): 1-D tensor of length C.
            reduction (str): One of ``'none'``, ``'mean'``, ``'sum'``.

        Returns:
            Tensor: Shape (N, C) for ``'none'``, otherwise a scalar.
            ``'mean'`` divides the summed loss by N (the batch size), not by
            N * C.

        Raises:
            AssertionError: On wrong target type, wrong shapes, or an unknown
                ``reduction``.
        """
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # The kernel always receives a tensor. An empty one stands in for
            # "no per-class weight" (presumably checked via numel in the
            # kernel - confirm).
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Per-element loss buffer, filled in place by the kernel.
        output = input.new_zeros(input.size())
        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()

        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; the other inputs get None.

        The gradient is produced by an extension kernel that autograd cannot
        differentiate through. ``once_differentiable`` therefore makes double
        backward raise instead of silently returning wrong higher-order
        gradients.
        """
        input, target, weight = ctx.saved_tensors
        grad_input = input.new_zeros(input.size())
        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        # grad_output is (N, C) for 'none' and a scalar otherwise. Both
        # broadcast correctly against the (N, C) gradient.
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            # Mirror the forward pass, which divided by the batch size.
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply
class SigmoidFocalLoss(nn.Module):
    """``nn.Module`` front-end for ``sigmoid_focal_loss``.

    Holds the loss hyper-parameters so the module can be called with just
    ``(input, target)``. ``weight`` is registered as a buffer, so it moves
    with the module on ``.to()`` / ``.cuda()``.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.reduction = reduction
        # Buffer rather than Parameter: follows the module's device, not trained.
        self.register_buffer('weight', weight)

    def forward(self, input, target):
        """Apply the sigmoid focal loss using this module's settings."""
        settings = (self.gamma, self.alpha, self.weight, self.reduction)
        return sigmoid_focal_loss(input, target, *settings)

    def __repr__(self):
        return (f'{type(self).__name__}(gamma={self.gamma}, '
                f'alpha={self.alpha}, reduction={self.reduction})')
class SoftmaxFocalLossFunction(Function):
    """Autograd Function wrapping the compiled softmax focal-loss kernels.

    This class does four things in Python:

    * computes a numerically stable softmax over the class dimension;
    * validates the inputs;
    * allocates the buffers that ``ext_module`` kernels fill in place;
    * applies the requested reduction.

    All methods are static, as ``torch.autograd.Function`` requires.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        """ONNX export hook: emit the custom ``mmcv::MMCVSoftmaxFocalLoss`` op.

        The loss hyper-parameters are forwarded as node attributes.
        """
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the softmax focal loss.

        Args:
            ctx: Autograd context; stores tensors and hyper-parameters for
                ``backward``.
            input (Tensor): 2-D tensor of shape (N, C). Softmax is taken
                over dim 1 before calling the kernel.
            target (LongTensor): 1-D tensor of length N.
            gamma (float): Focal-loss gamma, passed to the kernel as float.
            alpha (float): Focal-loss alpha, passed to the kernel as float.
            weight (Tensor, optional): 1-D tensor of length C.
            reduction (str): One of ``'none'``, ``'mean'``, ``'sum'``.

        Returns:
            Tensor: Shape (N,) (one loss per row) for ``'none'``, otherwise a
            scalar. ``'mean'`` divides the summed loss by N.

        Raises:
            AssertionError: On wrong target type, wrong shapes, or an unknown
                ``reduction``.
        """
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # The kernel always receives a tensor. An empty one stands in for
            # "no per-class weight" (presumably checked via numel in the
            # kernel - confirm).
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Numerically stable softmax over the class dim: subtract each row's
        # max before exponentiating, then normalise by the row sum.
        row_max, _ = torch.max(input, dim=1)
        input_softmax = input - row_max.unsqueeze(1)
        input_softmax.exp_()
        input_softmax /= input_softmax.sum(dim=1, keepdim=True)

        # One loss value per row, filled in place by the kernel.
        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()

        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; the other inputs get None.

        The gradient is produced by an extension kernel that autograd cannot
        differentiate through. ``once_differentiable`` therefore makes double
        backward raise instead of silently returning wrong higher-order
        gradients.
        """
        input_softmax, target, weight = ctx.saved_tensors
        # Length-N scratch buffer required by the kernel (presumably
        # per-row intermediate values - confirm in the kernel).
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())
        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['none']:
            # With 'none', the forward pass returned one loss per row, so
            # grad_output has shape (N,). Add a trailing axis so it scales each
            # row of the (N, C) gradient. Without it, the in-place multiply
            # would broadcast across the class axis: a shape error when
            # C != N, and silently wrong gradients when C == N.
            grad_output = grad_output.unsqueeze(1)
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            # Mirror the forward pass, which divided by the batch size.
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply
class SoftmaxFocalLoss(nn.Module):
    """``nn.Module`` front-end for ``softmax_focal_loss``.

    Holds the loss hyper-parameters so the module can be called with just
    ``(input, target)``. ``weight`` is registered as a buffer, so it moves
    with the module on ``.to()`` / ``.cuda()``.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.reduction = reduction
        # Buffer rather than Parameter: follows the module's device, not trained.
        self.register_buffer('weight', weight)

    def forward(self, input, target):
        """Apply the softmax focal loss using this module's settings."""
        settings = (self.gamma, self.alpha, self.weight, self.reduction)
        return softmax_focal_loss(input, target, *settings)

    def __repr__(self):
        return (f'{type(self).__name__}(gamma={self.gamma}, '
                f'alpha={self.alpha}, reduction={self.reduction})')
 
			

