ayousanz committed · verified
Commit 413e7ca · 1 parent: afada13

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .venv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc +0 -0
  2. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
  3. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
  4. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py +6 -0
  6. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  7. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc +0 -0
  8. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +6 -0
  9. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py +17 -0
  10. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  11. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc +0 -0
  12. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc +0 -0
  13. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc +0 -0
  14. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py +7 -0
  15. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py +8 -0
  16. .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py +6 -0
  17. .venv/Lib/site-packages/torch/nn/modules/__init__.py +334 -0
  18. .venv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  19. .venv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc +0 -0
  20. .venv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc +0 -0
  21. .venv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc +0 -0
  22. .venv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc +0 -0
  23. .venv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc +0 -0
  24. .venv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc +0 -0
  25. .venv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc +0 -0
  26. .venv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc +0 -0
  28. .venv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc +0 -0
  29. .venv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc +0 -0
  30. .venv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc +0 -0
  32. .venv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc +0 -0
  33. .venv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc +0 -0
  34. .venv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc +0 -0
  35. .venv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc +0 -0
  36. .venv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc +0 -0
  37. .venv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc +0 -0
  38. .venv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc +0 -0
  39. .venv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc +0 -0
  40. .venv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc +0 -0
  41. .venv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc +0 -0
  42. .venv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc +0 -0
  43. .venv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc +0 -0
  44. .venv/Lib/site-packages/torch/nn/modules/_functions.py +319 -0
  45. .venv/Lib/site-packages/torch/nn/modules/activation.py +1746 -0
  46. .venv/Lib/site-packages/torch/nn/modules/adaptive.py +330 -0
  47. .venv/Lib/site-packages/torch/nn/modules/batchnorm.py +883 -0
  48. .venv/Lib/site-packages/torch/nn/modules/channelshuffle.py +56 -0
  49. .venv/Lib/site-packages/torch/nn/modules/container.py +976 -0
  50. .venv/Lib/site-packages/torch/nn/modules/conv.py +1866 -0
.venv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (736 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (420 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
+ from torch.nn.intrinsic.quantized.dynamic.modules import *  # noqa: F403
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (269 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
+
+
+ __all__ = [
+     "LinearReLU",
+ ]
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (331 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc ADDED
Binary file (317 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py ADDED
@@ -0,0 +1,6 @@
+ from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
+
+
+ __all__ = [
+     "LinearReLU",
+ ]
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
+ from torch.nn.intrinsic.quantized.modules.conv_relu import (
+     ConvReLU1d,
+     ConvReLU2d,
+     ConvReLU3d,
+ )
+ from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
+
+
+ __all__ = [
+     "LinearReLU",
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+     "BNReLU2d",
+     "BNReLU3d",
+ ]
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (561 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc ADDED
Binary file (323 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc ADDED
Binary file (336 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc ADDED
Binary file (301 Bytes)
 
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py ADDED
@@ -0,0 +1,7 @@
+ from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
+
+
+ __all__ = [
+     "BNReLU2d",
+     "BNReLU3d",
+ ]
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py ADDED
@@ -0,0 +1,8 @@
+ from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
+
+
+ __all__ = [
+     "ConvReLU1d",
+     "ConvReLU2d",
+     "ConvReLU3d",
+ ]
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,6 @@
+ from torch.ao.nn.intrinsic.quantized import LinearReLU
+
+
+ __all__ = [
+     "LinearReLU",
+ ]
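
Note: the `torch.nn.intrinsic.quantized` files above are thin compatibility shims; each one only re-exports classes that now live under `torch.ao.nn.intrinsic.quantized`. A minimal sanity check, assuming a PyTorch build where this migration is in place, is that the legacy and `torch.ao` paths bind the very same class objects:

    # Hedged sketch (not part of the commit): verify the re-export shims
    # above alias the torch.ao implementations rather than defining copies.
    import torch.ao.nn.intrinsic.quantized as ao_q
    import torch.nn.intrinsic.quantized as legacy_q

    # Identity, not mere equality: both names resolve to one class object.
    assert legacy_q.LinearReLU is ao_q.LinearReLU
    assert legacy_q.ConvReLU2d is ao_q.ConvReLU2d
    assert legacy_q.BNReLU2d is ao_q.BNReLU2d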
.venv/Lib/site-packages/torch/nn/modules/__init__.py ADDED
@@ -0,0 +1,334 @@
+ from .module import Module  # usort: skip
+ from .linear import Bilinear, Identity, LazyLinear, Linear  # usort: skip
+ from .activation import (
+     CELU,
+     ELU,
+     GELU,
+     GLU,
+     Hardshrink,
+     Hardsigmoid,
+     Hardswish,
+     Hardtanh,
+     LeakyReLU,
+     LogSigmoid,
+     LogSoftmax,
+     Mish,
+     MultiheadAttention,
+     PReLU,
+     ReLU,
+     ReLU6,
+     RReLU,
+     SELU,
+     Sigmoid,
+     SiLU,
+     Softmax,
+     Softmax2d,
+     Softmin,
+     Softplus,
+     Softshrink,
+     Softsign,
+     Tanh,
+     Tanhshrink,
+     Threshold,
+ )
+ from .adaptive import AdaptiveLogSoftmaxWithLoss
+ from .batchnorm import (
+     BatchNorm1d,
+     BatchNorm2d,
+     BatchNorm3d,
+     LazyBatchNorm1d,
+     LazyBatchNorm2d,
+     LazyBatchNorm3d,
+     SyncBatchNorm,
+ )
+ from .channelshuffle import ChannelShuffle
+ from .container import (
+     Container,
+     ModuleDict,
+     ModuleList,
+     ParameterDict,
+     ParameterList,
+     Sequential,
+ )
+ from .conv import (
+     Conv1d,
+     Conv2d,
+     Conv3d,
+     ConvTranspose1d,
+     ConvTranspose2d,
+     ConvTranspose3d,
+     LazyConv1d,
+     LazyConv2d,
+     LazyConv3d,
+     LazyConvTranspose1d,
+     LazyConvTranspose2d,
+     LazyConvTranspose3d,
+ )
+ from .distance import CosineSimilarity, PairwiseDistance
+ from .dropout import (
+     AlphaDropout,
+     Dropout,
+     Dropout1d,
+     Dropout2d,
+     Dropout3d,
+     FeatureAlphaDropout,
+ )
+ from .flatten import Flatten, Unflatten
+ from .fold import Fold, Unfold
+ from .instancenorm import (
+     InstanceNorm1d,
+     InstanceNorm2d,
+     InstanceNorm3d,
+     LazyInstanceNorm1d,
+     LazyInstanceNorm2d,
+     LazyInstanceNorm3d,
+ )
+ from .loss import (
+     BCELoss,
+     BCEWithLogitsLoss,
+     CosineEmbeddingLoss,
+     CrossEntropyLoss,
+     CTCLoss,
+     GaussianNLLLoss,
+     HingeEmbeddingLoss,
+     HuberLoss,
+     KLDivLoss,
+     L1Loss,
+     MarginRankingLoss,
+     MSELoss,
+     MultiLabelMarginLoss,
+     MultiLabelSoftMarginLoss,
+     MultiMarginLoss,
+     NLLLoss,
+     NLLLoss2d,
+     PoissonNLLLoss,
+     SmoothL1Loss,
+     SoftMarginLoss,
+     TripletMarginLoss,
+     TripletMarginWithDistanceLoss,
+ )
+ from .normalization import (
+     CrossMapLRN2d,
+     GroupNorm,
+     LayerNorm,
+     LocalResponseNorm,
+     RMSNorm,
+ )
+ from .padding import (
+     CircularPad1d,
+     CircularPad2d,
+     CircularPad3d,
+     ConstantPad1d,
+     ConstantPad2d,
+     ConstantPad3d,
+     ReflectionPad1d,
+     ReflectionPad2d,
+     ReflectionPad3d,
+     ReplicationPad1d,
+     ReplicationPad2d,
+     ReplicationPad3d,
+     ZeroPad1d,
+     ZeroPad2d,
+     ZeroPad3d,
+ )
+ from .pixelshuffle import PixelShuffle, PixelUnshuffle
+ from .pooling import (
+     AdaptiveAvgPool1d,
+     AdaptiveAvgPool2d,
+     AdaptiveAvgPool3d,
+     AdaptiveMaxPool1d,
+     AdaptiveMaxPool2d,
+     AdaptiveMaxPool3d,
+     AvgPool1d,
+     AvgPool2d,
+     AvgPool3d,
+     FractionalMaxPool2d,
+     FractionalMaxPool3d,
+     LPPool1d,
+     LPPool2d,
+     LPPool3d,
+     MaxPool1d,
+     MaxPool2d,
+     MaxPool3d,
+     MaxUnpool1d,
+     MaxUnpool2d,
+     MaxUnpool3d,
+ )
+ from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase
+ from .sparse import Embedding, EmbeddingBag
+ from .transformer import (
+     Transformer,
+     TransformerDecoder,
+     TransformerDecoderLayer,
+     TransformerEncoder,
+     TransformerEncoderLayer,
+ )
+ from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d
+
+
+ __all__ = [
+     "AdaptiveAvgPool1d",
+     "AdaptiveAvgPool2d",
+     "AdaptiveAvgPool3d",
+     "AdaptiveLogSoftmaxWithLoss",
+     "AdaptiveMaxPool1d",
+     "AdaptiveMaxPool2d",
+     "AdaptiveMaxPool3d",
+     "AlphaDropout",
+     "AvgPool1d",
+     "AvgPool2d",
+     "AvgPool3d",
+     "BCELoss",
+     "BCEWithLogitsLoss",
+     "BatchNorm1d",
+     "BatchNorm2d",
+     "BatchNorm3d",
+     "Bilinear",
+     "CELU",
+     "CTCLoss",
+     "ChannelShuffle",
+     "CircularPad1d",
+     "CircularPad2d",
+     "CircularPad3d",
+     "ConstantPad1d",
+     "ConstantPad2d",
+     "ConstantPad3d",
+     "Container",
+     "Conv1d",
+     "Conv2d",
+     "Conv3d",
+     "ConvTranspose1d",
+     "ConvTranspose2d",
+     "ConvTranspose3d",
+     "CosineEmbeddingLoss",
+     "CosineSimilarity",
+     "CrossEntropyLoss",
+     "CrossMapLRN2d",
+     "Dropout",
+     "Dropout1d",
+     "Dropout2d",
+     "Dropout3d",
+     "ELU",
+     "Embedding",
+     "EmbeddingBag",
+     "FeatureAlphaDropout",
+     "Flatten",
+     "Fold",
+     "FractionalMaxPool2d",
+     "FractionalMaxPool3d",
+     "GELU",
+     "GLU",
+     "GRU",
+     "GRUCell",
+     "GaussianNLLLoss",
+     "GroupNorm",
+     "Hardshrink",
+     "Hardsigmoid",
+     "Hardswish",
+     "Hardtanh",
+     "HingeEmbeddingLoss",
+     "HuberLoss",
+     "Identity",
+     "InstanceNorm1d",
+     "InstanceNorm2d",
+     "InstanceNorm3d",
+     "KLDivLoss",
+     "L1Loss",
+     "LPPool1d",
+     "LPPool2d",
+     "LPPool3d",
+     "LSTM",
+     "LSTMCell",
+     "LayerNorm",
+     "LazyBatchNorm1d",
+     "LazyBatchNorm2d",
+     "LazyBatchNorm3d",
+     "LazyConv1d",
+     "LazyConv2d",
+     "LazyConv3d",
+     "LazyConvTranspose1d",
+     "LazyConvTranspose2d",
+     "LazyConvTranspose3d",
+     "LazyInstanceNorm1d",
+     "LazyInstanceNorm2d",
+     "LazyInstanceNorm3d",
+     "LazyLinear",
+     "LeakyReLU",
+     "Linear",
+     "LocalResponseNorm",
+     "LogSigmoid",
+     "LogSoftmax",
+     "MSELoss",
+     "MarginRankingLoss",
+     "MaxPool1d",
+     "MaxPool2d",
+     "MaxPool3d",
+     "MaxUnpool1d",
+     "MaxUnpool2d",
+     "MaxUnpool3d",
+     "Mish",
+     "Module",
+     "ModuleDict",
+     "ModuleList",
+     "MultiLabelMarginLoss",
+     "MultiLabelSoftMarginLoss",
+     "MultiMarginLoss",
+     "MultiheadAttention",
+     "NLLLoss",
+     "NLLLoss2d",
+     "PReLU",
+     "PairwiseDistance",
+     "ParameterDict",
+     "ParameterList",
+     "PixelShuffle",
+     "PixelUnshuffle",
+     "PoissonNLLLoss",
+     "RMSNorm",
+     "RNN",
+     "RNNBase",
+     "RNNCell",
+     "RNNCellBase",
+     "RReLU",
+     "ReLU",
+     "ReLU6",
+     "ReflectionPad1d",
+     "ReflectionPad2d",
+     "ReflectionPad3d",
+     "ReplicationPad1d",
+     "ReplicationPad2d",
+     "ReplicationPad3d",
+     "SELU",
+     "Sequential",
+     "SiLU",
+     "Sigmoid",
+     "SmoothL1Loss",
+     "SoftMarginLoss",
+     "Softmax",
+     "Softmax2d",
+     "Softmin",
+     "Softplus",
+     "Softshrink",
+     "Softsign",
+     "SyncBatchNorm",
+     "Tanh",
+     "Tanhshrink",
+     "Threshold",
+     "Transformer",
+     "TransformerDecoder",
+     "TransformerDecoderLayer",
+     "TransformerEncoder",
+     "TransformerEncoderLayer",
+     "TripletMarginLoss",
+     "TripletMarginWithDistanceLoss",
+     "Unflatten",
+     "Unfold",
+     "Upsample",
+     "UpsamplingBilinear2d",
+     "UpsamplingNearest2d",
+     "ZeroPad1d",
+     "ZeroPad2d",
+     "ZeroPad3d",
+ ]
+
+ # Please keep this list sorted
+ assert __all__ == sorted(__all__)
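
Note: the trailing `assert __all__ == sorted(__all__)` makes the "keep this list sorted" convention self-enforcing; the module fails at import time if a symbol is added out of order. A small illustration of the same check from user code, assuming this file is importable as `torch.nn.modules`:

    # Hedged sketch: mirrors the sorted-__all__ invariant asserted in the file.
    import torch.nn.modules as modules

    assert modules.__all__ == sorted(modules.__all__)
    print(len(modules.__all__), "public symbols, first three:", modules.__all__[:3])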
.venv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (5.16 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc ADDED
Binary file (6.07 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc ADDED
Binary file (56.9 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc ADDED
Binary file (10.6 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc ADDED
Binary file (32.2 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc ADDED
Binary file (2.23 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc ADDED
Binary file (35.2 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc ADDED
Binary file (61 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc ADDED
Binary file (4.11 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc ADDED
Binary file (12.6 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc ADDED
Binary file (5.99 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc ADDED
Binary file (13.1 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc ADDED
Binary file (20.9 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc ADDED
Binary file (11.9 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc ADDED
Binary file (10.5 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc ADDED
Binary file (94.7 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc ADDED
Binary file (95.7 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc ADDED
Binary file (15 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc ADDED
Binary file (34.2 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc ADDED
Binary file (4.52 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc ADDED
Binary file (58.6 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc ADDED
Binary file (55.4 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc ADDED
Binary file (21.5 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (37.2 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc ADDED
Binary file (11.9 kB)
 
.venv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.74 kB)
 
.venv/Lib/site-packages/torch/nn/modules/_functions.py ADDED
@@ -0,0 +1,319 @@
+ # mypy: allow-untyped-defs
+ import torch
+ import torch.distributed as dist
+ from torch.autograd.function import Function
+
+
+ class SyncBatchNorm(Function):
+     @staticmethod
+     def forward(
+         self,
+         input,
+         weight,
+         bias,
+         running_mean,
+         running_var,
+         eps,
+         momentum,
+         process_group,
+         world_size,
+     ):
+         if not (
+             input.is_contiguous(memory_format=torch.channels_last)
+             or input.is_contiguous(memory_format=torch.channels_last_3d)
+         ):
+             input = input.contiguous()
+         if weight is not None:
+             weight = weight.contiguous()
+
+         size = int(input.numel() // input.size(1))
+         if size == 1 and world_size < 2:
+             raise ValueError(
+                 f"Expected more than 1 value per channel when training, got input size {size}"
+             )
+
+         num_channels = input.shape[1]
+         if input.numel() > 0:
+             # calculate mean/invstd for input.
+             mean, invstd = torch.batch_norm_stats(input, eps)
+
+             count = torch.full(
+                 (1,),
+                 input.numel() // input.size(1),
+                 dtype=mean.dtype,
+                 device=mean.device,
+             )
+
+             # C, C, 1 -> (2C + 1)
+             combined = torch.cat([mean, invstd, count], dim=0)
+         else:
+             # for empty input, set stats and the count to zero. The stats with
+             # zero count will be filtered out later when computing global mean
+             # & invstd, but they still needs to participate the all_gather
+             # collective communication to unblock other peer processes.
+             combined = torch.zeros(
+                 2 * num_channels + 1, dtype=input.dtype, device=input.device
+             )
+
+         # Use allgather instead of allreduce because count could be different across
+         # ranks, simple all reduce op can not give correct results.
+         # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+         # all gathered mean, invstd and count.
+         # for nccl backend, use the optimized version of all gather.
+         # The Gloo backend does not support `all_gather_into_tensor`.
+         if process_group._get_backend_name() != "gloo":
+             # world_size * (2C + 1)
+             combined_size = combined.numel()
+             combined_flat = torch.empty(
+                 1,
+                 combined_size * world_size,
+                 dtype=combined.dtype,
+                 device=combined.device,
+             )
+             dist.all_gather_into_tensor(
+                 combined_flat, combined, process_group, async_op=False
+             )
+             combined = torch.reshape(combined_flat, (world_size, combined_size))
+             # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+             mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+         else:
+             # world_size * (2C + 1)
+             combined_list = [torch.empty_like(combined) for _ in range(world_size)]
+             dist.all_gather(combined_list, combined, process_group, async_op=False)
+             combined = torch.stack(combined_list, dim=0)
+             # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+             mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+
+         if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
+             # The lines below force a synchronization between CUDA and CPU, because
+             # the shape of the result count_all depends on the values in mask tensor.
+             # Such synchronizations break CUDA Graph capturing.
+             # See https://github.com/pytorch/pytorch/issues/78549
+             # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
+             # a better longer-term solution.
+
+             # remove stats from empty inputs
+             mask = count_all.squeeze(-1) >= 1
+             count_all = count_all[mask]
+             mean_all = mean_all[mask]
+             invstd_all = invstd_all[mask]
+
+         # calculate global mean & invstd
+         counts = count_all.view(-1)
+         if running_mean is not None and counts.dtype != running_mean.dtype:
+             counts = counts.to(running_mean.dtype)
+         mean, invstd = torch.batch_norm_gather_stats_with_counts(
+             input,
+             mean_all,
+             invstd_all,
+             running_mean,
+             running_var,
+             momentum,
+             eps,
+             counts,
+         )
+
+         self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
+         self.process_group = process_group
+
+         # apply element-wise normalization
+         if input.numel() > 0:
+             return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+         else:
+             return torch.empty_like(input)
+
+     @staticmethod
+     def backward(self, grad_output):
+         if not (
+             grad_output.is_contiguous(memory_format=torch.channels_last)
+             or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
+         ):
+             grad_output = grad_output.contiguous()
+         saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
+         grad_input = grad_weight = grad_bias = None
+         process_group = self.process_group
+
+         if saved_input.numel() > 0:
+             # calculate local stats as well as grad_weight / grad_bias
+             (
+                 sum_dy,
+                 sum_dy_xmu,
+                 grad_weight,
+                 grad_bias,
+             ) = torch.batch_norm_backward_reduce(
+                 grad_output,
+                 saved_input,
+                 mean,
+                 invstd,
+                 weight,
+                 self.needs_input_grad[0],
+                 self.needs_input_grad[1],
+                 self.needs_input_grad[2],
+             )
+
+             if self.needs_input_grad[0]:
+                 # synchronizing stats used to calculate input gradient.
+                 num_channels = sum_dy.shape[0]
+                 combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
+                 torch.distributed.all_reduce(
+                     combined,
+                     torch.distributed.ReduceOp.SUM,
+                     process_group,
+                     async_op=False,
+                 )
+                 sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
+
+                 # backward pass for gradient calculation
+                 if weight is not None and weight.dtype != mean.dtype:
+                     weight = weight.to(mean.dtype)
+                 grad_input = torch.batch_norm_backward_elemt(
+                     grad_output,
+                     saved_input,
+                     mean,
+                     invstd,
+                     weight,
+                     sum_dy,
+                     sum_dy_xmu,
+                     count_tensor,
+                 )
+             # synchronizing of grad_weight / grad_bias is not needed as distributed
+             # training would handle all reduce.
+             if weight is None or not self.needs_input_grad[1]:
+                 grad_weight = None
+
+             if weight is None or not self.needs_input_grad[2]:
+                 grad_bias = None
+         else:
+             # This process got an empty input tensor in the forward pass.
+             # Although this process can directly set grad_input as an empty
+             # tensor of zeros, it still needs to participate in the collective
+             # communication to unblock its peers, as other peer processes might
+             # have received non-empty inputs.
+             num_channels = saved_input.shape[1]
+             if self.needs_input_grad[0]:
+                 # launch all_reduce to unblock other peer processes
+                 combined = torch.zeros(
+                     2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
+                 )
+                 torch.distributed.all_reduce(
+                     combined,
+                     torch.distributed.ReduceOp.SUM,
+                     process_group,
+                     async_op=False,
+                 )
+
+             # Leave grad_input, grad_weight and grad_bias as None, which will be
+             # interpreted by the autograd engine as Tensors full of zeros.
+
+         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+ class CrossMapLRN2d(Function):
+     @staticmethod
+     def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
+         ctx.size = size
+         ctx.alpha = alpha
+         ctx.beta = beta
+         ctx.k = k
+         ctx.scale = None
+
+         if input.dim() != 4:
+             raise ValueError(
+                 f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
+             )
+
+         ctx.scale = ctx.scale or input.new()
+         output = input.new()
+
+         batch_size = input.size(0)
+         channels = input.size(1)
+         input_height = input.size(2)
+         input_width = input.size(3)
+
+         output.resize_as_(input)
+         ctx.scale.resize_as_(input)
+
+         # use output storage as temporary buffer
+         input_square = output
+         torch.pow(input, 2, out=input_square)
+
+         pre_pad = int((ctx.size - 1) / 2 + 1)
+         pre_pad_crop = min(pre_pad, channels)
+
+         scale_first = ctx.scale.select(1, 0)
+         scale_first.zero_()
+         # compute first feature map normalization
+         for c in range(pre_pad_crop):
+             scale_first.add_(input_square.select(1, c))
+
+         # reuse computations for next feature maps normalization
+         # by adding the next feature map and removing the previous
+         for c in range(1, channels):
+             scale_previous = ctx.scale.select(1, c - 1)
+             scale_current = ctx.scale.select(1, c)
+             scale_current.copy_(scale_previous)
+             if c < channels - pre_pad + 1:
+                 square_next = input_square.select(1, c + pre_pad - 1)
+                 scale_current.add_(square_next, alpha=1)
+
+             if c > pre_pad:
+                 square_previous = input_square.select(1, c - pre_pad)
+                 scale_current.add_(square_previous, alpha=-1)
+
+         ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
+
+         torch.pow(ctx.scale, -ctx.beta, out=output)
+         output.mul_(input)
+
+         ctx.save_for_backward(input, output)
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         input, output = ctx.saved_tensors
+         grad_input = grad_output.new()
+
+         batch_size = input.size(0)
+         channels = input.size(1)
+         input_height = input.size(2)
+         input_width = input.size(3)
+
+         paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
+         accum_ratio = input.new(input_height, input_width)
+
+         cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
+         inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
+
+         grad_input.resize_as_(input)
+         torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
+
+         paddded_ratio.zero_()
+         padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels)
+         for n in range(batch_size):
+             torch.mul(grad_output[n], output[n], out=padded_ratio_center)
+             padded_ratio_center.div_(ctx.scale[n])
+             torch.sum(
+                 paddded_ratio.narrow(0, 0, ctx.size - 1),
+                 0,
+                 keepdim=False,
+                 out=accum_ratio,
+             )
+             for c in range(channels):
+                 accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
+                 grad_input[n][c].addcmul_(
+                     input[n][c], accum_ratio, value=-cache_ratio_value
+                 )
+                 accum_ratio.add_(paddded_ratio[c], alpha=-1)
+
+         return grad_input, None, None, None, None
+
+
+ class BackwardHookFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, *args):
+         ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
+         return args
+
+     @staticmethod
+     def backward(ctx, *args):
+         return args
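
Note: `SyncBatchNorm` in `_functions.py` is the internal `autograd.Function` that performs the cross-rank all_gather/all_reduce; user code normally reaches it through the public `torch.nn.SyncBatchNorm` module. A hedged usage sketch, assuming a distributed process group is initialized elsewhere (e.g. before wrapping the model in DistributedDataParallel):

    # Hedged sketch: convert ordinary BatchNorm layers to SyncBatchNorm.
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    # Recursively replaces every BatchNorm*d module; the converted layers call
    # into the SyncBatchNorm autograd.Function above when run under a process group.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    assert isinstance(model[1], nn.SyncBatchNorm)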
.venv/Lib/site-packages/torch/nn/modules/activation.py ADDED
@@ -0,0 +1,1746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import warnings
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.parameter import Parameter
10
+
11
+ from .linear import NonDynamicallyQuantizableLinear
12
+ from .module import Module
13
+
14
+
15
+ __all__ = [
16
+ "Threshold",
17
+ "ReLU",
18
+ "RReLU",
19
+ "Hardtanh",
20
+ "ReLU6",
21
+ "Sigmoid",
22
+ "Hardsigmoid",
23
+ "Tanh",
24
+ "SiLU",
25
+ "Mish",
26
+ "Hardswish",
27
+ "ELU",
28
+ "CELU",
29
+ "SELU",
30
+ "GLU",
31
+ "GELU",
32
+ "Hardshrink",
33
+ "LeakyReLU",
34
+ "LogSigmoid",
35
+ "Softplus",
36
+ "Softshrink",
37
+ "MultiheadAttention",
38
+ "PReLU",
39
+ "Softsign",
40
+ "Tanhshrink",
41
+ "Softmin",
42
+ "Softmax",
43
+ "Softmax2d",
44
+ "LogSoftmax",
45
+ ]
46
+
47
+
48
+ class Threshold(Module):
49
+ r"""Thresholds each element of the input Tensor.
50
+
51
+ Threshold is defined as:
52
+
53
+ .. math::
54
+ y =
55
+ \begin{cases}
56
+ x, &\text{ if } x > \text{threshold} \\
57
+ \text{value}, &\text{ otherwise }
58
+ \end{cases}
59
+
60
+ Args:
61
+ threshold: The value to threshold at
62
+ value: The value to replace with
63
+ inplace: can optionally do the operation in-place. Default: ``False``
64
+
65
+ Shape:
66
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
67
+ - Output: :math:`(*)`, same shape as the input.
68
+
69
+ Examples::
70
+
71
+ >>> m = nn.Threshold(0.1, 20)
72
+ >>> input = torch.randn(2)
73
+ >>> output = m(input)
74
+ """
75
+
76
+ __constants__ = ["threshold", "value", "inplace"]
77
+
78
+ threshold: float
79
+ value: float
80
+ inplace: bool
81
+
82
+ def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
83
+ super().__init__()
84
+ self.threshold = threshold
85
+ self.value = value
86
+ self.inplace = inplace
87
+ # TODO: check in THNN (if inplace == True, then assert value <= threshold)
88
+
89
+ def forward(self, input: Tensor) -> Tensor:
90
+ return F.threshold(input, self.threshold, self.value, self.inplace)
91
+
92
+ def extra_repr(self):
93
+ inplace_str = ", inplace=True" if self.inplace else ""
94
+ return f"threshold={self.threshold}, value={self.value}{inplace_str}"
95
+
96
+
97
+ class ReLU(Module):
98
+ r"""Applies the rectified linear unit function element-wise.
99
+
100
+ :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
101
+
102
+ Args:
103
+ inplace: can optionally do the operation in-place. Default: ``False``
104
+
105
+ Shape:
106
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
107
+ - Output: :math:`(*)`, same shape as the input.
108
+
109
+ .. image:: ../scripts/activation_images/ReLU.png
110
+
111
+ Examples::
112
+
113
+ >>> m = nn.ReLU()
114
+ >>> input = torch.randn(2)
115
+ >>> output = m(input)
116
+
117
+
118
+ An implementation of CReLU - https://arxiv.org/abs/1603.05201
119
+
120
+ >>> m = nn.ReLU()
121
+ >>> input = torch.randn(2).unsqueeze(0)
122
+ >>> output = torch.cat((m(input), m(-input)))
123
+ """
124
+
125
+ __constants__ = ["inplace"]
126
+ inplace: bool
127
+
128
+ def __init__(self, inplace: bool = False):
129
+ super().__init__()
130
+ self.inplace = inplace
131
+
132
+ def forward(self, input: Tensor) -> Tensor:
133
+ return F.relu(input, inplace=self.inplace)
134
+
135
+ def extra_repr(self) -> str:
136
+ inplace_str = "inplace=True" if self.inplace else ""
137
+ return inplace_str
138
+
139
+
140
+ class RReLU(Module):
141
+ r"""Applies the randomized leaky rectified linear unit function, element-wise.
142
+
143
+ Method described in the paper:
144
+ `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
145
+
146
+ The function is defined as:
147
+
148
+ .. math::
149
+ \text{RReLU}(x) =
150
+ \begin{cases}
151
+ x & \text{if } x \geq 0 \\
152
+ ax & \text{ otherwise }
153
+ \end{cases}
154
+
155
+ where :math:`a` is randomly sampled from uniform distribution
156
+ :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
157
+ evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
158
+
159
+ Args:
160
+ lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
161
+ upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
162
+ inplace: can optionally do the operation in-place. Default: ``False``
163
+
164
+ Shape:
165
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
166
+ - Output: :math:`(*)`, same shape as the input.
167
+
168
+ .. image:: ../scripts/activation_images/RReLU.png
169
+
170
+ Examples::
171
+
172
+ >>> m = nn.RReLU(0.1, 0.3)
173
+ >>> input = torch.randn(2)
174
+ >>> output = m(input)
175
+
176
+ """
177
+
178
+ __constants__ = ["lower", "upper", "inplace"]
179
+
180
+ lower: float
181
+ upper: float
182
+ inplace: bool
183
+
184
+ def __init__(
185
+ self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False
186
+ ):
187
+ super().__init__()
188
+ self.lower = lower
189
+ self.upper = upper
190
+ self.inplace = inplace
191
+
192
+ def forward(self, input: Tensor) -> Tensor:
193
+ return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
194
+
195
+ def extra_repr(self):
196
+ inplace_str = ", inplace=True" if self.inplace else ""
197
+ return f"lower={self.lower}, upper={self.upper}{inplace_str}"
198
+
199
+
200
+ class Hardtanh(Module):
201
+ r"""Applies the HardTanh function element-wise.
202
+
203
+ HardTanh is defined as:
204
+
205
+ .. math::
206
+ \text{HardTanh}(x) = \begin{cases}
207
+ \text{max\_val} & \text{ if } x > \text{ max\_val } \\
208
+ \text{min\_val} & \text{ if } x < \text{ min\_val } \\
209
+ x & \text{ otherwise } \\
210
+ \end{cases}
211
+
212
+ Args:
213
+ min_val: minimum value of the linear region range. Default: -1
214
+ max_val: maximum value of the linear region range. Default: 1
215
+ inplace: can optionally do the operation in-place. Default: ``False``
216
+
217
+ Keyword arguments :attr:`min_value` and :attr:`max_value`
218
+ have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
219
+
220
+ Shape:
221
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
222
+ - Output: :math:`(*)`, same shape as the input.
223
+
224
+ .. image:: ../scripts/activation_images/Hardtanh.png
225
+
226
+ Examples::
227
+
228
+ >>> m = nn.Hardtanh(-2, 2)
229
+ >>> input = torch.randn(2)
230
+ >>> output = m(input)
231
+ """
232
+
233
+ __constants__ = ["min_val", "max_val", "inplace"]
234
+
235
+ min_val: float
236
+ max_val: float
237
+ inplace: bool
238
+
239
+ def __init__(
240
+ self,
241
+ min_val: float = -1.0,
242
+ max_val: float = 1.0,
243
+ inplace: bool = False,
244
+ min_value: Optional[float] = None,
245
+ max_value: Optional[float] = None,
246
+ ) -> None:
247
+ super().__init__()
248
+ if min_value is not None:
249
+ warnings.warn(
250
+ "keyword argument `min_value` is deprecated and rename to `min_val`",
251
+ FutureWarning,
252
+ stacklevel=2,
253
+ )
254
+ min_val = min_value
255
+ if max_value is not None:
256
+ warnings.warn(
257
+ "keyword argument `max_value` is deprecated and rename to `max_val`",
258
+ FutureWarning,
259
+ stacklevel=2,
260
+ )
261
+ max_val = max_value
262
+
263
+ self.min_val = min_val
264
+ self.max_val = max_val
265
+ self.inplace = inplace
266
+ assert self.max_val > self.min_val
267
+
268
+ def forward(self, input: Tensor) -> Tensor:
269
+ return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
270
+
271
+ def extra_repr(self) -> str:
272
+ inplace_str = ", inplace=True" if self.inplace else ""
273
+ return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}"
274
+
275
+
276
+ class ReLU6(Hardtanh):
277
+ r"""Applies the ReLU6 function element-wise.
278
+
279
+ .. math::
280
+ \text{ReLU6}(x) = \min(\max(0,x), 6)
281
+
282
+ Args:
283
+ inplace: can optionally do the operation in-place. Default: ``False``
284
+
285
+ Shape:
286
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
287
+ - Output: :math:`(*)`, same shape as the input.
288
+
289
+ .. image:: ../scripts/activation_images/ReLU6.png
290
+
291
+ Examples::
292
+
293
+ >>> m = nn.ReLU6()
294
+ >>> input = torch.randn(2)
295
+ >>> output = m(input)
296
+ """
297
+
298
+ def __init__(self, inplace: bool = False):
299
+ super().__init__(0.0, 6.0, inplace)
300
+
301
+ def extra_repr(self) -> str:
302
+ inplace_str = "inplace=True" if self.inplace else ""
303
+ return inplace_str
304
+
305
+
306
+ class Sigmoid(Module):
307
+ r"""Applies the Sigmoid function element-wise.
308
+
309
+ .. math::
310
+ \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
311
+
312
+
313
+ Shape:
314
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
315
+ - Output: :math:`(*)`, same shape as the input.
316
+
317
+ .. image:: ../scripts/activation_images/Sigmoid.png
318
+
319
+ Examples::
320
+
321
+ >>> m = nn.Sigmoid()
322
+ >>> input = torch.randn(2)
323
+ >>> output = m(input)
324
+ """
325
+
326
+ def forward(self, input: Tensor) -> Tensor:
327
+ return torch.sigmoid(input)
328
+
329
+
330
+ class Hardsigmoid(Module):
331
+ r"""Applies the Hardsigmoid function element-wise.
332
+
333
+ Hardsigmoid is defined as:
334
+
335
+ .. math::
336
+ \text{Hardsigmoid}(x) = \begin{cases}
337
+ 0 & \text{if~} x \le -3, \\
338
+ 1 & \text{if~} x \ge +3, \\
339
+ x / 6 + 1 / 2 & \text{otherwise}
340
+ \end{cases}
341
+
342
+ Args:
343
+ inplace: can optionally do the operation in-place. Default: ``False``
344
+
345
+ Shape:
346
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
347
+ - Output: :math:`(*)`, same shape as the input.
348
+
349
+ .. image:: ../scripts/activation_images/Hardsigmoid.png
350
+
351
+ Examples::
352
+
353
+ >>> m = nn.Hardsigmoid()
354
+ >>> input = torch.randn(2)
355
+ >>> output = m(input)
356
+ """
357
+
358
+ __constants__ = ["inplace"]
359
+
360
+ inplace: bool
361
+
362
+ def __init__(self, inplace: bool = False) -> None:
363
+ super().__init__()
364
+ self.inplace = inplace
365
+
366
+ def forward(self, input: Tensor) -> Tensor:
367
+ return F.hardsigmoid(input, self.inplace)
368
+
369
+
370
+ class Tanh(Module):
371
+ r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
372
+
373
+ Tanh is defined as:
374
+
375
+ .. math::
376
+ \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
377
+
378
+ Shape:
379
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
380
+ - Output: :math:`(*)`, same shape as the input.
381
+
382
+ .. image:: ../scripts/activation_images/Tanh.png
383
+
384
+ Examples::
385
+
386
+ >>> m = nn.Tanh()
387
+ >>> input = torch.randn(2)
388
+ >>> output = m(input)
389
+ """
390
+
391
+ def forward(self, input: Tensor) -> Tensor:
392
+ return torch.tanh(input)
393
+
394
+
395
+ class SiLU(Module):
396
+ r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
397
+
398
+ The SiLU function is also known as the swish function.
399
+
400
+ .. math::
401
+ \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
402
+
403
+ .. note::
404
+ See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
405
+ where the SiLU (Sigmoid Linear Unit) was originally coined, and see
406
+ `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
407
+ in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
408
+ a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
409
+ where the SiLU was experimented with later.
410
+
411
+ Shape:
412
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
413
+ - Output: :math:`(*)`, same shape as the input.
414
+
415
+ .. image:: ../scripts/activation_images/SiLU.png
416
+
417
+ Examples::
418
+
419
+ >>> m = nn.SiLU()
420
+ >>> input = torch.randn(2)
421
+ >>> output = m(input)
422
+ """
423
+
424
+ __constants__ = ["inplace"]
425
+ inplace: bool
426
+
427
+ def __init__(self, inplace: bool = False):
428
+ super().__init__()
429
+ self.inplace = inplace
430
+
431
+ def forward(self, input: Tensor) -> Tensor:
432
+ return F.silu(input, inplace=self.inplace)
433
+
434
+ def extra_repr(self) -> str:
435
+ inplace_str = "inplace=True" if self.inplace else ""
436
+ return inplace_str
437
+
438
+
439
+ class Mish(Module):
440
+ r"""Applies the Mish function, element-wise.
441
+
442
+ Mish: A Self Regularized Non-Monotonic Neural Activation Function.
443
+
444
+ .. math::
445
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
446
+
447
+ .. note::
448
+ See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
449
+
450
+ Shape:
451
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
452
+ - Output: :math:`(*)`, same shape as the input.
453
+
454
+ .. image:: ../scripts/activation_images/Mish.png
455
+
456
+ Examples::
457
+
458
+ >>> m = nn.Mish()
459
+ >>> input = torch.randn(2)
460
+ >>> output = m(input)
461
+ """
462
+
463
+ __constants__ = ["inplace"]
464
+ inplace: bool
465
+
466
+ def __init__(self, inplace: bool = False):
467
+ super().__init__()
468
+ self.inplace = inplace
469
+
470
+ def forward(self, input: Tensor) -> Tensor:
471
+ return F.mish(input, inplace=self.inplace)
472
+
473
+ def extra_repr(self) -> str:
474
+ inplace_str = "inplace=True" if self.inplace else ""
475
+ return inplace_str
476
+
477
+
478
+ class Hardswish(Module):
479
+ r"""Applies the Hardswish function, element-wise.
480
+
481
+ Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
482
+
483
+ Hardswish is defined as:
484
+
485
+ .. math::
486
+ \text{Hardswish}(x) = \begin{cases}
487
+ 0 & \text{if~} x \le -3, \\
488
+ x & \text{if~} x \ge +3, \\
489
+ x \cdot (x + 3) /6 & \text{otherwise}
490
+ \end{cases}
491
+
492
+ Args:
493
+ inplace: can optionally do the operation in-place. Default: ``False``
494
+
495
+ Shape:
496
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
497
+ - Output: :math:`(*)`, same shape as the input.
498
+
499
+ .. image:: ../scripts/activation_images/Hardswish.png
500
+
501
+ Examples::
502
+
503
+ >>> m = nn.Hardswish()
504
+ >>> input = torch.randn(2)
505
+ >>> output = m(input)
506
+ """
507
+
508
+ __constants__ = ["inplace"]
509
+
510
+ inplace: bool
511
+
512
+ def __init__(self, inplace: bool = False) -> None:
513
+ super().__init__()
514
+ self.inplace = inplace
515
+
516
+ def forward(self, input: Tensor) -> Tensor:
517
+ return F.hardswish(input, self.inplace)
518
+
519
+
520
+ class ELU(Module):
521
+ r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
522
+
523
+ Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
524
+ Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
525
+
526
+ ELU is defined as:
527
+
528
+ .. math::
529
+ \text{ELU}(x) = \begin{cases}
530
+ x, & \text{ if } x > 0\\
531
+ \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
532
+ \end{cases}
533
+
534
+ Args:
535
+ alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
536
+ inplace: can optionally do the operation in-place. Default: ``False``
537
+
538
+ Shape:
539
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
540
+ - Output: :math:`(*)`, same shape as the input.
541
+
542
+ .. image:: ../scripts/activation_images/ELU.png
543
+
544
+ Examples::
545
+
546
+ >>> m = nn.ELU()
547
+ >>> input = torch.randn(2)
548
+ >>> output = m(input)
549
+ """
550
+
551
+ __constants__ = ["alpha", "inplace"]
552
+ alpha: float
553
+ inplace: bool
554
+
555
+ def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
556
+ super().__init__()
557
+ self.alpha = alpha
558
+ self.inplace = inplace
559
+
560
+ def forward(self, input: Tensor) -> Tensor:
561
+ return F.elu(input, self.alpha, self.inplace)
562
+
563
+ def extra_repr(self) -> str:
564
+ inplace_str = ", inplace=True" if self.inplace else ""
565
+ return f"alpha={self.alpha}{inplace_str}"
566
+
567
+
568
+ class CELU(Module):
569
+ r"""Applies the CELU function element-wise.
570
+
571
+ .. math::
572
+ \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
573
+
574
+ More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
575
+
576
+ Args:
577
+ alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
578
+ inplace: can optionally do the operation in-place. Default: ``False``
579
+
580
+ Shape:
581
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
582
+ - Output: :math:`(*)`, same shape as the input.
583
+
584
+ .. image:: ../scripts/activation_images/CELU.png
585
+
586
+ Examples::
587
+
588
+ >>> m = nn.CELU()
589
+ >>> input = torch.randn(2)
590
+ >>> output = m(input)
591
+
592
+ .. _`Continuously Differentiable Exponential Linear Units`:
593
+ https://arxiv.org/abs/1704.07483
594
+ """
595
+
596
+ __constants__ = ["alpha", "inplace"]
597
+ alpha: float
598
+ inplace: bool
599
+
600
+ def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
601
+ super().__init__()
602
+ self.alpha = alpha
603
+ self.inplace = inplace
604
+
605
+ def forward(self, input: Tensor) -> Tensor:
606
+ return F.celu(input, self.alpha, self.inplace)
607
+
608
+ def extra_repr(self) -> str:
609
+ inplace_str = ", inplace=True" if self.inplace else ""
610
+ return f"alpha={self.alpha}{inplace_str}"
611
+
612
+
613
+ class SELU(Module):
614
+ r"""Applies the SELU function element-wise.
615
+
616
+ .. math::
617
+ \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
618
+
619
+ with :math:`\alpha = 1.6732632423543772848170429916717` and
620
+ :math:`\text{scale} = 1.0507009873554804934193349852946`.
621
+
622
+ .. warning::
623
+ When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
624
+ ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
625
+ in order to get `Self-Normalizing Neural Networks`_.
626
+ See :func:`torch.nn.init.calculate_gain` for more information.
627
+
628
+ More details can be found in the paper `Self-Normalizing Neural Networks`_ .
629
+
630
+ Args:
631
+ inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
632
+
633
+ Shape:
634
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
635
+ - Output: :math:`(*)`, same shape as the input.
636
+
637
+ .. image:: ../scripts/activation_images/SELU.png
638
+
639
+ Examples::
640
+
641
+ >>> m = nn.SELU()
642
+ >>> input = torch.randn(2)
643
+ >>> output = m(input)
644
+
645
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
646
+ """
647
+
648
+ __constants__ = ["inplace"]
649
+ inplace: bool
650
+
651
+ def __init__(self, inplace: bool = False) -> None:
652
+ super().__init__()
653
+ self.inplace = inplace
654
+
655
+ def forward(self, input: Tensor) -> Tensor:
656
+ return F.selu(input, self.inplace)
657
+
658
+ def extra_repr(self) -> str:
659
+ inplace_str = "inplace=True" if self.inplace else ""
660
+ return inplace_str
661
+
662
+
663
+ class GLU(Module):
664
+ r"""Applies the gated linear unit function.
665
+
666
+ :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
667
+ of the input matrices and :math:`b` is the second half.
668
+
669
+ Args:
670
+ dim (int): the dimension on which to split the input. Default: -1
671
+
672
+ Shape:
673
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
674
+ dimensions
675
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
676
+
677
+ Examples::
678
+
679
+ >>> m = nn.GLU()
680
+ >>> input = torch.randn(4, 2)
681
+ >>> output = m(input)
682
+ """
683
+
684
+ __constants__ = ["dim"]
685
+ dim: int
686
+
687
+ def __init__(self, dim: int = -1) -> None:
688
+ super().__init__()
689
+ self.dim = dim
690
+
691
+ def forward(self, input: Tensor) -> Tensor:
692
+ return F.glu(input, self.dim)
693
+
694
+ def extra_repr(self) -> str:
695
+ return f"dim={self.dim}"
696
+
697
+
698
+ class GELU(Module):
699
+ r"""Applies the Gaussian Error Linear Units function.
700
+
701
+ .. math:: \text{GELU}(x) = x * \Phi(x)
702
+
703
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
704
+
705
+ When the approximate argument is 'tanh', Gelu is estimated with:
706
+
707
+ .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
708
+
709
+ Args:
710
+ approximate (str, optional): the gelu approximation algorithm to use:
711
+ ``'none'`` | ``'tanh'``. Default: ``'none'``
712
+
713
+ Shape:
714
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
715
+ - Output: :math:`(*)`, same shape as the input.
716
+
717
+ .. image:: ../scripts/activation_images/GELU.png
718
+
719
+ Examples::
720
+
721
+ >>> m = nn.GELU()
722
+ >>> input = torch.randn(2)
723
+ >>> output = m(input)
724
+ """
725
+
726
+ __constants__ = ["approximate"]
727
+ approximate: str
728
+
729
+ def __init__(self, approximate: str = "none") -> None:
730
+ super().__init__()
731
+ self.approximate = approximate
732
+
733
+ def forward(self, input: Tensor) -> Tensor:
734
+ return F.gelu(input, approximate=self.approximate)
735
+
736
+ def extra_repr(self) -> str:
737
+ return f"approximate={repr(self.approximate)}"
738
+
739
+
740
+ class Hardshrink(Module):
741
+ r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
742
+
743
+ Hardshrink is defined as:
744
+
745
+ .. math::
746
+ \text{HardShrink}(x) =
747
+ \begin{cases}
748
+ x, & \text{ if } x > \lambda \\
749
+ x, & \text{ if } x < -\lambda \\
750
+ 0, & \text{ otherwise }
751
+ \end{cases}
752
+
753
+ Args:
754
+ lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
755
+
756
+ Shape:
757
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
758
+ - Output: :math:`(*)`, same shape as the input.
759
+
760
+ .. image:: ../scripts/activation_images/Hardshrink.png
761
+
762
+ Examples::
763
+
764
+ >>> m = nn.Hardshrink()
765
+ >>> input = torch.randn(2)
766
+ >>> output = m(input)
767
+ """
768
+
769
+ __constants__ = ["lambd"]
770
+ lambd: float
771
+
772
+ def __init__(self, lambd: float = 0.5) -> None:
773
+ super().__init__()
774
+ self.lambd = lambd
775
+
776
+ def forward(self, input: Tensor) -> Tensor:
777
+ return F.hardshrink(input, self.lambd)
778
+
779
+ def extra_repr(self) -> str:
780
+ return f"{self.lambd}"
781
+
782
+
783
+ class LeakyReLU(Module):
784
+ r"""Applies the LeakyReLU function element-wise.
785
+
786
+ .. math::
787
+ \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
788
+
789
+
790
+ or
791
+
792
+ .. math::
793
+ \text{LeakyReLU}(x) =
794
+ \begin{cases}
795
+ x, & \text{ if } x \geq 0 \\
796
+ \text{negative\_slope} \times x, & \text{ otherwise }
797
+ \end{cases}
798
+
799
+ Args:
800
+ negative_slope: Controls the angle of the negative slope (which is used for
801
+ negative input values). Default: 1e-2
802
+ inplace: can optionally do the operation in-place. Default: ``False``
803
+
804
+ Shape:
805
+ - Input: :math:`(*)` where `*` means any number of additional
806
+ dimensions
807
+ - Output: :math:`(*)`, same shape as the input
808
+
809
+ .. image:: ../scripts/activation_images/LeakyReLU.png
810
+
811
+ Examples::
812
+
813
+ >>> m = nn.LeakyReLU(0.1)
814
+ >>> input = torch.randn(2)
815
+ >>> output = m(input)
816
+ """
817
+
818
+ __constants__ = ["inplace", "negative_slope"]
819
+ inplace: bool
820
+ negative_slope: float
821
+
822
+ def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
823
+ super().__init__()
824
+ self.negative_slope = negative_slope
825
+ self.inplace = inplace
826
+
827
+ def forward(self, input: Tensor) -> Tensor:
828
+ return F.leaky_relu(input, self.negative_slope, self.inplace)
829
+
830
+ def extra_repr(self) -> str:
831
+ inplace_str = ", inplace=True" if self.inplace else ""
832
+ return f"negative_slope={self.negative_slope}{inplace_str}"
833
+
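+ # Illustrative sketch (not upstream code): LeakyReLU is equivalent to a
+ # ``torch.where`` over the sign of the input.
+ #
+ #     x = torch.randn(4)
+ #     torch.testing.assert_close(LeakyReLU(0.1)(x), torch.where(x >= 0, x, 0.1 * x))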
834
+
835
+ class LogSigmoid(Module):
836
+ r"""Applies the Logsigmoid function element-wise.
837
+
838
+ .. math::
839
+ \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
840
+
841
+ Shape:
842
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
843
+ - Output: :math:`(*)`, same shape as the input.
844
+
845
+ .. image:: ../scripts/activation_images/LogSigmoid.png
846
+
847
+ Examples::
848
+
849
+ >>> m = nn.LogSigmoid()
850
+ >>> input = torch.randn(2)
851
+ >>> output = m(input)
852
+ """
853
+
854
+ def forward(self, input: Tensor) -> Tensor:
855
+ return F.logsigmoid(input)
856
+
857
+
858
+ class Softplus(Module):
859
+ r"""Applies the Softplus function element-wise.
860
+
861
+ .. math::
862
+ \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
863
+
864
+ SoftPlus is a smooth approximation to the ReLU function and can be used
865
+ to constrain the output of a machine to always be positive.
866
+
867
+ For numerical stability the implementation reverts to the linear function
868
+ when :math:`input \times \beta > threshold`.
869
+
870
+ Args:
871
+ beta: the :math:`\beta` value for the Softplus formulation. Default: 1
872
+ threshold: values above this revert to a linear function. Default: 20
873
+
874
+ Shape:
875
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
876
+ - Output: :math:`(*)`, same shape as the input.
877
+
878
+ .. image:: ../scripts/activation_images/Softplus.png
879
+
880
+ Examples::
881
+
882
+ >>> m = nn.Softplus()
883
+ >>> input = torch.randn(2)
884
+ >>> output = m(input)
885
+ """
886
+
887
+ __constants__ = ["beta", "threshold"]
888
+ beta: float
889
+ threshold: float
890
+
891
+ def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
892
+ super().__init__()
893
+ self.beta = beta
894
+ self.threshold = threshold
895
+
896
+ def forward(self, input: Tensor) -> Tensor:
897
+ return F.softplus(input, self.beta, self.threshold)
898
+
899
+ def extra_repr(self) -> str:
900
+ return f"beta={self.beta}, threshold={self.threshold}"
901
+
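+ # Illustrative sketch (not upstream code): above ``threshold / beta`` the
+ # implementation returns the input unchanged, avoiding overflow in ``exp``.
+ #
+ #     sp = Softplus(beta=1.0, threshold=20.0)
+ #     x = torch.tensor([25.0])
+ #     torch.testing.assert_close(sp(x), x)   # exactly linear past the threshold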
902
+
903
+ class Softshrink(Module):
904
+ r"""Applies the soft shrinkage function element-wise.
905
+
906
+ .. math::
907
+ \text{SoftShrinkage}(x) =
908
+ \begin{cases}
909
+ x - \lambda, & \text{ if } x > \lambda \\
910
+ x + \lambda, & \text{ if } x < -\lambda \\
911
+ 0, & \text{ otherwise }
912
+ \end{cases}
913
+
914
+ Args:
915
+ lambd: the :math:`\lambda` value for the Softshrink formulation (must be no less than zero). Default: 0.5
916
+
917
+ Shape:
918
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
919
+ - Output: :math:`(*)`, same shape as the input.
920
+
921
+ .. image:: ../scripts/activation_images/Softshrink.png
922
+
923
+ Examples::
924
+
925
+ >>> m = nn.Softshrink()
926
+ >>> input = torch.randn(2)
927
+ >>> output = m(input)
928
+ """
929
+
930
+ __constants__ = ["lambd"]
931
+ lambd: float
932
+
933
+ def __init__(self, lambd: float = 0.5) -> None:
934
+ super().__init__()
935
+ self.lambd = lambd
936
+
937
+ def forward(self, input: Tensor) -> Tensor:
938
+ return F.softshrink(input, self.lambd)
939
+
940
+ def extra_repr(self) -> str:
941
+ return str(self.lambd)
942
+
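+ # Illustrative sketch (not upstream code): Softshrink is the classic
+ # soft-thresholding operator, expressible with ``sign`` and ``relu``.
+ #
+ #     x = torch.randn(4)
+ #     torch.testing.assert_close(
+ #         F.softshrink(x, 0.5), torch.sign(x) * torch.relu(x.abs() - 0.5)
+ #     )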
943
+
944
+ def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
945
+ if x is not None:
946
+ return x.device.type in [
947
+ "cpu",
948
+ "cuda",
949
+ torch.utils.backend_registration._privateuse1_backend_name,
950
+ ]
951
+ return True
952
+
953
+
954
+ def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
955
+ if x is not None:
956
+ return x.requires_grad
957
+ return False
958
+
959
+
960
+ def _is_make_fx_tracing():
961
+ if not torch.jit.is_scripting():
962
+ torch_dispatch_mode_stack = (
963
+ torch.utils._python_dispatch._get_current_dispatch_mode_stack()
964
+ )
965
+ return any(
966
+ type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode
967
+ for x in torch_dispatch_mode_stack
968
+ )
969
+ else:
970
+ return False
971
+
972
+
973
+ class MultiheadAttention(Module):
974
+ r"""Allows the model to jointly attend to information from different representation subspaces.
975
+
976
+ Method described in the paper:
977
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
978
+
979
+ Multi-Head Attention is defined as:
980
+
981
+ .. math::
982
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
983
+
984
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
985
+
986
+ ``nn.MultiheadAttention`` will use the optimized implementations of
987
+ ``scaled_dot_product_attention()`` when possible.
988
+
989
+ In addition to support for the new ``scaled_dot_product_attention()``
990
+ function, to speed up inference, MHA will use
991
+ fastpath inference with support for Nested Tensors, iff:
992
+
993
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
994
+ - inputs are batched (3D) with ``batch_first==True``
995
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
996
+ - training is disabled (using ``.eval()``)
997
+ - ``add_bias_kv`` is ``False``
998
+ - ``add_zero_attn`` is ``False``
999
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
1000
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
1001
+ nor ``attn_mask`` is passed
1002
+ - autocast is disabled
1003
+
1004
+ If the optimized inference fastpath implementation is in use, a
1005
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
1006
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
1007
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
1008
+ will be returned, and an additional speedup proportional to the fraction of the input
1009
+ that is padding can be expected.
1010
+
1011
+ Args:
1012
+ embed_dim: Total dimension of the model.
1013
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
1014
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
1015
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
1016
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
1017
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
1018
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
1019
+ Default: ``False``.
1020
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
1021
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
1022
+ batch_first: If ``True``, then the input and output tensors are provided
1023
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
1024
+
1025
+ Examples::
1026
+
1027
+ >>> # xdoctest: +SKIP
1028
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
1029
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
1030
+
1031
+ .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
1032
+ https://arxiv.org/abs/2205.14135
1033
+
1034
+ """
1035
+
1036
+ __constants__ = ["batch_first"]
1037
+ bias_k: Optional[torch.Tensor]
1038
+ bias_v: Optional[torch.Tensor]
1039
+
1040
+ def __init__(
1041
+ self,
1042
+ embed_dim,
1043
+ num_heads,
1044
+ dropout=0.0,
1045
+ bias=True,
1046
+ add_bias_kv=False,
1047
+ add_zero_attn=False,
1048
+ kdim=None,
1049
+ vdim=None,
1050
+ batch_first=False,
1051
+ device=None,
1052
+ dtype=None,
1053
+ ) -> None:
1054
+ if embed_dim <= 0 or num_heads <= 0:
1055
+ raise ValueError(
1056
+ f"embed_dim and num_heads must be greater than 0,"
1057
+ f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
1058
+ )
1059
+ factory_kwargs = {"device": device, "dtype": dtype}
1060
+ super().__init__()
1061
+ self.embed_dim = embed_dim
1062
+ self.kdim = kdim if kdim is not None else embed_dim
1063
+ self.vdim = vdim if vdim is not None else embed_dim
1064
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
1065
+
1066
+ self.num_heads = num_heads
1067
+ self.dropout = dropout
1068
+ self.batch_first = batch_first
1069
+ self.head_dim = embed_dim // num_heads
1070
+ assert (
1071
+ self.head_dim * num_heads == self.embed_dim
1072
+ ), "embed_dim must be divisible by num_heads"
1073
+
1074
+ if not self._qkv_same_embed_dim:
1075
+ self.q_proj_weight = Parameter(
1076
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
1077
+ )
1078
+ self.k_proj_weight = Parameter(
1079
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
1080
+ )
1081
+ self.v_proj_weight = Parameter(
1082
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
1083
+ )
1084
+ self.register_parameter("in_proj_weight", None)
1085
+ else:
1086
+ self.in_proj_weight = Parameter(
1087
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
1088
+ )
1089
+ self.register_parameter("q_proj_weight", None)
1090
+ self.register_parameter("k_proj_weight", None)
1091
+ self.register_parameter("v_proj_weight", None)
1092
+
1093
+ if bias:
1094
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
1095
+ else:
1096
+ self.register_parameter("in_proj_bias", None)
1097
+ self.out_proj = NonDynamicallyQuantizableLinear(
1098
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
1099
+ )
1100
+
1101
+ if add_bias_kv:
1102
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1103
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1104
+ else:
1105
+ self.bias_k = self.bias_v = None
1106
+
1107
+ self.add_zero_attn = add_zero_attn
1108
+
1109
+ self._reset_parameters()
1110
+
1111
+ def _reset_parameters(self):
1112
+ if self._qkv_same_embed_dim:
1113
+ xavier_uniform_(self.in_proj_weight)
1114
+ else:
1115
+ xavier_uniform_(self.q_proj_weight)
1116
+ xavier_uniform_(self.k_proj_weight)
1117
+ xavier_uniform_(self.v_proj_weight)
1118
+
1119
+ if self.in_proj_bias is not None:
1120
+ constant_(self.in_proj_bias, 0.0)
1121
+ constant_(self.out_proj.bias, 0.0)
1122
+ if self.bias_k is not None:
1123
+ xavier_normal_(self.bias_k)
1124
+ if self.bias_v is not None:
1125
+ xavier_normal_(self.bias_v)
1126
+
1127
+ def __setstate__(self, state):
1128
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
1129
+ if "_qkv_same_embed_dim" not in state:
1130
+ state["_qkv_same_embed_dim"] = True
1131
+
1132
+ super().__setstate__(state)
1133
+
1134
+ def forward(
1135
+ self,
1136
+ query: Tensor,
1137
+ key: Tensor,
1138
+ value: Tensor,
1139
+ key_padding_mask: Optional[Tensor] = None,
1140
+ need_weights: bool = True,
1141
+ attn_mask: Optional[Tensor] = None,
1142
+ average_attn_weights: bool = True,
1143
+ is_causal: bool = False,
1144
+ ) -> Tuple[Tensor, Optional[Tensor]]:
1145
+ r"""Compute attention outputs using query, key, and value embeddings.
1146
+
1147
+ Supports optional parameters for padding, masks and attention weights.
1148
+
1149
+ Args:
1150
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
1151
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
1152
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
1153
+ Queries are compared against key-value pairs to produce the output.
1154
+ See "Attention Is All You Need" for more details.
1155
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
1156
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
1157
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
1158
+ See "Attention Is All You Need" for more details.
1159
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
1160
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
1161
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
1162
+ See "Attention Is All You Need" for more details.
1163
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
1164
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
1165
+ Binary and float masks are supported.
1166
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
1167
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
1168
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
1169
+ Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
1170
+ and achieve the best performance for MHA.
1171
+ Default: ``True``.
1172
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
1173
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
1174
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
1175
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
1176
+ Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
1177
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
1178
+ the attention weight.
1179
+ If both attn_mask and key_padding_mask are supplied, their types should match.
1180
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
1181
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
1182
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
1183
+ is_causal: If specified, applies a causal mask as attention mask.
1184
+ Default: ``False``.
1185
+ Warning:
1186
+ ``is_causal`` provides a hint that ``attn_mask`` is the
1187
+ causal mask. Providing an incorrect hint can result in
1188
+ incorrect execution, in both the forward and
1189
+ backward passes.
1190
+
1191
+ Outputs:
1192
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
1193
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
1194
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
1195
+ embedding dimension ``embed_dim``.
1196
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
1197
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
1198
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
1199
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
1200
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
1201
+
1202
+ .. note::
1203
+ `batch_first` argument is ignored for unbatched inputs.
1204
+ """ # noqa: B950
1205
+ why_not_fast_path = ""
1206
+ if (
1207
+ (attn_mask is not None and torch.is_floating_point(attn_mask))
1208
+ or (key_padding_mask is not None)
1209
+ and torch.is_floating_point(key_padding_mask)
1210
+ ):
1211
+ why_not_fast_path = "floating-point masks are not supported for fast path."
1212
+
1213
+ is_batched = query.dim() == 3
1214
+
1215
+ key_padding_mask = F._canonical_mask(
1216
+ mask=key_padding_mask,
1217
+ mask_name="key_padding_mask",
1218
+ other_type=F._none_or_dtype(attn_mask),
1219
+ other_name="attn_mask",
1220
+ target_type=query.dtype,
1221
+ )
1222
+
1223
+ attn_mask = F._canonical_mask(
1224
+ mask=attn_mask,
1225
+ mask_name="attn_mask",
1226
+ other_type=None,
1227
+ other_name="",
1228
+ target_type=query.dtype,
1229
+ check_other=False,
1230
+ )
1231
+
1232
+ is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
1233
+
1234
+ if not is_fastpath_enabled:
1235
+ why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
1236
+ elif not is_batched:
1237
+ why_not_fast_path = (
1238
+ f"input not batched; expected query.dim() of 3 but got {query.dim()}"
1239
+ )
1240
+ elif query is not key or key is not value:
1241
+ # When lifting this restriction, don't forget to either
1242
+ # enforce that the dtypes all match or test cases where
1243
+ # they don't!
1244
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
1245
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
1246
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
1247
+ elif self.in_proj_weight is None:
1248
+ why_not_fast_path = "in_proj_weight was None"
1249
+ elif query.dtype != self.in_proj_weight.dtype:
1250
+ # this case will fail anyway, but at least they'll get a useful error message.
1251
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
1252
+ elif self.training:
1253
+ why_not_fast_path = "training is enabled"
1254
+ elif (self.num_heads % 2) != 0:
1255
+ why_not_fast_path = "self.num_heads is not even"
1256
+ elif not self.batch_first:
1257
+ why_not_fast_path = "batch_first was not True"
1258
+ elif self.bias_k is not None:
1259
+ why_not_fast_path = "self.bias_k was not None"
1260
+ elif self.bias_v is not None:
1261
+ why_not_fast_path = "self.bias_v was not None"
1262
+ elif self.add_zero_attn:
1263
+ why_not_fast_path = "add_zero_attn was enabled"
1264
+ elif not self._qkv_same_embed_dim:
1265
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
1266
+ elif query.is_nested and (
1267
+ key_padding_mask is not None or attn_mask is not None
1268
+ ):
1269
+ why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
1270
+ is not supported with NestedTensor input"
1271
+ elif torch.is_autocast_enabled():
1272
+ why_not_fast_path = "autocast is enabled"
1273
+
1274
+ if not why_not_fast_path:
1275
+ tensor_args = (
1276
+ query,
1277
+ key,
1278
+ value,
1279
+ self.in_proj_weight,
1280
+ self.in_proj_bias,
1281
+ self.out_proj.weight,
1282
+ self.out_proj.bias,
1283
+ )
1284
+ # We have to use list comprehensions below because TorchScript does not support
1285
+ # generator expressions.
1286
+ if torch.overrides.has_torch_function(tensor_args):
1287
+ why_not_fast_path = "some Tensor argument has_torch_function"
1288
+ elif _is_make_fx_tracing():
1289
+ why_not_fast_path = "we are running make_fx tracing"
1290
+ elif not all(_check_arg_device(x) for x in tensor_args):
1291
+ why_not_fast_path = (
1292
+ "some Tensor argument's device is neither one of "
1293
+ f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
1294
+ )
1295
+ elif torch.is_grad_enabled() and any(
1296
+ _arg_requires_grad(x) for x in tensor_args
1297
+ ):
1298
+ why_not_fast_path = (
1299
+ "grad is enabled and at least one of query or the "
1300
+ "input/output projection weights or biases requires_grad"
1301
+ )
1302
+ if not why_not_fast_path:
1303
+ merged_mask, mask_type = self.merge_masks(
1304
+ attn_mask, key_padding_mask, query
1305
+ )
1306
+
1307
+ if self.in_proj_bias is not None and self.in_proj_weight is not None:
1308
+ return torch._native_multi_head_attention(
1309
+ query,
1310
+ key,
1311
+ value,
1312
+ self.embed_dim,
1313
+ self.num_heads,
1314
+ self.in_proj_weight,
1315
+ self.in_proj_bias,
1316
+ self.out_proj.weight,
1317
+ self.out_proj.bias,
1318
+ merged_mask,
1319
+ need_weights,
1320
+ average_attn_weights,
1321
+ mask_type,
1322
+ )
1323
+
1324
+ any_nested = query.is_nested or key.is_nested or value.is_nested
1325
+ assert not any_nested, (
1326
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
1327
+ + f"The fast path was not hit because {why_not_fast_path}"
1328
+ )
1329
+
1330
+ if self.batch_first and is_batched:
1331
+ # make sure that the transpose op does not affect the "is" property
1332
+ if key is value:
1333
+ if query is key:
1334
+ query = key = value = query.transpose(1, 0)
1335
+ else:
1336
+ query, key = (x.transpose(1, 0) for x in (query, key))
1337
+ value = key
1338
+ else:
1339
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
1340
+
1341
+ if not self._qkv_same_embed_dim:
1342
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
1343
+ query,
1344
+ key,
1345
+ value,
1346
+ self.embed_dim,
1347
+ self.num_heads,
1348
+ self.in_proj_weight,
1349
+ self.in_proj_bias,
1350
+ self.bias_k,
1351
+ self.bias_v,
1352
+ self.add_zero_attn,
1353
+ self.dropout,
1354
+ self.out_proj.weight,
1355
+ self.out_proj.bias,
1356
+ training=self.training,
1357
+ key_padding_mask=key_padding_mask,
1358
+ need_weights=need_weights,
1359
+ attn_mask=attn_mask,
1360
+ use_separate_proj_weight=True,
1361
+ q_proj_weight=self.q_proj_weight,
1362
+ k_proj_weight=self.k_proj_weight,
1363
+ v_proj_weight=self.v_proj_weight,
1364
+ average_attn_weights=average_attn_weights,
1365
+ is_causal=is_causal,
1366
+ )
1367
+ else:
1368
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
1369
+ query,
1370
+ key,
1371
+ value,
1372
+ self.embed_dim,
1373
+ self.num_heads,
1374
+ self.in_proj_weight,
1375
+ self.in_proj_bias,
1376
+ self.bias_k,
1377
+ self.bias_v,
1378
+ self.add_zero_attn,
1379
+ self.dropout,
1380
+ self.out_proj.weight,
1381
+ self.out_proj.bias,
1382
+ training=self.training,
1383
+ key_padding_mask=key_padding_mask,
1384
+ need_weights=need_weights,
1385
+ attn_mask=attn_mask,
1386
+ average_attn_weights=average_attn_weights,
1387
+ is_causal=is_causal,
1388
+ )
1389
+ if self.batch_first and is_batched:
1390
+ return attn_output.transpose(1, 0), attn_output_weights
1391
+ else:
1392
+ return attn_output, attn_output_weights
1393
+
1394
+ def merge_masks(
1395
+ self,
1396
+ attn_mask: Optional[Tensor],
1397
+ key_padding_mask: Optional[Tensor],
1398
+ query: Tensor,
1399
+ ) -> Tuple[Optional[Tensor], Optional[int]]:
1400
+ r"""Determine mask type and combine masks if necessary.
1401
+
1402
+ If only one mask is provided, that mask
1403
+ and the corresponding mask type will be returned. If both masks are provided, they will both be
1404
+ expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``,
1405
+ and mask type 2 will be returned.
1406
+ Args:
1407
+ attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
1408
+ key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
1409
+ query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
1410
+ Returns:
1411
+ merged_mask: merged mask
1412
+ mask_type: merged mask type (0, 1, or 2)
1413
+ """
1414
+ mask_type: Optional[int] = None
1415
+ merged_mask: Optional[Tensor] = None
1416
+
1417
+ if key_padding_mask is not None:
1418
+ mask_type = 1
1419
+ merged_mask = key_padding_mask
1420
+
1421
+ if attn_mask is not None:
1422
+ # In this branch query can't be a nested tensor, so it has a shape
1423
+ batch_size, seq_len, _ = query.shape
1424
+ mask_type = 2
1425
+
1426
+ # Always expands attn_mask to 4D
1427
+ if attn_mask.dim() == 3:
1428
+ attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
1429
+ else: # attn_mask.dim() == 2:
1430
+ attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
1431
+ batch_size, self.num_heads, -1, -1
1432
+ )
1433
+ merged_mask = attn_mask_expanded
1434
+
1435
+ if key_padding_mask is not None:
1436
+ key_padding_mask_expanded = key_padding_mask.view(
1437
+ batch_size, 1, 1, seq_len
1438
+ ).expand(-1, self.num_heads, -1, -1)
1439
+ merged_mask = attn_mask_expanded + key_padding_mask_expanded
1440
+
1441
+ # no attn_mask and no key_padding_mask, returns None, None
1442
+ return merged_mask, mask_type
1443
+
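+ # Illustrative sketch (not upstream code): a configuration satisfying the
+ # fastpath conditions from the class docstring (self-attention, batched
+ # ``batch_first`` input, eval mode, no autograd). Shapes are hypothetical.
+ #
+ #     mha = MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()
+ #     x = torch.randn(2, 5, 32)
+ #     with torch.no_grad():
+ #         out, _ = mha(x, x, x, need_weights=False)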
1444
+
1445
+ class PReLU(Module):
1446
+ r"""Applies the element-wise PReLU function.
1447
+
1448
+ .. math::
1449
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
1450
+
1451
+ or
1452
+
1453
+ .. math::
1454
+ \text{PReLU}(x) =
1455
+ \begin{cases}
1456
+ x, & \text{ if } x \ge 0 \\
1457
+ ax, & \text{ otherwise }
1458
+ \end{cases}
1459
+
1460
+ Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
1461
+ parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
1462
+ a separate :math:`a` is used for each input channel.
1463
+
1464
+
1465
+ .. note::
1466
+ weight decay should not be used when learning :math:`a` for good performance.
1467
+
1468
+ .. note::
1469
+ Channel dim is the 2nd dim of input. When input has dims < 2, then there is
1470
+ no channel dim and the number of channels = 1.
1471
+
1472
+ Args:
1473
+ num_parameters (int): number of :math:`a` to learn.
1474
+ Although it takes an int as input, only two values are legitimate:
1475
+ 1, or the number of channels of the input. Default: 1
1476
+ init (float): the initial value of :math:`a`. Default: 0.25
1477
+
1478
+ Shape:
1479
+ - Input: :math:`(*)` where `*` means any number of additional
1480
+ dimensions.
1481
+ - Output: :math:`(*)`, same shape as the input.
1482
+
1483
+ Attributes:
1484
+ weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
1485
+
1486
+ .. image:: ../scripts/activation_images/PReLU.png
1487
+
1488
+ Examples::
1489
+
1490
+ >>> m = nn.PReLU()
1491
+ >>> input = torch.randn(2)
1492
+ >>> output = m(input)
1493
+ """
1494
+
1495
+ __constants__ = ["num_parameters"]
1496
+ num_parameters: int
1497
+
1498
+ def __init__(
1499
+ self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None
1500
+ ) -> None:
1501
+ factory_kwargs = {"device": device, "dtype": dtype}
1502
+ self.num_parameters = num_parameters
1503
+ super().__init__()
1504
+ self.init = init
1505
+ self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
1506
+ self.reset_parameters()
1507
+
1508
+ def reset_parameters(self):
1509
+ torch.nn.init.constant_(self.weight, self.init)
1510
+
1511
+ def forward(self, input: Tensor) -> Tensor:
1512
+ return F.prelu(input, self.weight)
1513
+
1514
+ def extra_repr(self) -> str:
1515
+ return f"num_parameters={self.num_parameters}"
1516
+
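+ # Illustrative sketch (not upstream code): with a single shared parameter,
+ # PReLU decomposes into its positive and negative parts.
+ #
+ #     m = PReLU()                        # one parameter ``a``, initialized to 0.25
+ #     x = torch.randn(4)
+ #     torch.testing.assert_close(
+ #         m(x), torch.clamp(x, min=0) + m.weight * torch.clamp(x, max=0)
+ #     )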
1517
+
1518
+ class Softsign(Module):
1519
+ r"""Applies the element-wise Softsign function.
1520
+
1521
+ .. math::
1522
+ \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
1523
+
1524
+ Shape:
1525
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1526
+ - Output: :math:`(*)`, same shape as the input.
1527
+
1528
+ .. image:: ../scripts/activation_images/Softsign.png
1529
+
1530
+ Examples::
1531
+
1532
+ >>> m = nn.Softsign()
1533
+ >>> input = torch.randn(2)
1534
+ >>> output = m(input)
1535
+ """
1536
+
1537
+ def forward(self, input: Tensor) -> Tensor:
1538
+ return F.softsign(input)
1539
+
1540
+
1541
+ class Tanhshrink(Module):
1542
+ r"""Applies the element-wise Tanhshrink function.
1543
+
1544
+ .. math::
1545
+ \text{Tanhshrink}(x) = x - \tanh(x)
1546
+
1547
+ Shape:
1548
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1549
+ - Output: :math:`(*)`, same shape as the input.
1550
+
1551
+ .. image:: ../scripts/activation_images/Tanhshrink.png
1552
+
1553
+ Examples::
1554
+
1555
+ >>> m = nn.Tanhshrink()
1556
+ >>> input = torch.randn(2)
1557
+ >>> output = m(input)
1558
+ """
1559
+
1560
+ def forward(self, input: Tensor) -> Tensor:
1561
+ return F.tanhshrink(input)
1562
+
1563
+
1564
+ class Softmin(Module):
1565
+ r"""Applies the Softmin function to an n-dimensional input Tensor.
1566
+
1567
+ Rescales them so that the elements of the n-dimensional output Tensor
1568
+ lie in the range `[0, 1]` and sum to 1.
1569
+
1570
+ Softmin is defined as:
1571
+
1572
+ .. math::
1573
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
1574
+
1575
+ Shape:
1576
+ - Input: :math:`(*)` where `*` means any number of additional
1577
+ dimensions
1578
+ - Output: :math:`(*)`, same shape as the input
1579
+
1580
+ Args:
1581
+ dim (int): A dimension along which Softmin will be computed (so every slice
1582
+ along dim will sum to 1).
1583
+
1584
+ Returns:
1585
+ a Tensor of the same dimension and shape as the input, with
1586
+ values in the range [0, 1]
1587
+
1588
+ Examples::
1589
+
1590
+ >>> m = nn.Softmin(dim=1)
1591
+ >>> input = torch.randn(2, 3)
1592
+ >>> output = m(input)
1593
+ """
1594
+
1595
+ __constants__ = ["dim"]
1596
+ dim: Optional[int]
1597
+
1598
+ def __init__(self, dim: Optional[int] = None) -> None:
1599
+ super().__init__()
1600
+ self.dim = dim
1601
+
1602
+ def __setstate__(self, state):
1603
+ super().__setstate__(state)
1604
+ if not hasattr(self, "dim"):
1605
+ self.dim = None
1606
+
1607
+ def forward(self, input: Tensor) -> Tensor:
1608
+ return F.softmin(input, self.dim, _stacklevel=5)
1609
+
1610
+ def extra_repr(self):
1611
+ return f"dim={self.dim}"
1612
+
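+ # Illustrative sketch (not upstream code): Softmin of ``x`` equals Softmax of
+ # ``-x``, which is how the negated exponents in the formula above play out.
+ #
+ #     x = torch.randn(2, 3)
+ #     torch.testing.assert_close(Softmin(dim=1)(x), Softmax(dim=1)(-x))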
1613
+
1614
+ class Softmax(Module):
1615
+ r"""Applies the Softmax function to an n-dimensional input Tensor.
1616
+
1617
+ Rescales them so that the elements of the n-dimensional output Tensor
1618
+ lie in the range [0,1] and sum to 1.
1619
+
1620
+ Softmax is defined as:
1621
+
1622
+ .. math::
1623
+ \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
1624
+
1625
+ When the input Tensor is a sparse tensor then the unspecified
1626
+ values are treated as ``-inf``.
1627
+
1628
+ Shape:
1629
+ - Input: :math:`(*)` where `*` means any number of additional
1630
+ dimensions
1631
+ - Output: :math:`(*)`, same shape as the input
1632
+
1633
+ Returns:
1634
+ a Tensor of the same dimension and shape as the input with
1635
+ values in the range [0, 1]
1636
+
1637
+ Args:
1638
+ dim (int): A dimension along which Softmax will be computed (so every slice
1639
+ along dim will sum to 1).
1640
+
1641
+ .. note::
1642
+ This module doesn't work directly with NLLLoss,
1643
+ which expects the Log to be computed between the Softmax and itself.
1644
+ Use `LogSoftmax` instead (it's faster and has better numerical properties).
1645
+
1646
+ Examples::
1647
+
1648
+ >>> m = nn.Softmax(dim=1)
1649
+ >>> input = torch.randn(2, 3)
1650
+ >>> output = m(input)
1651
+
1652
+ """
1653
+
1654
+ __constants__ = ["dim"]
1655
+ dim: Optional[int]
1656
+
1657
+ def __init__(self, dim: Optional[int] = None) -> None:
1658
+ super().__init__()
1659
+ self.dim = dim
1660
+
1661
+ def __setstate__(self, state):
1662
+ super().__setstate__(state)
1663
+ if not hasattr(self, "dim"):
1664
+ self.dim = None
1665
+
1666
+ def forward(self, input: Tensor) -> Tensor:
1667
+ return F.softmax(input, self.dim, _stacklevel=5)
1668
+
1669
+ def extra_repr(self) -> str:
1670
+ return f"dim={self.dim}"
1671
+
1672
+
1673
+ class Softmax2d(Module):
1674
+ r"""Applies SoftMax over features to each spatial location.
1675
+
1676
+ When given an image of ``Channels x Height x Width``, it will
1677
+ apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1678
+
1679
+ Shape:
1680
+ - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1681
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1682
+
1683
+ Returns:
1684
+ a Tensor of the same dimension and shape as the input with
1685
+ values in the range [0, 1]
1686
+
1687
+ Examples::
1688
+
1689
+ >>> m = nn.Softmax2d()
1690
+ >>> # you softmax over the 2nd dimension
1691
+ >>> input = torch.randn(2, 3, 12, 13)
1692
+ >>> output = m(input)
1693
+ """
1694
+
1695
+ def forward(self, input: Tensor) -> Tensor:
1696
+ if input.dim() not in (3, 4):
1697
+ raise ValueError(
1698
+ f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
1699
+ )
1700
+ return F.softmax(input, -3, _stacklevel=5)
1701
+
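+ # Illustrative sketch (not upstream code): Softmax2d is Softmax applied over
+ # the channel dimension, i.e. ``dim=-3``.
+ #
+ #     x = torch.randn(2, 3, 12, 13)
+ #     torch.testing.assert_close(Softmax2d()(x), Softmax(dim=-3)(x))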
1702
+
1703
+ class LogSoftmax(Module):
1704
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
1705
+
1706
+ The LogSoftmax formulation can be simplified as:
1707
+
1708
+ .. math::
1709
+ \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1710
+
1711
+ Shape:
1712
+ - Input: :math:`(*)` where `*` means any number of additional
1713
+ dimensions
1714
+ - Output: :math:`(*)`, same shape as the input
1715
+
1716
+ Args:
1717
+ dim (int): A dimension along which LogSoftmax will be computed.
1718
+
1719
+ Returns:
1720
+ a Tensor of the same dimension and shape as the input with
1721
+ values in the range [-inf, 0)
1722
+
1723
+ Examples::
1724
+
1725
+ >>> m = nn.LogSoftmax(dim=1)
1726
+ >>> input = torch.randn(2, 3)
1727
+ >>> output = m(input)
1728
+ """
1729
+
1730
+ __constants__ = ["dim"]
1731
+ dim: Optional[int]
1732
+
1733
+ def __init__(self, dim: Optional[int] = None) -> None:
1734
+ super().__init__()
1735
+ self.dim = dim
1736
+
1737
+ def __setstate__(self, state):
1738
+ super().__setstate__(state)
1739
+ if not hasattr(self, "dim"):
1740
+ self.dim = None
1741
+
1742
+ def forward(self, input: Tensor) -> Tensor:
1743
+ return F.log_softmax(input, self.dim, _stacklevel=5)
1744
+
1745
+ def extra_repr(self):
1746
+ return f"dim={self.dim}"
.venv/Lib/site-packages/torch/nn/modules/adaptive.py ADDED
@@ -0,0 +1,330 @@
1
+ # mypy: allow-untyped-defs
2
+
3
+ from collections import namedtuple
4
+ from typing import List, Sequence
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from .container import ModuleList, Sequential
11
+ from .linear import Linear
12
+ from .module import Module
13
+
14
+
15
+ __all__ = ["AdaptiveLogSoftmaxWithLoss"]
16
+
17
+ _ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])
18
+
19
+
20
+ class AdaptiveLogSoftmaxWithLoss(Module):
21
+ """Efficient softmax approximation.
22
+
23
+ As described in
24
+ `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
25
+ Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou
26
+ <https://arxiv.org/abs/1609.04309>`__.
27
+ """ r"""
28
+ Adaptive softmax is an approximate strategy for training models with large
29
+ output spaces. It is most effective when the label distribution is highly
30
+ imbalanced, for example in natural language modelling, where the word
31
+ frequency distribution approximately follows the `Zipf's law`_.
32
+
33
+ Adaptive softmax partitions the labels into several clusters, according to
34
+ their frequency. These clusters may contain different number of targets
35
+ each.
36
+ Additionally, clusters containing less frequent labels assign lower
37
+ dimensional embeddings to those labels, which speeds up the computation.
38
+ For each minibatch, only clusters for which at least one target is
39
+ present are evaluated.
40
+
41
+ The idea is that the clusters which are accessed frequently
42
+ (like the first one, containing most frequent labels), should also be cheap
43
+ to compute -- that is, contain a small number of assigned labels.
44
+
45
+ We highly recommend taking a look at the original paper for more details.
46
+
47
+ * :attr:`cutoffs` should be an ordered Sequence of integers sorted
48
+ in the increasing order.
49
+ It controls number of clusters and the partitioning of targets into
50
+ clusters. For example setting ``cutoffs = [10, 100, 1000]``
51
+ means that first `10` targets will be assigned
52
+ to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
53
+ assigned to the first cluster, and targets `101, 102, ..., 1000` will be
54
+ assigned to the second cluster, while targets
55
+ `1001, 1002, ..., n_classes - 1` will be assigned
56
+ to the last, third cluster.
57
+
58
+ * :attr:`div_value` is used to compute the size of each additional cluster,
59
+ which is given as
60
+ :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
61
+ where :math:`idx` is the cluster index (with clusters
62
+ for less frequent words having larger indices,
63
+ and indices starting from :math:`1`).
64
+
65
+ * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
66
+ adaptive softmax. See paper for details. Set to False in the official
67
+ implementation.
68
+
69
+ .. warning::
70
+ Labels passed as inputs to this module should be sorted according to
71
+ their frequency. This means that the most frequent label should be
72
+ represented by the index `0`, and the least frequent
73
+ label should be represented by the index `n_classes - 1`.
74
+
75
+ .. note::
76
+ This module returns a ``NamedTuple`` with ``output``
77
+ and ``loss`` fields. See further documentation for details.
78
+
79
+ .. note::
80
+ To compute log-probabilities for all classes, the ``log_prob``
81
+ method can be used.
82
+
83
+ Args:
84
+ in_features (int): Number of features in the input tensor
85
+ n_classes (int): Number of classes in the dataset
86
+ cutoffs (Sequence): Cutoffs used to assign targets to their buckets
87
+ div_value (float, optional): value used as an exponent to compute sizes
88
+ of the clusters. Default: 4.0
89
+ head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
90
+ adaptive softmax. Default: ``False``
91
+
92
+ Returns:
93
+ ``NamedTuple`` with ``output`` and ``loss`` fields:
94
+ * **output** is a Tensor of size ``N`` containing computed target
95
+ log probabilities for each example
96
+ * **loss** is a Scalar representing the computed negative
97
+ log likelihood loss
98
+
99
+ Shape:
100
+ - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
101
+ - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} < \texttt{n\_classes}`
102
+ - output1: :math:`(N)` or :math:`()`
103
+ - output2: ``Scalar``
104
+
105
+ .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
106
+ """
107
+
108
+ in_features: int
109
+ n_classes: int
110
+ cutoffs: List[int]
111
+ div_value: float
112
+ head_bias: bool
113
+ head: Linear
114
+ tail: ModuleList
115
+
116
+ def __init__(
117
+ self,
118
+ in_features: int,
119
+ n_classes: int,
120
+ cutoffs: Sequence[int],
121
+ div_value: float = 4.0,
122
+ head_bias: bool = False,
123
+ device=None,
124
+ dtype=None,
125
+ ) -> None:
126
+ factory_kwargs = {"device": device, "dtype": dtype}
127
+ super().__init__()
128
+
129
+ cutoffs = list(cutoffs)
130
+
131
+ if len(cutoffs) == 0:
132
+ raise ValueError("cutoffs should be a sequence of length larger than 0")
133
+
134
+ if (
135
+ (cutoffs != sorted(cutoffs))
136
+ or (min(cutoffs) <= 0)
137
+ or (max(cutoffs) > (n_classes - 1))
138
+ or (len(set(cutoffs)) != len(cutoffs))
139
+ or any(int(c) != c for c in cutoffs)
140
+ ):
141
+ raise ValueError(
142
+ "cutoffs should be a sequence of unique, positive "
143
+ "integers sorted in an increasing order, where "
144
+ "each value is between 1 and n_classes-1"
145
+ )
146
+
147
+ self.in_features = in_features
148
+ self.n_classes = n_classes
149
+ self.cutoffs = cutoffs + [n_classes]
150
+ self.div_value = div_value
151
+ self.head_bias = head_bias
152
+
153
+ self.shortlist_size = self.cutoffs[0]
154
+ self.n_clusters = len(self.cutoffs) - 1
155
+ self.head_size = self.shortlist_size + self.n_clusters
156
+
157
+ self.head = Linear(
158
+ self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
159
+ )
160
+ self.tail = ModuleList()
161
+
162
+ for i in range(self.n_clusters):
163
+ hsz = int(self.in_features // (self.div_value ** (i + 1)))
164
+ osz = self.cutoffs[i + 1] - self.cutoffs[i]
165
+
166
+ projection = Sequential(
167
+ Linear(self.in_features, hsz, bias=False, **factory_kwargs),
168
+ Linear(hsz, osz, bias=False, **factory_kwargs),
169
+ )
170
+
171
+ self.tail.append(projection)
172
+
173
+ def reset_parameters(self) -> None:
174
+ self.head.reset_parameters()
175
+ for i2h, h2o in self.tail:
176
+ i2h.reset_parameters()
177
+ h2o.reset_parameters()
178
+
179
+ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
180
+ targ_dim = target_.dim()
181
+
182
+ if targ_dim == 1:
183
+ if input_.size(0) != target_.size(0):
184
+ raise RuntimeError(
185
+ "Input and target should have the same size "
186
+ "in the batch dimension."
187
+ )
188
+ if input_.dim() != 2:
189
+ raise RuntimeError(
190
+ "1D target tensor expects 2D input tensors, "
191
+ "but found inputs with size",
192
+ input_.size(),
193
+ )
194
+ elif targ_dim == 0:
195
+ if input_.dim() != 1:
196
+ raise RuntimeError(
197
+ "0D target tensor expects 1D input tensors, "
198
+ "but found inputs with size",
199
+ input_.size(),
200
+ )
201
+ else:
202
+ raise RuntimeError(
203
+ "0D or 1D target tensor expected, " "multi-target not supported"
204
+ )
205
+
206
+ is_batched = targ_dim > 0
207
+ input = input_ if is_batched else input_.unsqueeze(0)
208
+ target = target_ if is_batched else target_.unsqueeze(0)
209
+
210
+ used_rows = 0
211
+ batch_size = target.size(0)
212
+
213
+ output = input.new_zeros(batch_size)
214
+ gather_inds = target.new_empty(batch_size)
215
+
216
+ cutoff_values = [0] + self.cutoffs
217
+ for i in range(len(cutoff_values) - 1):
218
+ low_idx = cutoff_values[i]
219
+ high_idx = cutoff_values[i + 1]
220
+
221
+ target_mask = (target >= low_idx) & (target < high_idx)
222
+ row_indices = target_mask.nonzero().squeeze()
223
+
224
+ if row_indices.numel() == 0:
225
+ continue
226
+
227
+ if i == 0:
228
+ gather_inds.index_copy_(0, row_indices, target[target_mask])
229
+
230
+ else:
231
+ relative_target = target[target_mask] - low_idx
232
+ input_subset = input.index_select(0, row_indices)
233
+
234
+ cluster_output = self.tail[i - 1](input_subset)
235
+ cluster_index = self.shortlist_size + i - 1
236
+
237
+ gather_inds.index_fill_(0, row_indices, cluster_index)
238
+ cluster_logprob = F.log_softmax(cluster_output, dim=1)
239
+ local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
240
+ output.index_copy_(0, row_indices, local_logprob.squeeze(1))
241
+
242
+ used_rows += row_indices.numel()
243
+
244
+ if used_rows != batch_size:
245
+ raise RuntimeError(
246
+ f"Target values should be in [0, {self.n_classes - 1}], "
247
+ f"but values in range [{target.min().item()}, {target.max().item()}] "
248
+ "were found. "
249
+ )
250
+
251
+ head_output = self.head(input)
252
+ head_logprob = F.log_softmax(head_output, dim=1)
253
+ output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
254
+ loss = (-output).mean()
255
+
256
+ if not is_batched:
257
+ output = output.squeeze(0)
258
+
259
+ return _ASMoutput(output, loss)
260
+
261
+ def _get_full_log_prob(self, input, head_output):
262
+ """Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
263
+ out = input.new_empty((head_output.size(0), self.n_classes))
264
+ head_logprob = F.log_softmax(head_output, dim=1)
265
+
266
+ out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]
267
+
268
+ for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
269
+ cluster_output = self.tail[i](input)
270
+ cluster_logprob = F.log_softmax(cluster_output, dim=1)
271
+ output_logprob = cluster_logprob + head_logprob[
272
+ :, self.shortlist_size + i
273
+ ].unsqueeze(1)
274
+
275
+ out[:, start_idx:stop_idx] = output_logprob
276
+
277
+ return out
278
+
279
+ def log_prob(self, input: Tensor) -> Tensor:
280
+ r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.
281
+
282
+ Args:
283
+ input (Tensor): a minibatch of examples
284
+
285
+ Returns:
286
+ log-probabilities of for each class :math:`c`
287
+ in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
288
+ parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
289
+
290
+ Shape:
291
+ - Input: :math:`(N, \texttt{in\_features})`
292
+ - Output: :math:`(N, \texttt{n\_classes})`
293
+
294
+ """
295
+ head_output = self.head(input)
296
+ return self._get_full_log_prob(input, head_output)
297
+
298
+ def predict(self, input: Tensor) -> Tensor:
299
+ r"""Return the class with the highest probability for each example in the input minibatch.
300
+
301
+ This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.
302
+
303
+ Args:
304
+ input (Tensor): a minibatch of examples
305
+
306
+ Returns:
307
+ output (Tensor): a class with the highest probability for each example
308
+
309
+ Shape:
310
+ - Input: :math:`(N, \texttt{in\_features})`
311
+ - Output: :math:`(N)`
312
+ """
313
+ head_output = self.head(input)
314
+ output = torch.argmax(head_output, dim=1)
315
+ not_in_shortlist = output >= self.shortlist_size
316
+ all_in_shortlist = not (not_in_shortlist.any())
317
+
318
+ if all_in_shortlist:
319
+ return output
320
+
321
+ elif not_in_shortlist.all():
322
+ log_prob = self._get_full_log_prob(input, head_output)
323
+ return torch.argmax(log_prob, dim=1)
324
+
325
+ else:
326
+ log_prob = self._get_full_log_prob(
327
+ input[not_in_shortlist], head_output[not_in_shortlist]
328
+ )
329
+ output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
330
+ return output
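+ # Illustrative usage sketch (not upstream code); all sizes are hypothetical.
+ # With ``cutoffs=[10, 100, 1000]`` and ``n_classes=2000``, the labels split
+ # into a shortlist of 10 frequent classes plus three tail clusters.
+ #
+ #     asm = AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=2000,
+ #                                      cutoffs=[10, 100, 1000])
+ #     hidden = torch.randn(8, 64)
+ #     target = torch.randint(0, 2000, (8,))
+ #     out, loss = asm(hidden, target)    # per-example log-probs and mean NLL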
.venv/Lib/site-packages/torch/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,883 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import functional as F, init
7
+ from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter
8
+
9
+ from ._functions import SyncBatchNorm as sync_batch_norm
10
+ from .lazy import LazyModuleMixin
11
+ from .module import Module
12
+
13
+
14
+ __all__ = [
15
+ "BatchNorm1d",
16
+ "LazyBatchNorm1d",
17
+ "BatchNorm2d",
18
+ "LazyBatchNorm2d",
19
+ "BatchNorm3d",
20
+ "LazyBatchNorm3d",
21
+ "SyncBatchNorm",
22
+ ]
23
+
24
+
25
+ class _NormBase(Module):
26
+ """Common base of _InstanceNorm and _BatchNorm."""
27
+
28
+ _version = 2
29
+ __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
30
+ num_features: int
31
+ eps: float
32
+ momentum: Optional[float]
33
+ affine: bool
34
+ track_running_stats: bool
35
+ # WARNING: weight and bias purposely not defined here.
36
+ # See https://github.com/pytorch/pytorch/issues/39670
37
+
38
+ def __init__(
39
+ self,
40
+ num_features: int,
41
+ eps: float = 1e-5,
42
+ momentum: Optional[float] = 0.1,
43
+ affine: bool = True,
44
+ track_running_stats: bool = True,
45
+ device=None,
46
+ dtype=None,
47
+ ) -> None:
48
+ factory_kwargs = {"device": device, "dtype": dtype}
49
+ super().__init__()
50
+ self.num_features = num_features
51
+ self.eps = eps
52
+ self.momentum = momentum
53
+ self.affine = affine
54
+ self.track_running_stats = track_running_stats
55
+ if self.affine:
56
+ self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
57
+ self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
58
+ else:
59
+ self.register_parameter("weight", None)
60
+ self.register_parameter("bias", None)
61
+ if self.track_running_stats:
62
+ self.register_buffer(
63
+ "running_mean", torch.zeros(num_features, **factory_kwargs)
64
+ )
65
+ self.register_buffer(
66
+ "running_var", torch.ones(num_features, **factory_kwargs)
67
+ )
68
+ self.running_mean: Optional[Tensor]
69
+ self.running_var: Optional[Tensor]
70
+ self.register_buffer(
71
+ "num_batches_tracked",
72
+ torch.tensor(
73
+ 0,
74
+ dtype=torch.long,
75
+ **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
76
+ ),
77
+ )
78
+ self.num_batches_tracked: Optional[Tensor]
79
+ else:
80
+ self.register_buffer("running_mean", None)
81
+ self.register_buffer("running_var", None)
82
+ self.register_buffer("num_batches_tracked", None)
83
+ self.reset_parameters()
84
+
85
+ def reset_running_stats(self) -> None:
86
+ if self.track_running_stats:
87
+ # running_mean/running_var/num_batches... are registered at runtime depending
88
+ # if self.track_running_stats is on
89
+ self.running_mean.zero_() # type: ignore[union-attr]
90
+ self.running_var.fill_(1) # type: ignore[union-attr]
91
+ self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
92
+
93
+ def reset_parameters(self) -> None:
94
+ self.reset_running_stats()
95
+ if self.affine:
96
+ init.ones_(self.weight)
97
+ init.zeros_(self.bias)
98
+
99
+ def _check_input_dim(self, input):
100
+ raise NotImplementedError
101
+
102
+ def extra_repr(self):
103
+ return (
104
+ "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
105
+ "track_running_stats={track_running_stats}".format(**self.__dict__)
106
+ )
107
+
108
+ def _load_from_state_dict(
109
+ self,
110
+ state_dict,
111
+ prefix,
112
+ local_metadata,
113
+ strict,
114
+ missing_keys,
115
+ unexpected_keys,
116
+ error_msgs,
117
+ ):
118
+ version = local_metadata.get("version", None)
119
+
120
+ if (version is None or version < 2) and self.track_running_stats:
121
+ # at version 2: added num_batches_tracked buffer
122
+ # this should have a default value of 0
123
+ num_batches_tracked_key = prefix + "num_batches_tracked"
124
+ if num_batches_tracked_key not in state_dict:
125
+ state_dict[num_batches_tracked_key] = (
126
+ self.num_batches_tracked
127
+ if self.num_batches_tracked is not None
128
+ and self.num_batches_tracked.device != torch.device("meta")
129
+ else torch.tensor(0, dtype=torch.long)
130
+ )
131
+
132
+ super()._load_from_state_dict(
133
+ state_dict,
134
+ prefix,
135
+ local_metadata,
136
+ strict,
137
+ missing_keys,
138
+ unexpected_keys,
139
+ error_msgs,
140
+ )
141
+
142
+
143
+ class _BatchNorm(_NormBase):
144
+ def __init__(
145
+ self,
146
+ num_features: int,
147
+ eps: float = 1e-5,
148
+ momentum: Optional[float] = 0.1,
149
+ affine: bool = True,
150
+ track_running_stats: bool = True,
151
+ device=None,
152
+ dtype=None,
153
+ ) -> None:
154
+ factory_kwargs = {"device": device, "dtype": dtype}
155
+ super().__init__(
156
+ num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
157
+ )
158
+
159
+ def forward(self, input: Tensor) -> Tensor:
160
+ self._check_input_dim(input)
161
+
162
+ # exponential_average_factor is set to self.momentum
163
+ # (when it is available) only so that it gets updated
164
+ # in ONNX graph when this node is exported to ONNX.
165
+ if self.momentum is None:
166
+ exponential_average_factor = 0.0
167
+ else:
168
+ exponential_average_factor = self.momentum
169
+
170
+ if self.training and self.track_running_stats:
171
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
172
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
173
+ self.num_batches_tracked.add_(1) # type: ignore[has-type]
174
+ if self.momentum is None: # use cumulative moving average
175
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
176
+ else: # use exponential moving average
177
+ exponential_average_factor = self.momentum
178
+
179
+ r"""
180
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
181
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
182
+ """
183
+ if self.training:
184
+ bn_training = True
185
+ else:
186
+ bn_training = (self.running_mean is None) and (self.running_var is None)
187
+
188
+ r"""
189
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
190
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
191
+ used for normalization (i.e. in eval mode when buffers are not None).
192
+ """
193
+ return F.batch_norm(
194
+ input,
195
+ # If buffers are not to be tracked, ensure that they won't be updated
196
+ self.running_mean
197
+ if not self.training or self.track_running_stats
198
+ else None,
199
+ self.running_var if not self.training or self.track_running_stats else None,
200
+ self.weight,
201
+ self.bias,
202
+ bn_training,
203
+ exponential_average_factor,
204
+ self.eps,
205
+ )
206
+
207
+
208
+class _LazyNormBase(LazyModuleMixin, _NormBase):
+    weight: UninitializedParameter  # type: ignore[assignment]
+    bias: UninitializedParameter  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        eps=1e-5,
+        momentum=0.1,
+        affine=True,
+        track_running_stats=True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            # affine and track_running_stats are hardcoded to False to
+            # avoid creating tensors that will soon be overwritten.
+            0,
+            eps,
+            momentum,
+            False,
+            False,
+            **factory_kwargs,
+        )
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        if self.affine:
+            self.weight = UninitializedParameter(**factory_kwargs)
+            self.bias = UninitializedParameter(**factory_kwargs)
+        if self.track_running_stats:
+            self.running_mean = UninitializedBuffer(**factory_kwargs)
+            self.running_var = UninitializedBuffer(**factory_kwargs)
+            self.num_batches_tracked = torch.tensor(
+                0,
+                dtype=torch.long,
+                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
+            )
+
+    def reset_parameters(self) -> None:
+        if not self.has_uninitialized_params() and self.num_features != 0:
+            super().reset_parameters()
+
+    def initialize_parameters(self, input) -> None:  # type: ignore[override]
+        if self.has_uninitialized_params():
+            self.num_features = input.shape[1]
+            if self.affine:
+                assert isinstance(self.weight, UninitializedParameter)
+                assert isinstance(self.bias, UninitializedParameter)
+                self.weight.materialize((self.num_features,))
+                self.bias.materialize((self.num_features,))
+            if self.track_running_stats:
+                self.running_mean.materialize(  # type:ignore[union-attr]
+                    (self.num_features,)
+                )
+                self.running_var.materialize(  # type:ignore[union-attr]
+                    (self.num_features,)
+                )
+            self.reset_parameters()
+
+
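`initialize_parameters` is what turns the uninitialized placeholders into real tensors on the first forward call, taking `num_features` from `input.shape[1]`. A short sketch of the lifecycle, using the lazy classes defined below:

    >>> m = torch.nn.LazyBatchNorm1d()  # num_features not known yet
    >>> m.weight
    <UninitializedParameter>
    >>> _ = m(torch.randn(4, 7))
    >>> m.num_features, m.weight.shape
    (7, torch.Size([7]))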
+class BatchNorm1d(_BatchNorm):
+    r"""Applies Batch Normalization over a 2D or 3D input.
+
+    Method described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the number of features or channels of the input). By default, the
+    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
+    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
+    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
+    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
+    ``torch.var(input, unbiased=True)``.
+
+    Also by default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
+
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Because the Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
+
+    Args:
+        num_features: number of features or channels :math:`C` of the input
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
+          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
+        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+    Examples::
+
+        >>> # With Learnable Parameters
+        >>> m = nn.BatchNorm1d(100)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm1d(100, affine=False)
+        >>> input = torch.randn(20, 100)
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
+
+
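One consequence of the train/eval split described above: a freshly constructed module in eval mode normalizes with the default buffers (`running_mean=0`, `running_var=1`), so it acts as a near-identity. A sketch, with the tolerance chosen loosely to absorb the `eps` term:

    >>> m = torch.nn.BatchNorm1d(3).eval()
    >>> x = torch.randn(5, 3)
    >>> torch.allclose(m(x), x, atol=1e-4)
    True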
+class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
+    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
+
+    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
+    from the ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight`, `bias`,
+    `running_mean` and `running_var`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm1d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
+
+
+class BatchNorm2d(_BatchNorm):
+    r"""Applies Batch Normalization over a 4D input.
+
+    4D is a mini-batch of 2D inputs
+    with an additional channel dimension. Method described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
+    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
+    standard-deviation is calculated via the biased estimator, equivalent to
+    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
+    standard-deviation is calculated via the unbiased estimator, equivalent to
+    ``torch.var(input, unbiased=True)``.
+
+    Also by default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
+
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Because the Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, H, W)`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, H, W)`
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+
+    Examples::
+
+        >>> # With Learnable Parameters
+        >>> m = nn.BatchNorm2d(100)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm2d(100, affine=False)
+        >>> input = torch.randn(20, 100, 35, 45)
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
+
+
+class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
+    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
+
+    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
+    from the ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight`, `bias`,
+    `running_mean` and `running_var`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm2d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
+
+
+class BatchNorm3d(_BatchNorm):
+    r"""Applies Batch Normalization over a 5D input.
+
+    5D is a mini-batch of 3D inputs with an additional channel dimension as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
+    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
+    standard-deviation is calculated via the biased estimator, equivalent to
+    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
+    standard-deviation is calculated via the unbiased estimator, equivalent to
+    ``torch.var(input, unbiased=True)``.
+
+    Also by default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
+
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Because the Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
+    or Spatio-temporal Batch Normalization.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, D, H, W)`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples::
+
+        >>> # With Learnable Parameters
+        >>> m = nn.BatchNorm3d(100)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm3d(100, affine=False)
+        >>> input = torch.randn(20, 100, 35, 45, 10)
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
+
+
+class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
+    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.
+
+    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
+    from the ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight`, `bias`,
+    `running_mean` and `running_var`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm3d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
+
+
+class SyncBatchNorm(_BatchNorm):
+    r"""Applies Batch Normalization over an N-dimensional input.
+
+    The N-D input is a mini-batch of [N-2]D inputs with an additional channel
+    dimension, as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over all
+    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
+    are learnable parameter vectors of size `C` (where `C` is the input size).
+    By default, the elements of :math:`\gamma` are sampled from
+    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+    The standard-deviation is calculated via the biased estimator, equivalent to
+    `torch.var(input, unbiased=False)`.
+
+    Also by default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
+
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
+    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
+    Normalization or Spatio-temporal Batch Normalization.
+
+    Currently :class:`SyncBatchNorm` only supports
+    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
+    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
+    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
+    the network with DDP.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, +)`
+        eps: a value added to the denominator for numerical stability.
+            Default: ``1e-5``
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+        process_group: synchronization of stats happens within each process group
+            individually. The default behavior is synchronization across the whole
+            world
+
+    Shape:
+        - Input: :math:`(N, C, +)`
+        - Output: :math:`(N, C, +)` (same shape as input)
+
+    .. note::
+        Synchronization of batchnorm statistics occurs only while training, i.e.
+        synchronization is disabled when ``model.eval()`` is set or if
+        ``self.training`` is otherwise ``False``.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With Learnable Parameters
+        >>> m = nn.SyncBatchNorm(100)
+        >>> # creating process group (optional)
+        >>> # ranks is a list of int identifying rank ids.
+        >>> ranks = list(range(8))
+        >>> r1, r2 = ranks[:4], ranks[4:]
+        >>> # Note: every rank calls into new_group for every
+        >>> # process group created, even if that rank is not
+        >>> # part of the group.
+        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
+        >>> # Without Learnable Parameters
+        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
+        >>> input = torch.randn(20, 100, 35, 45, 10)
+        >>> output = m(input)
+
+        >>> # network is nn.BatchNorm layer
+        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
+        >>> # only single gpu per process is currently supported
+        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
+        >>>                         sync_bn_network,
+        >>>                         device_ids=[args.local_rank],
+        >>>                         output_device=args.local_rank)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: Optional[float] = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+        process_group: Optional[Any] = None,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
+        )
+        self.process_group = process_group
+
+    def _check_input_dim(self, input):
+        if input.dim() < 2:
+            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")
+
+    def _check_non_zero_input_channels(self, input):
+        if input.size(1) == 0:
+            raise ValueError(
+                "SyncBatchNorm number of input channels should be non-zero"
+            )
+
+    def forward(self, input: Tensor) -> Tensor:
+        self._check_input_dim(input)
+        self._check_non_zero_input_channels(input)
+
+        # exponential_average_factor is set to self.momentum
+        # (when it is available) only so that it gets updated
+        # in ONNX graph when this node is exported to ONNX.
+        if self.momentum is None:
+            exponential_average_factor = 0.0
+        else:
+            exponential_average_factor = self.momentum
+
+        if self.training and self.track_running_stats:
+            assert self.num_batches_tracked is not None
+            self.num_batches_tracked.add_(1)
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
+        r"""
+        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
+        """
+        if self.training:
+            bn_training = True
+        else:
+            bn_training = (self.running_mean is None) and (self.running_var is None)
+
+        r"""
+        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
+        used for normalization (i.e. in eval mode when buffers are not None).
+        """
+        # If buffers are not to be tracked, ensure that they won't be updated
+        running_mean = (
+            self.running_mean if not self.training or self.track_running_stats else None
+        )
+        running_var = (
+            self.running_var if not self.training or self.track_running_stats else None
+        )
+
+        # Don't sync batchnorm stats in inference mode (model.eval()).
+        need_sync = (
+            bn_training
+            and self.training
+            and torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        )
+        if need_sync:
+            # currently only GPU/PrivateUse1 input is supported
+            if input.device.type not in [
+                "cuda",
+                torch._C._get_privateuse1_backend_name(),
+            ]:
+                raise ValueError(
+                    "SyncBatchNorm expected input tensor to be on GPU or "
+                    f"{torch._C._get_privateuse1_backend_name()}"
+                )
+
+            process_group = torch.distributed.group.WORLD
+            if self.process_group:
+                process_group = self.process_group
+            world_size = torch.distributed.get_world_size(process_group)
+            need_sync = world_size > 1
+
+        # fallback to framework BN when synchronization is not necessary
+        if not need_sync:
+            return F.batch_norm(
+                input,
+                running_mean,
+                running_var,
+                self.weight,
+                self.bias,
+                bn_training,
+                exponential_average_factor,
+                self.eps,
+            )
+        else:
+            assert bn_training
+            return sync_batch_norm.apply(
+                input,
+                self.weight,
+                self.bias,
+                running_mean,
+                running_var,
+                self.eps,
+                exponential_average_factor,
+                process_group,  # type: ignore[possibly-undefined]
+                world_size,  # type: ignore[possibly-undefined]
+            )
+
+    @classmethod
+    def convert_sync_batchnorm(cls, module, process_group=None):
+        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.
+
+        Args:
+            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
+            process_group (optional): process group to scope synchronization,
+                default is the whole world
+
+        Returns:
+            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
+            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
+            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
+            instead.
+
+        Example::
+
+            >>> # Network with nn.BatchNorm layer
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+            >>> module = torch.nn.Sequential(
+            >>>     torch.nn.Linear(20, 100),
+            >>>     torch.nn.BatchNorm1d(100),
+            >>> ).cuda()
+            >>> # creating process group (optional)
+            >>> # ranks is a list of int identifying rank ids.
+            >>> ranks = list(range(8))
+            >>> r1, r2 = ranks[:4], ranks[4:]
+            >>> # Note: every rank calls into new_group for every
+            >>> # process group created, even if that rank is not
+            >>> # part of the group.
+            >>> # xdoctest: +SKIP("distributed")
+            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
+            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
+
+        """
+        module_output = module
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+            if module.affine:
+                with torch.no_grad():
+                    module_output.weight = module.weight
+                    module_output.bias = module.bias
+            module_output.running_mean = module.running_mean
+            module_output.running_var = module.running_var
+            module_output.num_batches_tracked = module.num_batches_tracked
+            module_output.training = module.training
+            if hasattr(module, "qconfig"):
+                module_output.qconfig = module.qconfig
+        for name, child in module.named_children():
+            module_output.add_module(
+                name, cls.convert_sync_batchnorm(child, process_group)
+            )
+        del module
+        return module_output
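Note that `convert_sync_batchnorm` rebinds the existing parameters and buffers rather than copying them, and it needs no initialized process group to run; synchronization itself is deferred to `forward`, which falls back to `F.batch_norm` when `need_sync` is false. A sketch:

    >>> net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    >>> w = net[1].weight
    >>> sync_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    >>> type(sync_net[1]).__name__, sync_net[1].weight is w
    ('SyncBatchNorm', True)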
.venv/Lib/site-packages/torch/nn/modules/channelshuffle.py ADDED
@@ -0,0 +1,56 @@
+import torch.nn.functional as F
+from torch import Tensor
+
+from .module import Module
+
+
+__all__ = ["ChannelShuffle"]
+
+
+class ChannelShuffle(Module):
+    r"""Divides and rearranges the channels in a tensor.
+
+    This operation divides the channels in a tensor of shape :math:`(N, C, *)`
+    into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them,
+    while retaining the original tensor shape in the final output.
+
+    Args:
+        groups (int): number of groups to divide the channels into.
+
+    Examples::
+
+        >>> channel_shuffle = nn.ChannelShuffle(2)
+        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
+        >>> input
+        tensor([[[[ 1.,  2.],
+                  [ 3.,  4.]],
+                 [[ 5.,  6.],
+                  [ 7.,  8.]],
+                 [[ 9., 10.],
+                  [11., 12.]],
+                 [[13., 14.],
+                  [15., 16.]]]])
+        >>> output = channel_shuffle(input)
+        >>> output
+        tensor([[[[ 1.,  2.],
+                  [ 3.,  4.]],
+                 [[ 9., 10.],
+                  [11., 12.]],
+                 [[ 5.,  6.],
+                  [ 7.,  8.]],
+                 [[13., 14.],
+                  [15., 16.]]]])
+    """
+
+    __constants__ = ["groups"]
+    groups: int
+
+    def __init__(self, groups: int) -> None:
+        super().__init__()
+        self.groups = groups
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.channel_shuffle(input, self.groups)
+
+    def extra_repr(self) -> str:
+        return f"groups={self.groups}"
.venv/Lib/site-packages/torch/nn/modules/container.py ADDED
@@ -0,0 +1,976 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import operator
+from collections import abc as container_abcs, OrderedDict
+from itertools import chain, islice
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    Mapping,
+    Optional,
+    overload,
+    Tuple,
+    TypeVar,
+    Union,
+)
+from typing_extensions import deprecated, Self
+
+import torch
+from torch._jit_internal import _copy_to_script_wrapper
+from torch.nn.parameter import Parameter
+
+from .module import Module
+
+
+__all__ = [
+    "Container",
+    "Sequential",
+    "ModuleList",
+    "ModuleDict",
+    "ParameterList",
+    "ParameterDict",
+]
+
+T = TypeVar("T", bound=Module)
+
+
+# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
+def _addindent(s_, numSpaces):
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(numSpaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
+
+
+@deprecated(
+    "`nn.Container` is deprecated. "
+    "All of its functionality is now implemented in `nn.Module`. Subclass that instead.",
+    category=FutureWarning,
+)
+class Container(Module):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        for key, value in kwargs.items():
+            self.add_module(key, value)
+
+
+class Sequential(Module):
+    r"""A sequential container.
+
+    Modules will be added to it in the order they are passed in the
+    constructor. Alternatively, an ``OrderedDict`` of modules can be
+    passed in. The ``forward()`` method of ``Sequential`` accepts any
+    input and forwards it to the first module it contains. It then
+    "chains" outputs to inputs sequentially for each subsequent module,
+    finally returning the output of the last module.
+
+    The value a ``Sequential`` provides over manually calling a sequence
+    of modules is that it allows treating the whole container as a
+    single module, such that performing a transformation on the
+    ``Sequential`` applies to each of the modules it stores (which are
+    each a registered submodule of the ``Sequential``).
+
+    What's the difference between a ``Sequential`` and a
+    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
+    sounds like--a list for storing ``Module`` s! On the other hand,
+    the layers in a ``Sequential`` are connected in a cascading way.
+
+    Example::
+
+        # Using Sequential to create a small model. When `model` is run,
+        # input will first be passed to `Conv2d(1,20,5)`. The output of
+        # `Conv2d(1,20,5)` will be used as the input to the first
+        # `ReLU`; the output of the first `ReLU` will become the input
+        # for `Conv2d(20,64,5)`. Finally, the output of
+        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
+        model = nn.Sequential(
+            nn.Conv2d(1,20,5),
+            nn.ReLU(),
+            nn.Conv2d(20,64,5),
+            nn.ReLU()
+        )
+
+        # Using Sequential with OrderedDict. This is functionally the
+        # same as the above code
+        model = nn.Sequential(OrderedDict([
+            ('conv1', nn.Conv2d(1,20,5)),
+            ('relu1', nn.ReLU()),
+            ('conv2', nn.Conv2d(20,64,5)),
+            ('relu2', nn.ReLU())
+        ]))
+    """
+
+    _modules: Dict[str, Module]  # type: ignore[assignment]
+
+    @overload
+    def __init__(self, *args: Module) -> None:
+        ...
+
+    @overload
+    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
+        ...
+
+    def __init__(self, *args):
+        super().__init__()
+        if len(args) == 1 and isinstance(args[0], OrderedDict):
+            for key, module in args[0].items():
+                self.add_module(key, module)
+        else:
+            for idx, module in enumerate(args):
+                self.add_module(str(idx), module)
+
+    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
+        """Get the idx-th item of the iterator."""
+        size = len(self)
+        idx = operator.index(idx)
+        if not -size <= idx < size:
+            raise IndexError(f"index {idx} is out of range")
+        idx %= size
+        return next(islice(iterator, idx, None))
+
+    @_copy_to_script_wrapper
+    def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
+        if isinstance(idx, slice):
+            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
+        else:
+            return self._get_item_by_idx(self._modules.values(), idx)
+
+    def __setitem__(self, idx: int, module: Module) -> None:
+        key: str = self._get_item_by_idx(self._modules.keys(), idx)
+        return setattr(self, key, module)
+
+    def __delitem__(self, idx: Union[slice, int]) -> None:
+        if isinstance(idx, slice):
+            for key in list(self._modules.keys())[idx]:
+                delattr(self, key)
+        else:
+            key = self._get_item_by_idx(self._modules.keys(), idx)
+            delattr(self, key)
+        # To preserve numbering
+        str_indices = [str(i) for i in range(len(self._modules))]
+        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+    @_copy_to_script_wrapper
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    def __add__(self, other) -> "Sequential":
+        if isinstance(other, Sequential):
+            ret = Sequential()
+            for layer in self:
+                ret.append(layer)
+            for layer in other:
+                ret.append(layer)
+            return ret
+        else:
+            raise ValueError(
+                "add operator supports only objects "
+                f"of Sequential class, but {str(type(other))} is given."
+            )
+
+    def pop(self, key: Union[int, slice]) -> Module:
+        v = self[key]
+        del self[key]
+        return v
+
+    def __iadd__(self, other) -> Self:
+        if isinstance(other, Sequential):
+            offset = len(self)
+            for i, module in enumerate(other):
+                self.add_module(str(i + offset), module)
+            return self
+        else:
+            raise ValueError(
+                "add operator supports only objects "
+                f"of Sequential class, but {str(type(other))} is given."
+            )
+
+    def __mul__(self, other: int) -> "Sequential":
+        if not isinstance(other, int):
+            raise TypeError(
+                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
+            )
+        elif other <= 0:
+            raise ValueError(
+                f"Non-positive multiplication factor {other} for {type(self)}"
+            )
+        else:
+            combined = Sequential()
+            offset = 0
+            for _ in range(other):
+                for module in self:
+                    combined.add_module(str(offset), module)
+                    offset += 1
+            return combined
+
+    def __rmul__(self, other: int) -> "Sequential":
+        return self.__mul__(other)
+
+    def __imul__(self, other: int) -> Self:
+        if not isinstance(other, int):
+            raise TypeError(
+                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
+            )
+        elif other <= 0:
+            raise ValueError(
+                f"Non-positive multiplication factor {other} for {type(self)}"
+            )
+        else:
+            len_original = len(self)
+            offset = len(self)
+            for _ in range(other - 1):
+                for i in range(len_original):
+                    self.add_module(str(i + offset), self._modules[str(i)])
+                offset += len_original
+            return self
+
+    @_copy_to_script_wrapper
+    def __dir__(self):
+        keys = super().__dir__()
+        keys = [key for key in keys if not key.isdigit()]
+        return keys
+
+    @_copy_to_script_wrapper
+    def __iter__(self) -> Iterator[Module]:
+        return iter(self._modules.values())
+
+    # NB: We can't really type check this function as the type of input
+    # may change dynamically (as is tested in
+    # TestScript.test_sequential_intermediary_types). Cannot annotate
+    # with Any as TorchScript expects a more precise type
+    def forward(self, input):
+        for module in self:
+            input = module(input)
+        return input
+
+    def append(self, module: Module) -> "Sequential":
+        r"""Append a given module to the end.
+
+        Args:
+            module (nn.Module): module to append
+        """
+        self.add_module(str(len(self)), module)
+        return self
+
+    def insert(self, index: int, module: Module) -> "Sequential":
+        if not isinstance(module, Module):
+            raise AssertionError(f"module should be of type: {Module}")
+        n = len(self._modules)
+        if not (-n <= index <= n):
+            raise IndexError(f"Index out of range: {index}")
+        if index < 0:
+            index += n
+        for i in range(n, index, -1):
+            self._modules[str(i)] = self._modules[str(i - 1)]
+        self._modules[str(index)] = module
+        return self
+
+    def extend(self, sequential) -> "Sequential":
+        for layer in sequential:
+            self.append(layer)
+        return self
+
+
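A detail of `__mul__`/`__imul__` above worth keeping in mind: repetition re-registers references to the same `Module` objects rather than deep-copying them, so the repeated layers share parameters. A sketch:

    >>> block = torch.nn.Sequential(torch.nn.Linear(4, 4))
    >>> stacked = block * 3
    >>> len(stacked), stacked[0] is stacked[2]
    (3, True)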
+class ModuleList(Module):
+    r"""Holds submodules in a list.
+
+    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
+    modules it contains are properly registered, and will be visible by all
+    :class:`~torch.nn.Module` methods.
+
+    Args:
+        modules (iterable, optional): an iterable of modules to add
+
+    Example::
+
+        class MyModule(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
+
+            def forward(self, x):
+                # ModuleList can act as an iterable, or be indexed using ints
+                for i, l in enumerate(self.linears):
+                    x = self.linears[i // 2](x) + l(x)
+                return x
+    """
+
+    _modules: Dict[str, Module]  # type: ignore[assignment]
+
+    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
+        super().__init__()
+        if modules is not None:
+            self += modules
+
+    def _get_abs_string_index(self, idx):
+        """Get the absolute index for the list of modules."""
+        idx = operator.index(idx)
+        if not (-len(self) <= idx < len(self)):
+            raise IndexError(f"index {idx} is out of range")
+        if idx < 0:
+            idx += len(self)
+        return str(idx)
+
+    @overload
+    def __getitem__(self, idx: slice) -> "ModuleList":
+        ...
+
+    @overload
+    def __getitem__(self, idx: int) -> Module:
+        ...
+
+    @_copy_to_script_wrapper
+    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
+        if isinstance(idx, slice):
+            return self.__class__(list(self._modules.values())[idx])
+        else:
+            return self._modules[self._get_abs_string_index(idx)]
+
+    def __setitem__(self, idx: int, module: Module) -> None:
+        idx = self._get_abs_string_index(idx)
+        return setattr(self, str(idx), module)
+
+    def __delitem__(self, idx: Union[int, slice]) -> None:
+        if isinstance(idx, slice):
+            for k in range(len(self._modules))[idx]:
+                delattr(self, str(k))
+        else:
+            delattr(self, self._get_abs_string_index(idx))
+        # To preserve numbering, self._modules is being reconstructed with modules after deletion
+        str_indices = [str(i) for i in range(len(self._modules))]
+        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+    @_copy_to_script_wrapper
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    @_copy_to_script_wrapper
+    def __iter__(self) -> Iterator[Module]:
+        return iter(self._modules.values())
+
+    def __iadd__(self, modules: Iterable[Module]) -> Self:
+        return self.extend(modules)
+
+    def __add__(self, other: Iterable[Module]) -> "ModuleList":
+        combined = ModuleList()
+        for i, module in enumerate(chain(self, other)):
+            combined.add_module(str(i), module)
+        return combined
+
+    def __repr__(self):
+        """Return a custom repr for ModuleList that compresses repeated module representations."""
+        list_of_reprs = [repr(item) for item in self]
+        if len(list_of_reprs) == 0:
+            return self._get_name() + "()"
+
+        start_end_indices = [[0, 0]]
+        repeated_blocks = [list_of_reprs[0]]
+        for i, r in enumerate(list_of_reprs[1:], 1):
+            if r == repeated_blocks[-1]:
+                start_end_indices[-1][1] += 1
+                continue
+
+            start_end_indices.append([i, i])
+            repeated_blocks.append(r)
+
+        lines = []
+        main_str = self._get_name() + "("
+        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+            local_repr = f"({start_id}): {b}"  # default repr
+
+            if start_id != end_id:
+                n = end_id - start_id + 1
+                local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+            local_repr = _addindent(local_repr, 2)
+            lines.append(local_repr)
+
+        main_str += "\n  " + "\n  ".join(lines) + "\n"
+        main_str += ")"
+        return main_str
+
+    @_copy_to_script_wrapper
+    def __dir__(self):
+        keys = super().__dir__()
+        keys = [key for key in keys if not key.isdigit()]
+        return keys
+
+    def insert(self, index: int, module: Module) -> None:
+        r"""Insert a given module before a given index in the list.
+
+        Args:
+            index (int): index to insert.
+            module (nn.Module): module to insert
+        """
+        for i in range(len(self._modules), index, -1):
+            self._modules[str(i)] = self._modules[str(i - 1)]
+        self._modules[str(index)] = module
+
+    def append(self, module: Module) -> "ModuleList":
+        r"""Append a given module to the end of the list.
+
+        Args:
+            module (nn.Module): module to append
+        """
+        self.add_module(str(len(self)), module)
+        return self
+
+    def pop(self, key: Union[int, slice]) -> Module:
+        v = self[key]
+        del self[key]
+        return v
+
+    def extend(self, modules: Iterable[Module]) -> Self:
+        r"""Append modules from a Python iterable to the end of the list.
+
+        Args:
+            modules (iterable): iterable of modules to append
+        """
+        if not isinstance(modules, container_abcs.Iterable):
+            raise TypeError(
+                "ModuleList.extend should be called with an "
+                "iterable, but got " + type(modules).__name__
+            )
+        offset = len(self)
+        for i, module in enumerate(modules):
+            self.add_module(str(offset + i), module)
+        return self
+
+    # remove forward altogether to fall back on Module's _forward_unimplemented
+
+
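The registration guarantee in the class docstring is the whole point of `ModuleList`: a plain Python list attribute hides its contents from `parameters()`, `to()`, `state_dict()`, and friends. A sketch:

    >>> class Net(torch.nn.Module):
    ...     def __init__(self) -> None:
    ...         super().__init__()
    ...         self.seen = torch.nn.ModuleList([torch.nn.Linear(2, 2)])
    ...         self.hidden = [torch.nn.Linear(2, 2)]  # plain list: not registered
    >>> len(list(Net().parameters()))  # only the ModuleList's weight and bias
    2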
+class ModuleDict(Module):
+    r"""Holds submodules in a dictionary.
+
+    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
+    but modules it contains are properly registered, and will be visible by all
+    :class:`~torch.nn.Module` methods.
+
+    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
+
+    * the order of insertion, and
+
+    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
+      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
+      :class:`~torch.nn.ModuleDict` (the argument to
+      :meth:`~torch.nn.ModuleDict.update`).
+
+    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
+    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
+    preserve the order of the merged mapping.
+
+    Args:
+        modules (iterable, optional): a mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module)
+
+    Example::
+
+        class MyModule(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.choices = nn.ModuleDict({
+                    'conv': nn.Conv2d(10, 10, 3),
+                    'pool': nn.MaxPool2d(3)
+                })
+                self.activations = nn.ModuleDict([
+                    ['lrelu', nn.LeakyReLU()],
+                    ['prelu', nn.PReLU()]
+                ])
+
+            def forward(self, x, choice, act):
+                x = self.choices[choice](x)
+                x = self.activations[act](x)
+                return x
+    """
+
+    _modules: Dict[str, Module]  # type: ignore[assignment]
+
+    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
+        super().__init__()
+        if modules is not None:
+            self.update(modules)
+
+    @_copy_to_script_wrapper
+    def __getitem__(self, key: str) -> Module:
+        return self._modules[key]
+
+    def __setitem__(self, key: str, module: Module) -> None:
+        self.add_module(key, module)
+
+    def __delitem__(self, key: str) -> None:
+        del self._modules[key]
+
+    @_copy_to_script_wrapper
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    @_copy_to_script_wrapper
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._modules)
+
+    @_copy_to_script_wrapper
+    def __contains__(self, key: str) -> bool:
+        return key in self._modules
+
+    def clear(self) -> None:
+        """Remove all items from the ModuleDict."""
+        self._modules.clear()
+
+    def pop(self, key: str) -> Module:
+        r"""Remove key from the ModuleDict and return its module.
+
+        Args:
+            key (str): key to pop from the ModuleDict
+        """
+        v = self[key]
+        del self[key]
+        return v
+
+    @_copy_to_script_wrapper
+    def keys(self) -> Iterable[str]:
+        r"""Return an iterable of the ModuleDict keys."""
+        return self._modules.keys()
+
+    @_copy_to_script_wrapper
+    def items(self) -> Iterable[Tuple[str, Module]]:
+        r"""Return an iterable of the ModuleDict key/value pairs."""
+        return self._modules.items()
+
+    @_copy_to_script_wrapper
+    def values(self) -> Iterable[Module]:
+        r"""Return an iterable of the ModuleDict values."""
+        return self._modules.values()
+
+    def update(self, modules: Mapping[str, Module]) -> None:
+        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
+
+        .. note::
+            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
+            an iterable of key-value pairs, the order of new elements in it is preserved.
+
+        Args:
+            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
+                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
+        """
+        if not isinstance(modules, container_abcs.Iterable):
+            raise TypeError(
+                "ModuleDict.update should be called with an "
+                "iterable of key/value pairs, but got " + type(modules).__name__
+            )
+
+        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
+            for key, module in modules.items():
+                self[key] = module
+        else:
+            # modules here can be a list with two items
+            for j, m in enumerate(modules):
+                if not isinstance(m, container_abcs.Iterable):
+                    raise TypeError(
+                        "ModuleDict update sequence element "
+                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
+                    )
+                if not len(m) == 2:
+                    raise ValueError(
+                        "ModuleDict update sequence element "
+                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
+                    )
+                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
+                # that's too cumbersome to type correctly with overloads, so we add an ignore here
+                self[m[0]] = m[1]  # type: ignore[assignment]
+
+    # remove forward altogether to fall back on Module's _forward_unimplemented
+
+
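`update` accepts either a mapping or an iterable of `(key, module)` pairs, per the branches above. A sketch of the pair form:

    >>> d = torch.nn.ModuleDict()
    >>> d.update([("conv", torch.nn.Conv2d(1, 1, 1)), ("act", torch.nn.ReLU())])
    >>> list(d.keys())
    ['conv', 'act']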
+ class ParameterList(Module):
592
+ r"""Holds parameters in a list.
593
+
594
+ :class:`~torch.nn.ParameterList` can be used like a regular Python
595
+ list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
596
+ and will be visible by all :class:`~torch.nn.Module` methods.
597
+
598
+ Note that the constructor, assigning an element of the list, the
599
+ :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
600
+ method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
601
+
602
+ Args:
603
+ parameters (iterable, optional): an iterable of elements to add to the list.
604
+
605
+ Example::
606
+
607
+ class MyModule(nn.Module):
608
+ def __init__(self) -> None:
609
+ super().__init__()
610
+ self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
611
+
612
+ def forward(self, x):
613
+ # ParameterList can act as an iterable, or be indexed using ints
614
+ for i, p in enumerate(self.params):
615
+ x = self.params[i // 2].mm(x) + p.mm(x)
616
+ return x
617
+ """
618
+
619
+ def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
620
+ super().__init__()
621
+ self._size = 0
622
+ if values is not None:
623
+ self += values
624
+
625
+ def _get_abs_string_index(self, idx):
626
+ """Get the absolute index for the list of modules."""
627
+ idx = operator.index(idx)
628
+ if not (-len(self) <= idx < len(self)):
629
+ raise IndexError(f"index {idx} is out of range")
630
+ if idx < 0:
631
+ idx += len(self)
632
+ return str(idx)
633
+
634
+ @overload
635
+ def __getitem__(self, idx: int) -> Any:
636
+ ...
637
+
638
+ @overload
639
+ def __getitem__(self: T, idx: slice) -> T:
640
+ ...
641
+
642
+ def __getitem__(self, idx):
643
+ if isinstance(idx, slice):
644
+ start, stop, step = idx.indices(len(self))
645
+ out = self.__class__()
646
+ for i in range(start, stop, step):
647
+ out.append(self[i])
648
+ return out
649
+ else:
650
+ idx = self._get_abs_string_index(idx)
651
+ return getattr(self, str(idx))
652
+
653
+ def __setitem__(self, idx: int, param: Any) -> None:
654
+ # Note that all other function that add an entry to the list part of
655
+ # the ParameterList end up here. So this is the only place where we need
656
+ # to wrap things into Parameter if needed.
657
+ # Objects added via setattr() are not in the list part and thus won't
658
+ # call into this function.
659
+ idx = self._get_abs_string_index(idx)
660
+ if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
661
+ param = Parameter(param)
662
+ return setattr(self, str(idx), param)
663
+
664
+ def __len__(self) -> int:
665
+ return self._size
666
+
667
+ def __iter__(self) -> Iterator[Any]:
668
+ return iter(self[i] for i in range(len(self)))
669
+
670
+ def __iadd__(self, parameters: Iterable[Any]) -> Self:
671
+ return self.extend(parameters)
672
+
673
+ def __dir__(self):
674
+ keys = super().__dir__()
675
+ keys = [key for key in keys if not key.isdigit()]
676
+ return keys
677
+
678
+ def append(self, value: Any) -> "ParameterList":
679
+ """Append a given value at the end of the list.
680
+
681
+ Args:
682
+ value (Any): value to append
683
+ """
684
+ new_idx = len(self)
685
+ self._size += 1
686
+ self[new_idx] = value
687
+ return self
688
+
689
+ def extend(self, values: Iterable[Any]) -> Self:
690
+ """Append values from a Python iterable to the end of the list.
691
+
692
+ Args:
693
+ values (iterable): iterable of values to append
694
+ """
695
+ # Tensor is an iterable but we never want to unpack it here
696
+ if not isinstance(values, container_abcs.Iterable) or isinstance(
697
+ values, torch.Tensor
698
+ ):
699
+ raise TypeError(
700
+ "ParameterList.extend should be called with an "
701
+ "iterable, but got " + type(values).__name__
702
+ )
703
+ for value in values:
704
+ self.append(value)
705
+ return self
706
+
707
+ def extra_repr(self) -> str:
708
+ child_lines = []
709
+ for k, p in enumerate(self):
710
+ if isinstance(p, torch.Tensor):
711
+ size_str = "x".join(str(size) for size in p.size())
712
+ if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
713
+ device_str = f" ({p.device})"
714
+ else:
715
+ device_str = ""
716
+ parastr = "{} containing: [{} of size {}{}]".format(
717
+ "Parameter" if isinstance(p, Parameter) else "Tensor",
718
+ p.dtype,
719
+ size_str,
720
+ device_str,
721
+ )
722
+ child_lines.append(" (" + str(k) + "): " + parastr)
723
+ else:
724
+ child_lines.append(
725
+ " (" + str(k) + "): Object of type: " + type(p).__name__
726
+ )
727
+
728
+ tmpstr = "\n".join(child_lines)
729
+ return tmpstr
730
+
731
+ def __call__(self, *args, **kwargs):
732
+ raise RuntimeError("ParameterList should not be called.")
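+ 
+ # --- Editorial sketch (not part of the upstream module): a minimal, runnable
+ # check of the ParameterList semantics implemented above. append() routes
+ # through __setitem__, so a plain Tensor is wrapped into a Parameter and
+ # becomes visible to Module.parameters(); slicing copies into a new list.
+ _example_plist = ParameterList([Parameter(torch.randn(2, 2))])
+ _example_plist.append(torch.zeros(2, 2))  # plain Tensor -> wrapped to Parameter
+ assert isinstance(_example_plist[1], Parameter)
+ assert len(list(_example_plist.parameters())) == 2  # both entries are registered
+ assert isinstance(_example_plist[0:2], ParameterList)  # slice -> new ParameterList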
733
+
734
+
735
+ class ParameterDict(Module):
736
+ r"""Holds parameters in a dictionary.
737
+
738
+ ParameterDict can be indexed like a regular Python dictionary, but Parameters it
739
+ contains are properly registered, and will be visible by all Module methods.
740
+ Other objects are treated as they would be by a regular Python dictionary.
741
+
742
+ :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
743
+ :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
744
+ types (e.g., Python's plain ``dict``) does not preserve the order of the
745
+ merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
746
+ will preserve their ordering.
747
+
748
+ Note that the constructor, assigning an element of the dictionary and the
749
+ :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
750
+ :class:`~torch.nn.Parameter`.
751
+
752
+ Args:
753
+ values (iterable, optional): a mapping (dictionary) of
754
+ (string : Any) or an iterable of key-value pairs
755
+ of type (string, Any)
756
+
757
+ Example::
758
+
759
+ class MyModule(nn.Module):
760
+ def __init__(self) -> None:
761
+ super().__init__()
762
+ self.params = nn.ParameterDict({
763
+ 'left': nn.Parameter(torch.randn(5, 10)),
764
+ 'right': nn.Parameter(torch.randn(5, 10))
765
+ })
766
+
767
+ def forward(self, x, choice):
768
+ x = self.params[choice].mm(x)
769
+ return x
770
+ """
771
+
772
+ def __init__(self, parameters: Any = None) -> None:
773
+ super().__init__()
774
+ self._keys: Dict[str, None] = {}
775
+ if parameters is not None:
776
+ self.update(parameters)
777
+
778
+ def _key_to_attr(self, key: str) -> str:
779
+ if not isinstance(key, str):
780
+ raise TypeError(
781
+ "Index given to ParameterDict cannot be used as a key as it is "
782
+ f"not a string (type is '{type(key).__name__}'). Open an issue on "
783
+ "github if you need non-string keys."
784
+ )
785
+ else:
786
+ # Use the key as-is so that `.named_parameters()` returns the right thing
787
+ return key
788
+
789
+ def __getitem__(self, key: str) -> Any:
790
+ attr = self._key_to_attr(key)
791
+ return getattr(self, attr)
792
+
793
+ def __setitem__(self, key: str, value: Any) -> None:
794
+ # Note that all other functions that add an entry to the dictionary part of
795
+ # the ParameterDict end up here. So this is the only place where we need
796
+ # to wrap things into Parameter if needed.
797
+ # Objects added via setattr() are not in the dictionary part and thus won't
798
+ # call into this function.
799
+ self._keys[key] = None
800
+ attr = self._key_to_attr(key)
801
+ if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
802
+ value = Parameter(value)
803
+ setattr(self, attr, value)
804
+
805
+ def __delitem__(self, key: str) -> None:
806
+ del self._keys[key]
807
+ attr = self._key_to_attr(key)
808
+ delattr(self, attr)
809
+
810
+ def __len__(self) -> int:
811
+ return len(self._keys)
812
+
813
+ def __iter__(self) -> Iterator[str]:
814
+ return iter(self._keys)
815
+
816
+ def __reversed__(self) -> Iterator[str]:
817
+ return reversed(list(self._keys))
818
+
819
+ def copy(self) -> "ParameterDict":
820
+ """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
821
+ # We have to use an OrderedDict because the ParameterDict constructor
822
+ # behaves differently on plain dict vs OrderedDict
823
+ return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
824
+
825
+ def __contains__(self, key: str) -> bool:
826
+ return key in self._keys
827
+
828
+ def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
829
+ """Set the default for a key in the ParameterDict.
830
+
831
+ If key is in the ParameterDict, return its value.
832
+ If not, insert `key` with a parameter `default` and return `default`.
833
+ `default` defaults to `None`.
834
+
835
+ Args:
836
+ key (str): key to set default for
837
+ default (Any): the parameter set to the key
838
+ """
839
+ if key not in self:
840
+ self[key] = default
841
+ return self[key]
842
+
843
+ def clear(self) -> None:
844
+ """Remove all items from the ParameterDict."""
845
+ for k in self._keys.copy():
846
+ del self[k]
847
+
848
+ def pop(self, key: str) -> Any:
849
+ r"""Remove key from the ParameterDict and return its parameter.
850
+
851
+ Args:
852
+ key (str): key to pop from the ParameterDict
853
+ """
854
+ v = self[key]
855
+ del self[key]
856
+ return v
857
+
858
+ def popitem(self) -> Tuple[str, Any]:
859
+ """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
860
+ k, _ = self._keys.popitem()
861
+ # We need the key in the _keys to be able to access/del
862
+ self._keys[k] = None
863
+ val = self[k]
864
+ del self[k]
865
+ return k, val
866
+
867
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
868
+ r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
869
+
870
+ Args:
871
+ key (str): key to get from the ParameterDict
872
+ default (Parameter, optional): value to return if key not present
873
+ """
874
+ return self[key] if key in self else default
875
+
876
+ def fromkeys(
877
+ self, keys: Iterable[str], default: Optional[Any] = None
878
+ ) -> "ParameterDict":
879
+ r"""Return a new ParameterDict with the keys provided.
880
+
881
+ Args:
882
+ keys (iterable of str): keys to make the new ParameterDict from
883
+ default (Parameter, optional): value to set for all keys
884
+ """
885
+ return ParameterDict((k, default) for k in keys)
886
+
887
+ def keys(self) -> Iterable[str]:
888
+ r"""Return an iterable of the ParameterDict keys."""
889
+ return self._keys.keys()
890
+
891
+ def items(self) -> Iterable[Tuple[str, Any]]:
892
+ r"""Return an iterable of the ParameterDict key/value pairs."""
893
+ return ((k, self[k]) for k in self._keys)
894
+
895
+ def values(self) -> Iterable[Any]:
896
+ r"""Return an iterable of the ParameterDict values."""
897
+ return (self[k] for k in self._keys)
898
+
899
+ def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None:
900
+ r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
901
+
902
+ .. note::
903
+ If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
904
+ an iterable of key-value pairs, the order of new elements in it is preserved.
905
+
906
+ Args:
907
+ parameters (iterable): a mapping (dictionary) from string to
908
+ :class:`~torch.nn.Parameter`, or an iterable of
909
+ key-value pairs of type (string, :class:`~torch.nn.Parameter`)
910
+ """
911
+ if not isinstance(parameters, container_abcs.Iterable):
912
+ raise TypeError(
913
+ "ParameterDict.update should be called with an "
914
+ "iterable of key/value pairs, but got " + type(parameters).__name__
915
+ )
916
+
917
+ if isinstance(parameters, (OrderedDict, ParameterDict)):
918
+ for key, parameter in parameters.items():
919
+ self[key] = parameter
920
+ elif isinstance(parameters, container_abcs.Mapping):
921
+ for key, parameter in sorted(parameters.items()):
922
+ self[key] = parameter
923
+ else:
924
+ for j, p in enumerate(parameters):
925
+ if not isinstance(p, container_abcs.Iterable):
926
+ raise TypeError(
927
+ "ParameterDict update sequence element "
928
+ "#" + str(j) + " should be Iterable; is " + type(p).__name__
929
+ )
930
+ if not len(p) == 2:
931
+ raise ValueError(
932
+ "ParameterDict update sequence element "
933
+ "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
934
+ )
935
+ # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
936
+ self[p[0]] = p[1] # type: ignore[assignment]
937
+
938
+ def extra_repr(self) -> str:
939
+ child_lines = []
940
+ for k, p in self.items():
941
+ if isinstance(p, torch.Tensor):
942
+ size_str = "x".join(str(size) for size in p.size())
943
+ if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
944
+ device_str = f" ({p.device})"
945
+ else:
946
+ device_str = ""
947
+ parastr = "{} containing: [{} of size {}{}]".format(
948
+ "Parameter" if isinstance(p, Parameter) else "Tensor",
949
+ torch.typename(p),
950
+ size_str,
951
+ device_str,
952
+ )
953
+ child_lines.append(" (" + str(k) + "): " + parastr)
954
+ else:
955
+ child_lines.append(
956
+ " (" + str(k) + "): Object of type: " + type(p).__name__
957
+ )
958
+ tmpstr = "\n".join(child_lines)
959
+ return tmpstr
960
+
961
+ def __call__(self, input):
962
+ raise RuntimeError("ParameterDict should not be called.")
963
+
964
+ def __or__(self, other: "ParameterDict") -> "ParameterDict":
965
+ copy = self.copy()
966
+ copy.update(other)
967
+ return copy
968
+
969
+ def __ror__(self, other: "ParameterDict") -> "ParameterDict":
970
+ copy = other.copy()
971
+ copy.update(self)
972
+ return copy
973
+
974
+ def __ior__(self, other: "ParameterDict") -> Self:
975
+ self.update(other)
976
+ return self
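+ 
+ # --- Editorial sketch (not part of the upstream module): exercising the
+ # ParameterDict behavior implemented above. Item assignment wraps Tensors
+ # into Parameters, and | / |= merge via copy() + update(), preserving
+ # insertion order for ordered inputs such as another ParameterDict.
+ _example_left = ParameterDict({"w": Parameter(torch.randn(3, 3))})
+ _example_right = ParameterDict({"b": torch.zeros(3)})  # Tensor -> Parameter
+ assert isinstance(_example_right["b"], Parameter)
+ _example_merged = _example_left | _example_right  # __or__: copy, then update()
+ assert list(_example_merged.keys()) == ["w", "b"]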
.venv/Lib/site-packages/torch/nn/modules/conv.py ADDED
@@ -0,0 +1,1866 @@
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ from typing import List, Optional, Tuple, Union
4
+ from typing_extensions import deprecated
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch._torch_docs import reproducibility_notes
9
+ from torch.nn import functional as F, init
10
+ from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
11
+ from torch.nn.parameter import Parameter, UninitializedParameter
12
+
13
+ from .lazy import LazyModuleMixin
14
+ from .module import Module
15
+ from .utils import _pair, _reverse_repeat_tuple, _single, _triple
16
+
17
+
18
+ __all__ = [
19
+ "Conv1d",
20
+ "Conv2d",
21
+ "Conv3d",
22
+ "ConvTranspose1d",
23
+ "ConvTranspose2d",
24
+ "ConvTranspose3d",
25
+ "LazyConv1d",
26
+ "LazyConv2d",
27
+ "LazyConv3d",
28
+ "LazyConvTranspose1d",
29
+ "LazyConvTranspose2d",
30
+ "LazyConvTranspose3d",
31
+ ]
32
+
33
+ convolution_notes = {
34
+ "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs.
35
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
36
+ :attr:`groups`. For example,
37
+
38
+ * At groups=1, all inputs are convolved to all outputs.
39
+ * At groups=2, the operation becomes equivalent to having two conv
40
+ layers side by side, each seeing half the input channels
41
+ and producing half the output channels, and both subsequently
42
+ concatenated.
43
+ * At groups= :attr:`in_channels`, each input channel is convolved with
44
+ its own set of filters (of size
45
+ :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""",
46
+ "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`,
47
+ where `K` is a positive integer, this operation is also known as a "depthwise convolution".
48
+
49
+ In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
50
+ a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments
51
+ :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""",
52
+ } # noqa: B950
53
+
54
+
55
+ class _ConvNd(Module):
56
+ __constants__ = [
57
+ "stride",
58
+ "padding",
59
+ "dilation",
60
+ "groups",
61
+ "padding_mode",
62
+ "output_padding",
63
+ "in_channels",
64
+ "out_channels",
65
+ "kernel_size",
66
+ ]
67
+ __annotations__ = {"bias": Optional[torch.Tensor]}
68
+
69
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
70
+ ...
71
+
72
+ in_channels: int
73
+ _reversed_padding_repeated_twice: List[int]
74
+ out_channels: int
75
+ kernel_size: Tuple[int, ...]
76
+ stride: Tuple[int, ...]
77
+ padding: Union[str, Tuple[int, ...]]
78
+ dilation: Tuple[int, ...]
79
+ transposed: bool
80
+ output_padding: Tuple[int, ...]
81
+ groups: int
82
+ padding_mode: str
83
+ weight: Tensor
84
+ bias: Optional[Tensor]
85
+
86
+ def __init__(
87
+ self,
88
+ in_channels: int,
89
+ out_channels: int,
90
+ kernel_size: Tuple[int, ...],
91
+ stride: Tuple[int, ...],
92
+ padding: Tuple[int, ...],
93
+ dilation: Tuple[int, ...],
94
+ transposed: bool,
95
+ output_padding: Tuple[int, ...],
96
+ groups: int,
97
+ bias: bool,
98
+ padding_mode: str,
99
+ device=None,
100
+ dtype=None,
101
+ ) -> None:
102
+ factory_kwargs = {"device": device, "dtype": dtype}
103
+ super().__init__()
104
+ if groups <= 0:
105
+ raise ValueError("groups must be a positive integer")
106
+ if in_channels % groups != 0:
107
+ raise ValueError("in_channels must be divisible by groups")
108
+ if out_channels % groups != 0:
109
+ raise ValueError("out_channels must be divisible by groups")
110
+ valid_padding_strings = {"same", "valid"}
111
+ if isinstance(padding, str):
112
+ if padding not in valid_padding_strings:
113
+ raise ValueError(
114
+ f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
115
+ )
116
+ if padding == "same" and any(s != 1 for s in stride):
117
+ raise ValueError(
118
+ "padding='same' is not supported for strided convolutions"
119
+ )
120
+
121
+ valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
122
+ if padding_mode not in valid_padding_modes:
123
+ raise ValueError(
124
+ f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'"
125
+ )
126
+ self.in_channels = in_channels
127
+ self.out_channels = out_channels
128
+ self.kernel_size = kernel_size
129
+ self.stride = stride
130
+ self.padding = padding
131
+ self.dilation = dilation
132
+ self.transposed = transposed
133
+ self.output_padding = output_padding
134
+ self.groups = groups
135
+ self.padding_mode = padding_mode
136
+ # `_reversed_padding_repeated_twice` is the padding to be passed to
137
+ # `F.pad` if needed (e.g., for non-zero padding types that are
138
+ # implemented as two ops: padding + conv). `F.pad` accepts paddings in
139
+ # reverse order of the dimensions.
140
+ if isinstance(self.padding, str):
141
+ self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
142
+ if padding == "same":
143
+ for d, k, i in zip(
144
+ dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)
145
+ ):
146
+ total_padding = d * (k - 1)
147
+ left_pad = total_padding // 2
148
+ self._reversed_padding_repeated_twice[2 * i] = left_pad
149
+ self._reversed_padding_repeated_twice[2 * i + 1] = (
150
+ total_padding - left_pad
151
+ )
152
+ else:
153
+ self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
154
+ self.padding, 2
155
+ )
156
+
157
+ if transposed:
158
+ self.weight = Parameter(
159
+ torch.empty(
160
+ (in_channels, out_channels // groups, *kernel_size),
161
+ **factory_kwargs,
162
+ )
163
+ )
164
+ else:
165
+ self.weight = Parameter(
166
+ torch.empty(
167
+ (out_channels, in_channels // groups, *kernel_size),
168
+ **factory_kwargs,
169
+ )
170
+ )
171
+ if bias:
172
+ self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
173
+ else:
174
+ self.register_parameter("bias", None)
175
+
176
+ self.reset_parameters()
177
+
178
+ def reset_parameters(self) -> None:
179
+ # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
180
+ # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
181
+ # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
182
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
183
+ if self.bias is not None:
184
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
185
+ if fan_in != 0:
186
+ bound = 1 / math.sqrt(fan_in)
187
+ init.uniform_(self.bias, -bound, bound)
188
+
189
+ def extra_repr(self):
190
+ s = (
191
+ "{in_channels}, {out_channels}, kernel_size={kernel_size}"
192
+ ", stride={stride}"
193
+ )
194
+ if self.padding != (0,) * len(self.padding):
195
+ s += ", padding={padding}"
196
+ if self.dilation != (1,) * len(self.dilation):
197
+ s += ", dilation={dilation}"
198
+ if self.output_padding != (0,) * len(self.output_padding):
199
+ s += ", output_padding={output_padding}"
200
+ if self.groups != 1:
201
+ s += ", groups={groups}"
202
+ if self.bias is None:
203
+ s += ", bias=False"
204
+ if self.padding_mode != "zeros":
205
+ s += ", padding_mode={padding_mode}"
206
+ return s.format(**self.__dict__)
207
+
208
+ def __setstate__(self, state):
209
+ super().__setstate__(state)
210
+ if not hasattr(self, "padding_mode"):
211
+ self.padding_mode = "zeros"
212
+
213
+
214
+ class Conv1d(_ConvNd):
215
+ __doc__ = (
216
+ r"""Applies a 1D convolution over an input signal composed of several input
217
+ planes.
218
+
219
+ In the simplest case, the output value of the layer with input size
220
+ :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
221
+ precisely described as:
222
+
223
+ .. math::
224
+ \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
225
+ \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
226
+ \star \text{input}(N_i, k)
227
+
228
+ where :math:`\star` is the valid `cross-correlation`_ operator,
229
+ :math:`N` is the batch size, :math:`C` denotes the number of channels, and
230
+ :math:`L` is the length of the signal sequence.
231
+ """
232
+ + r"""
233
+
234
+ This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
235
+
236
+ On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
237
+
238
+ * :attr:`stride` controls the stride for the cross-correlation, a single
239
+ number or a one-element tuple.
240
+
241
+ * :attr:`padding` controls the amount of padding applied to the input. It
242
+ can be either a string {{'valid', 'same'}} or a tuple of ints giving the
243
+ amount of implicit padding applied on both sides.
244
+ """
245
+ """
246
+ * :attr:`dilation` controls the spacing between the kernel points; also
247
+ known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
248
+ has a nice visualization of what :attr:`dilation` does.
249
+ """
250
+ r"""
251
+ {groups_note}
252
+
253
+ Note:
254
+ {depthwise_separable_note}
255
+ Note:
256
+ {cudnn_reproducibility_note}
257
+
258
+ Note:
259
+ ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
260
+ the input so the output has the same shape as the input. However, this mode
261
+ doesn't support any stride values other than 1.
262
+
263
+ Note:
264
+ This module supports complex data types i.e. ``complex32, complex64, complex128``.
265
+
266
+ Args:
267
+ in_channels (int): Number of channels in the input image
268
+ out_channels (int): Number of channels produced by the convolution
269
+ kernel_size (int or tuple): Size of the convolving kernel
270
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
271
+ padding (int, tuple or str, optional): Padding added to both sides of
272
+ the input. Default: 0
273
+ dilation (int or tuple, optional): Spacing between kernel
274
+ elements. Default: 1
275
+ groups (int, optional): Number of blocked connections from input
276
+ channels to output channels. Default: 1
277
+ bias (bool, optional): If ``True``, adds a learnable bias to the
278
+ output. Default: ``True``
279
+ padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
280
+ ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
281
+
282
+ """.format(
283
+ **reproducibility_notes, **convolution_notes
284
+ )
285
+ + r"""
286
+
287
+ Shape:
288
+ - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
289
+ - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
290
+
291
+ .. math::
292
+ L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
293
+ \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
294
+
295
+ Attributes:
296
+ weight (Tensor): the learnable weights of the module of shape
297
+ :math:`(\text{out\_channels},
298
+ \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
299
+ The values of these weights are sampled from
300
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
301
+ :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
302
+ bias (Tensor): the learnable bias of the module of shape
303
+ (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
304
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
305
+ :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
306
+
307
+ Examples::
308
+
309
+ >>> m = nn.Conv1d(16, 33, 3, stride=2)
310
+ >>> input = torch.randn(20, 16, 50)
311
+ >>> output = m(input)
312
+
313
+ .. _cross-correlation:
314
+ https://en.wikipedia.org/wiki/Cross-correlation
315
+
316
+ .. _link:
317
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
318
+ """
319
+ )
320
+
321
+ def __init__(
322
+ self,
323
+ in_channels: int,
324
+ out_channels: int,
325
+ kernel_size: _size_1_t,
326
+ stride: _size_1_t = 1,
327
+ padding: Union[str, _size_1_t] = 0,
328
+ dilation: _size_1_t = 1,
329
+ groups: int = 1,
330
+ bias: bool = True,
331
+ padding_mode: str = "zeros", # TODO: refine this type
332
+ device=None,
333
+ dtype=None,
334
+ ) -> None:
335
+ factory_kwargs = {"device": device, "dtype": dtype}
336
+ # we create new variables below to make mypy happy since kernel_size has
337
+ # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
338
+ kernel_size_ = _single(kernel_size)
339
+ stride_ = _single(stride)
340
+ padding_ = padding if isinstance(padding, str) else _single(padding)
341
+ dilation_ = _single(dilation)
342
+ super().__init__(
343
+ in_channels,
344
+ out_channels,
345
+ kernel_size_,
346
+ stride_,
347
+ padding_,
348
+ dilation_,
349
+ False,
350
+ _single(0),
351
+ groups,
352
+ bias,
353
+ padding_mode,
354
+ **factory_kwargs,
355
+ )
356
+
357
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
358
+ if self.padding_mode != "zeros":
359
+ return F.conv1d(
360
+ F.pad(
361
+ input, self._reversed_padding_repeated_twice, mode=self.padding_mode
362
+ ),
363
+ weight,
364
+ bias,
365
+ self.stride,
366
+ _single(0),
367
+ self.dilation,
368
+ self.groups,
369
+ )
370
+ return F.conv1d(
371
+ input, weight, bias, self.stride, self.padding, self.dilation, self.groups
372
+ )
373
+
374
+ def forward(self, input: Tensor) -> Tensor:
375
+ return self._conv_forward(input, self.weight, self.bias)
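+ 
+ # --- Editorial sketch (not part of the upstream module): checking the L_out
+ # formula from the Conv1d docstring against an actual forward pass,
+ # L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1)/stride + 1).
+ _example_conv1d = Conv1d(16, 33, kernel_size=3, stride=2, padding=1)
+ _example_out = _example_conv1d(torch.randn(20, 16, 50))
+ assert _example_out.shape == (20, 33, (50 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1)  # (20, 33, 25)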
376
+
377
+
378
+ class Conv2d(_ConvNd):
379
+ __doc__ = (
380
+ r"""Applies a 2D convolution over an input signal composed of several input
381
+ planes.
382
+
383
+ In the simplest case, the output value of the layer with input size
384
+ :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
385
+ can be precisely described as:
386
+
387
+ .. math::
388
+ \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
389
+ \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
390
+
391
+
392
+ where :math:`\star` is the valid 2D `cross-correlation`_ operator,
393
+ :math:`N` is the batch size, :math:`C` denotes the number of channels,
394
+ :math:`H` is the height of the input planes in pixels, and :math:`W` is
395
+ the width in pixels.
396
+ """
397
+ + r"""
398
+
399
+ This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
400
+
401
+ On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
402
+
403
+ * :attr:`stride` controls the stride for the cross-correlation, a single
404
+ number or a tuple.
405
+
406
+ * :attr:`padding` controls the amount of padding applied to the input. It
407
+ can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the
408
+ amount of implicit padding applied on both sides.
409
+ """
410
+ """
411
+ * :attr:`dilation` controls the spacing between the kernel points; also
412
+ known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
413
+ has a nice visualization of what :attr:`dilation` does.
414
+ """
415
+ r"""
416
+
417
+ {groups_note}
418
+
419
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
420
+
421
+ - a single ``int`` -- in which case the same value is used for the height and width dimension
422
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
423
+ and the second `int` for the width dimension
424
+
425
+ Note:
426
+ {depthwise_separable_note}
427
+
428
+ Note:
429
+ {cudnn_reproducibility_note}
430
+
431
+ Note:
432
+ ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
433
+ the input so the output has the same shape as the input. However, this mode
434
+ doesn't support any stride values other than 1.
435
+
436
+ Note:
437
+ This module supports complex data types i.e. ``complex32, complex64, complex128``.
438
+
439
+ Args:
440
+ in_channels (int): Number of channels in the input image
441
+ out_channels (int): Number of channels produced by the convolution
442
+ kernel_size (int or tuple): Size of the convolving kernel
443
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
444
+ padding (int, tuple or str, optional): Padding added to all four sides of
445
+ the input. Default: 0
446
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
447
+ groups (int, optional): Number of blocked connections from input
448
+ channels to output channels. Default: 1
449
+ bias (bool, optional): If ``True``, adds a learnable bias to the
450
+ output. Default: ``True``
451
+ padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
452
+ ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
453
+ """.format(
454
+ **reproducibility_notes, **convolution_notes
455
+ )
456
+ + r"""
457
+
458
+ Shape:
459
+ - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
460
+ - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
461
+
462
+ .. math::
463
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
464
+ \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
465
+
466
+ .. math::
467
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
468
+ \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
469
+
470
+ Attributes:
471
+ weight (Tensor): the learnable weights of the module of shape
472
+ :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
473
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
474
+ The values of these weights are sampled from
475
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
476
+ :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
477
+ bias (Tensor): the learnable bias of the module of shape
478
+ (out_channels). If :attr:`bias` is ``True``,
479
+ then the values of these weights are
480
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
481
+ :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
482
+
483
+ Examples:
484
+
485
+ >>> # With square kernels and equal stride
486
+ >>> m = nn.Conv2d(16, 33, 3, stride=2)
487
+ >>> # non-square kernels and unequal stride and with padding
488
+ >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
489
+ >>> # non-square kernels and unequal stride and with padding and dilation
490
+ >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
491
+ >>> input = torch.randn(20, 16, 50, 100)
492
+ >>> output = m(input)
493
+
494
+ .. _cross-correlation:
495
+ https://en.wikipedia.org/wiki/Cross-correlation
496
+
497
+ .. _link:
498
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
499
+ """
500
+ )
501
+
502
+ def __init__(
503
+ self,
504
+ in_channels: int,
505
+ out_channels: int,
506
+ kernel_size: _size_2_t,
507
+ stride: _size_2_t = 1,
508
+ padding: Union[str, _size_2_t] = 0,
509
+ dilation: _size_2_t = 1,
510
+ groups: int = 1,
511
+ bias: bool = True,
512
+ padding_mode: str = "zeros", # TODO: refine this type
513
+ device=None,
514
+ dtype=None,
515
+ ) -> None:
516
+ factory_kwargs = {"device": device, "dtype": dtype}
517
+ kernel_size_ = _pair(kernel_size)
518
+ stride_ = _pair(stride)
519
+ padding_ = padding if isinstance(padding, str) else _pair(padding)
520
+ dilation_ = _pair(dilation)
521
+ super().__init__(
522
+ in_channels,
523
+ out_channels,
524
+ kernel_size_,
525
+ stride_,
526
+ padding_,
527
+ dilation_,
528
+ False,
529
+ _pair(0),
530
+ groups,
531
+ bias,
532
+ padding_mode,
533
+ **factory_kwargs,
534
+ )
535
+
536
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
537
+ if self.padding_mode != "zeros":
538
+ return F.conv2d(
539
+ F.pad(
540
+ input, self._reversed_padding_repeated_twice, mode=self.padding_mode
541
+ ),
542
+ weight,
543
+ bias,
544
+ self.stride,
545
+ _pair(0),
546
+ self.dilation,
547
+ self.groups,
548
+ )
549
+ return F.conv2d(
550
+ input, weight, bias, self.stride, self.padding, self.dilation, self.groups
551
+ )
552
+
553
+ def forward(self, input: Tensor) -> Tensor:
554
+ return self._conv_forward(input, self.weight, self.bias)
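+ 
+ # --- Editorial sketch (not part of the upstream module): the depthwise case
+ # from the groups note above. With groups == in_channels, each input channel
+ # is convolved with its own filters, so dim 1 of the weight collapses to 1.
+ _example_dw = Conv2d(8, 16, kernel_size=3, groups=8, padding=1)  # multiplier K=2
+ assert _example_dw.weight.shape == (16, 1, 3, 3)  # (out, in/groups, kH, kW)
+ assert _example_dw(torch.randn(1, 8, 32, 32)).shape == (1, 16, 32, 32)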
555
+
556
+
557
+ class Conv3d(_ConvNd):
558
+ __doc__ = (
559
+ r"""Applies a 3D convolution over an input signal composed of several input
560
+ planes.
561
+
562
+ In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
563
+ and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:
564
+
565
+ .. math::
566
+ out(N_i, C_{out_j}) = bias(C_{out_j}) +
567
+ \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)
568
+
569
+ where :math:`\star` is the valid 3D `cross-correlation`_ operator
570
+ """
571
+ + r"""
572
+
573
+ This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
574
+
575
+ On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
576
+
577
+ * :attr:`stride` controls the stride for the cross-correlation.
578
+
579
+ * :attr:`padding` controls the amount of padding applied to the input. It
580
+ can be either a string {{'valid', 'same'}} or a tuple of ints giving the
581
+ amount of implicit padding applied on both sides.
582
+ """
583
+ """
584
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
585
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
586
+ """
587
+ r"""
588
+
589
+ {groups_note}
590
+
591
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
592
+
593
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
594
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
595
+ the second `int` for the height dimension and the third `int` for the width dimension
596
+
597
+ Note:
598
+ {depthwise_separable_note}
599
+
600
+ Note:
601
+ {cudnn_reproducibility_note}
602
+
603
+ Note:
604
+ ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
605
+ the input so the output has the same shape as the input. However, this mode
606
+ doesn't support any stride values other than 1.
607
+
608
+ Note:
609
+ This module supports complex data types i.e. ``complex32, complex64, complex128``.
610
+
611
+ Args:
612
+ in_channels (int): Number of channels in the input image
613
+ out_channels (int): Number of channels produced by the convolution
614
+ kernel_size (int or tuple): Size of the convolving kernel
615
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
616
+ padding (int, tuple or str, optional): Padding added to all six sides of
617
+ the input. Default: 0
618
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
619
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
620
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
621
+ padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
622
+ """.format(
623
+ **reproducibility_notes, **convolution_notes
624
+ )
625
+ + r"""
626
+
627
+ Shape:
628
+ - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
629
+ - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
630
+ where
631
+
632
+ .. math::
633
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
634
+ \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
635
+
636
+ .. math::
637
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
638
+ \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
639
+
640
+ .. math::
641
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
642
+ \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
643
+
644
+ Attributes:
645
+ weight (Tensor): the learnable weights of the module of shape
646
+ :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
647
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
648
+ The values of these weights are sampled from
649
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
650
+ :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
651
+ bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
652
+ then the values of these weights are
653
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
654
+ :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
655
+
656
+ Examples::
657
+
658
+ >>> # With square kernels and equal stride
659
+ >>> m = nn.Conv3d(16, 33, 3, stride=2)
660
+ >>> # non-square kernels and unequal stride and with padding
661
+ >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
662
+ >>> input = torch.randn(20, 16, 10, 50, 100)
663
+ >>> output = m(input)
664
+
665
+ .. _cross-correlation:
666
+ https://en.wikipedia.org/wiki/Cross-correlation
667
+
668
+ .. _link:
669
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
670
+ """
671
+ )
672
+
673
+ def __init__(
674
+ self,
675
+ in_channels: int,
676
+ out_channels: int,
677
+ kernel_size: _size_3_t,
678
+ stride: _size_3_t = 1,
679
+ padding: Union[str, _size_3_t] = 0,
680
+ dilation: _size_3_t = 1,
681
+ groups: int = 1,
682
+ bias: bool = True,
683
+ padding_mode: str = "zeros",
684
+ device=None,
685
+ dtype=None,
686
+ ) -> None:
687
+ factory_kwargs = {"device": device, "dtype": dtype}
688
+ kernel_size_ = _triple(kernel_size)
689
+ stride_ = _triple(stride)
690
+ padding_ = padding if isinstance(padding, str) else _triple(padding)
691
+ dilation_ = _triple(dilation)
692
+ super().__init__(
693
+ in_channels,
694
+ out_channels,
695
+ kernel_size_,
696
+ stride_,
697
+ padding_,
698
+ dilation_,
699
+ False,
700
+ _triple(0),
701
+ groups,
702
+ bias,
703
+ padding_mode,
704
+ **factory_kwargs,
705
+ )
706
+
707
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
708
+ if self.padding_mode != "zeros":
709
+ return F.conv3d(
710
+ F.pad(
711
+ input, self._reversed_padding_repeated_twice, mode=self.padding_mode
712
+ ),
713
+ weight,
714
+ bias,
715
+ self.stride,
716
+ _triple(0),
717
+ self.dilation,
718
+ self.groups,
719
+ )
720
+ return F.conv3d(
721
+ input, weight, bias, self.stride, self.padding, self.dilation, self.groups
722
+ )
723
+
724
+ def forward(self, input: Tensor) -> Tensor:
725
+ return self._conv_forward(input, self.weight, self.bias)
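+ 
+ # --- Editorial sketch (not part of the upstream module): padding='same' as
+ # described in the notes above; the output spatial dims match the input, and
+ # the asymmetric split needed for the even kernel dim is handled downstream.
+ _example_same = Conv3d(4, 4, kernel_size=(3, 5, 2), padding="same")
+ assert _example_same(torch.randn(2, 4, 10, 12, 14)).shape == (2, 4, 10, 12, 14)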
726
+
727
+
728
+ class _ConvTransposeNd(_ConvNd):
729
+ def __init__(
730
+ self,
731
+ in_channels,
732
+ out_channels,
733
+ kernel_size,
734
+ stride,
735
+ padding,
736
+ dilation,
737
+ transposed,
738
+ output_padding,
739
+ groups,
740
+ bias,
741
+ padding_mode,
742
+ device=None,
743
+ dtype=None,
744
+ ) -> None:
745
+ if padding_mode != "zeros":
746
+ raise ValueError(
747
+ f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
748
+ )
749
+
750
+ factory_kwargs = {"device": device, "dtype": dtype}
751
+ super().__init__(
752
+ in_channels,
753
+ out_channels,
754
+ kernel_size,
755
+ stride,
756
+ padding,
757
+ dilation,
758
+ transposed,
759
+ output_padding,
760
+ groups,
761
+ bias,
762
+ padding_mode,
763
+ **factory_kwargs,
764
+ )
765
+
766
+ # dilation being an optional parameter is for backwards
767
+ # compatibility
768
+ def _output_padding(
769
+ self,
770
+ input: Tensor,
771
+ output_size: Optional[List[int]],
772
+ stride: List[int],
773
+ padding: List[int],
774
+ kernel_size: List[int],
775
+ num_spatial_dims: int,
776
+ dilation: Optional[List[int]] = None,
777
+ ) -> List[int]:
778
+ if output_size is None:
779
+ ret = _single(self.output_padding)  # converting to a list if it was not one already
780
+ else:
781
+ has_batch_dim = input.dim() == num_spatial_dims + 2
782
+ num_non_spatial_dims = 2 if has_batch_dim else 1
783
+ if len(output_size) == num_non_spatial_dims + num_spatial_dims:
784
+ output_size = output_size[num_non_spatial_dims:]
785
+ if len(output_size) != num_spatial_dims:
786
+ raise ValueError(
787
+ f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} "
788
+ f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
789
+ )
790
+
791
+ min_sizes = torch.jit.annotate(List[int], [])
792
+ max_sizes = torch.jit.annotate(List[int], [])
793
+ for d in range(num_spatial_dims):
794
+ dim_size = (
795
+ (input.size(d + num_non_spatial_dims) - 1) * stride[d]
796
+ - 2 * padding[d]
797
+ + (dilation[d] if dilation is not None else 1)
798
+ * (kernel_size[d] - 1)
799
+ + 1
800
+ )
801
+ min_sizes.append(dim_size)
802
+ max_sizes.append(min_sizes[d] + stride[d] - 1)
803
+
804
+ for i in range(len(output_size)):
805
+ size = output_size[i]
806
+ min_size = min_sizes[i]
807
+ max_size = max_sizes[i]
808
+ if size < min_size or size > max_size:
809
+ raise ValueError(
810
+ f"requested an output size of {output_size}, but valid sizes range "
811
+ f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
812
+ )
813
+
814
+ res = torch.jit.annotate(List[int], [])
815
+ for d in range(num_spatial_dims):
816
+ res.append(output_size[d] - min_sizes[d])
817
+
818
+ ret = res
819
+ return ret
820
+
821
+
822
+ class ConvTranspose1d(_ConvTransposeNd):
823
+ __doc__ = (
824
+ r"""Applies a 1D transposed convolution operator over an input image
825
+ composed of several input planes.
826
+
827
+ This module can be seen as the gradient of Conv1d with respect to its input.
828
+ It is also known as a fractionally-strided convolution or
829
+ a deconvolution (although it is not an actual deconvolution operation as it does
830
+ not compute a true inverse of convolution). For more information, see the visualizations
831
+ `here`_ and the `Deconvolutional Networks`_ paper.
832
+
833
+ This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
834
+
835
+ On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
836
+
837
+ * :attr:`stride` controls the stride for the cross-correlation.
838
+
839
+ * :attr:`padding` controls the amount of implicit zero padding on both
840
+ sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
841
+ below for details.
842
+
843
+ * :attr:`output_padding` controls the additional size added to one side
844
+ of the output shape. See note below for details.
845
+ """
846
+ """
847
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
848
+ It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
849
+ """
850
+ r"""
851
+ {groups_note}
852
+
853
+ Note:
854
+ The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
855
+ amount of zero padding to both sizes of the input. This is set so that
856
+ when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
857
+ are initialized with same parameters, they are inverses of each other in
858
+ regard to the input and output shapes. However, when ``stride > 1``,
859
+ :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
860
+ shape. :attr:`output_padding` is provided to resolve this ambiguity by
861
+ effectively increasing the calculated output shape on one side. Note
862
+ that :attr:`output_padding` is only used to find output shape, but does
863
+ not actually add zero-padding to output.
864
+
865
+ Note:
866
+ In some circumstances when using the CUDA backend with CuDNN, this operator
867
+ may select a nondeterministic algorithm to increase performance. If this is
868
+ undesirable, you can try to make the operation deterministic (potentially at
869
+ a performance cost) by setting ``torch.backends.cudnn.deterministic =
870
+ True``.
871
+ Please see the notes on :doc:`/notes/randomness` for background.
872
+
873
+
874
+ Args:
875
+ in_channels (int): Number of channels in the input image
876
+ out_channels (int): Number of channels produced by the convolution
877
+ kernel_size (int or tuple): Size of the convolving kernel
878
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
879
+ padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
880
+ will be added to both sides of the input. Default: 0
881
+ output_padding (int or tuple, optional): Additional size added to one side
882
+ of the output shape. Default: 0
883
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
884
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
885
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
886
+ """.format(
887
+ **reproducibility_notes, **convolution_notes
888
+ )
889
+ + r"""
890
+
891
+ Shape:
892
+ - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
893
+ - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
894
+
895
+ .. math::
896
+ L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
897
+ \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1
898
+
899
+ Attributes:
900
+ weight (Tensor): the learnable weights of the module of shape
901
+ :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
902
+ :math:`\text{kernel\_size})`.
903
+ The values of these weights are sampled from
904
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
905
+ :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
906
+ bias (Tensor): the learnable bias of the module of shape (out_channels).
907
+ If :attr:`bias` is ``True``, then the values of these weights are
908
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
909
+ :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
910
+
911
+ .. _`here`:
912
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
913
+
914
+ .. _`Deconvolutional Networks`:
915
+ https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
916
+ """
917
+ )
918
+
919
+ def __init__(
920
+ self,
921
+ in_channels: int,
922
+ out_channels: int,
923
+ kernel_size: _size_1_t,
924
+ stride: _size_1_t = 1,
925
+ padding: _size_1_t = 0,
926
+ output_padding: _size_1_t = 0,
927
+ groups: int = 1,
928
+ bias: bool = True,
929
+ dilation: _size_1_t = 1,
930
+ padding_mode: str = "zeros",
931
+ device=None,
932
+ dtype=None,
933
+ ) -> None:
934
+ factory_kwargs = {"device": device, "dtype": dtype}
935
+ kernel_size = _single(kernel_size)
936
+ stride = _single(stride)
937
+ padding = _single(padding)
938
+ dilation = _single(dilation)
939
+ output_padding = _single(output_padding)
940
+ super().__init__(
941
+ in_channels,
942
+ out_channels,
943
+ kernel_size,
944
+ stride,
945
+ padding,
946
+ dilation,
947
+ True,
948
+ output_padding,
949
+ groups,
950
+ bias,
951
+ padding_mode,
952
+ **factory_kwargs,
953
+ )
954
+
955
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
956
+ if self.padding_mode != "zeros":
957
+ raise ValueError(
958
+ "Only `zeros` padding mode is supported for ConvTranspose1d"
959
+ )
960
+
961
+ assert isinstance(self.padding, tuple)
962
+ # One cannot replace List by Tuple or Sequence in "_output_padding" because
963
+ # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
964
+ num_spatial_dims = 1
965
+ output_padding = self._output_padding(
966
+ input,
967
+ output_size,
968
+ self.stride, # type: ignore[arg-type]
969
+ self.padding, # type: ignore[arg-type]
970
+ self.kernel_size, # type: ignore[arg-type]
971
+ num_spatial_dims,
972
+ self.dilation, # type: ignore[arg-type]
973
+ )
974
+ return F.conv_transpose1d(
975
+ input,
976
+ self.weight,
977
+ self.bias,
978
+ self.stride,
979
+ self.padding,
980
+ output_padding,
981
+ self.groups,
982
+ self.dilation,
983
+ )
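+ 
+ # --- Editorial sketch (not part of the upstream module): the shape ambiguity
+ # that output_size / _output_padding resolves. With stride=2, inputs of
+ # length 5 and 6 both convolve down to 3, so the transpose must be told
+ # which of the two sizes to reproduce.
+ _example_up = ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=1)
+ _example_h = torch.randn(1, 1, 3)
+ assert _example_up(_example_h, output_size=[5]).shape[-1] == 5
+ assert _example_up(_example_h, output_size=[6]).shape[-1] == 6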
984
+
985
+
986
+ class ConvTranspose2d(_ConvTransposeNd):
987
+ __doc__ = (
988
+ r"""Applies a 2D transposed convolution operator over an input image
989
+ composed of several input planes.
990
+
991
+ This module can be seen as the gradient of Conv2d with respect to its input.
992
+ It is also known as a fractionally-strided convolution or
993
+ a deconvolution (although it is not an actual deconvolution operation as it does
994
+ not compute a true inverse of convolution). For more information, see the visualizations
995
+ `here`_ and the `Deconvolutional Networks`_ paper.
996
+
997
+ This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
998
+
999
+ On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
1000
+
1001
+ * :attr:`stride` controls the stride for the cross-correlation.
1002
+
1003
+ * :attr:`padding` controls the amount of implicit zero padding on both
1004
+ sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
1005
+ below for details.
1006
+
1007
+ * :attr:`output_padding` controls the additional size added to one side
1008
+ of the output shape. See note below for details.
1009
+ """
1010
+ """
1011
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
1012
+ It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
1013
+ """
1014
+ r"""
1015
+ {groups_note}
1016
+
1017
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
1018
+ can either be:
1019
+
1020
+ - a single ``int`` -- in which case the same value is used for the height and width dimensions
1021
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
1022
+ and the second `int` for the width dimension
1023
+
1024
+ Note:
1025
+ The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
1026
+ amount of zero padding to both sizes of the input. This is set so that
1027
+ when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
1028
+ are initialized with same parameters, they are inverses of each other in
1029
+ regard to the input and output shapes. However, when ``stride > 1``,
1030
+ :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
1031
+ shape. :attr:`output_padding` is provided to resolve this ambiguity by
1032
+ effectively increasing the calculated output shape on one side. Note
1033
+ that :attr:`output_padding` is only used to find output shape, but does
1034
+ not actually add zero-padding to output.
1035
+
1036
+ Note:
1037
+ {cudnn_reproducibility_note}
1038
+
1039
+ Args:
1040
+ in_channels (int): Number of channels in the input image
1041
+ out_channels (int): Number of channels produced by the convolution
1042
+ kernel_size (int or tuple): Size of the convolving kernel
1043
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
1044
+ padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
1045
+ will be added to both sides of each dimension in the input. Default: 0
1046
+ output_padding (int or tuple, optional): Additional size added to one side
1047
+ of each dimension in the output shape. Default: 0
1048
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
1049
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
1050
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
1051
+ """.format(
1052
+ **reproducibility_notes, **convolution_notes
+    )
+    + r"""
+
+    Shape:
+        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
+
+    .. math::
+          H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
+                    \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
+    .. math::
+          W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
+                    \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape
+                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
+                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
+                         The values of these weights are sampled from
+                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+        bias (Tensor):   the learnable bias of the module of shape (out_channels).
+                         If :attr:`bias` is ``True``, then the values of these weights are
+                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+
+    Examples::
+
+        >>> # With square kernels and equal stride
+        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> input = torch.randn(20, 16, 50, 100)
+        >>> output = m(input)
+        >>> # exact output size can also be specified as an argument
+        >>> input = torch.randn(1, 16, 12, 12)
+        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(input)
+        >>> h.size()
+        torch.Size([1, 16, 6, 6])
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12, 12])
+
+    .. _`here`:
+        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+    .. _`Deconvolutional Networks`:
+        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
+    """
+    )
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_2_t,
+        stride: _size_2_t = 1,
+        padding: _size_2_t = 0,
+        output_padding: _size_2_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_2_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        output_padding = _pair(output_padding)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            True,
+            output_padding,
+            groups,
+            bias,
+            padding_mode,
+            **factory_kwargs,
+        )
+
+    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+        if self.padding_mode != "zeros":
+            raise ValueError(
+                "Only `zeros` padding mode is supported for ConvTranspose2d"
+            )
+
+        assert isinstance(self.padding, tuple)
+        # One cannot replace List by Tuple or Sequence in "_output_padding" because
+        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+        num_spatial_dims = 2
+        output_padding = self._output_padding(
+            input,
+            output_size,
+            self.stride,  # type: ignore[arg-type]
+            self.padding,  # type: ignore[arg-type]
+            self.kernel_size,  # type: ignore[arg-type]
+            num_spatial_dims,
+            self.dilation,  # type: ignore[arg-type]
+        )
+
+        return F.conv_transpose2d(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            output_padding,
+            self.groups,
+            self.dilation,
+        )
+
+
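+# --- Editor's illustration (not part of the upstream file) ------------------
+# A minimal sketch checking the H_out/W_out formula from the ConvTranspose2d
+# docstring above against an actual forward pass. The helper name
+# `_example_convtranspose2d_output_shape` and the concrete shapes are
+# hypothetical values chosen for this document.
+def _example_convtranspose2d_output_shape() -> None:
+    m = ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+    x = torch.randn(20, 16, 50, 100)
+    y = m(x)
+    # H_out = (50 - 1) * 2 - 2 * 4 + 1 * (3 - 1) + 0 + 1 = 93
+    # W_out = (100 - 1) * 1 - 2 * 2 + 1 * (5 - 1) + 0 + 1 = 100
+    assert y.shape == (20, 33, 93, 100)
+
+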
+class ConvTranspose3d(_ConvTransposeNd):
+    __doc__ = (
+        r"""Applies a 3D transposed convolution operator over an input image composed of several input
+    planes.
+    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
+    and sums over the outputs from all input feature planes.
+
+    This module can be seen as the gradient of Conv3d with respect to its input.
+    It is also known as a fractionally-strided convolution or
+    a deconvolution (although it is not an actual deconvolution operation as it does
+    not compute a true inverse of convolution). For more information, see the visualizations
+    `here`_ and the `Deconvolutional Networks`_ paper.
+
+    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
+
+    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
+
+    * :attr:`stride` controls the stride for the cross-correlation.
+
+    * :attr:`padding` controls the amount of implicit zero padding on both
+      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
+      below for details.
+
+    * :attr:`output_padding` controls the additional size added to one side
+      of the output shape. See note below for details.
+    """
+        """
+    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
+    """
+        r"""
+    {groups_note}
+
+    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
+    can either be:
+
+        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
+        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+          the second `int` for the height dimension and the third `int` for the width dimension
+
+    Note:
+        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
+        amount of zero padding to both sides of the input. This is set so that
+        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
+        are initialized with the same parameters, they are inverses of each other in
+        regard to the input and output shapes. However, when ``stride > 1``,
+        :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
+        shape. :attr:`output_padding` is provided to resolve this ambiguity by
+        effectively increasing the calculated output shape on one side. Note
+        that :attr:`output_padding` is only used to find the output shape, but does
+        not actually add zero-padding to the output.
+
+    Note:
+        {cudnn_reproducibility_note}
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+            will be added to both sides of each dimension in the input. Default: 0
+        output_padding (int or tuple, optional): Additional size added to one side
+            of each dimension in the output shape. Default: 0
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+    """.format(
+        **reproducibility_notes, **convolution_notes
+    )
+    + r"""
+
+    Shape:
+        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
+          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where
+
+    .. math::
+          D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
+                    \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
+    .. math::
+          H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
+                    \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
+    .. math::
+          W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
+                    \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape
+                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
+                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
+                         The values of these weights are sampled from
+                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+        bias (Tensor):   the learnable bias of the module of shape (out_channels).
+                         If :attr:`bias` is ``True``, then the values of these weights are
+                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+
+    Examples::
+
+        >>> # With cubic kernels and equal stride
+        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
+        >>> # non-cubic kernels and unequal stride and with padding
+        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
+        >>> input = torch.randn(20, 16, 10, 50, 100)
+        >>> output = m(input)
+
+    .. _`here`:
+        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+    .. _`Deconvolutional Networks`:
+        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
+    """
+    )
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_3_t,
+        stride: _size_3_t = 1,
+        padding: _size_3_t = 0,
+        output_padding: _size_3_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_3_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        output_padding = _triple(output_padding)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            True,
+            output_padding,
+            groups,
+            bias,
+            padding_mode,
+            **factory_kwargs,
+        )
+
+    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+        if self.padding_mode != "zeros":
+            raise ValueError(
+                "Only `zeros` padding mode is supported for ConvTranspose3d"
+            )
+
+        assert isinstance(self.padding, tuple)
+        # One cannot replace List by Tuple or Sequence in "_output_padding" because
+        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+        num_spatial_dims = 3
+        output_padding = self._output_padding(
+            input,
+            output_size,
+            self.stride,  # type: ignore[arg-type]
+            self.padding,  # type: ignore[arg-type]
+            self.kernel_size,  # type: ignore[arg-type]
+            num_spatial_dims,
+            self.dilation,  # type: ignore[arg-type]
+        )
+
+        return F.conv_transpose3d(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            output_padding,
+            self.groups,
+            self.dilation,
+        )
+
+
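+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the `output_size` argument handled by `_output_padding`
+# above: with stride > 1, Conv3d maps several input sizes to one output size,
+# and `output_size` selects which inverse shape the transpose should recover.
+# The helper name and the concrete shapes are hypothetical.
+def _example_convtranspose3d_output_size() -> None:
+    down = Conv3d(8, 8, 3, stride=2, padding=1)
+    up = ConvTranspose3d(8, 8, 3, stride=2, padding=1)
+    x = torch.randn(1, 8, 11, 12, 13)
+    h = down(x)  # torch.Size([1, 8, 6, 6, 7]): depths 11 and 12 both map to 6
+    y = up(h, output_size=x.size())  # resolves the shape ambiguity
+    assert y.shape == x.shape
+
+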
+# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`.
+#
+# `_ConvTransposeMixin` was a mixin that has since been removed. It was meant
+# to be used with `_ConvNd` to construct actual module classes that implement
+# conv transpose ops:
+#
+#   class MyConvTranspose(_ConvNd, _ConvTransposeMixin):
+#       ...
+#
+# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper
+# subclass of `_ConvNd`. However, some user code in the wild still
+# (incorrectly) uses the internal class `_ConvTransposeMixin`. Hence, we
+# provide this alias for BC, because it is cheap and easy for us to do so,
+# even though `_ConvTransposeNd` is not really a mixin anymore (but the
+# multiple inheritance shown above would still work).
+class _ConvTransposeMixin(_ConvTransposeNd):
+    @deprecated(
+        "`_ConvTransposeMixin` is a deprecated internal class. "
+        "Please consider using public APIs.",
+        category=FutureWarning,
+    )
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
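+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the BC alias above: constructing `_ConvTransposeMixin`
+# goes through the deprecated `__init__` and should emit a FutureWarning.
+# The helper name is hypothetical, and the positional arguments assume
+# `_ConvTransposeNd.__init__(in_channels, out_channels, kernel_size, stride,
+# padding, dilation, transposed, output_padding, groups, bias, padding_mode)`.
+def _example_convtranspose_mixin_warns() -> None:
+    import warnings
+
+    with warnings.catch_warnings(record=True) as caught:
+        warnings.simplefilter("always")
+        _ConvTransposeMixin(
+            3, 6, _pair(3), _pair(1), _pair(0), _pair(1),
+            True, _pair(0), 1, True, "zeros",
+        )
+    assert any(issubclass(w.category, FutureWarning) for w in caught)
+
+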
+# TODO: Conv2dLocal
+# TODO: Conv2dMap
+# TODO: ConvTranspose2dMap
+
+
+class _LazyConvXdMixin(LazyModuleMixin):
+    groups: int
+    transposed: bool
+    in_channels: int
+    out_channels: int
+    kernel_size: Tuple[int, ...]
+    weight: UninitializedParameter
+    bias: UninitializedParameter
+
+    def reset_parameters(self) -> None:
+        # has_uninitialized_params is defined in the parent class and uses a protocol on self
+        if not self.has_uninitialized_params() and self.in_channels != 0:  # type: ignore[misc]
+            # "type: ignore[..]" is required because mypy thinks that "reset_parameters" is undefined
+            # in the super class. It is actually defined in _ConvNd, which is inherited by any class
+            # that also inherits _LazyConvXdMixin.
+            super().reset_parameters()  # type: ignore[misc]
+
+    # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin
+    def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None:  # type: ignore[override]
+        # defined by the parent class but using a protocol
+        if self.has_uninitialized_params():  # type: ignore[misc]
+            self.in_channels = self._get_in_channels(input)
+            if self.in_channels % self.groups != 0:
+                raise ValueError("in_channels must be divisible by groups")
+            assert isinstance(self.weight, UninitializedParameter)
+            if self.transposed:
+                self.weight.materialize(
+                    (
+                        self.in_channels,
+                        self.out_channels // self.groups,
+                        *self.kernel_size,
+                    )
+                )
+            else:
+                self.weight.materialize(
+                    (
+                        self.out_channels,
+                        self.in_channels // self.groups,
+                        *self.kernel_size,
+                    )
+                )
+            if self.bias is not None:
+                assert isinstance(self.bias, UninitializedParameter)
+                self.bias.materialize((self.out_channels,))
+            self.reset_parameters()
+
+    # Function to extract in_channels from the first input.
+    def _get_in_channels(self, input: Tensor) -> int:
+        num_spatial_dims = self._get_num_spatial_dims()
+        num_dims_no_batch = num_spatial_dims + 1  # +1 for channels dim
+        num_dims_batch = num_dims_no_batch + 1
+        if input.dim() not in (num_dims_no_batch, num_dims_batch):
+            raise RuntimeError(
+                f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input "
+                f"to {self.__class__.__name__}, but "
+                f"got input of size: {input.shape}"
+            )
+        return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]
+
+    # Function to return the number of spatial dims expected for inputs to the module.
+    # This is expected to be implemented by subclasses.
+    def _get_num_spatial_dims(self) -> int:
+        raise NotImplementedError
+
+
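+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the lazy-initialization flow implemented by the mixin
+# above: before the first forward pass `weight` is an UninitializedParameter;
+# the first input materializes it, and the module is then swapped to its
+# `cls_to_become` class. The helper name and shapes are hypothetical.
+def _example_lazy_materialization() -> None:
+    m = LazyConv2d(out_channels=8, kernel_size=3)
+    assert isinstance(m.weight, UninitializedParameter)
+    y = m(torch.randn(2, 5, 16, 16))  # in_channels inferred as 5
+    assert type(m) is Conv2d  # the lazy class was swapped out
+    assert m.weight.shape == (8, 5, 3, 3)  # (out, in // groups, *kernel_size)
+    assert y.shape == (2, 8, 14, 14)
+
+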
+# LazyConv1d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`Conv1d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel
+            elements. Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
+            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+
+    .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = Conv1d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_1_t,
+        stride: _size_1_t = 1,
+        padding: _size_1_t = 0,
+        dilation: _size_1_t = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 1
+
+
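+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of `_get_in_channels` above for the 1D case: the mixin
+# accepts a 2D (unbatched) or 3D (batched) first input and reads the channel
+# dimension from index 0 or 1 accordingly. The helper name is hypothetical.
+def _example_lazy_conv1d_channel_inference() -> None:
+    m = LazyConv1d(out_channels=4, kernel_size=3)
+    m(torch.randn(7, 20))  # unbatched (C_in, L)
+    assert m.in_channels == 7
+
+    m2 = LazyConv1d(out_channels=4, kernel_size=3)
+    m2(torch.randn(2, 7, 20))  # batched (N, C_in, L)
+    assert m2.in_channels == 7
+
+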
+# LazyConv2d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`Conv2d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel
+            elements. Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
+            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+
+    .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = Conv2d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_2_t,
+        stride: _size_2_t = 1,
+        padding: _size_2_t = 0,
+        dilation: _size_2_t = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",  # TODO: refine this type
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 2
+
+
+# LazyConv3d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`Conv3d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel
+            elements. Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
+            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+
+    .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = Conv3d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_3_t,
+        stride: _size_3_t = 1,
+        padding: _size_3_t = 0,
+        dilation: _size_3_t = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 3
+
+
+# LazyConvTranspose1d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`ConvTranspose1d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+            will be added to both sides of the input. Default: 0
+        output_padding (int or tuple, optional): Additional size added to one side
+            of the output shape. Default: 0
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+    .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = ConvTranspose1d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_1_t,
+        stride: _size_1_t = 1,
+        padding: _size_1_t = 0,
+        output_padding: _size_1_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_1_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            output_padding,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            dilation,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 1
+
+
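+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the guarded `reset_parameters` in `_LazyConvXdMixin`:
+# while parameters are still uninitialized the call is a deliberate no-op
+# rather than an error. The helper name is hypothetical.
+def _example_lazy_reset_parameters_noop() -> None:
+    m = LazyConvTranspose1d(out_channels=4, kernel_size=3)
+    m.reset_parameters()  # no-op: has_uninitialized_params() is still True
+    assert isinstance(m.weight, UninitializedParameter)
+
+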
+# LazyConvTranspose2d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+            will be added to both sides of each dimension in the input. Default: 0
+        output_padding (int or tuple, optional): Additional size added to one side
+            of each dimension in the output shape. Default: 0
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+    .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = ConvTranspose2d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_2_t,
+        stride: _size_2_t = 1,
+        padding: _size_2_t = 0,
+        output_padding: _size_2_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_2_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            output_padding,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            dilation,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 2
+
+
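+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the `transposed` branch in `initialize_parameters`:
+# transposed lazy modules materialize `weight` as
+# (in_channels, out_channels // groups, *kernel_size), the reverse of the
+# regular layout. The helper name and shapes are hypothetical.
+def _example_lazy_transposed_weight_layout() -> None:
+    m = LazyConvTranspose2d(out_channels=6, kernel_size=3, groups=2)
+    m(torch.randn(1, 4, 8, 8))  # in_channels inferred as 4
+    assert m.weight.shape == (4, 3, 3, 3)  # (in, out // groups, kH, kW)
+
+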
+# LazyConvTranspose3d defines weight as a Tensor, but the derived class defines it as an UninitializedParameter.
+class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
+    r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.
+
+    The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from ``input.size(1)``.
+    The attributes that will be lazily initialized are `weight` and `bias`.
+
+    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
+    on lazy modules and their limitations.
+
+    Args:
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+            will be added to both sides of each dimension in the input. Default: 0
+        output_padding (int or tuple, optional): Additional size added to one side
+            of each dimension in the output shape. Default: 0
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+    .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
+    """
+
+    # The superclass defines this variable as None. "type: ignore[..]" is required
+    # since we are redefining the variable.
+    cls_to_become = ConvTranspose3d  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        out_channels: int,
+        kernel_size: _size_3_t,
+        stride: _size_3_t = 1,
+        padding: _size_3_t = 0,
+        output_padding: _size_3_t = 0,
+        groups: int = 1,
+        bias: bool = True,
+        dilation: _size_3_t = 1,
+        padding_mode: str = "zeros",
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__(
+            0,
+            0,
+            kernel_size,
+            stride,
+            padding,
+            output_padding,
+            groups,
+            # bias is hardcoded to False to avoid creating a tensor
+            # that will soon be overwritten.
+            False,
+            dilation,
+            padding_mode,
+            **factory_kwargs,
+        )
+        self.weight = UninitializedParameter(**factory_kwargs)
+        self.out_channels = out_channels
+        if bias:
+            self.bias = UninitializedParameter(**factory_kwargs)
+
+    def _get_num_spatial_dims(self) -> int:
+        return 3
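+
+
+# --- Editor's illustration (not part of the upstream file) ------------------
+# A hedged sketch of the divisibility check in `initialize_parameters`: if the
+# inferred in_channels is not divisible by `groups`, the first forward pass
+# raises a ValueError. The helper name and shapes are hypothetical.
+def _example_lazy_groups_check() -> None:
+    m = LazyConv3d(out_channels=8, kernel_size=3, groups=2)
+    try:
+        m(torch.randn(1, 5, 4, 8, 8))  # 5 channels, not divisible by 2
+    except ValueError as e:
+        assert "divisible by groups" in str(e)
+    else:
+        raise AssertionError("expected ValueError")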