Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .venv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py +6 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +6 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py +17 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py +7 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py +8 -0
- .venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py +6 -0
- .venv/Lib/site-packages/torch/nn/modules/__init__.py +334 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/modules/_functions.py +319 -0
- .venv/Lib/site-packages/torch/nn/modules/activation.py +1746 -0
- .venv/Lib/site-packages/torch/nn/modules/adaptive.py +330 -0
- .venv/Lib/site-packages/torch/nn/modules/batchnorm.py +883 -0
- .venv/Lib/site-packages/torch/nn/modules/channelshuffle.py +56 -0
- .venv/Lib/site-packages/torch/nn/modules/container.py +976 -0
- .venv/Lib/site-packages/torch/nn/modules/conv.py +1866 -0
.venv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (736 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (420 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py
ADDED
@@ -0,0 +1 @@
+from torch.nn.intrinsic.quantized.dynamic.modules import *  # noqa: F403

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (269 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
+
+
+__all__ = [
+    "LinearReLU",
+]

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (331 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc
ADDED
Binary file (317 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
ADDED
@@ -0,0 +1,6 @@
+from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
+
+
+__all__ = [
+    "LinearReLU",
+]

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
+from torch.nn.intrinsic.quantized.modules.conv_relu import (
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+)
+from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
+
+
+__all__ = [
+    "LinearReLU",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "BNReLU2d",
+    "BNReLU3d",
+]

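Note: fused intrinsic modules such as these are normally produced by module fusion during post-training quantization rather than constructed by hand. A minimal sketch of that flow, assuming a standard eager-mode quantization setup (the fusion step shown yields the float intrinsic ConvReLU2d; the quantized variants re-exported above replace it after conversion):

    import torch
    from torch.ao.quantization import fuse_modules

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    # Fuse the Conv2d at index "0" with the ReLU at index "1" into one module.
    fused = fuse_modules(model.eval(), [["0", "1"]])
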
.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (561 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc
ADDED
Binary file (323 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc
ADDED
Binary file (336 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc
ADDED
Binary file (301 Bytes).

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py
ADDED
@@ -0,0 +1,7 @@
+from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
+
+
+__all__ = [
+    "BNReLU2d",
+    "BNReLU3d",
+]

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py
ADDED
@@ -0,0 +1,8 @@
+from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
+
+
+__all__ = [
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+]

.venv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py
ADDED
@@ -0,0 +1,6 @@
+from torch.ao.nn.intrinsic.quantized import LinearReLU
+
+
+__all__ = [
+    "LinearReLU",
+]

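The linear_relu.py, bn_relu.py, and conv_relu.py shims above only re-export classes that now live under torch.ao.nn.intrinsic.quantized, kept for backward compatibility. A minimal sketch of the equivalence, assuming a PyTorch build where both import paths exist:

    import torch.nn.intrinsic.quantized as nniq
    import torch.ao.nn.intrinsic.quantized as ao_nniq

    # The legacy path and the torch.ao path resolve to the same class object.
    assert nniq.LinearReLU is ao_nniq.LinearReLU
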
.venv/Lib/site-packages/torch/nn/modules/__init__.py
ADDED
@@ -0,0 +1,334 @@
+from .module import Module  # usort: skip
+from .linear import Bilinear, Identity, LazyLinear, Linear  # usort: skip
+from .activation import (
+    CELU,
+    ELU,
+    GELU,
+    GLU,
+    Hardshrink,
+    Hardsigmoid,
+    Hardswish,
+    Hardtanh,
+    LeakyReLU,
+    LogSigmoid,
+    LogSoftmax,
+    Mish,
+    MultiheadAttention,
+    PReLU,
+    ReLU,
+    ReLU6,
+    RReLU,
+    SELU,
+    Sigmoid,
+    SiLU,
+    Softmax,
+    Softmax2d,
+    Softmin,
+    Softplus,
+    Softshrink,
+    Softsign,
+    Tanh,
+    Tanhshrink,
+    Threshold,
+)
+from .adaptive import AdaptiveLogSoftmaxWithLoss
+from .batchnorm import (
+    BatchNorm1d,
+    BatchNorm2d,
+    BatchNorm3d,
+    LazyBatchNorm1d,
+    LazyBatchNorm2d,
+    LazyBatchNorm3d,
+    SyncBatchNorm,
+)
+from .channelshuffle import ChannelShuffle
+from .container import (
+    Container,
+    ModuleDict,
+    ModuleList,
+    ParameterDict,
+    ParameterList,
+    Sequential,
+)
+from .conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+    LazyConv1d,
+    LazyConv2d,
+    LazyConv3d,
+    LazyConvTranspose1d,
+    LazyConvTranspose2d,
+    LazyConvTranspose3d,
+)
+from .distance import CosineSimilarity, PairwiseDistance
+from .dropout import (
+    AlphaDropout,
+    Dropout,
+    Dropout1d,
+    Dropout2d,
+    Dropout3d,
+    FeatureAlphaDropout,
+)
+from .flatten import Flatten, Unflatten
+from .fold import Fold, Unfold
+from .instancenorm import (
+    InstanceNorm1d,
+    InstanceNorm2d,
+    InstanceNorm3d,
+    LazyInstanceNorm1d,
+    LazyInstanceNorm2d,
+    LazyInstanceNorm3d,
+)
+from .loss import (
+    BCELoss,
+    BCEWithLogitsLoss,
+    CosineEmbeddingLoss,
+    CrossEntropyLoss,
+    CTCLoss,
+    GaussianNLLLoss,
+    HingeEmbeddingLoss,
+    HuberLoss,
+    KLDivLoss,
+    L1Loss,
+    MarginRankingLoss,
+    MSELoss,
+    MultiLabelMarginLoss,
+    MultiLabelSoftMarginLoss,
+    MultiMarginLoss,
+    NLLLoss,
+    NLLLoss2d,
+    PoissonNLLLoss,
+    SmoothL1Loss,
+    SoftMarginLoss,
+    TripletMarginLoss,
+    TripletMarginWithDistanceLoss,
+)
+from .normalization import (
+    CrossMapLRN2d,
+    GroupNorm,
+    LayerNorm,
+    LocalResponseNorm,
+    RMSNorm,
+)
+from .padding import (
+    CircularPad1d,
+    CircularPad2d,
+    CircularPad3d,
+    ConstantPad1d,
+    ConstantPad2d,
+    ConstantPad3d,
+    ReflectionPad1d,
+    ReflectionPad2d,
+    ReflectionPad3d,
+    ReplicationPad1d,
+    ReplicationPad2d,
+    ReplicationPad3d,
+    ZeroPad1d,
+    ZeroPad2d,
+    ZeroPad3d,
+)
+from .pixelshuffle import PixelShuffle, PixelUnshuffle
+from .pooling import (
+    AdaptiveAvgPool1d,
+    AdaptiveAvgPool2d,
+    AdaptiveAvgPool3d,
+    AdaptiveMaxPool1d,
+    AdaptiveMaxPool2d,
+    AdaptiveMaxPool3d,
+    AvgPool1d,
+    AvgPool2d,
+    AvgPool3d,
+    FractionalMaxPool2d,
+    FractionalMaxPool3d,
+    LPPool1d,
+    LPPool2d,
+    LPPool3d,
+    MaxPool1d,
+    MaxPool2d,
+    MaxPool3d,
+    MaxUnpool1d,
+    MaxUnpool2d,
+    MaxUnpool3d,
+)
+from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase
+from .sparse import Embedding, EmbeddingBag
+from .transformer import (
+    Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
+from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d
+
+
+__all__ = [
+    "AdaptiveAvgPool1d",
+    "AdaptiveAvgPool2d",
+    "AdaptiveAvgPool3d",
+    "AdaptiveLogSoftmaxWithLoss",
+    "AdaptiveMaxPool1d",
+    "AdaptiveMaxPool2d",
+    "AdaptiveMaxPool3d",
+    "AlphaDropout",
+    "AvgPool1d",
+    "AvgPool2d",
+    "AvgPool3d",
+    "BCELoss",
+    "BCEWithLogitsLoss",
+    "BatchNorm1d",
+    "BatchNorm2d",
+    "BatchNorm3d",
+    "Bilinear",
+    "CELU",
+    "CTCLoss",
+    "ChannelShuffle",
+    "CircularPad1d",
+    "CircularPad2d",
+    "CircularPad3d",
+    "ConstantPad1d",
+    "ConstantPad2d",
+    "ConstantPad3d",
+    "Container",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+    "CosineEmbeddingLoss",
+    "CosineSimilarity",
+    "CrossEntropyLoss",
+    "CrossMapLRN2d",
+    "Dropout",
+    "Dropout1d",
+    "Dropout2d",
+    "Dropout3d",
+    "ELU",
+    "Embedding",
+    "EmbeddingBag",
+    "FeatureAlphaDropout",
+    "Flatten",
+    "Fold",
+    "FractionalMaxPool2d",
+    "FractionalMaxPool3d",
+    "GELU",
+    "GLU",
+    "GRU",
+    "GRUCell",
+    "GaussianNLLLoss",
+    "GroupNorm",
+    "Hardshrink",
+    "Hardsigmoid",
+    "Hardswish",
+    "Hardtanh",
+    "HingeEmbeddingLoss",
+    "HuberLoss",
+    "Identity",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+    "KLDivLoss",
+    "L1Loss",
+    "LPPool1d",
+    "LPPool2d",
+    "LPPool3d",
+    "LSTM",
+    "LSTMCell",
+    "LayerNorm",
+    "LazyBatchNorm1d",
+    "LazyBatchNorm2d",
+    "LazyBatchNorm3d",
+    "LazyConv1d",
+    "LazyConv2d",
+    "LazyConv3d",
+    "LazyConvTranspose1d",
+    "LazyConvTranspose2d",
+    "LazyConvTranspose3d",
+    "LazyInstanceNorm1d",
+    "LazyInstanceNorm2d",
+    "LazyInstanceNorm3d",
+    "LazyLinear",
+    "LeakyReLU",
+    "Linear",
+    "LocalResponseNorm",
+    "LogSigmoid",
+    "LogSoftmax",
+    "MSELoss",
+    "MarginRankingLoss",
+    "MaxPool1d",
+    "MaxPool2d",
+    "MaxPool3d",
+    "MaxUnpool1d",
+    "MaxUnpool2d",
+    "MaxUnpool3d",
+    "Mish",
+    "Module",
+    "ModuleDict",
+    "ModuleList",
+    "MultiLabelMarginLoss",
+    "MultiLabelSoftMarginLoss",
+    "MultiMarginLoss",
+    "MultiheadAttention",
+    "NLLLoss",
+    "NLLLoss2d",
+    "PReLU",
+    "PairwiseDistance",
+    "ParameterDict",
+    "ParameterList",
+    "PixelShuffle",
+    "PixelUnshuffle",
+    "PoissonNLLLoss",
+    "RMSNorm",
+    "RNN",
+    "RNNBase",
+    "RNNCell",
+    "RNNCellBase",
+    "RReLU",
+    "ReLU",
+    "ReLU6",
+    "ReflectionPad1d",
+    "ReflectionPad2d",
+    "ReflectionPad3d",
+    "ReplicationPad1d",
+    "ReplicationPad2d",
+    "ReplicationPad3d",
+    "SELU",
+    "Sequential",
+    "SiLU",
+    "Sigmoid",
+    "SmoothL1Loss",
+    "SoftMarginLoss",
+    "Softmax",
+    "Softmax2d",
+    "Softmin",
+    "Softplus",
+    "Softshrink",
+    "Softsign",
+    "SyncBatchNorm",
+    "Tanh",
+    "Tanhshrink",
+    "Threshold",
+    "Transformer",
+    "TransformerDecoder",
+    "TransformerDecoderLayer",
+    "TransformerEncoder",
+    "TransformerEncoderLayer",
+    "TripletMarginLoss",
+    "TripletMarginWithDistanceLoss",
+    "Unflatten",
+    "Unfold",
+    "Upsample",
+    "UpsamplingBilinear2d",
+    "UpsamplingNearest2d",
+    "ZeroPad1d",
+    "ZeroPad2d",
+    "ZeroPad3d",
+]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)

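The trailing assert makes the file fail at import time if __all__ ever drifts out of alphabetical order. A quick sketch of the same check from the outside (importing the module already runs the assert itself):

    import torch.nn.modules as modules

    assert modules.__all__ == sorted(modules.__all__)  # mirrors the module's own check
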
.venv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (5.16 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc
ADDED
Binary file (6.07 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc
ADDED
Binary file (56.9 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc
ADDED
Binary file (10.6 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc
ADDED
Binary file (32.2 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc
ADDED
Binary file (2.23 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc
ADDED
Binary file (35.2 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc
ADDED
Binary file (61 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc
ADDED
Binary file (4.11 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc
ADDED
Binary file (12.6 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc
ADDED
Binary file (5.99 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc
ADDED
Binary file (13.1 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc
ADDED
Binary file (20.9 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc
ADDED
Binary file (11.9 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc
ADDED
Binary file (10.5 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc
ADDED
Binary file (94.7 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc
ADDED
Binary file (95.7 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc
ADDED
Binary file (15 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc
ADDED
Binary file (34.2 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc
ADDED
Binary file (4.52 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc
ADDED
Binary file (58.6 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc
ADDED
Binary file (55.4 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc
ADDED
Binary file (21.5 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (37.2 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc
ADDED
Binary file (11.9 kB).

.venv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.74 kB).

.venv/Lib/site-packages/torch/nn/modules/_functions.py
ADDED
@@ -0,0 +1,319 @@
+# mypy: allow-untyped-defs
+import torch
+import torch.distributed as dist
+from torch.autograd.function import Function
+
+
+class SyncBatchNorm(Function):
+    @staticmethod
+    def forward(
+        self,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        eps,
+        momentum,
+        process_group,
+        world_size,
+    ):
+        if not (
+            input.is_contiguous(memory_format=torch.channels_last)
+            or input.is_contiguous(memory_format=torch.channels_last_3d)
+        ):
+            input = input.contiguous()
+        if weight is not None:
+            weight = weight.contiguous()
+
+        size = int(input.numel() // input.size(1))
+        if size == 1 and world_size < 2:
+            raise ValueError(
+                f"Expected more than 1 value per channel when training, got input size {size}"
+            )
+
+        num_channels = input.shape[1]
+        if input.numel() > 0:
+            # calculate mean/invstd for input.
+            mean, invstd = torch.batch_norm_stats(input, eps)
+
+            count = torch.full(
+                (1,),
+                input.numel() // input.size(1),
+                dtype=mean.dtype,
+                device=mean.device,
+            )
+
+            # C, C, 1 -> (2C + 1)
+            combined = torch.cat([mean, invstd, count], dim=0)
+        else:
+            # for empty input, set stats and the count to zero. The stats with
+            # zero count will be filtered out later when computing global mean
+            # & invstd, but they still needs to participate the all_gather
+            # collective communication to unblock other peer processes.
+            combined = torch.zeros(
+                2 * num_channels + 1, dtype=input.dtype, device=input.device
+            )
+
+        # Use allgather instead of allreduce because count could be different across
+        # ranks, simple all reduce op can not give correct results.
+        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+        # all gathered mean, invstd and count.
+        # for nccl backend, use the optimized version of all gather.
+        # The Gloo backend does not support `all_gather_into_tensor`.
+        if process_group._get_backend_name() != "gloo":
+            # world_size * (2C + 1)
+            combined_size = combined.numel()
+            combined_flat = torch.empty(
+                1,
+                combined_size * world_size,
+                dtype=combined.dtype,
+                device=combined.device,
+            )
+            dist.all_gather_into_tensor(
+                combined_flat, combined, process_group, async_op=False
+            )
+            combined = torch.reshape(combined_flat, (world_size, combined_size))
+            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+        else:
+            # world_size * (2C + 1)
+            combined_list = [torch.empty_like(combined) for _ in range(world_size)]
+            dist.all_gather(combined_list, combined, process_group, async_op=False)
+            combined = torch.stack(combined_list, dim=0)
+            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+
+        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
+            # The lines below force a synchronization between CUDA and CPU, because
+            # the shape of the result count_all depends on the values in mask tensor.
+            # Such synchronizations break CUDA Graph capturing.
+            # See https://github.com/pytorch/pytorch/issues/78549
+            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
+            # a better longer-term solution.
+
+            # remove stats from empty inputs
+            mask = count_all.squeeze(-1) >= 1
+            count_all = count_all[mask]
+            mean_all = mean_all[mask]
+            invstd_all = invstd_all[mask]
+
+        # calculate global mean & invstd
+        counts = count_all.view(-1)
+        if running_mean is not None and counts.dtype != running_mean.dtype:
+            counts = counts.to(running_mean.dtype)
+        mean, invstd = torch.batch_norm_gather_stats_with_counts(
+            input,
+            mean_all,
+            invstd_all,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+            counts,
+        )
+
+        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
+        self.process_group = process_group
+
+        # apply element-wise normalization
+        if input.numel() > 0:
+            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+        else:
+            return torch.empty_like(input)
+
+    @staticmethod
+    def backward(self, grad_output):
+        if not (
+            grad_output.is_contiguous(memory_format=torch.channels_last)
+            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
+        ):
+            grad_output = grad_output.contiguous()
+        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
+        grad_input = grad_weight = grad_bias = None
+        process_group = self.process_group
+
+        if saved_input.numel() > 0:
+            # calculate local stats as well as grad_weight / grad_bias
+            (
+                sum_dy,
+                sum_dy_xmu,
+                grad_weight,
+                grad_bias,
+            ) = torch.batch_norm_backward_reduce(
+                grad_output,
+                saved_input,
+                mean,
+                invstd,
+                weight,
+                self.needs_input_grad[0],
+                self.needs_input_grad[1],
+                self.needs_input_grad[2],
+            )
+
+            if self.needs_input_grad[0]:
+                # synchronizing stats used to calculate input gradient.
+                num_channels = sum_dy.shape[0]
+                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
+                torch.distributed.all_reduce(
+                    combined,
+                    torch.distributed.ReduceOp.SUM,
+                    process_group,
+                    async_op=False,
+                )
+                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
+
+                # backward pass for gradient calculation
+                if weight is not None and weight.dtype != mean.dtype:
+                    weight = weight.to(mean.dtype)
+                grad_input = torch.batch_norm_backward_elemt(
+                    grad_output,
+                    saved_input,
+                    mean,
+                    invstd,
+                    weight,
+                    sum_dy,
+                    sum_dy_xmu,
+                    count_tensor,
+                )
+            # synchronizing of grad_weight / grad_bias is not needed as distributed
+            # training would handle all reduce.
+            if weight is None or not self.needs_input_grad[1]:
+                grad_weight = None
+
+            if weight is None or not self.needs_input_grad[2]:
+                grad_bias = None
+        else:
+            # This process got an empty input tensor in the forward pass.
+            # Although this process can directly set grad_input as an empty
+            # tensor of zeros, it still needs to participate in the collective
+            # communication to unblock its peers, as other peer processes might
+            # have received non-empty inputs.
+            num_channels = saved_input.shape[1]
+            if self.needs_input_grad[0]:
+                # launch all_reduce to unblock other peer processes
+                combined = torch.zeros(
+                    2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
+                )
+                torch.distributed.all_reduce(
+                    combined,
+                    torch.distributed.ReduceOp.SUM,
+                    process_group,
+                    async_op=False,
+                )
+
+            # Leave grad_input, grad_weight and grad_bias as None, which will be
+            # interpreted by the autograd engine as Tensors full of zeros.
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+class CrossMapLRN2d(Function):
+    @staticmethod
+    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
+        ctx.size = size
+        ctx.alpha = alpha
+        ctx.beta = beta
+        ctx.k = k
+        ctx.scale = None
+
+        if input.dim() != 4:
+            raise ValueError(
+                f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
+            )
+
+        ctx.scale = ctx.scale or input.new()
+        output = input.new()
+
+        batch_size = input.size(0)
+        channels = input.size(1)
+        input_height = input.size(2)
+        input_width = input.size(3)
+
+        output.resize_as_(input)
+        ctx.scale.resize_as_(input)
+
+        # use output storage as temporary buffer
+        input_square = output
+        torch.pow(input, 2, out=input_square)
+
+        pre_pad = int((ctx.size - 1) / 2 + 1)
+        pre_pad_crop = min(pre_pad, channels)
+
+        scale_first = ctx.scale.select(1, 0)
+        scale_first.zero_()
+        # compute first feature map normalization
+        for c in range(pre_pad_crop):
+            scale_first.add_(input_square.select(1, c))
+
+        # reuse computations for next feature maps normalization
+        # by adding the next feature map and removing the previous
+        for c in range(1, channels):
+            scale_previous = ctx.scale.select(1, c - 1)
+            scale_current = ctx.scale.select(1, c)
+            scale_current.copy_(scale_previous)
+            if c < channels - pre_pad + 1:
+                square_next = input_square.select(1, c + pre_pad - 1)
+                scale_current.add_(square_next, alpha=1)
+
+            if c > pre_pad:
+                square_previous = input_square.select(1, c - pre_pad)
+                scale_current.add_(square_previous, alpha=-1)
+
+        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
+
+        torch.pow(ctx.scale, -ctx.beta, out=output)
+        output.mul_(input)
+
+        ctx.save_for_backward(input, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, output = ctx.saved_tensors
+        grad_input = grad_output.new()
+
+        batch_size = input.size(0)
+        channels = input.size(1)
+        input_height = input.size(2)
+        input_width = input.size(3)
+
+        paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
+        accum_ratio = input.new(input_height, input_width)
+
+        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
+        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
+
+        grad_input.resize_as_(input)
+        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
+
+        paddded_ratio.zero_()
+        padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels)
+        for n in range(batch_size):
+            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
+            padded_ratio_center.div_(ctx.scale[n])
+            torch.sum(
+                paddded_ratio.narrow(0, 0, ctx.size - 1),
+                0,
+                keepdim=False,
+                out=accum_ratio,
+            )
+            for c in range(channels):
+                accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
+                grad_input[n][c].addcmul_(
+                    input[n][c], accum_ratio, value=-cache_ratio_value
+                )
+                accum_ratio.add_(paddded_ratio[c], alpha=-1)
+
+        return grad_input, None, None, None, None
+
+
+class BackwardHookFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, *args):
+        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
+        return args
+
+    @staticmethod
+    def backward(ctx, *args):
+        return args

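The SyncBatchNorm autograd Function above is internal plumbing; the public entry point is torch.nn.SyncBatchNorm, which dispatches to it when a process group is active. A hedged usage sketch (cross-rank synchronization only actually happens under torch.distributed with multiple ranks):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    # Replace every BatchNorm*d layer with a SyncBatchNorm equivalent.
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
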
.venv/Lib/site-packages/torch/nn/modules/activation.py
ADDED
|
@@ -0,0 +1,1746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
| 9 |
+
from torch.nn.parameter import Parameter
|
| 10 |
+
|
| 11 |
+
from .linear import NonDynamicallyQuantizableLinear
|
| 12 |
+
from .module import Module
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"Threshold",
|
| 17 |
+
"ReLU",
|
| 18 |
+
"RReLU",
|
| 19 |
+
"Hardtanh",
|
| 20 |
+
"ReLU6",
|
| 21 |
+
"Sigmoid",
|
| 22 |
+
"Hardsigmoid",
|
| 23 |
+
"Tanh",
|
| 24 |
+
"SiLU",
|
| 25 |
+
"Mish",
|
| 26 |
+
"Hardswish",
|
| 27 |
+
"ELU",
|
| 28 |
+
"CELU",
|
| 29 |
+
"SELU",
|
| 30 |
+
"GLU",
|
| 31 |
+
"GELU",
|
| 32 |
+
"Hardshrink",
|
| 33 |
+
"LeakyReLU",
|
| 34 |
+
"LogSigmoid",
|
| 35 |
+
"Softplus",
|
| 36 |
+
"Softshrink",
|
| 37 |
+
"MultiheadAttention",
|
| 38 |
+
"PReLU",
|
| 39 |
+
"Softsign",
|
| 40 |
+
"Tanhshrink",
|
| 41 |
+
"Softmin",
|
| 42 |
+
"Softmax",
|
| 43 |
+
"Softmax2d",
|
| 44 |
+
"LogSoftmax",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Threshold(Module):
|
| 49 |
+
r"""Thresholds each element of the input Tensor.
|
| 50 |
+
|
| 51 |
+
Threshold is defined as:
|
| 52 |
+
|
| 53 |
+
.. math::
|
| 54 |
+
y =
|
| 55 |
+
\begin{cases}
|
| 56 |
+
x, &\text{ if } x > \text{threshold} \\
|
| 57 |
+
\text{value}, &\text{ otherwise }
|
| 58 |
+
\end{cases}
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
threshold: The value to threshold at
|
| 62 |
+
value: The value to replace with
|
| 63 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 64 |
+
|
| 65 |
+
Shape:
|
| 66 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 67 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 68 |
+
|
| 69 |
+
Examples::
|
| 70 |
+
|
| 71 |
+
>>> m = nn.Threshold(0.1, 20)
|
| 72 |
+
>>> input = torch.randn(2)
|
| 73 |
+
>>> output = m(input)
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
__constants__ = ["threshold", "value", "inplace"]
|
| 77 |
+
|
| 78 |
+
threshold: float
|
| 79 |
+
value: float
|
| 80 |
+
inplace: bool
|
| 81 |
+
|
| 82 |
+
def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.threshold = threshold
|
| 85 |
+
self.value = value
|
| 86 |
+
self.inplace = inplace
|
| 87 |
+
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
|
| 88 |
+
|
| 89 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 90 |
+
return F.threshold(input, self.threshold, self.value, self.inplace)
|
| 91 |
+
|
| 92 |
+
def extra_repr(self):
|
| 93 |
+
inplace_str = ", inplace=True" if self.inplace else ""
|
| 94 |
+
return f"threshold={self.threshold}, value={self.value}{inplace_str}"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ReLU(Module):
|
| 98 |
+
r"""Applies the rectified linear unit function element-wise.
|
| 99 |
+
|
| 100 |
+
:math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 104 |
+
|
| 105 |
+
Shape:
|
| 106 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 107 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 108 |
+
|
| 109 |
+
.. image:: ../scripts/activation_images/ReLU.png
|
| 110 |
+
|
| 111 |
+
Examples::
|
| 112 |
+
|
| 113 |
+
>>> m = nn.ReLU()
|
| 114 |
+
>>> input = torch.randn(2)
|
| 115 |
+
>>> output = m(input)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
An implementation of CReLU - https://arxiv.org/abs/1603.05201
|
| 119 |
+
|
| 120 |
+
>>> m = nn.ReLU()
|
| 121 |
+
>>> input = torch.randn(2).unsqueeze(0)
|
| 122 |
+
>>> output = torch.cat((m(input), m(-input)))
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
__constants__ = ["inplace"]
|
| 126 |
+
inplace: bool
|
| 127 |
+
|
| 128 |
+
def __init__(self, inplace: bool = False):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.inplace = inplace
|
| 131 |
+
|
| 132 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 133 |
+
return F.relu(input, inplace=self.inplace)
|
| 134 |
+
|
| 135 |
+
def extra_repr(self) -> str:
|
| 136 |
+
inplace_str = "inplace=True" if self.inplace else ""
|
| 137 |
+
return inplace_str
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class RReLU(Module):
|
| 141 |
+
r"""Applies the randomized leaky rectified linear unit function, element-wise.
|
| 142 |
+
|
| 143 |
+
Method described in the paper:
|
| 144 |
+
`Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
|
| 145 |
+
|
| 146 |
+
The function is defined as:
|
| 147 |
+
|
| 148 |
+
.. math::
|
| 149 |
+
\text{RReLU}(x) =
|
| 150 |
+
\begin{cases}
|
| 151 |
+
x & \text{if } x \geq 0 \\
|
| 152 |
+
ax & \text{ otherwise }
|
| 153 |
+
\end{cases}
|
| 154 |
+
|
| 155 |
+
where :math:`a` is randomly sampled from uniform distribution
|
| 156 |
+
:math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
|
| 157 |
+
evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
|
| 161 |
+
upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
|
| 162 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 163 |
+
|
| 164 |
+
Shape:
|
| 165 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 166 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 167 |
+
|
| 168 |
+
.. image:: ../scripts/activation_images/RReLU.png
|
| 169 |
+
|
| 170 |
+
Examples::
|
| 171 |
+
|
| 172 |
+
>>> m = nn.RReLU(0.1, 0.3)
|
| 173 |
+
>>> input = torch.randn(2)
|
| 174 |
+
>>> output = m(input)
|
| 175 |
+
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
__constants__ = ["lower", "upper", "inplace"]
|
| 179 |
+
|
| 180 |
+
lower: float
|
| 181 |
+
upper: float
|
| 182 |
+
inplace: bool
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.lower = lower
|
| 189 |
+
self.upper = upper
|
| 190 |
+
self.inplace = inplace
|
| 191 |
+
|
| 192 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 193 |
+
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
|
| 194 |
+
|
| 195 |
+
def extra_repr(self):
|
| 196 |
+
inplace_str = ", inplace=True" if self.inplace else ""
|
| 197 |
+
return f"lower={self.lower}, upper={self.upper}{inplace_str}"
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+


class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["min_val", "max_val", "inplace"]

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.0,
        max_val: float = 1.0,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
    ) -> None:
        super().__init__()
        if min_value is not None:
            warnings.warn(
                "keyword argument `min_value` is deprecated and renamed to `min_val`",
                FutureWarning,
                stacklevel=2,
            )
            min_val = min_value
        if max_value is not None:
            warnings.warn(
                "keyword argument `max_value` is deprecated and renamed to `max_val`",
                FutureWarning,
                stacklevel=2,
            )
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        assert self.max_val > self.min_val

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ", inplace=True" if self.inplace else ""
        return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}"
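

# Illustrative sketch (not part of the torch source): Hardtanh simply clamps
# its input to [min_val, max_val]. Hypothetical helper, never called:
def _hardtanh_clamp_sketch() -> None:
    m = Hardtanh(min_val=-2.0, max_val=2.0)
    x = torch.tensor([-3.0, 0.5, 3.0])
    y = m(x)  # tensor([-2.0000, 0.5000, 2.0000]): out-of-range values clamped
    assert torch.equal(y, torch.clamp(x, -2.0, 2.0))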


class ReLU6(Hardtanh):
    r"""Applies the ReLU6 function element-wise.

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__(0.0, 6.0, inplace)

    def extra_repr(self) -> str:
        inplace_str = "inplace=True" if self.inplace else ""
        return inplace_str


class Sigmoid(Module):
    r"""Applies the Sigmoid function element-wise.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.sigmoid(input)


class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise.

    Hardsigmoid is defined as:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardsigmoid.png

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["inplace"]

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)
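

# Illustrative sketch (not part of the torch source): the three branches of
# Hardsigmoid worked out on concrete values. Hypothetical helper, never called:
def _hardsigmoid_branches_sketch() -> None:
    m = Hardsigmoid()
    x = torch.tensor([-4.0, 0.0, 4.0])
    y = m(x)  # -4 -> 0 (x <= -3); 0 -> 0/6 + 1/2 = 0.5; 4 -> 1 (x >= 3)
    assert torch.allclose(y, torch.tensor([0.0, 0.5, 1.0]))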


class Tanh(Module):
    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.

    Tanh is defined as:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.tanh(input)


class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.

    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SiLU.png

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = "inplace=True" if self.inplace else ""
        return inplace_str
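

# Illustrative sketch (not part of the torch source): SiLU equals
# x * sigmoid(x), which this hypothetical, never-called check verifies:
def _silu_identity_sketch() -> None:
    x = torch.randn(8)
    assert torch.allclose(SiLU()(x), x * torch.sigmoid(x))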


class Mish(Module):
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Mish.png

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = "inplace=True" if self.inplace else ""
        return inplace_str


class Hardswish(Module):
    r"""Applies the Hardswish function, element-wise.

    Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    Hardswish is defined as:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) / 6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardswish.png

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["inplace"]

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)
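

# Illustrative sketch (not part of the torch source): Hardswish is the input
# gated by Hardsigmoid, i.e. x * hardsigmoid(x), which reproduces all three
# branches of the piecewise definition. Hypothetical helper, never called:
def _hardswish_gating_sketch() -> None:
    x = torch.randn(8)
    assert torch.allclose(Hardswish()(x), x * F.hardsigmoid(x))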


class ELU(Module):
    r"""Applies the Exponential Linear Unit (ELU) function, element-wise.

    Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

    ELU is defined as:

    .. math::
        \text{ELU}(x) = \begin{cases}
            x, & \text{ if } x > 0\\
            \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["alpha", "inplace"]
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ", inplace=True" if self.inplace else ""
        return f"alpha={self.alpha}{inplace_str}"
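

# Illustrative sketch (not part of the torch source): ELU passes positives
# through unchanged and saturates to -alpha for very negative inputs.
# Hypothetical helper, never called:
def _elu_saturation_sketch() -> None:
    m = ELU(alpha=1.0)
    x = torch.tensor([-10.0, 0.0, 2.0])
    y = m(x)  # ~[-1.0, 0.0, 2.0]: exp(-10) - 1 is ~-1, positives unchanged
    assert torch.allclose(y, torch.tensor([-1.0, 0.0, 2.0]), atol=1e-4)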


class CELU(Module):
    r"""Applies the CELU function element-wise.

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ["alpha", "inplace"]
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ", inplace=True" if self.inplace else ""
        return f"alpha={self.alpha}{inplace_str}"


class SELU(Module):
    r"""Applies the SELU function element-wise.

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialization,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = "inplace=True" if self.inplace else ""
        return inplace_str
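

# Illustrative sketch (not part of the torch source): SELU is a scaled ELU
# with fixed constants, i.e. scale * elu(x, alpha). Hypothetical helper,
# never called:
def _selu_as_scaled_elu_sketch() -> None:
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    x = torch.randn(8)
    assert torch.allclose(SELU()(x), scale * F.elu(x, alpha))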


class GLU(Module):
    r"""Applies the gated linear unit function.

    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ["dim"]
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"
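

# Illustrative sketch (not part of the torch source): GLU halves the chosen
# dimension, gating the first half with the sigmoid of the second.
# Hypothetical helper, never called:
def _glu_split_sketch() -> None:
    x = torch.randn(4, 6)
    y = GLU(dim=-1)(x)  # shape (4, 3): last dim split into two halves of 3
    a, b = x.chunk(2, dim=-1)
    assert torch.allclose(y, a * torch.sigmoid(b))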


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function.

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When the approximate argument is 'tanh', GELU is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["approximate"]
    approximate: str

    def __init__(self, approximate: str = "none") -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return f"approximate={repr(self.approximate)}"
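

# Illustrative sketch (not part of the torch source): the exact and
# tanh-approximate GELU variants agree closely for typical inputs; the
# tolerance below is an assumption, not a documented bound. Hypothetical
# helper, never called:
def _gelu_approximation_sketch() -> None:
    x = torch.randn(100)
    exact = GELU(approximate="none")(x)
    approx = GELU(approximate="tanh")(x)
    assert torch.allclose(exact, approx, atol=1e-3)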


class Hardshrink(Module):
    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.

    Hardshrink is defined as:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
            x, & \text{ if } x > \lambda \\
            x, & \text{ if } x < -\lambda \\
            0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["lambd"]
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return f"{self.lambd}"


class LeakyReLU(Module):
    r"""Applies the LeakyReLU function element-wise.

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
            x, & \text{ if } x \geq 0 \\
            \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope (which is used for
            negative input values). Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["inplace", "negative_slope"]
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ", inplace=True" if self.inplace else ""
        return f"negative_slope={self.negative_slope}{inplace_str}"


class LogSigmoid(Module):
    r"""Applies the Logsigmoid function element-wise.

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{1}{1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.logsigmoid(input)


class Softplus(Module):
    r"""Applies the Softplus function element-wise.

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

    Softplus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["beta", "threshold"]
    beta: float
    threshold: float

    def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return f"beta={self.beta}, threshold={self.threshold}"
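

# Illustrative sketch (not part of the torch source): above
# beta * x > threshold, softplus(x) is numerically indistinguishable from x,
# which is why the implementation switches to the identity there.
# Hypothetical helper, never called:
def _softplus_threshold_sketch() -> None:
    m = Softplus(beta=1.0, threshold=20.0)
    x = torch.tensor([25.0])
    assert torch.equal(m(x), x)  # linear regime: input returned as-is
    small = torch.tensor([0.0])
    assert torch.allclose(m(small), torch.log(torch.tensor(2.0)))  # log(1+e^0)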


class Softshrink(Module):
    r"""Applies the soft shrinkage function element-wise.

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
            x - \lambda, & \text{ if } x > \lambda \\
            x + \lambda, & \text{ if } x < -\lambda \\
            0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value (must be non-negative) for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["lambd"]
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)
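

# Illustrative sketch (not part of the torch source): Softshrink moves every
# value lambd closer to zero and zeroes the band [-lambd, lambd].
# Hypothetical helper, never called:
def _softshrink_band_sketch() -> None:
    m = Softshrink(lambd=0.5)
    x = torch.tensor([-2.0, -0.3, 0.3, 2.0])
    y = m(x)  # [-1.5, 0.0, 0.0, 1.5]
    assert torch.allclose(y, torch.tensor([-1.5, 0.0, 0.0, 1.5]))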


def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
    if x is not None:
        return x.device.type in [
            "cpu",
            "cuda",
            torch.utils.backend_registration._privateuse1_backend_name,
        ]
    return True


def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
    if x is not None:
        return x.requires_grad
    return False


def _is_make_fx_tracing():
    if not torch.jit.is_scripting():
        torch_dispatch_mode_stack = (
            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
        )
        return any(
            type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode
            for x in torch_dispatch_mode_stack
        )
    else:
        return False


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information from different representation subspaces.

    Method described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``nn.MultiheadAttention`` will use the optimized implementations of
    ``scaled_dot_product_attention()`` when possible.

    In addition to support for the new ``scaled_dot_product_attention()``
    function, to speed up inference, MHA will use
    fastpath inference with support for Nested Tensors, iff:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
    - inputs are batched (3D) with ``batch_first==True``
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed
    - autocast is disabled

    If the optimized inference fastpath implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
        https://arxiv.org/abs/2205.14135

    """
__constants__ = ["batch_first"]
|
| 1037 |
+
bias_k: Optional[torch.Tensor]
|
| 1038 |
+
bias_v: Optional[torch.Tensor]
|
| 1039 |
+
|
| 1040 |
+
def __init__(
|
| 1041 |
+
self,
|
| 1042 |
+
embed_dim,
|
| 1043 |
+
num_heads,
|
| 1044 |
+
dropout=0.0,
|
| 1045 |
+
bias=True,
|
| 1046 |
+
add_bias_kv=False,
|
| 1047 |
+
add_zero_attn=False,
|
| 1048 |
+
kdim=None,
|
| 1049 |
+
vdim=None,
|
| 1050 |
+
batch_first=False,
|
| 1051 |
+
device=None,
|
| 1052 |
+
dtype=None,
|
| 1053 |
+
) -> None:
|
| 1054 |
+
if embed_dim <= 0 or num_heads <= 0:
|
| 1055 |
+
raise ValueError(
|
| 1056 |
+
f"embed_dim and num_heads must be greater than 0,"
|
| 1057 |
+
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
|
| 1058 |
+
)
|
| 1059 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 1060 |
+
super().__init__()
|
| 1061 |
+
self.embed_dim = embed_dim
|
| 1062 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 1063 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 1064 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 1065 |
+
|
| 1066 |
+
self.num_heads = num_heads
|
| 1067 |
+
self.dropout = dropout
|
| 1068 |
+
self.batch_first = batch_first
|
| 1069 |
+
self.head_dim = embed_dim // num_heads
|
| 1070 |
+
assert (
|
| 1071 |
+
self.head_dim * num_heads == self.embed_dim
|
| 1072 |
+
), "embed_dim must be divisible by num_heads"
|
| 1073 |
+
|
| 1074 |
+
if not self._qkv_same_embed_dim:
|
| 1075 |
+
self.q_proj_weight = Parameter(
|
| 1076 |
+
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
| 1077 |
+
)
|
| 1078 |
+
self.k_proj_weight = Parameter(
|
| 1079 |
+
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
| 1080 |
+
)
|
| 1081 |
+
self.v_proj_weight = Parameter(
|
| 1082 |
+
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
| 1083 |
+
)
|
| 1084 |
+
self.register_parameter("in_proj_weight", None)
|
| 1085 |
+
else:
|
| 1086 |
+
self.in_proj_weight = Parameter(
|
| 1087 |
+
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
| 1088 |
+
)
|
| 1089 |
+
self.register_parameter("q_proj_weight", None)
|
| 1090 |
+
self.register_parameter("k_proj_weight", None)
|
| 1091 |
+
self.register_parameter("v_proj_weight", None)
|
| 1092 |
+
|
| 1093 |
+
if bias:
|
| 1094 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
| 1095 |
+
else:
|
| 1096 |
+
self.register_parameter("in_proj_bias", None)
|
| 1097 |
+
self.out_proj = NonDynamicallyQuantizableLinear(
|
| 1098 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
if add_bias_kv:
|
| 1102 |
+
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
| 1103 |
+
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
| 1104 |
+
else:
|
| 1105 |
+
self.bias_k = self.bias_v = None
|
| 1106 |
+
|
| 1107 |
+
self.add_zero_attn = add_zero_attn
|
| 1108 |
+
|
| 1109 |
+
self._reset_parameters()
|
| 1110 |
+
|
| 1111 |
+

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Compute attention outputs using query, key, and value embeddings.

        Supports optional parameters for padding, masks and attention weights.

        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and float masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
                and achieve the best performance for MHA.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
                If both attn_mask and key_padding_mask are supplied, their types should match.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
            is_causal: If specified, applies a causal mask as attention mask.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``attn_mask`` is the
                causal mask. Providing an incorrect hint can result in
                incorrect execution, including silently incorrect forward
                and backward results.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """  # noqa: B950
why_not_fast_path = ""
|
| 1206 |
+
if (
|
| 1207 |
+
(attn_mask is not None and torch.is_floating_point(attn_mask))
|
| 1208 |
+
or (key_padding_mask is not None)
|
| 1209 |
+
and torch.is_floating_point(key_padding_mask)
|
| 1210 |
+
):
|
| 1211 |
+
why_not_fast_path = "floating-point masks are not supported for fast path."
|
| 1212 |
+
|
| 1213 |
+
is_batched = query.dim() == 3
|
| 1214 |
+
|
| 1215 |
+
key_padding_mask = F._canonical_mask(
|
| 1216 |
+
mask=key_padding_mask,
|
| 1217 |
+
mask_name="key_padding_mask",
|
| 1218 |
+
other_type=F._none_or_dtype(attn_mask),
|
| 1219 |
+
other_name="attn_mask",
|
| 1220 |
+
target_type=query.dtype,
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
attn_mask = F._canonical_mask(
|
| 1224 |
+
mask=attn_mask,
|
| 1225 |
+
mask_name="attn_mask",
|
| 1226 |
+
other_type=None,
|
| 1227 |
+
other_name="",
|
| 1228 |
+
target_type=query.dtype,
|
| 1229 |
+
check_other=False,
|
| 1230 |
+
)
|
| 1231 |
+
|
| 1232 |
+
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
| 1233 |
+
|
| 1234 |
+
if not is_fastpath_enabled:
|
| 1235 |
+
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
| 1236 |
+
elif not is_batched:
|
| 1237 |
+
why_not_fast_path = (
|
| 1238 |
+
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
| 1239 |
+
)
|
| 1240 |
+
elif query is not key or key is not value:
|
| 1241 |
+
# When lifting this restriction, don't forget to either
|
| 1242 |
+
# enforce that the dtypes all match or test cases where
|
| 1243 |
+
# they don't!
|
| 1244 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
| 1245 |
+
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
| 1246 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
| 1247 |
+
elif self.in_proj_weight is None:
|
| 1248 |
+
why_not_fast_path = "in_proj_weight was None"
|
| 1249 |
+
elif query.dtype != self.in_proj_weight.dtype:
|
| 1250 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
| 1251 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
| 1252 |
+
elif self.training:
|
| 1253 |
+
why_not_fast_path = "training is enabled"
|
| 1254 |
+
elif (self.num_heads % 2) != 0:
|
| 1255 |
+
why_not_fast_path = "self.num_heads is not even"
|
| 1256 |
+
elif not self.batch_first:
|
| 1257 |
+
why_not_fast_path = "batch_first was not True"
|
| 1258 |
+
elif self.bias_k is not None:
|
| 1259 |
+
why_not_fast_path = "self.bias_k was not None"
|
| 1260 |
+
elif self.bias_v is not None:
|
| 1261 |
+
why_not_fast_path = "self.bias_v was not None"
|
| 1262 |
+
elif self.add_zero_attn:
|
| 1263 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
| 1264 |
+
elif not self._qkv_same_embed_dim:
|
| 1265 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
| 1266 |
+
elif query.is_nested and (
|
| 1267 |
+
key_padding_mask is not None or attn_mask is not None
|
| 1268 |
+
):
|
| 1269 |
+
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
|
| 1270 |
+
is not supported with NestedTensor input"
|
| 1271 |
+
elif torch.is_autocast_enabled():
|
| 1272 |
+
why_not_fast_path = "autocast is enabled"
|
| 1273 |
+
|
| 1274 |
+
if not why_not_fast_path:
|
| 1275 |
+
tensor_args = (
|
| 1276 |
+
query,
|
| 1277 |
+
key,
|
| 1278 |
+
value,
|
| 1279 |
+
self.in_proj_weight,
|
| 1280 |
+
self.in_proj_bias,
|
| 1281 |
+
self.out_proj.weight,
|
| 1282 |
+
self.out_proj.bias,
|
| 1283 |
+
)
|
| 1284 |
+
# We have to use list comprehensions below because TorchScript does not support
|
| 1285 |
+
# generator expressions.
|
| 1286 |
+
if torch.overrides.has_torch_function(tensor_args):
|
| 1287 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
| 1288 |
+
elif _is_make_fx_tracing():
|
| 1289 |
+
why_not_fast_path = "we are running make_fx tracing"
|
| 1290 |
+
elif not all(_check_arg_device(x) for x in tensor_args):
|
| 1291 |
+
why_not_fast_path = (
|
| 1292 |
+
"some Tensor argument's device is neither one of "
|
| 1293 |
+
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
|
| 1294 |
+
)
|
| 1295 |
+
elif torch.is_grad_enabled() and any(
|
| 1296 |
+
_arg_requires_grad(x) for x in tensor_args
|
| 1297 |
+
):
|
| 1298 |
+
why_not_fast_path = (
|
| 1299 |
+
"grad is enabled and at least one of query or the "
|
| 1300 |
+
"input/output projection weights or biases requires_grad"
|
| 1301 |
+
)
|
| 1302 |
+
if not why_not_fast_path:
|
| 1303 |
+
merged_mask, mask_type = self.merge_masks(
|
| 1304 |
+
attn_mask, key_padding_mask, query
|
| 1305 |
+
)
|
| 1306 |
+
|
| 1307 |
+
if self.in_proj_bias is not None and self.in_proj_weight is not None:
|
| 1308 |
+
return torch._native_multi_head_attention(
|
| 1309 |
+
query,
|
| 1310 |
+
key,
|
| 1311 |
+
value,
|
| 1312 |
+
self.embed_dim,
|
| 1313 |
+
self.num_heads,
|
| 1314 |
+
self.in_proj_weight,
|
| 1315 |
+
self.in_proj_bias,
|
| 1316 |
+
self.out_proj.weight,
|
| 1317 |
+
self.out_proj.bias,
|
| 1318 |
+
merged_mask,
|
| 1319 |
+
need_weights,
|
| 1320 |
+
average_attn_weights,
|
| 1321 |
+
mask_type,
|
| 1322 |
+
)
|
| 1323 |
+
|
| 1324 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
| 1325 |
+
assert not any_nested, (
|
| 1326 |
+
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
| 1327 |
+
+ f"The fast path was not hit because {why_not_fast_path}"
|
| 1328 |
+
)
|
| 1329 |
+
|
| 1330 |
+
if self.batch_first and is_batched:
|
| 1331 |
+
# make sure that the transpose op does not affect the "is" property
|
| 1332 |
+
if key is value:
|
| 1333 |
+
if query is key:
|
| 1334 |
+
query = key = value = query.transpose(1, 0)
|
| 1335 |
+
else:
|
| 1336 |
+
query, key = (x.transpose(1, 0) for x in (query, key))
|
| 1337 |
+
value = key
|
| 1338 |
+
else:
|
| 1339 |
+
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
|
| 1340 |
+
|
| 1341 |
+
if not self._qkv_same_embed_dim:
|
| 1342 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
| 1343 |
+
query,
|
| 1344 |
+
key,
|
| 1345 |
+
value,
|
| 1346 |
+
self.embed_dim,
|
| 1347 |
+
self.num_heads,
|
| 1348 |
+
self.in_proj_weight,
|
| 1349 |
+
self.in_proj_bias,
|
| 1350 |
+
self.bias_k,
|
| 1351 |
+
self.bias_v,
|
| 1352 |
+
self.add_zero_attn,
|
| 1353 |
+
self.dropout,
|
| 1354 |
+
self.out_proj.weight,
|
| 1355 |
+
self.out_proj.bias,
|
| 1356 |
+
training=self.training,
|
| 1357 |
+
key_padding_mask=key_padding_mask,
|
| 1358 |
+
need_weights=need_weights,
|
| 1359 |
+
attn_mask=attn_mask,
|
| 1360 |
+
use_separate_proj_weight=True,
|
| 1361 |
+
q_proj_weight=self.q_proj_weight,
|
| 1362 |
+
k_proj_weight=self.k_proj_weight,
|
| 1363 |
+
v_proj_weight=self.v_proj_weight,
|
| 1364 |
+
average_attn_weights=average_attn_weights,
|
| 1365 |
+
is_causal=is_causal,
|
| 1366 |
+
)
|
| 1367 |
+
else:
|
| 1368 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
| 1369 |
+
query,
|
| 1370 |
+
key,
|
| 1371 |
+
value,
|
| 1372 |
+
self.embed_dim,
|
| 1373 |
+
self.num_heads,
|
| 1374 |
+
self.in_proj_weight,
|
| 1375 |
+
self.in_proj_bias,
|
| 1376 |
+
self.bias_k,
|
| 1377 |
+
self.bias_v,
|
| 1378 |
+
self.add_zero_attn,
|
| 1379 |
+
self.dropout,
|
| 1380 |
+
self.out_proj.weight,
|
| 1381 |
+
self.out_proj.bias,
|
| 1382 |
+
training=self.training,
|
| 1383 |
+
key_padding_mask=key_padding_mask,
|
| 1384 |
+
need_weights=need_weights,
|
| 1385 |
+
attn_mask=attn_mask,
|
| 1386 |
+
average_attn_weights=average_attn_weights,
|
| 1387 |
+
is_causal=is_causal,
|
| 1388 |
+
)
|
| 1389 |
+
if self.batch_first and is_batched:
|
| 1390 |
+
return attn_output.transpose(1, 0), attn_output_weights
|
| 1391 |
+
else:
|
| 1392 |
+
return attn_output, attn_output_weights
|
| 1393 |
+
|
| 1394 |
+

    def merge_masks(
        self,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        query: Tensor,
    ) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""Determine mask type and combine masks if necessary.

        If only one mask is provided, that mask
        and the corresponding mask type will be returned. If both masks are provided, they will be both
        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``,
        and mask type 2 will be returned.

        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
                    batch_size, self.num_heads, -1, -1
                )
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(
                    batch_size, 1, 1, seq_len
                ).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type
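

# Illustrative sketch (not part of the torch source): a minimal self-attention
# call with ``batch_first=True``. Shapes are (batch, seq, embed_dim); passing
# the same tensor as query/key/value plus eval()/no_grad() makes the call
# eligible for the inference fast path described above. Hypothetical helper,
# never called:
def _multihead_attention_usage_sketch() -> None:
    mha = MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
    mha.eval()
    x = torch.randn(2, 5, 16)  # batch of 2, sequence length 5
    with torch.no_grad():
        out, weights = mha(x, x, x)  # self-attention: q, k, v are the same tensor
    assert out.shape == (2, 5, 16)
    assert weights.shape == (2, 5, 5)  # averaged across heads by default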


class PReLU(Module):
    r"""Applies the element-wise PReLU function.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
            x, & \text{ if } x \ge 0 \\
            ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels of the input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ["num_parameters"]
    num_parameters: int

    def __init__(
        self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_parameters = num_parameters
        super().__init__()
        self.init = init
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.constant_(self.weight, self.init)

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return f"num_parameters={self.num_parameters}"
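

# Illustrative sketch (not part of the torch source): unlike LeakyReLU's fixed
# slope, PReLU's negative slope is a trainable parameter and can be learned
# per channel. Hypothetical helper, never called:
def _prelu_per_channel_sketch() -> None:
    m = PReLU(num_parameters=3)  # one learnable slope per channel
    x = torch.randn(2, 3, 8)  # channel dim is dim 1
    y = m(x)
    assert y.shape == x.shape
    assert m.weight.shape == (3,)  # each slope initialized to 0.25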


class Softsign(Module):
    r"""Applies the element-wise Softsign function.

    .. math::
        \text{SoftSign}(x) = \frac{x}{1 + |x|}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.softsign(input)


class Tanhshrink(Module):
    r"""Applies the element-wise Tanhshrink function.

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.tanhshrink(input)


class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor.

    Rescales the elements so that they lie in the range `[0, 1]`
    and sum to 1 along the chosen dimension.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ["dim"]
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, "dim"):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return f"dim={self.dim}"


class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor.

    Rescales the elements so that they lie in the range [0, 1]
    and sum to 1 along the chosen dimension.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor, the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)

    """

    __constants__ = ["dim"]
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, "dim"):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"
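

# Illustrative sketch (not part of the torch source): ``dim`` selects which
# axis is normalized; every slice along that axis sums to 1. Hypothetical
# helper, never called:
def _softmax_dim_sketch() -> None:
    x = torch.randn(2, 3)
    rows = Softmax(dim=1)(x)  # each of the 2 rows sums to 1
    cols = Softmax(dim=0)(x)  # each of the 3 columns sums to 1
    assert torch.allclose(rows.sum(dim=1), torch.ones(2))
    assert torch.allclose(cols.sum(dim=0), torch.ones(3))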


class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmax2d()
        >>> # softmax is applied over the channel dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        if input.dim() not in (3, 4):
            raise ValueError(
                f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
            )
        return F.softmax(input, -3, _stacklevel=5)
class LogSoftmax(Module):
|
| 1704 |
+
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
|
| 1705 |
+
|
| 1706 |
+
The LogSoftmax formulation can be simplified as:
|
| 1707 |
+
|
| 1708 |
+
.. math::
|
| 1709 |
+
\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
|
| 1710 |
+
|
| 1711 |
+
Shape:
|
| 1712 |
+
- Input: :math:`(*)` where `*` means, any number of additional
|
| 1713 |
+
dimensions
|
| 1714 |
+
- Output: :math:`(*)`, same shape as the input
|
| 1715 |
+
|
| 1716 |
+
Args:
|
| 1717 |
+
dim (int): A dimension along which LogSoftmax will be computed.
|
| 1718 |
+
|
| 1719 |
+
Returns:
|
| 1720 |
+
a Tensor of the same dimension and shape as the input with
|
| 1721 |
+
values in the range [-inf, 0)
|
| 1722 |
+
|
| 1723 |
+
Examples::
|
| 1724 |
+
|
| 1725 |
+
>>> m = nn.LogSoftmax(dim=1)
|
| 1726 |
+
>>> input = torch.randn(2, 3)
|
| 1727 |
+
>>> output = m(input)
|
| 1728 |
+
"""
|
| 1729 |
+
|
| 1730 |
+
__constants__ = ["dim"]
|
| 1731 |
+
dim: Optional[int]
|
| 1732 |
+
|
| 1733 |
+
def __init__(self, dim: Optional[int] = None) -> None:
|
| 1734 |
+
super().__init__()
|
| 1735 |
+
self.dim = dim
|
| 1736 |
+
|
| 1737 |
+
def __setstate__(self, state):
|
| 1738 |
+
super().__setstate__(state)
|
| 1739 |
+
if not hasattr(self, "dim"):
|
| 1740 |
+
self.dim = None
|
| 1741 |
+
|
| 1742 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 1743 |
+
return F.log_softmax(input, self.dim, _stacklevel=5)
|
| 1744 |
+
|
| 1745 |
+
def extra_repr(self):
|
| 1746 |
+
return f"dim={self.dim}"
|
.venv/Lib/site-packages/torch/nn/modules/adaptive.py
ADDED
@@ -0,0 +1,330 @@
# mypy: allow-untyped-defs

from collections import namedtuple
from typing import List, Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from .container import ModuleList, Sequential
from .linear import Linear
from .module import Module


__all__ = ["AdaptiveLogSoftmaxWithLoss"]

_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])


class AdaptiveLogSoftmaxWithLoss(Module):
    """Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Cissé, David Grangier, and Hervé Jégou
    <https://arxiv.org/abs/1609.04309>`__.
    """ r"""
    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may contain different numbers of targets
    each.
    Additionally, clusters containing less frequent labels assign lower
    dimensional embeddings to those labels, which speeds up the computation.
    For each minibatch, only clusters for which at least one target is
    present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels), should also be cheap
    to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`).

    * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See paper for details. Set to False in the official
      implementation.

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:
            * **output** is a Tensor of size ``N`` containing computed target
              log probabilities for each example
            * **loss** is a Scalar representing the computed negative
              log likelihood loss

    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``

    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.0,
        head_bias: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (
            (cutoffs != sorted(cutoffs))
            or (min(cutoffs) <= 0)
            or (max(cutoffs) > (n_classes - 1))
            or (len(set(cutoffs)) != len(cutoffs))
            or any(int(c) != c for c in cutoffs)
        ):
            raise ValueError(
                "cutoffs should be a sequence of unique, positive "
                "integers sorted in an increasing order, where "
                "each value is between 1 and n_classes-1"
            )

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(
            self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
        )
        self.tail = ModuleList()

        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError(
                    "Input and target should have the same size "
                    "in the batch dimension."
                )
            if input_.dim() != 2:
                raise RuntimeError(
                    "1D target tensor expects 2D input tensors, "
                    "but found inputs with size",
                    input_.size(),
                )
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError(
                    "0D target tensor expects 1D input tensors, "
                    "but found inputs with size",
                    input_.size(),
                )
        else:
            raise RuntimeError(
                "0D or 1D target tensor expected, multi-target not supported"
            )

        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = F.log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(
                f"Target values should be in [0, {self.n_classes - 1}], "
                f"but values in range [{target.min().item()}, {target.max().item()}] "
                "were found."
            )

        head_output = self.head(input)
        head_logprob = F.log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = F.log_softmax(head_output, dim=1)

        out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]

        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = F.log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[
                :, self.shortlist_size + i
            ].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`

        """
        head_output = self.head(input)
        return self._get_full_log_prob(input, head_output)

    def predict(self, input: Tensor) -> Tensor:
        r"""Return the class with the highest probability for each example in the input minibatch.

        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): a class with the highest probability for each example

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N)`
        """
        head_output = self.head(input)
        output = torch.argmax(head_output, dim=1)
        not_in_shortlist = output >= self.shortlist_size
        all_in_shortlist = not (not_in_shortlist.any())

        if all_in_shortlist:
            return output

        elif not_in_shortlist.all():
            log_prob = self._get_full_log_prob(input, head_output)
            return torch.argmax(log_prob, dim=1)

        else:
            log_prob = self._get_full_log_prob(
                input[not_in_shortlist], head_output[not_in_shortlist]
            )
            output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
            return output
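
To make the cluster mechanics above concrete, a short usage sketch (editorial, not part of the torch sources; the sizes and cutoffs are illustrative assumptions):

import torch
from torch import nn

# 64-dim inputs, a 1000-class output space split into a 100-class shortlist
# plus two tail clusters via cutoffs=[100, 500]. With div_value=4.0 the tail
# projections get 64 // 4**1 = 16 and 64 // 4**2 = 4 hidden dimensions.
asm = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=64, n_classes=1000, cutoffs=[100, 500], div_value=4.0
)

x = torch.randn(128, 64)            # minibatch of 128 examples
y = torch.randint(0, 1000, (128,))  # targets in [0, n_classes - 1]

out, loss = asm(x, y)               # NamedTuple with output and loss fields
print(out.shape, loss.item())       # torch.Size([128]) and a scalar NLL

log_probs = asm.log_prob(x)         # full (128, 1000) log-distribution
preds = asm.predict(x)              # equivalent to log_probs.argmax(dim=1)
assert torch.equal(preds, log_probs.argmax(dim=1))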
.venv/Lib/site-packages/torch/nn/modules/batchnorm.py
ADDED
@@ -0,0 +1,883 @@
# mypy: allow-untyped-defs
from typing import Any, Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter

from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module


__all__ = [
    "BatchNorm1d",
    "LazyBatchNorm1d",
    "BatchNorm2d",
    "LazyBatchNorm2d",
    "BatchNorm3d",
    "LazyBatchNorm3d",
    "SyncBatchNorm",
]


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(num_features, **factory_kwargs)
            )
            self.register_buffer(
                "running_var", torch.ones(num_features, **factory_kwargs)
            )
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer(
                "num_batches_tracked",
                torch.tensor(
                    0,
                    dtype=torch.long,
                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
                ),
            )
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    and self.num_batches_tracked.device != torch.device("meta")
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):
    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(
        self,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
                self.running_var.materialize(  # type:ignore[union-attr]
                    (self.num_features,)
                )
            self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

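# Editorial sketch, not part of the torch source above: the momentum note in
# the BatchNorm1d docstring says running stats follow
#     x_hat_new = (1 - momentum) * x_hat + momentum * x_t,
# and that momentum=None switches to a cumulative (simple) average with
# factor 1 / num_batches_tracked. A minimal check of both behaviours,
# assuming only public torch APIs (guarded so it never runs on import):
if __name__ == "__main__":
    _x = torch.randn(32, 4)

    # Default momentum=0.1: after one training step, running_mean moves
    # 10% of the way from its initial zeros toward the batch mean.
    _bn = BatchNorm1d(4, momentum=0.1)
    _bn.train()
    _bn(_x)
    assert torch.allclose(_bn.running_mean, 0.1 * _x.mean(0), atol=1e-6)

    # momentum=None: cumulative average; after the first batch the factor is
    # 1 / num_batches_tracked == 1, so running_mean equals the batch mean.
    _bn_cma = BatchNorm1d(4, momentum=None)
    _bn_cma.train()
    _bn_cma(_x)
    assert torch.allclose(_bn_cma.running_mean, _x.mean(0), atol=1e-6)
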
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>     sync_bn_network,
        >>>     device_ids=[args.local_rank],
        >>>     output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
|
| 756 |
+
else:
|
| 757 |
+
bn_training = (self.running_mean is None) and (self.running_var is None)
|
| 758 |
+
|
| 759 |
+
r"""
|
| 760 |
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
| 761 |
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
| 762 |
+
used for normalization (i.e. in eval mode when buffers are not None).
|
| 763 |
+
"""
|
| 764 |
+
# If buffers are not to be tracked, ensure that they won't be updated
|
| 765 |
+
running_mean = (
|
| 766 |
+
self.running_mean if not self.training or self.track_running_stats else None
|
| 767 |
+
)
|
| 768 |
+
running_var = (
|
| 769 |
+
self.running_var if not self.training or self.track_running_stats else None
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
# Don't sync batchnorm stats in inference mode (model.eval()).
|
| 773 |
+
need_sync = (
|
| 774 |
+
bn_training
|
| 775 |
+
and self.training
|
| 776 |
+
and torch.distributed.is_available()
|
| 777 |
+
and torch.distributed.is_initialized()
|
| 778 |
+
)
|
| 779 |
+
if need_sync:
|
| 780 |
+
# currently only GPU/PrivateUse1 input is supported
|
| 781 |
+
if input.device.type not in [
|
| 782 |
+
"cuda",
|
| 783 |
+
torch._C._get_privateuse1_backend_name(),
|
| 784 |
+
]:
|
| 785 |
+
raise ValueError(
|
| 786 |
+
"SyncBatchNorm expected input tensor to be on GPU or "
|
| 787 |
+
f"{torch._C._get_privateuse1_backend_name()}"
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
process_group = torch.distributed.group.WORLD
|
| 791 |
+
if self.process_group:
|
| 792 |
+
process_group = self.process_group
|
| 793 |
+
world_size = torch.distributed.get_world_size(process_group)
|
| 794 |
+
need_sync = world_size > 1
|
| 795 |
+
|
| 796 |
+
# fallback to framework BN when synchronization is not necessary
|
| 797 |
+
if not need_sync:
|
| 798 |
+
return F.batch_norm(
|
| 799 |
+
input,
|
| 800 |
+
running_mean,
|
| 801 |
+
running_var,
|
| 802 |
+
self.weight,
|
| 803 |
+
self.bias,
|
| 804 |
+
bn_training,
|
| 805 |
+
exponential_average_factor,
|
| 806 |
+
self.eps,
|
| 807 |
+
)
|
| 808 |
+
else:
|
| 809 |
+
assert bn_training
|
| 810 |
+
return sync_batch_norm.apply(
|
| 811 |
+
input,
|
| 812 |
+
self.weight,
|
| 813 |
+
self.bias,
|
| 814 |
+
running_mean,
|
| 815 |
+
running_var,
|
| 816 |
+
self.eps,
|
| 817 |
+
exponential_average_factor,
|
| 818 |
+
process_group, # type: ignore[possibly-undefined]
|
| 819 |
+
world_size, # type: ignore[possibly-undefined]
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
@classmethod
|
| 823 |
+
def convert_sync_batchnorm(cls, module, process_group=None):
|
| 824 |
+
r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.
|
| 825 |
+
|
| 826 |
+
Args:
|
| 827 |
+
module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
|
| 828 |
+
process_group (optional): process group to scope synchronization,
|
| 829 |
+
default is the whole world
|
| 830 |
+
|
| 831 |
+
Returns:
|
| 832 |
+
The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
|
| 833 |
+
layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
|
| 834 |
+
a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
|
| 835 |
+
instead.
|
| 836 |
+
|
| 837 |
+
Example::
|
| 838 |
+
|
| 839 |
+
>>> # Network with nn.BatchNorm layer
|
| 840 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
| 841 |
+
>>> module = torch.nn.Sequential(
|
| 842 |
+
>>> torch.nn.Linear(20, 100),
|
| 843 |
+
>>> torch.nn.BatchNorm1d(100),
|
| 844 |
+
>>> ).cuda()
|
| 845 |
+
>>> # creating process group (optional)
|
| 846 |
+
>>> # ranks is a list of int identifying rank ids.
|
| 847 |
+
>>> ranks = list(range(8))
|
| 848 |
+
>>> r1, r2 = ranks[:4], ranks[4:]
|
| 849 |
+
>>> # Note: every rank calls into new_group for every
|
| 850 |
+
>>> # process group created, even if that rank is not
|
| 851 |
+
>>> # part of the group.
|
| 852 |
+
>>> # xdoctest: +SKIP("distributed")
|
| 853 |
+
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
|
| 854 |
+
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
|
| 855 |
+
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
|
| 856 |
+
|
| 857 |
+
"""
|
| 858 |
+
module_output = module
|
| 859 |
+
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
|
| 860 |
+
module_output = torch.nn.SyncBatchNorm(
|
| 861 |
+
module.num_features,
|
| 862 |
+
module.eps,
|
| 863 |
+
module.momentum,
|
| 864 |
+
module.affine,
|
| 865 |
+
module.track_running_stats,
|
| 866 |
+
process_group,
|
| 867 |
+
)
|
| 868 |
+
if module.affine:
|
| 869 |
+
with torch.no_grad():
|
| 870 |
+
module_output.weight = module.weight
|
| 871 |
+
module_output.bias = module.bias
|
| 872 |
+
module_output.running_mean = module.running_mean
|
| 873 |
+
module_output.running_var = module.running_var
|
| 874 |
+
module_output.num_batches_tracked = module.num_batches_tracked
|
| 875 |
+
module_output.training = module.training
|
| 876 |
+
if hasattr(module, "qconfig"):
|
| 877 |
+
module_output.qconfig = module.qconfig
|
| 878 |
+
for name, child in module.named_children():
|
| 879 |
+
module_output.add_module(
|
| 880 |
+
name, cls.convert_sync_batchnorm(child, process_group)
|
| 881 |
+
)
|
| 882 |
+
del module
|
| 883 |
+
return module_output
|
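A minimal sketch (not part of the uploaded file) of the running-stats update rule that `SyncBatchNorm.forward()` implements above: with a fixed `momentum` the buffers follow an exponential moving average, while `momentum=None` uses a factor of `1/num_batches_tracked`, which reduces to a cumulative (simple) average. The `batch_means` values are made-up illustration data.

    import torch

    batch_means = [torch.tensor(1.0), torch.tensor(3.0), torch.tensor(5.0)]

    # Exponential moving average (momentum=0.1):
    # x_hat_new = (1 - momentum) * x_hat + momentum * x_t
    running = torch.tensor(0.0)
    for x_t in batch_means:
        running = (1 - 0.1) * running + 0.1 * x_t

    # Cumulative moving average (momentum=None): the factor is 1/t at step t,
    # matching exponential_average_factor = 1.0 / num_batches_tracked above.
    cumulative = torch.tensor(0.0)
    for t, x_t in enumerate(batch_means, start=1):
        cumulative = (1 - 1.0 / t) * cumulative + (1.0 / t) * x_t

    assert torch.isclose(cumulative, torch.tensor(3.0))  # the plain mean of 1, 3, 5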
.venv/Lib/site-packages/torch/nn/modules/channelshuffle.py
ADDED
@@ -0,0 +1,56 @@
import torch.nn.functional as F
from torch import Tensor

from .module import Module


__all__ = ["ChannelShuffle"]


class ChannelShuffle(Module):
    r"""Divides and rearranges the channels in a tensor.

    This operation divides the channels in a tensor of shape :math:`(N, C, *)`
    into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them,
    while retaining the original tensor shape in the final output.

    Args:
        groups (int): number of groups to divide channels in.

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[13., 14.],
                  [15., 16.]]]])
    """

    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        return f"groups={self.groups}"
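A quick equivalence check (not part of the uploaded file): `ChannelShuffle(g)` matches a manual view/transpose/reshape of the channel dimension, which is a common way to reason about the operation. This assumes only the public `torch.nn` API shown above.

    import torch
    import torch.nn as nn

    x = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
    shuffled = nn.ChannelShuffle(2)(x)

    # Manual equivalent: split channels into g groups, swap the group and
    # per-group axes, then flatten back to (N, C, H, W).
    n, c, h, w = x.shape
    g = 2
    manual = x.view(n, g, c // g, h, w).transpose(1, 2).reshape(n, c, h, w)
    assert torch.equal(shuffled, manual)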
.venv/Lib/site-packages/torch/nn/modules/container.py
ADDED
@@ -0,0 +1,976 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import operator
from collections import abc as container_abcs, OrderedDict
from itertools import chain, islice
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated, Self

import torch
from torch._jit_internal import _copy_to_script_wrapper
from torch.nn.parameter import Parameter

from .module import Module


__all__ = [
    "Container",
    "Sequential",
    "ModuleList",
    "ModuleDict",
    "ParameterList",
    "ParameterDict",
]

T = TypeVar("T", bound=Module)


# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


@deprecated(
    "`nn.Container` is deprecated. "
    "All of its functionality is now implemented in `nn.Module`. Subclass that instead.",
    category=FutureWarning,
)
class Container(Module):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        for key, value in kwargs.items():
            self.add_module(key, value)


class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
        ...

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
        """Get the idx-th item of the iterator."""
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range")
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    def __add__(self, other) -> "Sequential":
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def __mul__(self, other: int) -> "Sequential":
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> "Sequential":
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> "Sequential":
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> "Sequential":
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> "Sequential":
        for layer in sequential:
            self.append(layer)
        return self


class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: slice) -> "ModuleList":
        ...

    @overload
    def __getitem__(self, idx: int) -> Module:
        ...

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> "ModuleList":
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self):
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> "ModuleList":
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (starting from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(modules).__name__
            )

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
                    )
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be used like a regular Python
    list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
    and will be visible by all :class:`~torch.nn.Module` methods.

    Note that the constructor, assigning an element of the list, the
    :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
    method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

    Args:
        parameters (iterable, optional): an iterable of elements to add to the list.

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other functions that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> "ParameterList":
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(
            values, torch.Tensor
        ):
            raise TypeError(
                "ParameterList.extend should be called with an "
                "iterable, but got " + type(values).__name__
            )
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    p.dtype,
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )

        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError("ParameterList should not be called.")


class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as would be done by a regular Python dictionary.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        self._keys: Dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        if not isinstance(key, str):
            raise TypeError(
                "Index given to ParameterDict cannot be used as a key as it is "
                f"not a string (type is '{type(key).__name__}'). Open an issue on "
                "github if you need non-string keys."
            )
        else:
            # Use the key as-is so that `.named_parameters()` returns the right thing
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Note that all other functions that add an entry to the dictionary part of
        # the ParameterDict end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the dictionary part and thus won't
        # call into this function.
        self._keys[key] = None
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)

    def __delitem__(self, key: str) -> None:
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(list(self._keys))

    def copy(self) -> "ParameterDict":
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # We have to use an OrderedDict because the ParameterDict constructor
        # behaves differently on plain dict vs OrderedDict
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
        """Set the default for a key in the ParameterDict.

        If key is in the ParameterDict, return its value.
        If not, insert `key` with a parameter `default` and return `default`.
        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> Tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # We need the key in the _keys to be able to access/del
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default

    def fromkeys(
        self, keys: Iterable[str], default: Optional[Any] = None
    ) -> "ParameterDict":
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[Tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError(
                "ParameterDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(parameters).__name__
            )

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
                self[p[0]] = p[1]  # type: ignore[assignment]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p),
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("ParameterDict should not be called.")

    def __or__(self, other: "ParameterDict") -> "ParameterDict":
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: "ParameterDict") -> "ParameterDict":
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: "ParameterDict") -> Self:
        self.update(other)
        return self
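A short usage sketch (not part of the uploaded file) exercising behaviors implemented in `container.py` above: slicing a `Sequential` returns a new `Sequential`, a `ModuleList` registers its submodules' parameters, and `ParameterDict.__setitem__` wraps plain tensors into `Parameter`s.

    import torch
    import torch.nn as nn

    seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    head = seq[:2]  # __getitem__ with a slice builds a new Sequential
    assert isinstance(head, nn.Sequential) and len(head) == 2

    ml = nn.ModuleList([nn.Linear(3, 3) for _ in range(2)])
    assert len(list(ml.parameters())) == 4  # two weights + two biases registered

    pd = nn.ParameterDict({"w": torch.randn(2, 2)})
    assert isinstance(pd["w"], nn.Parameter)  # Tensor was wrapped into Parameter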
.venv/Lib/site-packages/torch/nn/modules/conv.py
ADDED
@@ -0,0 +1,1866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# mypy: allow-untyped-defs
import math
from typing import List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch._torch_docs import reproducibility_notes
from torch.nn import functional as F, init
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.parameter import Parameter, UninitializedParameter

from .lazy import LazyModuleMixin
from .module import Module
from .utils import _pair, _reverse_repeat_tuple, _single, _triple


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "LazyConv1d",
    "LazyConv2d",
    "LazyConv3d",
    "LazyConvTranspose1d",
    "LazyConvTranspose2d",
    "LazyConvTranspose3d",
]

convolution_notes = {
    "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters (of size
          :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""",
    "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`,
    where `K` is a positive integer, this operation is also known as a "depthwise convolution".

    In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments
    :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""",
}  # noqa: B950

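# To make the ``groups_note`` above concrete, a brief sketch (sizes chosen
# arbitrarily): a grouped convolution's weight carries in_channels // groups
# input channels per filter, and groups == in_channels together with
# out_channels == K * in_channels is exactly the depthwise case.
import torch
from torch import nn

grouped = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, groups=2)
print(grouped.weight.shape)    # torch.Size([32, 8, 3, 3]) -- 16 // 2 = 8

depthwise = nn.Conv2d(in_channels=16, out_channels=48, kernel_size=3, groups=16)
print(depthwise.weight.shape)  # torch.Size([48, 1, 3, 3]) -- K = 3 filters per channel

x = torch.randn(4, 16, 28, 28)
print(depthwise(x).shape)      # torch.Size([4, 48, 26, 26])
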
class _ConvNd(Module):
    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:  # type: ignore[empty-body]
        ...

    in_channels: int
    _reversed_padding_repeated_twice: List[int]
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Union[str, Tuple[int, ...]]
    dilation: Tuple[int, ...]
    transposed: bool
    output_padding: Tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
                )
            if padding == "same" and any(s != 1 for s in stride):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions"
                )

        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
        if padding_mode not in valid_padding_modes:
            raise ValueError(
                f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        if isinstance(self.padding, str):
            self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
            if padding == "same":
                for d, k, i in zip(
                    dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)
                ):
                    total_padding = d * (k - 1)
                    left_pad = total_padding // 2
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = (
                        total_padding - left_pad
                    )
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
                self.padding, 2
            )

        if transposed:
            self.weight = Parameter(
                torch.empty(
                    (in_channels, out_channels // groups, *kernel_size),
                    **factory_kwargs,
                )
            )
        else:
            self.weight = Parameter(
                torch.empty(
                    (out_channels, in_channels // groups, *kernel_size),
                    **factory_kwargs,
                )
            )
        if bias:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
        # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        s = (
            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
            ", stride={stride}"
        )
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        if self.padding_mode != "zeros":
            s += ", padding_mode={padding_mode}"
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"

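# The padding='same' branch above splits dilation * (kernel_size - 1) cells of
# padding between the two sides of each spatial dimension, putting the extra
# cell on the right when the total is odd (left_pad = total_padding // 2).
# A small sketch with arbitrary sizes:
import torch
from torch import nn

x = torch.randn(1, 3, 11)

# Odd kernel: total padding = 1 * (5 - 1) = 4, split 2 / 2.
m = nn.Conv1d(3, 8, kernel_size=5, padding="same")
print(m(x).shape)  # torch.Size([1, 8, 11])

# Even kernel: total padding = 3, split 1 (left) / 2 (right).
m = nn.Conv1d(3, 8, kernel_size=4, padding="same")
print(m._reversed_padding_repeated_twice)  # [1, 2]
print(m(x).shape)                          # torch.Size([1, 8, 11])
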
class Conv1d(_ConvNd):
    __doc__ = (
        r"""Applies a 1D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
    precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`L` is a length of signal sequence.
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a one-element tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    Note:
        {depthwise_separable_note}
    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                        \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`

    Examples::

        >>> m = nn.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # we create new variables below to make mypy happy since kernel_size has
        # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,
            _single(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv1d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _single(0),
                self.dilation,
                self.groups,
            )
        return F.conv1d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)

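# The floor formula for L_out in the Conv1d docstring can be checked
# numerically; this sketch (arbitrary sizes) computes it in plain Python and
# compares against the module's actual output shape.
import math

import torch
from torch import nn

L_in, kernel_size, stride, padding, dilation = 50, 3, 2, 1, 2
l_out = math.floor(
    (L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
)

m = nn.Conv1d(16, 33, kernel_size, stride=stride, padding=padding, dilation=dilation)
out = m(torch.randn(20, 16, L_in))
assert out.shape == (20, 33, l_out)  # (20, 33, 24)
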
class Conv2d(_ConvNd):
    __doc__ = (
        r"""Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)


    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
    """
        r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
        return F.conv2d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)

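# For any padding_mode other than 'zeros', _conv_forward above pads explicitly
# with F.pad and then convolves with zero padding. A sketch (arbitrary sizes)
# reproducing that branch by hand for the circular mode:
import torch
from torch import nn
from torch.nn import functional as F

x = torch.randn(1, 3, 8, 8)
m = nn.Conv2d(3, 4, kernel_size=3, padding=1, padding_mode="circular")

manual = F.conv2d(
    F.pad(x, m._reversed_padding_repeated_twice, mode="circular"),
    m.weight,
    m.bias,
    m.stride,
    (0, 0),  # padding was already applied by F.pad
    m.dilation,
    m.groups,
)
assert torch.allclose(m(x), manual)
print(m(x).shape)  # torch.Size([1, 4, 8, 8])
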
class Conv3d(_ConvNd):
    __doc__ = (
        r"""Applies a 3D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:

    .. math::
        out(N_i, C_{out_j}) = bias(C_{out_j}) +
                                \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)

    where :math:`\star` is the valid 3D `cross-correlation`_ operator
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all six sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
          where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                    \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                    \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
                    \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv3d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _triple(0),
                self.dilation,
                self.groups,
            )
        return F.conv3d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)

class _ConvTransposeNd(_ConvNd):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        device=None,
        dtype=None,
    ) -> None:
        if padding_mode != "zeros":
            raise ValueError(
                f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
            )

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    # dilation being an optional parameter is for backwards
    # compatibility
    def _output_padding(
        self,
        input: Tensor,
        output_size: Optional[List[int]],
        stride: List[int],
        padding: List[int],
        kernel_size: List[int],
        num_spatial_dims: int,
        dilation: Optional[List[int]] = None,
    ) -> List[int]:
        if output_size is None:
            ret = _single(self.output_padding)  # converting to list if was not already
        else:
            has_batch_dim = input.dim() == num_spatial_dims + 2
            num_non_spatial_dims = 2 if has_batch_dim else 1
            if len(output_size) == num_non_spatial_dims + num_spatial_dims:
                output_size = output_size[num_non_spatial_dims:]
            if len(output_size) != num_spatial_dims:
                raise ValueError(
                    f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} "
                    f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
                )

            min_sizes = torch.jit.annotate(List[int], [])
            max_sizes = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                dim_size = (
                    (input.size(d + num_non_spatial_dims) - 1) * stride[d]
                    - 2 * padding[d]
                    + (dilation[d] if dilation is not None else 1)
                    * (kernel_size[d] - 1)
                    + 1
                )
                min_sizes.append(dim_size)
                max_sizes.append(min_sizes[d] + stride[d] - 1)

            for i in range(len(output_size)):
                size = output_size[i]
                min_size = min_sizes[i]
                max_size = max_sizes[i]
                if size < min_size or size > max_size:
                    raise ValueError(
                        f"requested an output size of {output_size}, but valid sizes range "
                        f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
                    )

            res = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                res.append(output_size[d] - min_sizes[d])

            ret = res
        return ret

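# The min_sizes / max_sizes logic above means a requested output_size must lie
# within stride - 1 of the smallest producible transposed-convolution output,
# and the returned output_padding is simply the difference. The same
# arithmetic for one spatial dimension, with arbitrarily chosen parameters:
L_in, stride, padding, dilation, kernel_size = 6, 2, 1, 1, 3

min_size = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1
max_size = min_size + stride - 1
print(min_size, max_size)  # 11 12

requested = 12
assert min_size <= requested <= max_size
print(requested - min_size)  # 1 -> what _output_padding would return
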
class ConvTranspose1d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
        are initialized with the same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        In some circumstances when using the CUDA backend with CuDNN, this operator
        may select a nondeterministic algorithm to increase performance. If this is
        undesirable, you can try to make the operation deterministic (potentially at
        a performance cost) by setting ``torch.backends.cudnn.deterministic = True``.
        Please see the notes on :doc:`/notes/randomness` for background.


    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
                        \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
        bias (Tensor): the learnable bias of the module of shape (out_channels).
            If :attr:`bias` is ``True``, then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )
        return F.conv_transpose1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

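# As the shape note above says, a ConvTranspose1d built with the same
# kernel_size, stride, and padding as a Conv1d inverts its effect on shapes,
# and output_size (via output_padding) picks among the `stride` input lengths
# that a strided Conv1d collapses together. A sketch with arbitrary sizes:
import torch
from torch import nn

down = nn.Conv1d(8, 16, kernel_size=3, stride=2, padding=1)
up = nn.ConvTranspose1d(16, 8, kernel_size=3, stride=2, padding=1)

# With stride=2, inputs of length 13 and 14 both map to length 7 ...
for L in (13, 14):
    print(down(torch.randn(1, 8, L)).shape[-1])  # 7, then 7 again

# ... so the transpose needs output_size (or output_padding) to pick one back.
h = down(torch.randn(1, 8, 14))
print(up(h).shape[-1])                          # 13 (default output_padding=0)
print(up(h, output_size=[1, 8, 14]).shape[-1])  # 14
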
class ConvTranspose2d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv2d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimensions
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
        are initialized with the same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels)
            If :attr:`bias` is ``True``, then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
|
| 1122 |
+
kernel_size = _pair(kernel_size)
|
| 1123 |
+
stride = _pair(stride)
|
| 1124 |
+
padding = _pair(padding)
|
| 1125 |
+
dilation = _pair(dilation)
|
| 1126 |
+
output_padding = _pair(output_padding)
|
| 1127 |
+
super().__init__(
|
| 1128 |
+
in_channels,
|
| 1129 |
+
out_channels,
|
| 1130 |
+
kernel_size,
|
| 1131 |
+
stride,
|
| 1132 |
+
padding,
|
| 1133 |
+
dilation,
|
| 1134 |
+
True,
|
| 1135 |
+
output_padding,
|
| 1136 |
+
groups,
|
| 1137 |
+
bias,
|
| 1138 |
+
padding_mode,
|
| 1139 |
+
**factory_kwargs,
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
| 1143 |
+
if self.padding_mode != "zeros":
|
| 1144 |
+
raise ValueError(
|
| 1145 |
+
"Only `zeros` padding mode is supported for ConvTranspose2d"
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
assert isinstance(self.padding, tuple)
|
| 1149 |
+
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
| 1150 |
+
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
| 1151 |
+
num_spatial_dims = 2
|
| 1152 |
+
output_padding = self._output_padding(
|
| 1153 |
+
input,
|
| 1154 |
+
output_size,
|
| 1155 |
+
self.stride, # type: ignore[arg-type]
|
| 1156 |
+
self.padding, # type: ignore[arg-type]
|
| 1157 |
+
self.kernel_size, # type: ignore[arg-type]
|
| 1158 |
+
num_spatial_dims,
|
| 1159 |
+
self.dilation, # type: ignore[arg-type]
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
return F.conv_transpose2d(
|
| 1163 |
+
input,
|
| 1164 |
+
self.weight,
|
| 1165 |
+
self.bias,
|
| 1166 |
+
self.stride,
|
| 1167 |
+
self.padding,
|
| 1168 |
+
output_padding,
|
| 1169 |
+
self.groups,
|
| 1170 |
+
self.dilation,
|
| 1171 |
+
)
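

# Illustrative sketch (hypothetical helper, not part of the original file): checks
# the output-shape formula from the docstring above against an actual forward pass.
# All names and numbers below are local to this example.
def _convtranspose2d_shape_demo() -> None:
    import torch
    from torch import nn

    m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
    y = m(torch.randn(20, 16, 50, 100))
    # H_out = (50 - 1) * 2 - 2 * 4 + 1 * (3 - 1) + 0 + 1 = 93
    # W_out = (100 - 1) * 1 - 2 * 2 + 1 * (5 - 1) + 0 + 1 = 100
    assert y.shape == (20, 33, 93, 100)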


class ConvTranspose3d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 3D transposed convolution operator over an input image composed of several input
    planes.
    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
    and sums over the outputs from all input feature planes.

    This module can be seen as the gradient of Conv3d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
        are initialized with the same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(
            **reproducibility_notes, **convolution_notes
        )
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where

        .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
                        \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1


    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels)
                         If :attr:`bias` is ``True``, then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def forward(
        self, input: Tensor, output_size: Optional[List[int]] = None
    ) -> Tensor:
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose3d"
            )

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        return F.conv_transpose3d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
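

# Illustrative sketch (hypothetical helper, not part of the original file): when
# ``stride > 1``, several input sizes map to the same Conv3d output size, so the
# inverse shape is ambiguous; passing ``output_size`` to forward() lets the module
# compute the output_padding needed to recover the intended shape.
def _convtranspose3d_output_size_demo() -> None:
    import torch
    from torch import nn

    down = nn.Conv3d(16, 16, 3, stride=2, padding=1)
    up = nn.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
    x = torch.randn(1, 16, 8, 8, 8)
    h = down(x)  # torch.Size([1, 16, 4, 4, 4])
    assert up(h).shape == (1, 16, 7, 7, 7)  # default output_padding=0 falls short
    assert up(h, output_size=x.size()).shape == x.shape  # ambiguity resolved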


# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`.
#
# `_ConvTransposeMixin` was a mixin that was removed. It was meant to be used
# with `_ConvNd` to construct actual module classes that implement conv
# transpose ops:
#
#   class MyConvTranspose(_ConvNd, _ConvTransposeMixin):
#       ...
#
# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper
# subclass of `_ConvNd`. However, some user code in the wild still (incorrectly)
# uses the internal class `_ConvTransposeMixin`. Hence, we provide this alias
# for BC, because it is cheap and easy for us to do so, even though
# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as
# above would still work).
class _ConvTransposeMixin(_ConvTransposeNd):
    @deprecated(
        "`_ConvTransposeMixin` is a deprecated internal class. "
        "Please consider using public APIs.",
        category=FutureWarning,
    )
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


# TODO: Conv2dLocal
# TODO: Conv2dMap
# TODO: ConvTranspose2dMap


class _LazyConvXdMixin(LazyModuleMixin):
    groups: int
    transposed: bool
    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    weight: UninitializedParameter
    bias: UninitializedParameter

    def reset_parameters(self) -> None:
        # has_uninitialized_params is defined in the parent class and uses a protocol on self
        if not self.has_uninitialized_params() and self.in_channels != 0:  # type: ignore[misc]
            # "type: ignore[..]" is required because mypy thinks that "reset_parameters" is undefined
            # in the superclass. It is actually defined in _ConvNd, which is inherited by any class
            # that also inherits _LazyConvXdMixin.
            super().reset_parameters()  # type: ignore[misc]

    # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin
    def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None:  # type: ignore[override]
        # defined by the parent class but using a protocol
        if self.has_uninitialized_params():  # type: ignore[misc]
            self.in_channels = self._get_in_channels(input)
            if self.in_channels % self.groups != 0:
                raise ValueError("in_channels must be divisible by groups")
            assert isinstance(self.weight, UninitializedParameter)
            if self.transposed:
                self.weight.materialize(
                    (
                        self.in_channels,
                        self.out_channels // self.groups,
                        *self.kernel_size,
                    )
                )
            else:
                self.weight.materialize(
                    (
                        self.out_channels,
                        self.in_channels // self.groups,
                        *self.kernel_size,
                    )
                )
            if self.bias is not None:
                assert isinstance(self.bias, UninitializedParameter)
                self.bias.materialize((self.out_channels,))
            self.reset_parameters()

    # Function to extract in_channels from the first input.
    def _get_in_channels(self, input: Tensor) -> int:
        num_spatial_dims = self._get_num_spatial_dims()
        num_dims_no_batch = num_spatial_dims + 1  # +1 for channels dim
        num_dims_batch = num_dims_no_batch + 1
        if input.dim() not in (num_dims_no_batch, num_dims_batch):
            raise RuntimeError(
                f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input "
                f"to {self.__class__.__name__}, but "
                f"got input of size: {input.shape}"
            )
        return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]

    # Function to return the number of spatial dims expected for inputs to the module.
    # This is expected to be implemented by subclasses.
    def _get_num_spatial_dims(self) -> int:
        raise NotImplementedError
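

# Illustrative sketch (hypothetical helper, not part of the original file): shows
# how the lazy machinery above behaves at runtime. A LazyConv2d starts with an
# UninitializedParameter; the first forward pass triggers initialize_parameters(),
# which reads in_channels from the input's channel dim and materializes the weight
# as (out_channels, in_channels // groups, *kernel_size).
def _lazy_materialize_demo() -> None:
    import torch
    from torch import nn
    from torch.nn.parameter import UninitializedParameter as _Uninit

    m = nn.LazyConv2d(out_channels=8, kernel_size=3)
    assert isinstance(m.weight, _Uninit)  # nothing allocated yet
    m(torch.randn(2, 5, 16, 16))  # in_channels is inferred as 5 here
    assert m.weight.shape == (8, 5, 3, 3)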


# LazyConv1d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv1d` is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = Conv1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 1


# LazyConv2d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv2d` is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = Conv2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 2
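

# Illustrative sketch (hypothetical helper, not part of the original file): after
# the first forward pass materializes the parameters, LazyModuleMixin swaps the
# instance's class to ``cls_to_become``, so a LazyConv2d becomes a plain Conv2d.
def _cls_to_become_demo() -> None:
    import torch
    from torch import nn

    m = nn.LazyConv2d(out_channels=4, kernel_size=1)
    assert type(m) is nn.LazyConv2d
    m(torch.randn(1, 3, 8, 8))
    assert type(m) is nn.Conv2d  # the lazy class was replaced by cls_to_become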


# LazyConv3d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv3d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = Conv3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 3


# LazyConvTranspose1d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose1d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1

    .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 1
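

# Illustrative sketch (hypothetical helper, not part of the original file): for
# transposed convolutions, initialize_parameters() takes the ``transposed`` branch
# in _LazyConvXdMixin and materializes the weight with the
# (in_channels, out_channels // groups, *kernel_size) layout instead.
def _lazy_transpose_weight_layout_demo() -> None:
    import torch
    from torch import nn

    m = nn.LazyConvTranspose1d(out_channels=6, kernel_size=3)
    m(torch.randn(2, 4, 10))  # in_channels is inferred as 4 here
    assert m.weight.shape == (4, 6, 3)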


# LazyConvTranspose2d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1

    .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 2


# LazyConvTranspose3d defines weight as a Tensor but the derived class defines it as UninitializedParameter
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1

    .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 3