Spaces:
Running
on
Zero
Running
on
Zero
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
import torch | |
#---------------------------------------------------------------------------- | |
def fma(a, b, c):
    """Compute `a * b + c` through the custom autograd function.

    The three arguments are forwarded unchanged to `_FusedMultiplyAdd.apply()`
    and its result is returned as-is.
    """
    result = _FusedMultiplyAdd.apply(a, b, c)
    return result
#---------------------------------------------------------------------------- | |
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Custom autograd function computing `a * b + c`.

    The forward pass delegates to `torch.addcmul()`. The backward pass only
    computes the gradients that autograd reports as needed, and reduces each
    of them back to the shape of the corresponding input via `_unbroadcast()`.
    """

    # `torch.autograd.Function` documents `forward` and `backward` as static
    # methods that receive the context object as their first argument.
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        """Return `torch.addcmul(c, a, b)` and record what backward needs."""
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        # Backward never reads the values of `c`, only its shape.
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        """Return the gradients `(da, db, dc)` for the incoming `dout`.

        An entry is left as `None` when `ctx.needs_input_grad` says the
        matching input does not need a gradient.
        """
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            # d(a * b + c) / da = b
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            # d(a * b + c) / db = a
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            # d(a * b + c) / dc = 1
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc
#---------------------------------------------------------------------------- | |
def _unbroadcast(x, shape):
    """Sum `x` down to `shape`, undoing a broadcast from `shape` to `x.shape`.

    Args:
        x: Tensor whose shape is the result of broadcasting `shape`.
        shape: Target shape (`torch.Size` or tuple); may be empty for a
            0-dim tensor and may contain zero-sized dimensions.

    Returns:
        A tensor of shape `shape` holding `x` summed over every dimension
        that broadcasting added or expanded.

    Raises:
        AssertionError: If `shape` has more dimensions than `x`, or if the
            reduced tensor does not end up with exactly `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0

    # Reduce the leading dims that broadcasting prepended, plus any dim where
    # the target has size 1 but `x` does not. Comparing with `!= 1` rather
    # than `> 1` also collapses zero-sized dims, whose sum is a size-1 zero.
    dim = [
        i for i, size in enumerate(x.shape)
        if size != 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if dim:
        x = x.sum(dim=dim, keepdim=True)

    if extra_dims:
        # All leading extra dims are size 1 now, so drop them by reshaping to
        # the explicit trailing shape. Unlike `reshape(-1, ...)`, this also
        # works for a 0-dim target and for tensors with zero elements.
        x = x.reshape(x.shape[extra_dims:])

    assert x.shape == shape
    return x
#---------------------------------------------------------------------------- | |