# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn
from torch import Tensor


class DoubleSwishFunction(torch.autograd.Function):
    """
      double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x)
                      = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) = x * s'(x) + s(x).
                      = x * s(x) * (1-s(x)) + s(x).
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    The derivative is stored for the backward pass as a randomly-rounded
    8-bit integer tensor rather than as a floating-point tensor.
    """

    # Bounds on the derivative of x * sigmoid(x - 1), see
    # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
    #   min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
    #   max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
    # forward() and backward() must use the same constants, so they are
    # defined once here.
    _DERIV_FLOOR = -0.043637
    _DERIV_CEIL = 1.2

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        """Compute x * sigmoid(x - 1), saving a quantized derivative.

        Args:
          ctx: autograd context; receives the uint8-quantized derivative
            when `x.requires_grad` is true.
          x: the input tensor.

        Returns:
          A tensor with the same shape as `x`.  It is float16 if the input
          was float16 or if autocast is enabled; otherwise it has the dtype
          that the computation produced.
        """
        requires_grad = x.requires_grad
        # Remember the caller's dtype: `x` may be rebound to a float32 copy
        # below, after which `x.dtype` no longer tells us what was passed in.
        x_dtype = x.dtype
        if x_dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            floor = DoubleSwishFunction._DERIV_FLOOR
            ceil = DoubleSwishFunction._DERIV_CEIL
            # the combination of "+ torch.rand_like(deriv)" and casting to
            # torch.uint8 (which floors), should be expectation-preserving.
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)

        # Use the *original* dtype here; `x.dtype` is float32 by now whenever
        # the input was float16, so testing it would never cast back.
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        """Multiply the incoming gradient by the de-quantized derivative.

        Args:
          ctx: autograd context holding the uint8 derivative saved by
            forward().
          y_grad: gradient of the loss w.r.t. the output of forward().

        Returns:
          The gradient of the loss w.r.t. the input of forward().
        """
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = DoubleSwishFunction._DERIV_FLOOR
        ceil = DoubleSwishFunction._DERIV_CEIL
        # Undo the affine map applied in forward(): [0, 255] -> [floor, ceil].
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d


class DoubleSwish(torch.nn.Module):
    """Module wrapper around the double-swish activation."""

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to
        Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1).
        """
        # Under TorchScript scripting/tracing, use the plain expression
        # instead of the custom autograd function.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x * torch.sigmoid(x - 1.0)
        return DoubleSwishFunction.apply(x)


class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; modifies the gradient in the backward pass.

    The gradient is changed by an amount proportional to its absolute value,
    with a per-channel factor built from `scale_factor`, the optional
    `sign_factor`, and the sign of the forward input.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        """Return `x` unchanged, saving what backward() needs.

        Args:
          ctx: autograd context.
          x: the input tensor.
          scale_factor: per-channel factor for the magnitude constraint.
          sign_factor: per-channel factor for the sign constraint, or None
            if the sign constraint is not in use.
          channel_dim: the channel axis of `x`; negative values are
            interpreted as offsets from `x.ndim`.
        """
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        # Only the sign of x is needed in backward(), so save a boolean mask
        # rather than x itself.
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        """Return the modified gradient for `x`; other inputs get None."""
        # The number of saved tensors tells us whether a sign_factor was
        # supplied to forward().
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            # Append trailing singleton axes so the per-channel factors
            # broadcast against x_grad along the channel axis.
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


def _compute_scale_factor(
    x: Tensor,
    channel_dim: int,
    min_abs: float,
    max_abs: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    """Compute the per-channel factor for the magnitude constraint.

    Args:
      x: the activations.
      channel_dim: the channel axis of `x` (may be negative).
      min_abs: lower limit on the per-channel mean absolute value; 0.0
        disables the lower constraint.
      max_abs: upper limit on the per-channel mean absolute value.
      gain_factor: slope with which the factor grows once a limit is violated.
      max_factor: upper bound on the magnitude of each of the two terms.

    Returns:
      A float32 tensor with one element per channel: positive where the mean
      absolute value is below `min_abs`, negative where it is above `max_abs`.
    """
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs, can be at most
        # max_factor if x_abs_mean < min_abs.
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
            min=0, max=max_factor
        )

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )

    return below_threshold - above_threshold


def _compute_sign_factor(
    x: Tensor,
    channel_dim: int,
    min_positive: float,
    max_positive: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    """Compute the per-channel factor for the sign constraint.

    Args:
      x: the activations.
      channel_dim: the channel axis of `x` (may be negative).
      min_positive: lower limit on the per-channel proportion of (x > 0);
        0.0 disables the lower constraint.
      max_positive: upper limit on the per-channel proportion of (x > 0);
        1.0 disables the upper constraint.
      gain_factor: slope with which the factor grows once a limit is violated.
      max_factor: upper bound on the magnitude of each of the two terms.

    Returns:
      A float32 tensor with one element per channel.

    Raises:
      AssertionError: if both constraints are disabled (min_positive == 0.0
        and max_positive == 1.0), in which case there is nothing to compute.
    """
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if proportion_positive <= max_positive, else can be
        # as large as max_factor (it is subtracted below).
        factor2 = (
            (proportion_positive - max_positive)
            * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor


class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
           num_channels: the number of channels
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
               that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
               that (x > 0), above which we start to modify the derivatives.
           max_factor: the maximum factor by which we modify the derivatives for
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the derivatives would be multiplied by
              values in the range [0.98..1.02].
           sign_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_positive and max_positive
              are violated.
           scale_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_abs and max_abs
              are violated.
           min_abs:  the minimum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
           max_abs:  the maximum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
          min_prob: determines the minimum probability with which we modify the
             gradients for the {min,max}_positive and {min,max}_abs constraints,
             on each forward().  This is done randomly to prevent all layers
             from doing it at the same time.  Early in training we may use
             higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        """Return `x`, possibly wrapped so that its gradient gets modified.

        Args:
          x: the activations.

        Returns:
          A tensor with the same values as `x`.
        """
        # NOTE: `_no_op` is not defined in this section of the file; it is
        # expected to be provided elsewhere in the module.
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            # The gain factors are divided by `prob` so that the expected
            # size of the gradient modification does not depend on how
            # often this branch is taken.
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)


def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish

    Args:
      d_model: the number of channels, passed to ActivationBalancer.
      channel_dim: the channel axis, passed to ActivationBalancer.
      max_abs: passed to ActivationBalancer.
      min_prob: passed to ActivationBalancer.

    Returns:
      An nn.Sequential containing the balancer followed by DoubleSwish.
    """
    balancer = ActivationBalancer(
        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
    )
    return nn.Sequential(
        balancer,
        DoubleSwish(),
    )