i72sijia committed on
Commit
c0068b4
1 Parent(s): e06425e

Upload bias_act.py

Browse files
Files changed (1) hide show
  1. torch_utils/ops/bias_act.py +212 -0
torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ import traceback
17
+
18
+ from .. import custom_ops
19
+ from .. import misc
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
# Table of supported activation functions, keyed by the name accepted through
# the `act` argument of `bias_act()`. Fields of each entry:
#   func         - reference PyTorch implementation, invoked as `func(x, alpha=...)`.
#   def_alpha    - value used for `alpha` when the caller passes `None`.
#   def_gain     - value used for `gain` when the caller passes `None`.
#   cuda_idx     - integer forwarded to the compiled plugin to identify the activation.
#   ref          - which tensor the CUDA path keeps for backward: 'x', 'y', or '' for neither.
#   has_2nd_grad - whether the CUDA path evaluates the second-order gradient term.
activation_funcs = {
    'linear': dnnlib.EasyDict(
        func=lambda x, **_: x,
        def_alpha=0,
        def_gain=1,
        cuda_idx=1,
        ref='',
        has_2nd_grad=False,
    ),
    'relu': dnnlib.EasyDict(
        func=lambda x, **_: torch.nn.functional.relu(x),
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=2,
        ref='y',
        has_2nd_grad=False,
    ),
    'lrelu': dnnlib.EasyDict(
        # The only entry whose reference implementation consumes `alpha`.
        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
        def_alpha=0.2,
        def_gain=np.sqrt(2),
        cuda_idx=3,
        ref='y',
        has_2nd_grad=False,
    ),
    'tanh': dnnlib.EasyDict(
        func=lambda x, **_: torch.tanh(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=4,
        ref='y',
        has_2nd_grad=True,
    ),
    'sigmoid': dnnlib.EasyDict(
        func=lambda x, **_: torch.sigmoid(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=5,
        ref='y',
        has_2nd_grad=True,
    ),
    'elu': dnnlib.EasyDict(
        func=lambda x, **_: torch.nn.functional.elu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=6,
        ref='y',
        has_2nd_grad=True,
    ),
    'selu': dnnlib.EasyDict(
        func=lambda x, **_: torch.nn.functional.selu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=7,
        ref='y',
        has_2nd_grad=True,
    ),
    'softplus': dnnlib.EasyDict(
        func=lambda x, **_: torch.nn.functional.softplus(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=8,
        ref='y',
        has_2nd_grad=True,
    ),
    'swish': dnnlib.EasyDict(
        func=lambda x, **_: torch.sigmoid(x) * x,
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=9,
        ref='x',
        has_2nd_grad=True,
    ),
}
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
# Lazy-initialization state for the compiled plugin.
_inited = False                  # Becomes True on the first `_init()` call, whether or not the build succeeds.
_plugin = None                   # Object returned by `custom_ops.get_plugin()`, or None if unavailable.
_null_tensor = torch.empty([0])  # Empty placeholder passed to the plugin where no real tensor applies.

def _init():
    """Build/load the `bias_act` CUDA plugin on first use.

    The build is attempted at most once: `_inited` is set before the attempt,
    so a failed build is not retried on later calls. On failure a warning
    containing the traceback is issued and `_plugin` stays `None`, which makes
    `bias_act()` fall back to the reference implementation.

    Returns:
        `True` if the plugin is available, `False` otherwise.
    """
    global _inited, _plugin
    if not _inited:
        _inited = True
        here = os.path.dirname(__file__)
        sources = [os.path.join(here, name) for name in ('bias_act.cpp', 'bias_act.cu')]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except Exception:
            # Catch `Exception` rather than using a bare `except:` so that
            # KeyboardInterrupt / SystemExit raised during the build still
            # propagate instead of being converted into a warning.
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias-add, activation, gain and clamp.

    Computes, in order: add bias `b` to `x`, apply the activation named by
    `act`, multiply by `gain`, and clamp to `[-clamp, +clamp]`. Every step is
    optional. The fused op is usually considerably more efficient than the
    equivalent sequence of standard PyTorch ops. First and second order
    gradients are supported; third order gradients are not.

    Args:
        x:      Input activation tensor of any shape.
        b:      Bias vector, or `None` to skip the bias. Must be a 1D tensor of
                the same type as `x` whose length matches the size of `x`
                along dimension `dim`.
        dim:    Dimension of `x` that the elements of `b` correspond to.
                Ignored when `b` is not given.
        act:    Name of the activation function, e.g. `"relu"`, `"lrelu"`,
                `"tanh"`, `"sigmoid"`, `"swish"`; use `"linear"` for none.
                See `activation_funcs` for the full list. `None` is not allowed.
        alpha:  Shape parameter of the activation, or `None` for its default.
        gain:   Output scaling factor, or `None` for the activation's default
                (see `activation_funcs`). If unsure, consider specifying 1.
        clamp:  Clamp the output to `[-clamp, +clamp]`, or `None` (default)
                for no clamping.
        impl:   Implementation to use: `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ('ref', 'cuda')

    # The fused kernel is used only when it was requested, the input lives on
    # a CUDA device, and the plugin is available. `_init()` is evaluated last
    # so that no build is triggered for 'ref' requests or non-CUDA inputs.
    use_fused = impl == 'cuda' and x.device.type == 'cuda' and _init()
    if not use_fused:
        return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

    fused_op = _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
    return fused_op.apply(x, b)
90
+
91
+ #----------------------------------------------------------------------------
92
+
93
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.

    Args:
        x:      Input tensor.
        b:      Optional 1D bias tensor whose length equals `x.shape[dim]`.
        dim:    Dimension of `x` the bias is broadcast along; must satisfy
                `0 <= dim < x.ndim` when `b` is given.
        act:    Key into `activation_funcs`.
        alpha:  Activation shape parameter, or `None` for the entry's `def_alpha`.
        gain:   Output scale, or `None` for the entry's `def_gain`.
        clamp:  Non-negative clamp bound, or `None` to disable clamping.

    Returns:
        The tensor after bias, activation, gain and clamp have been applied.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)  # -1 encodes "no clamping".

    # Add bias, reshaped so that it broadcasts along `dim` only.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function. `alpha` and `gain` are already floats at
    # this point, so no further conversion is needed.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x
124
+
125
+ #----------------------------------------------------------------------------
126
+
127
# Maps (dim, act, alpha, gain, clamp) to the autograd function class built for it.
_bias_act_cuda_cache = {}

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.

    Builds (or fetches from `_bias_act_cuda_cache`) a `torch.autograd.Function`
    subclass specialised for the given configuration. The caller invokes it as
    `.apply(x, b)`.

    Args:
        dim:    Dimension of the input that the bias corresponds to.
        act:    Key into `activation_funcs`.
        alpha:  Activation shape parameter, or `None` for the entry's `def_alpha`.
        gain:   Output scale, or `None` for the entry's `def_gain`.
        clamp:  Non-negative clamp bound, or `None` to disable clamping.

    Returns:
        The `BiasActCuda` autograd function class for this configuration.
    """
    # Parse arguments; a clamp of -1 encodes "no clamping".
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(spec.def_alpha if alpha is None else alpha)
    gain = float(spec.def_gain if gain is None else gain)
    clamp = float(-1 if clamp is None else clamp)

    # Reuse a previously built class for the same configuration.
    cache_key = (dim, act, alpha, gain, clamp)
    cached = _bias_act_cuda_cache.get(cache_key)
    if cached is not None:
        return cached

    # True when activation, gain or clamp changes values, i.e. the op is more
    # than a plain bias add. Depends only on the parsed arguments above.
    needs_kernel = act != 'linear' or gain != 1 or clamp >= 0

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            # Keep a channels-last layout if the input already has stride 1 along dim 1.
            is_channels_last = x.ndim > 2 and x.stride()[1] == 1
            ctx.memory_format = torch.channels_last if is_channels_last else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = _null_tensor if b is None else b.contiguous()

            # Skip the kernel entirely when there is nothing to compute.
            if needs_kernel or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            else:
                y = x

            # Save only the tensors the gradient code will read.
            keep_input = 'x' in spec.ref or spec.has_2nd_grad
            ctx.save_for_backward(
                x if keep_input else _null_tensor,
                b if keep_input else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            want_dx = ctx.needs_input_grad[0]
            want_db = ctx.needs_input_grad[1]
            dx = db = None

            # The bias gradient is a reduction of dx, so dx is needed in both cases.
            if want_dx or want_db:
                dx = BiasActCudaGrad.apply(dy, x, b, y) if needs_kernel else dy

            if want_db:
                db = dx.sum([axis for axis in range(dx.ndim) if axis != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            is_channels_last = dy.ndim > 2 and dy.stride()[1] == 1
            ctx.memory_format = torch.channels_last if is_channels_last else torch.contiguous_format
            # NOTE(review): the literal 1 here (0 in the forward op, 2 below)
            # presumably selects the gradient order inside the plugin - confirm
            # against bias_act.cpp.
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = d_x = d_b = d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad:
                if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
                    d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
                if ctx.needs_input_grad[2]:
                    d_b = d_x.sum([axis for axis in range(d_x.ndim) if axis != dim])

            # No gradient is ever produced for the saved output `y`.
            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[cache_key] = BiasActCuda
    return BiasActCuda
211
+
212
+ #----------------------------------------------------------------------------