# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import math import random from typing import Optional from typing import Tuple from typing import Union import torch import torch.nn as nn from torch import Tensor class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) This is a definition, originally motivated by its close numerical similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). Memory-efficient derivative computation: double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). Now, s'(x) = s(x) * (1-s(x)). double_swish'(x) = x * s'(x) + s(x). = x * s(x) * (1-s(x)) + s(x). = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad x_dtype = x.dtype if x.dtype == torch.float16: x = x.to(torch.float32) s = torch.sigmoid(x - 1.0) y = x * s if requires_grad: deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( deriv ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 assert d_scaled.max() < 256.0 d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): y = y.to(torch.float16) return y @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 d = d * ((ceil - floor) / 255.0) + floor return y_grad * d class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). """ if torch.jit.is_scripting() or torch.jit.is_tracing(): return x * torch.sigmoid(x - 1.0) return DoubleSwishFunction.apply(x) class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( ctx, x: Tensor, scale_factor: Tensor, sign_factor: Optional[Tensor], channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x @staticmethod def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): scale_factor = scale_factor.unsqueeze(-1) sign_factor = sign_factor.unsqueeze(-1) factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) else: xgt0, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor return ( x_grad - neg_delta_grad, None, None, None, ) def _compute_scale_factor( x: Tensor, channel_dim: int, min_abs: float, max_abs: float, gain_factor: float, max_factor: float, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) if min_abs == 0.0: below_threshold = 0.0 else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( min=0, max=max_factor ) above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( min=0, max=max_factor ) return below_threshold - above_threshold def _compute_sign_factor( x: Tensor, channel_dim: int, min_positive: float, max_positive: float, gain_factor: float, max_factor: float, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. factor1 = ( (min_positive - proportion_positive) * (gain_factor / min_positive) ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. factor2 = ( (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for each channel, that it is positive at least a proportion `threshold` of the time. It does this by multiplying negative derivative values by up to (1+max_factor), and positive derivative values by up to (1-max_factor), interpolated from 1 at the threshold to those extremal values when none of the inputs are positive. Args: num_channels: the number of channels channel_dim: the dimension/axis corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. min_positive: the minimum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_positive: the maximum, per channel, of the proportion of the time that (x > 0), above which we start to modify the derivatives. max_factor: the maximum factor by which we modify the derivatives for either the sign constraint or the magnitude constraint; e.g. with max_factor=0.02, the the derivatives would be multiplied by values in the range [0.98..1.02]. sign_gain_factor: determines the 'gain' with which we increase the change in gradient once the constraints on min_positive and max_positive are violated. scale_gain_factor: determines the 'gain' with which we increase the change in gradient once the constraints on min_abs and max_abs are violated. min_abs: the minimum average-absolute-value difference from the mean value per channel, which we allow, before we start to modify the derivatives to prevent this. max_abs: the maximum average-absolute-value difference from the mean value per channel, which we allow, before we start to modify the derivatives to prevent this. min_prob: determines the minimum probability with which we modify the gradients for the {min,max}_positive and {min,max}_abs constraints, on each forward(). This is done randomly to prevent all layers from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ def __init__( self, num_channels: int, channel_dim: int, min_positive: float = 0.05, max_positive: float = 0.95, max_factor: float = 0.04, sign_gain_factor: float = 0.01, scale_gain_factor: float = 0.02, min_abs: float = 0.2, max_abs: float = 100.0, min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.min_positive = min_positive self.max_positive = max_positive self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs self.min_prob = min_prob self.sign_gain_factor = sign_gain_factor self.scale_gain_factor = scale_gain_factor # count measures how many times the forward() function has been called. # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count self.cpu_count += 1 if random.random() < 0.01: # Occasionally sync self.cpu_count with self.count. # count affects the decay of 'prob'. don't do this on every iter, # because syncing with the GPU is slow. self.cpu_count = max(self.cpu_count, self.count.item()) self.count.fill_(self.cpu_count) # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: sign_factor = _compute_sign_factor( x, self.channel_dim, self.min_positive, self.max_positive, gain_factor=self.sign_gain_factor / prob, max_factor=self.max_factor, ) else: sign_factor = None scale_factor = _compute_scale_factor( x.detach(), self.channel_dim, min_abs=self.min_abs, max_abs=self.max_abs, gain_factor=self.scale_gain_factor / prob, max_factor=self.max_factor, ) return ActivationBalancerFunction.apply( x, scale_factor, sign_factor, self.channel_dim, ) else: return _no_op(x) def BalancedDoubleSwish( d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25 ) -> nn.Sequential: """ ActivationBalancer -> DoubleSwish """ balancer = ActivationBalancer( d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob ) return nn.Sequential( balancer, DoubleSwish(), )