Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .registry import CONV_LAYERS | |
| def conv_ws_2d(input, | |
| weight, | |
| bias=None, | |
| stride=1, | |
| padding=0, | |
| dilation=1, | |
| groups=1, | |
| eps=1e-5): | |
| c_in = weight.size(0) | |
| weight_flat = weight.view(c_in, -1) | |
| mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1) | |
| std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1) | |
| weight = (weight - mean) / (std + eps) | |
| return F.conv2d(input, weight, bias, stride, padding, dilation, groups) | |
class ConvWS2d(nn.Conv2d):
    """Conv2d layer with Weight Standardization.

    The kernel is standardized per output channel via :func:`conv_ws_2d`
    on every forward pass (https://arxiv.org/pdf/1903.10520.pdf).

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the conv kernel.
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1.
        padding (int or tuple, optional): Zero-padding added to both sides
            of the input. Default: 0.
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1.
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1.
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True.
        eps (float, optional): Epsilon added to the kernel std for
            numerical stability. Default: 1e-5.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        # Zero-argument super() for consistency with ConvAWS2d below.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.eps = eps

    def forward(self, x):
        """Convolve ``x`` using the standardized version of the kernel."""
        return conv_ws_2d(x, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups, self.eps)
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Non-learnable affine parameters applied after standardization:
        # weight = gamma * standardized_weight + beta. Registered as
        # buffers so they are saved in / loaded from the state dict but
        # not updated by the optimizer. Shape (out_channels, 1, 1, 1)
        # broadcasts over the kernel dimensions.
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        """Standardize ``weight`` per output channel, then re-apply the
        stored affine transform (weight_gamma / weight_beta)."""
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        # var + 1e-5 keeps the sqrt away from zero for degenerate kernels.
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        """Convolve ``x`` with the adaptively standardized kernel."""
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """
        # Sentinel: pre-fill gamma with -1. A checkpoint that contains
        # weight_gamma overwrites it with positive values (it was stored
        # as a std, see the copy_ below), so the mean check afterwards
        # tells whether gamma was actually loaded.
        self.weight_gamma.data.fill_(-1)
        # Collect missing keys into a local list first, so the recovered
        # gamma/beta keys can be filtered out before reporting upstream.
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            # gamma was present in the checkpoint: nothing to recover,
            # forward the locally collected missing keys unchanged.
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        # gamma/beta absent (e.g. loading a plain-Conv2d pretrained
        # checkpoint): set beta/gamma to the weight's own per-channel
        # mean/std so _get_weight initially reproduces the raw weight.
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        # The gamma/beta entries were just recovered — drop them from the
        # keys reported as missing to the caller.
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)