Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,447 Bytes
11e6f7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""2D convolution with optional up/downsampling."""
import torch
from .. import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size
#----------------------------------------------------------------------------
def _get_weight_shape(w):
    """Return the dimensions of `w` as a list of plain Python ints.

    The sizes are converted while tracer warnings are suppressed, and the
    tensor is then passed to `misc.assert_shape()` together with the
    extracted sizes.
    """
    # This value will be treated as a constant.
    with misc.suppress_tracer_warnings():
        dims = list(map(int, w.shape))
    misc.assert_shape(w, dims)
    return dims
#----------------------------------------------------------------------------
def _conv2d_wrapper(x,
                    w,
                    stride=1,
                    padding=0,
                    groups=1,
                    transpose=False,
                    flip_weight=True):
    """Dispatch to the `conv2d()` or `conv_transpose2d()` of conv2d_gradfix.

    Args:
        x:           Input tensor handed to the selected operation.
        w:           Weight tensor; its last two dimensions are read as the
                     kernel height and width.
        stride:      Forwarded unchanged to the selected operation.
        padding:     Forwarded unchanged to the selected operation.
        groups:      Forwarded unchanged to the selected operation.
        transpose:   True selects `conv_transpose2d()`, False selects
                     `conv2d()`.
        flip_weight: When False and the kernel is larger than 1x1, `w` is
                     flipped along dimensions 2 and 3 before the call.

    Returns:
        Whatever the selected conv2d_gradfix operation returns.
    """
    _, _, kernel_h, kernel_w = _get_weight_shape(w)

    # Note: conv2d() actually performs correlation (flip_weight=True) not
    # convolution (flip_weight=False). Flipping a 1x1 kernel would be a
    # no-op, so it is skipped in that case.
    kernel_is_1x1 = kernel_w <= 1 and kernel_h <= 1
    if not flip_weight and not kernel_is_1x1:
        w = w.flip([2, 3])

    # Pick the conv2d_gradfix entry point and run it.
    if transpose:
        conv_fn = conv2d_gradfix.conv_transpose2d
    else:
        conv_fn = conv2d_gradfix.conv2d
    return conv_fn(x, w, stride=stride, padding=padding, groups=groups)
#----------------------------------------------------------------------------
@misc.profiled_function
def conv2d_resample(x,
                    w,
                    f=None,
                    up=1,
                    down=1,
                    padding=0,
                    groups=1,
                    flip_weight=True,
                    flip_filter=False):
    r"""Apply a 2D convolution combined with optional up/downsampling.

    The requested padding is applied a single time up front; no additional
    padding is inserted between the individual operations.

    Args:
        x:           Input tensor of shape
                     `[batch_size, in_channels, in_height, in_width]`.
        w:           Weight tensor of shape
                     `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:           Low-pass filter used for up/downsampling. It must be
                     prepared beforehand with upfirdn2d.setup_filter().
                     None = identity (default).
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Either a
                     single number, a list/tuple `[x, y]`, or
                     `[x_before, x_after, y_before, y_after]` (default: 0).
        groups:      Number of groups the input channels are split into
                     (default: 1).
        flip_weight: False = convolution, True = correlation (default: True).
        flip_filter: False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    assert isinstance(w, torch.Tensor) and w.ndim == 4 and w.dtype == x.dtype
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in (1, 2)
                         and f.dtype == torch.float32)
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    assert isinstance(groups, int) and groups >= 1

    out_ch, in_ch_per_group, kernel_h, kernel_w = _get_weight_shape(w)
    filt_w, filt_h = _get_filter_size(f)
    pad_x0, pad_x1, pad_y0, pad_y1 = _parse_padding(padding)

    # Widen the padding by the amount the resampling filter requires.
    if up > 1:
        pad_x0 += (filt_w + up - 1) // 2
        pad_x1 += (filt_w - up) // 2
        pad_y0 += (filt_h + up - 1) // 2
        pad_y1 += (filt_h - up) // 2
    if down > 1:
        pad_x0 += (filt_w - down + 1) // 2
        pad_x1 += (filt_w - down) // 2
        pad_y0 += (filt_h - down + 1) // 2
        pad_y1 += (filt_h - down) // 2

    # Predicates shared by the fast paths below.
    pointwise = kernel_w == 1 and kernel_h == 1
    down_only = down > 1 and up == 1
    up_only = up > 1 and down == 1

    # Fast path: 1x1 kernel, downsampling only => downsample, then convolve.
    if pointwise and down_only:
        shrunk = upfirdn2d.upfirdn2d(
            x=x,
            f=f,
            down=down,
            padding=[pad_x0, pad_x1, pad_y0, pad_y1],
            flip_filter=flip_filter)
        return _conv2d_wrapper(x=shrunk,
                               w=w,
                               groups=groups,
                               flip_weight=flip_weight)

    # Fast path: 1x1 kernel, upsampling only => convolve, then upsample.
    if pointwise and up_only:
        convolved = _conv2d_wrapper(x=x,
                                    w=w,
                                    groups=groups,
                                    flip_weight=flip_weight)
        return upfirdn2d.upfirdn2d(
            x=convolved,
            f=f,
            up=up,
            padding=[pad_x0, pad_x1, pad_y0, pad_y1],
            gain=up**2,
            flip_filter=flip_filter)

    # Fast path: downsampling only => filter, then strided convolution.
    if down_only:
        filtered = upfirdn2d.upfirdn2d(
            x=x,
            f=f,
            padding=[pad_x0, pad_x1, pad_y0, pad_y1],
            flip_filter=flip_filter)
        return _conv2d_wrapper(x=filtered,
                               w=w,
                               stride=down,
                               groups=groups,
                               flip_weight=flip_weight)

    # Fast path: upsampling with optional downsampling => transposed strided
    # convolution followed by filtering.
    if up > 1:
        # Swap the in/out channel axes of the weight (per group when grouped).
        if groups == 1:
            w_t = w.transpose(0, 1)
        else:
            out_per_group = out_ch // groups
            w_t = (w.reshape(groups, out_per_group, in_ch_per_group,
                             kernel_h, kernel_w)
                   .transpose(1, 2)
                   .reshape(groups * in_ch_per_group, out_per_group,
                            kernel_h, kernel_w))

        # Reduce the padding by the kernel extent.
        pad_x0 -= kernel_w - 1
        pad_x1 -= kernel_w - up
        pad_y0 -= kernel_h - 1
        pad_y1 -= kernel_h - up

        # Part of any negative padding that is handed to the transposed
        # convolution itself; the same amount is added back to the filter
        # padding below.
        conv_pad_x = max(min(-pad_x0, -pad_x1), 0)
        conv_pad_y = max(min(-pad_y0, -pad_y1), 0)

        upsampled = _conv2d_wrapper(x=x,
                                    w=w_t,
                                    stride=up,
                                    padding=[conv_pad_y, conv_pad_x],
                                    groups=groups,
                                    transpose=True,
                                    flip_weight=(not flip_weight))
        result = upfirdn2d.upfirdn2d(
            x=upsampled,
            f=f,
            padding=[
                pad_x0 + conv_pad_x,
                pad_x1 + conv_pad_x,
                pad_y0 + conv_pad_y,
                pad_y1 + conv_pad_y,
            ],
            gain=up**2,
            flip_filter=flip_filter)
        if down > 1:
            result = upfirdn2d.upfirdn2d(x=result,
                                         f=f,
                                         down=down,
                                         flip_filter=flip_filter)
        return result

    # Fast path: no resampling and padding that is symmetric and non-negative
    # on both axes => plain conv2d.
    pad_symmetric = pad_x0 == pad_x1 and pad_y0 == pad_y1
    pad_non_negative = pad_x0 >= 0 and pad_y0 >= 0
    if up == 1 and down == 1 and pad_symmetric and pad_non_negative:
        return _conv2d_wrapper(x=x,
                               w=w,
                               padding=[pad_y0, pad_x0],
                               groups=groups,
                               flip_weight=flip_weight)

    # Fallback: generic reference implementation. With integer factors >= 1
    # (as asserted above) this is reached only for up == 1 and down == 1 with
    # padding that is asymmetric or negative.
    padded = upfirdn2d.upfirdn2d(
        x=x,
        f=(f if up > 1 else None),
        up=up,
        padding=[pad_x0, pad_x1, pad_y0, pad_y1],
        gain=up**2,
        flip_filter=flip_filter)
    result = _conv2d_wrapper(x=padded,
                             w=w,
                             groups=groups,
                             flip_weight=flip_weight)
    if down > 1:
        result = upfirdn2d.upfirdn2d(x=result,
                                     f=f,
                                     down=down,
                                     flip_filter=flip_filter)
    return result
#----------------------------------------------------------------------------
|