i72sijia committed on
Commit 4f48b17
1 Parent(s): 2b83341

Upload conv2d_resample.py

Files changed (1)
  1. torch_utils/ops/conv2d_resample.py +156 -0
torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """2D convolution with optional up/downsampling."""
+
+ import torch
+
+ from .. import misc
+ from . import conv2d_gradfix
+ from . import upfirdn2d
+ from .upfirdn2d import _parse_padding
+ from .upfirdn2d import _get_filter_size
+
+ #----------------------------------------------------------------------------
+
+ def _get_weight_shape(w):
+     with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+         shape = [int(sz) for sz in w.shape]
+     misc.assert_shape(w, shape)
+     return shape
+
+ #----------------------------------------------------------------------------
+
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+     """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+     """
+     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+     # Flip weight if requested.
+     if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+         w = w.flip([2, 3])
+
+     # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
+     # 1x1 kernel + memory_format=channels_last + less than 64 channels.
+     if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
+         if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
+             if out_channels <= 4 and groups == 1:
+                 in_shape = x.shape
+                 x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
+                 x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
+             else:
+                 x = x.to(memory_format=torch.contiguous_format)
+                 w = w.to(memory_format=torch.contiguous_format)
+                 x = conv2d_gradfix.conv2d(x, w, groups=groups)
+             return x.to(memory_format=torch.channels_last)
+
+     # Otherwise => execute using conv2d_gradfix.
+     op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+     return op(x, w, stride=stride, padding=padding, groups=groups)
+
+ #----------------------------------------------------------------------------
+
+ @misc.profiled_function
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+     r"""2D convolution with optional up/downsampling.
+
+     Padding is performed only once at the beginning, not between the operations.
+
+     Args:
+         x:              Input tensor of shape
+                         `[batch_size, in_channels, in_height, in_width]`.
+         w:              Weight tensor of shape
+                         `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+         f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
+                         calling upfirdn2d.setup_filter(). None = identity (default).
+         up:             Integer upsampling factor (default: 1).
+         down:           Integer downsampling factor (default: 1).
+         padding:        Padding with respect to the upsampled image. Can be a single number
+                         or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                         (default: 0).
+         groups:         Split input channels into N groups (default: 1).
+         flip_weight:    False = convolution, True = correlation (default: True).
+         flip_filter:    False = convolution, True = correlation (default: False).
+
+     Returns:
+         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+     """
+     # Validate arguments.
+     assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+     assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+     assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+     assert isinstance(up, int) and (up >= 1)
+     assert isinstance(down, int) and (down >= 1)
+     assert isinstance(groups, int) and (groups >= 1)
+     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+     fw, fh = _get_filter_size(f)
+     px0, px1, py0, py1 = _parse_padding(padding)
+
+     # Adjust padding to account for up/downsampling.
+     if up > 1:
+         px0 += (fw + up - 1) // 2
+         px1 += (fw - up) // 2
+         py0 += (fh + up - 1) // 2
+         py1 += (fh - up) // 2
+     if down > 1:
+         px0 += (fw - down + 1) // 2
+         px1 += (fw - down) // 2
+         py0 += (fh - down + 1) // 2
+         py1 += (fh - down) // 2
+
+     # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+     if kw == 1 and kh == 1 and (down > 1 and up == 1):
+         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+         return x
+
+     # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+     if kw == 1 and kh == 1 and (up > 1 and down == 1):
+         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+         x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+         return x
+
+     # Fast path: downsampling only => use strided convolution.
+     if down > 1 and up == 1:
+         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+         x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+         return x
+
+     # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+     if up > 1:
+         if groups == 1:
+             w = w.transpose(0, 1)
+         else:
+             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+             w = w.transpose(1, 2)
+             w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+         px0 -= kw - 1
+         px1 -= kw - up
+         py0 -= kh - 1
+         py1 -= kh - up
+         pxt = max(min(-px0, -px1), 0)
+         pyt = max(min(-py0, -py1), 0)
+         x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+         if down > 1:
+             x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+         return x
+
+     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+     if up == 1 and down == 1:
+         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+             return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+     # Fallback: Generic reference implementation.
+     x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+     if down > 1:
+         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+     return x
+
+ #----------------------------------------------------------------------------
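
A minimal usage sketch of conv2d_resample(), assuming the repository root is on PYTHONPATH so that torch_utils.ops is importable, and using [1, 3, 3, 1] as an example of the filter taps that upfirdn2d.setup_filter() accepts; exact shapes and taps are illustrative, not prescribed by this commit:

import torch
from torch_utils.ops import upfirdn2d
from torch_utils.ops.conv2d_resample import conv2d_resample

# The low-pass filter must be prepared once via setup_filter(), per the docstring.
f = upfirdn2d.setup_filter([1, 3, 3, 1])

x = torch.randn(4, 64, 32, 32)   # [batch_size, in_channels, in_height, in_width]
w = torch.randn(128, 64, 3, 3)   # [out_channels, in_channels//groups, kh, kw], same dtype as x

# padding=1 centers the 3x3 kernel; padding is interpreted relative to the upsampled image.
y_up   = conv2d_resample(x, w, f=f, up=2, padding=1)    # -> [4, 128, 64, 64]
y_down = conv2d_resample(x, w, f=f, down=2, padding=1)  # -> [4, 128, 16, 16]

The same filter f can be reused for both the up- and downsampling paths; passing f=None skips the low-pass filtering entirely (identity), as noted in the docstring.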