Spaces:
Runtime error
Runtime error
File size: 12,916 Bytes
d47f0a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import numpy as np
import torch
import warnings
from .. import custom_ops
from .. import misc
from . import upfirdn2d
from . import bias_act
#----------------------------------------------------------------------------
_plugin = None  # Compiled extension module; stays None until the first successful _init() call.

def _init():
    """Load the `filtered_lrelu_plugin` extension via `custom_ops.get_plugin` on first use.

    The loaded module is memoized in the module-level `_plugin`, so later calls
    return immediately without touching `custom_ops` again.

    Returns:
        Always `True`, which lets callers chain this inside a boolean condition
        (e.g. `impl == 'cuda' and ... and _init()`).
    """
    global _plugin
    if _plugin is not None:
        return True
    plugin_dir = os.path.dirname(__file__)
    _plugin = custom_ops.get_plugin(
        module_name='filtered_lrelu_plugin',
        sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
        headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
        source_dir=plugin_dir,
        extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
    )
    return True
def _get_filter_size(f):
    """Return the `(width, height)` footprint of an FIR filter.

    Args:
        f: `None` (treated as a 1x1 identity filter), a 1D tensor of taps
           (separable, so width == height == number of taps), or a 2D tensor
           laid out as `[height, width]`.

    Returns:
        Tuple `(width, height)`.
    """
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert f.ndim in (1, 2)
    # For a 1D filter, shape[0] and shape[-1] are the same axis, so width == height.
    height = f.shape[0]
    width = f.shape[-1]
    return width, height
def _parse_padding(padding):
    """Normalize a padding specification into `(px0, px1, py0, py1)`.

    Args:
        padding: One of
            - a single integer: the same padding on all four sides;
            - a list/tuple `[x, y]`: symmetric padding per axis;
            - a list/tuple `[x_before, x_after, y_before, y_after]`.
            Python ints and NumPy integer scalars are accepted everywhere.

    Returns:
        Tuple of four Python ints `(px0, px1, py0, py1)`.

    Raises:
        AssertionError: If `padding` (or one of its elements) is not an integer.
        ValueError: If a list/tuple does not have exactly 2 or 4 elements.
    """
    # Accept NumPy integer scalars for the single-number form too; the list form
    # below already accepts np.integer elements.
    if isinstance(padding, (int, np.integer)):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    # Same exception type the 4-way unpack used to raise, but with a clear message.
    if len(padding) not in (2, 4):
        raise ValueError(f'padding must have 1, 2, or 4 elements, got {len(padding)}')
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1
#----------------------------------------------------------------------------
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    For each channel, the following steps are applied in order:

    1. Add the channel-specific bias `b`, if provided.
    2. Upsample by inserting `up - 1` zeros after each pixel.
    3. Pad with the given number of zeros on each side (`padding`); negative
       padding crops the image instead.
    4. Convolve with the upsampling FIR filter `fu`, shrinking the image so that
       the footprint of every output pixel lies within the input image.
    5. Multiply every value by `gain`.
    6. Apply the leaky ReLU activation.
    7. Clamp every value to `[-clamp, +clamp]`, if `clamp` is given.
    8. Convolve with the downsampling FIR filter `fd`, again shrinking the image
       so that the footprint of every output pixel lies within the input image.
    9. Downsample by keeping every `down`-th pixel.

    The fused op is considerably more efficient than performing the same
    calculation with standard PyTorch ops, and it supports gradients of
    arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        fd:          Float32 downsampling FIR filter of shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the
                     same dtype as `x`, whose length matches the channel
                     dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Either a single
                     number, a list/tuple `[x, y]`, or a list/tuple
                     `[x_before, x_after, y_before, y_after]` (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude of the leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        `'ref'` or `'cuda'` (default: `'cuda'`). With `'cuda'`, the
                     reference implementation is still used when `x` is not on
                     a CUDA device.

    Returns:
        Tensor of shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ('ref', 'cuda')
    # `_init()` always returns True; it is evaluated only when the CUDA path is eligible.
    use_cuda = impl == 'cuda' and x.device.type == 'cuda' and _init()
    if not use_cuda:
        return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
    op = _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
    # Trailing args: no input sign tensor, sign offsets (0, 0).
    return op.apply(x, fu, fd, b, None, 0, 0)
#----------------------------------------------------------------------------
@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Reference implementation of `filtered_lrelu()` composed from `upfirdn2d()` and `bias_act()`.

    Slow and memory-inefficient, since every intermediate image is materialized.
    `filtered_lrelu()` routes here when the CUDA path is not taken. Arguments
    and return value are as documented on `filtered_lrelu()`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype

    def out_extent(in_size, pad_total, fu_taps, fd_taps):
        # Upsample and pad, lose (taps - 1) pixels to each "valid" convolution,
        # then downsample with ceiling division.
        span = in_size * up + pad_total - (fu_taps - 1) - (fd_taps - 1)
        return (span + (down - 1)) // down

    out_w = out_extent(in_w, px0 + px1, fu_w, fd_w)
    out_h = out_extent(in_h, py0 + py1, fu_h, fd_h)

    # Run the pipeline with existing ops.
    y = bias_act.bias_act(x=x, b=b)  # Per-channel bias.
    # Upsample + pad + fu; output scaled by up**2 (presumably to offset zero insertion).
    y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter)
    y = bias_act.bias_act(x=y, act='lrelu', alpha=slope, gain=gain, clamp=clamp)  # Gain, leaky ReLU, clamp (no bias here).
    y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter)  # fd + downsample.

    # The composed result must match the analytically predicted shape and the input dtype.
    misc.assert_shape(y, [batch_size, channels, out_h, out_w])
    assert y.dtype == in_dtype
    return y
#----------------------------------------------------------------------------
# Memoizes the generated autograd.Function subclasses, keyed by the op's static
# (non-tensor) configuration, so each configuration builds its class only once.
_filtered_lrelu_cuda_cache = dict()
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialized for the given static
    parameters; run the op with `.apply(x, fu, fd, b, si, sx, sy)`. The class is
    stored in `_filtered_lrelu_cuda_cache`, so repeated calls with the same
    (normalized) configuration return the same class object.

    Args:
        up, down:    Integer up/downsampling factors (>= 1).
        padding:     Padding spec accepted by `_parse_padding()`.
        gain:        Positive gain; converted to float.
        slope:       Non-negative leaky ReLU slope; converted to float.
        clamp:       Non-negative clamp magnitude, or None (stored as +inf).
        flip_filter: False = convolution, True = correlation.

    `.apply()` inputs, as used below:
        x:      4D input tensor.
        fu, fd: Up/downsampling filters, or None for a 1x1 identity filter.
        b:      Per-channel bias of length x.shape[1], or None for zeros.
        si:     Sign tensor from an earlier pass, or None. When non-empty, new
                signs are not written and `si` is passed to the plugin instead;
                backward() uses this to hand the saved signs back in.
        sx, sy: Sign-tensor offsets, passed through to the plugin.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    # None is normalized to +inf, so clamp=None and clamp=inf share one cache entry.
    clamp = float(clamp if clamp is not None else 'inf')

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
    # The class closes over up/down/px0..py1/gain/slope/clamp/flip_filter from this call.
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            # A single separable tap t applied on both axes equals the 2D kernel [[t*t]];
            # `[None]` adds the leading dim to make it shape [1, 1].
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            # An empty tensor stands for "no input signs" (tested via si.numel() below).
            # NOTE(review): created on the default (CPU) device; presumably the plugin
            # only inspects its size in that case — confirm.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            # Size-1 dims are skipped since their stride is irrelevant. `a` and `b` in the
            # generator below are local to it and do not rebind the bias `b`.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                # The plugin returns (output, output signs, return code); a negative code
                # means no specialized kernel handled this configuration.
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                # Other dtypes (e.g. float64) always take the generic path below.
                return_code = -1

            # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
            # only the bit-packed sign tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)

                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
            # Keep the caller-supplied signs if any were given, otherwise the signs `so`
            # produced by this pass.
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            fu, fd, si = ctx.saved_tensors
            _, _, xh, xw = ctx.x_shape
            _, _, yh, yw = ctx.y_shape
            sx, sy = ctx.s_ofs
            # Gradients are only produced for x (input 0) and b (input 3); the filters,
            # sign tensor and sign offsets must not require grad.
            dx = None # 0
            dfu = None; assert not ctx.needs_input_grad[1]
            dfd = None; assert not ctx.needs_input_grad[2]
            db = None # 3
            dsi = None; assert not ctx.needs_input_grad[4]
            dsx = None; assert not ctx.needs_input_grad[5]
            dsy = None; assert not ctx.needs_input_grad[6]

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
                # The gradient is computed by this same op with up/down and fu/fd swapped
                # and flip_filter inverted (the transposed resampling chain). The padding,
                # gain and sign offsets below map it back onto the input's geometry.
                # The saved signs `si` are passed in, so no new signs are written; presumably
                # the plugin then applies the leaky ReLU slope per saved sign instead of
                # re-evaluating the activation (TODO confirm against the CUDA sources).
                pp = [
                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
                    xw * up - yw * down + px0 - (up - 1),
                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
                    xh * up - yh * down + py0 - (up - 1),
                ]
                gg = gain * (up ** 2) / (down ** 2)
                ff = (not flip_filter)
                sx = sx - (fu.shape[-1] - 1) + px0
                sy = sy - (fu.shape[0] - 1) + py0
                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)

            if ctx.needs_input_grad[3]:
                # The bias is added per channel before anything else, so its gradient is
                # dx summed over the batch and spatial dimensions.
                db = dx.sum([0, 2, 3])

            return dx, dfu, dfd, db, dsi, dsx, dsy

    # Add to cache.
    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
    return FilteredLReluCuda
#----------------------------------------------------------------------------
|