Spaces:
Runtime error
Runtime error
up
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- apex/.github/ISSUE_TEMPLATE/bug_report.md +0 -23
- apex/.gitignore +0 -147
- apex/.gitmodules +0 -7
- apex/.nojekyll +0 -0
- apex/LICENSE +0 -11
- apex/README.md +0 -187
- apex/apex/RNN/README.md +0 -3
- apex/apex/RNN/RNNBackend.py +0 -365
- apex/apex/RNN/__init__.py +0 -3
- apex/apex/RNN/cells.py +0 -84
- apex/apex/RNN/models.py +0 -56
- apex/apex/__init__.py +0 -68
- apex/apex/_autocast_utils.py +0 -26
- apex/apex/amp/README.md +0 -72
- apex/apex/amp/__init__.py +0 -5
- apex/apex/amp/__version__.py +0 -2
- apex/apex/amp/_amp_state.py +0 -59
- apex/apex/amp/_initialize.py +0 -265
- apex/apex/amp/_process_optimizer.py +0 -489
- apex/apex/amp/amp.py +0 -183
- apex/apex/amp/compat.py +0 -46
- apex/apex/amp/frontend.py +0 -446
- apex/apex/amp/handle.py +0 -281
- apex/apex/amp/lists/__init__.py +0 -0
- apex/apex/amp/lists/functional_overrides.py +0 -80
- apex/apex/amp/lists/tensor_overrides.py +0 -63
- apex/apex/amp/lists/torch_overrides.py +0 -115
- apex/apex/amp/opt.py +0 -103
- apex/apex/amp/rnn_compat.py +0 -53
- apex/apex/amp/scaler.py +0 -217
- apex/apex/amp/utils.py +0 -210
- apex/apex/amp/wrap.py +0 -276
- apex/apex/contrib/__init__.py +0 -0
- apex/apex/contrib/bottleneck/__init__.py +0 -2
- apex/apex/contrib/bottleneck/bottleneck.py +0 -749
- apex/apex/contrib/bottleneck/halo_exchangers.py +0 -180
- apex/apex/contrib/bottleneck/test.py +0 -71
- apex/apex/contrib/clip_grad/__init__.py +0 -1
- apex/apex/contrib/clip_grad/clip_grad.py +0 -128
- apex/apex/contrib/conv_bias_relu/__init__.py +0 -2
- apex/apex/contrib/conv_bias_relu/conv_bias_relu.py +0 -104
- apex/apex/contrib/csrc/bottleneck/bottleneck.cpp +0 -0
- apex/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +0 -2153
- apex/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +0 -163
- apex/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp +0 -479
- apex/apex/contrib/csrc/cudnn_gbn/norm_sample.h +0 -153
- apex/apex/contrib/csrc/fmha/fmha_api.cpp +0 -365
- apex/apex/contrib/csrc/fmha/src/fmha.h +0 -163
- apex/apex/contrib/csrc/fmha/src/fmha/gemm.h +0 -314
- apex/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h +0 -456
apex/.github/ISSUE_TEMPLATE/bug_report.md
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
---
|
2 |
-
name: Bug report
|
3 |
-
about: Create a report to help us improve apex
|
4 |
-
title: ''
|
5 |
-
labels: bug
|
6 |
-
assignees: ''
|
7 |
-
|
8 |
-
---
|
9 |
-
|
10 |
-
**Describe the Bug**
|
11 |
-
|
12 |
-
**Minimal Steps/Code to Reproduce the Bug**
|
13 |
-
<!--
|
14 |
-
Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug.
|
15 |
-
|
16 |
-
A helpful guide on on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
|
17 |
-
-->
|
18 |
-
|
19 |
-
**Expected Behavior**
|
20 |
-
<!-- A clear and concise description of what you expected to happen. -->
|
21 |
-
|
22 |
-
**Environment**
|
23 |
-
<!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/.gitignore
DELETED
@@ -1,147 +0,0 @@
|
|
1 |
-
apex.egg-info
|
2 |
-
dist
|
3 |
-
build
|
4 |
-
docs/build
|
5 |
-
*~
|
6 |
-
__pycache__
|
7 |
-
.vscode
|
8 |
-
|
9 |
-
# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
|
10 |
-
# Byte-compiled / optimized / DLL files
|
11 |
-
__pycache__/
|
12 |
-
*.py[cod]
|
13 |
-
*$py.class
|
14 |
-
|
15 |
-
# C extensions
|
16 |
-
*.so
|
17 |
-
|
18 |
-
# Distribution / packaging
|
19 |
-
.Python
|
20 |
-
build/
|
21 |
-
develop-eggs/
|
22 |
-
dist/
|
23 |
-
downloads/
|
24 |
-
eggs/
|
25 |
-
.eggs/
|
26 |
-
lib/
|
27 |
-
lib64/
|
28 |
-
parts/
|
29 |
-
sdist/
|
30 |
-
var/
|
31 |
-
wheels/
|
32 |
-
share/python-wheels/
|
33 |
-
*.egg-info/
|
34 |
-
.installed.cfg
|
35 |
-
*.egg
|
36 |
-
MANIFEST
|
37 |
-
|
38 |
-
# PyInstaller
|
39 |
-
# Usually these files are written by a python script from a template
|
40 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
41 |
-
*.manifest
|
42 |
-
*.spec
|
43 |
-
|
44 |
-
# Installer logs
|
45 |
-
pip-log.txt
|
46 |
-
pip-delete-this-directory.txt
|
47 |
-
|
48 |
-
# Unit test / coverage reports
|
49 |
-
htmlcov/
|
50 |
-
.tox/
|
51 |
-
.nox/
|
52 |
-
.coverage
|
53 |
-
.coverage.*
|
54 |
-
.cache
|
55 |
-
nosetests.xml
|
56 |
-
coverage.xml
|
57 |
-
*.cover
|
58 |
-
*.py,cover
|
59 |
-
.hypothesis/
|
60 |
-
.pytest_cache/
|
61 |
-
cover/
|
62 |
-
|
63 |
-
# Translations
|
64 |
-
*.mo
|
65 |
-
*.pot
|
66 |
-
|
67 |
-
# Django stuff:
|
68 |
-
*.log
|
69 |
-
local_settings.py
|
70 |
-
db.sqlite3
|
71 |
-
db.sqlite3-journal
|
72 |
-
|
73 |
-
# Flask stuff:
|
74 |
-
instance/
|
75 |
-
.webassets-cache
|
76 |
-
|
77 |
-
# Scrapy stuff:
|
78 |
-
.scrapy
|
79 |
-
|
80 |
-
# Sphinx documentation
|
81 |
-
docs/_build/
|
82 |
-
|
83 |
-
# PyBuilder
|
84 |
-
.pybuilder/
|
85 |
-
target/
|
86 |
-
|
87 |
-
# Jupyter Notebook
|
88 |
-
.ipynb_checkpoints
|
89 |
-
|
90 |
-
# IPython
|
91 |
-
profile_default/
|
92 |
-
ipython_config.py
|
93 |
-
|
94 |
-
# pyenv
|
95 |
-
# For a library or package, you might want to ignore these files since the code is
|
96 |
-
# intended to run in multiple environments; otherwise, check them in:
|
97 |
-
# .python-version
|
98 |
-
|
99 |
-
# pipenv
|
100 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
101 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
102 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
103 |
-
# install all needed dependencies.
|
104 |
-
#Pipfile.lock
|
105 |
-
|
106 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
107 |
-
__pypackages__/
|
108 |
-
|
109 |
-
# Celery stuff
|
110 |
-
celerybeat-schedule
|
111 |
-
celerybeat.pid
|
112 |
-
|
113 |
-
# SageMath parsed files
|
114 |
-
*.sage.py
|
115 |
-
|
116 |
-
# Environments
|
117 |
-
.env
|
118 |
-
.venv
|
119 |
-
env/
|
120 |
-
venv/
|
121 |
-
ENV/
|
122 |
-
env.bak/
|
123 |
-
venv.bak/
|
124 |
-
|
125 |
-
# Spyder project settings
|
126 |
-
.spyderproject
|
127 |
-
.spyproject
|
128 |
-
|
129 |
-
# Rope project settings
|
130 |
-
.ropeproject
|
131 |
-
|
132 |
-
# mkdocs documentation
|
133 |
-
/site
|
134 |
-
|
135 |
-
# mypy
|
136 |
-
.mypy_cache/
|
137 |
-
.dmypy.json
|
138 |
-
dmypy.json
|
139 |
-
|
140 |
-
# Pyre type checker
|
141 |
-
.pyre/
|
142 |
-
|
143 |
-
# pytype static type analyzer
|
144 |
-
.pytype/
|
145 |
-
|
146 |
-
# Cython debug symbols
|
147 |
-
cython_debug/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/.gitmodules
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
[submodule "apex/contrib/csrc/multihead_attn/cutlass"]
|
2 |
-
path = apex/contrib/csrc/multihead_attn/cutlass
|
3 |
-
url = https://github.com/NVIDIA/cutlass.git
|
4 |
-
branch = v1.2.0
|
5 |
-
[submodule "apex/contrib/csrc/cudnn-frontend"]
|
6 |
-
path = apex/contrib/csrc/cudnn-frontend
|
7 |
-
url = https://github.com/NVIDIA/cudnn-frontend.git
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/.nojekyll
DELETED
File without changes
|
apex/LICENSE
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
All rights reserved.
|
2 |
-
|
3 |
-
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
4 |
-
|
5 |
-
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
6 |
-
|
7 |
-
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
8 |
-
|
9 |
-
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
10 |
-
|
11 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/README.md
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
# Introduction
|
2 |
-
|
3 |
-
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
|
4 |
-
Some of the code here will be included in upstream Pytorch eventually.
|
5 |
-
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
|
6 |
-
|
7 |
-
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
|
8 |
-
|
9 |
-
## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides
|
10 |
-
|
11 |
-
# Contents
|
12 |
-
|
13 |
-
## 1. Amp: Automatic Mixed Precision
|
14 |
-
|
15 |
-
**Deprecated. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)**
|
16 |
-
|
17 |
-
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
|
18 |
-
Users can easily experiment with different pure and mixed precision training modes by supplying
|
19 |
-
different flags to `amp.initialize`.
|
20 |
-
|
21 |
-
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
|
22 |
-
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
|
23 |
-
|
24 |
-
[API Documentation](https://nvidia.github.io/apex/amp.html)
|
25 |
-
|
26 |
-
[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
|
27 |
-
|
28 |
-
[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
|
29 |
-
|
30 |
-
[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
|
31 |
-
|
32 |
-
## 2. Distributed Training
|
33 |
-
|
34 |
-
**`apex.parallel.DistributedDataParallel` is deprecated. Use [`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel)**
|
35 |
-
|
36 |
-
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
|
37 |
-
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
|
38 |
-
optimized for NVIDIA's NCCL communication library.
|
39 |
-
|
40 |
-
[API Documentation](https://nvidia.github.io/apex/parallel.html)
|
41 |
-
|
42 |
-
[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
|
43 |
-
|
44 |
-
[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
|
45 |
-
|
46 |
-
The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
|
47 |
-
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
|
48 |
-
|
49 |
-
### Synchronized Batch Normalization
|
50 |
-
|
51 |
-
**Deprecated. Use [`torch.nn.SyncBatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)**
|
52 |
-
|
53 |
-
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
|
54 |
-
support synchronized BN.
|
55 |
-
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
|
56 |
-
Synchronous BN has been used in cases where only a small
|
57 |
-
local minibatch can fit on each GPU.
|
58 |
-
Allreduced stats increase the effective batch size for the BN layer to the
|
59 |
-
global batch size across all processes (which, technically, is the correct
|
60 |
-
formulation).
|
61 |
-
Synchronous BN has been observed to improve converged accuracy in some of our research models.
|
62 |
-
|
63 |
-
### Checkpointing
|
64 |
-
|
65 |
-
To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
|
66 |
-
as well as `amp.load_state_dict()` to restore these attributes.
|
67 |
-
|
68 |
-
In order to get bitwise accuracy, we recommend the following workflow:
|
69 |
-
```python
|
70 |
-
# Initialization
|
71 |
-
opt_level = 'O1'
|
72 |
-
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
73 |
-
|
74 |
-
# Train your model
|
75 |
-
...
|
76 |
-
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
77 |
-
scaled_loss.backward()
|
78 |
-
...
|
79 |
-
|
80 |
-
# Save checkpoint
|
81 |
-
checkpoint = {
|
82 |
-
'model': model.state_dict(),
|
83 |
-
'optimizer': optimizer.state_dict(),
|
84 |
-
'amp': amp.state_dict()
|
85 |
-
}
|
86 |
-
torch.save(checkpoint, 'amp_checkpoint.pt')
|
87 |
-
...
|
88 |
-
|
89 |
-
# Restore
|
90 |
-
model = ...
|
91 |
-
optimizer = ...
|
92 |
-
checkpoint = torch.load('amp_checkpoint.pt')
|
93 |
-
|
94 |
-
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
95 |
-
model.load_state_dict(checkpoint['model'])
|
96 |
-
optimizer.load_state_dict(checkpoint['optimizer'])
|
97 |
-
amp.load_state_dict(checkpoint['amp'])
|
98 |
-
|
99 |
-
# Continue training
|
100 |
-
...
|
101 |
-
```
|
102 |
-
|
103 |
-
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
|
104 |
-
|
105 |
-
# Installation
|
106 |
-
Each [`apex.contrib`](./apex/contrib) module requires one or more install options other than `--cpp_ext` and `--cuda_ext`.
|
107 |
-
Note that contrib modules do not necessarily support stable PyTorch releases.
|
108 |
-
|
109 |
-
## Containers
|
110 |
-
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
|
111 |
-
The containers come with all the custom extensions available at the moment.
|
112 |
-
|
113 |
-
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
|
114 |
-
- how to pull a container
|
115 |
-
- how to run a pulled container
|
116 |
-
- release notes
|
117 |
-
|
118 |
-
## From Source
|
119 |
-
|
120 |
-
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
|
121 |
-
|
122 |
-
The latest stable release obtainable from https://pytorch.org should also work.
|
123 |
-
|
124 |
-
We recommend installing [`Ninja`](https://ninja-build.org/) to make compilation faster.
|
125 |
-
|
126 |
-
### Linux
|
127 |
-
For performance and full functionality, we recommend installing Apex with
|
128 |
-
CUDA and C++ extensions via
|
129 |
-
```bash
|
130 |
-
git clone https://github.com/NVIDIA/apex
|
131 |
-
cd apex
|
132 |
-
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
|
133 |
-
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
134 |
-
# otherwise
|
135 |
-
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
136 |
-
```
|
137 |
-
|
138 |
-
APEX also supports a Python-only build via
|
139 |
-
```bash
|
140 |
-
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
|
141 |
-
```
|
142 |
-
A Python-only build omits:
|
143 |
-
- Fused kernels required to use `apex.optimizers.FusedAdam`.
|
144 |
-
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
|
145 |
-
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
|
146 |
-
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
|
147 |
-
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
|
148 |
-
|
149 |
-
|
150 |
-
### [Experimental] Windows
|
151 |
-
`pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .` may work if you were able to build Pytorch from source
|
152 |
-
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
|
153 |
-
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
|
154 |
-
|
155 |
-
|
156 |
-
## Custom C++/CUDA Extensions and Install Options
|
157 |
-
|
158 |
-
If a requirement of a module is not met, then it will not be built.
|
159 |
-
|
160 |
-
| Module Name | Install Option | Misc |
|
161 |
-
|---------------|------------------|--------|
|
162 |
-
| `apex_C` | `--cpp_ext` | |
|
163 |
-
| `amp_C` | `--cuda_ext` | |
|
164 |
-
| `syncbn` | `--cuda_ext` | |
|
165 |
-
| `fused_layer_norm_cuda` | `--cuda_ext` | [`apex.normalization`](./apex/normalization) |
|
166 |
-
| `mlp_cuda` | `--cuda_ext` | |
|
167 |
-
| `scaled_upper_triang_masked_softmax_cuda` | `--cuda_ext` | |
|
168 |
-
| `generic_scaled_masked_softmax_cuda` | `--cuda_ext` | |
|
169 |
-
| `scaled_masked_softmax_cuda` | `--cuda_ext` | |
|
170 |
-
| `fused_weight_gradient_mlp_cuda` | `--cuda_ext` | Requires CUDA>=11 |
|
171 |
-
| `permutation_search_cuda` | `--permutation_search` | [`apex.contrib.sparsity`](./apex/contrib/sparsity) |
|
172 |
-
| `bnp` | `--bnp` | [`apex.contrib.groupbn`](./apex/contrib/groupbn) |
|
173 |
-
| `xentropy` | `--xentropy` | [`apex.contrib.xentropy`](./apex/contrib/xentropy) |
|
174 |
-
| `focal_loss_cuda` | `--focal_loss` | [`apex.contrib.focal_loss`](./apex/contrib/focal_loss) |
|
175 |
-
| `fused_index_mul_2d` | `--index_mul_2d` | [`apex.contrib.index_mul_2d`](./apex/contrib/index_mul_2d) |
|
176 |
-
| `fused_adam_cuda` | `--deprecated_fused_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
|
177 |
-
| `fused_lamb_cuda` | `--deprecated_fused_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
|
178 |
-
| `fast_layer_norm` | `--fast_layer_norm` | [`apex.contrib.layer_norm`](./apex/contrib/layer_norm). different from `fused_layer_norm` |
|
179 |
-
| `fmhalib` | `--fmha` | [`apex.contrib.fmha`](./apex/contrib/fmha) |
|
180 |
-
| `fast_multihead_attn` | `--fast_multihead_attn` | [`apex.contrib.multihead_attn`](./apex/contrib/multihead_attn) |
|
181 |
-
| `transducer_joint_cuda` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
|
182 |
-
| `transducer_loss_cuda` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
|
183 |
-
| `cudnn_gbn_lib` | `--cudnn_gbn` | Requires cuDNN>=8.5, [`apex.contrib.cudnn_gbn`](./apex/contrib/cudnn_gbn) |
|
184 |
-
| `peer_memory_cuda` | `--peer_memory` | [`apex.contrib.peer_memory`](./apex/contrib/peer_memory) |
|
185 |
-
| `nccl_p2p_cuda` | `--nccl_p2p` | Requires NCCL >= 2.10, [`apex.contrib.nccl_p2p`](./apex/contrib/nccl_p2p) |
|
186 |
-
| `fast_bottleneck` | `--fast_bottleneck` | Requires `peer_memory_cuda` and `nccl_p2p_cuda`, [`apex.contrib.bottleneck`](./apex/contrib/bottleneck) |
|
187 |
-
| `fused_conv_bias_relu` | `--fused_conv_bias_relu` | Requires cuDNN>=8.4, [`apex.contrib.conv_bias_relu`](./apex/contrib/conv_bias_relu) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/RNN/README.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
**This module will be removed by the end of February 2023**
|
2 |
-
|
3 |
-
Under construction...
|
|
|
|
|
|
|
|
apex/apex/RNN/RNNBackend.py
DELETED
@@ -1,365 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
from torch.autograd import Variable
|
4 |
-
|
5 |
-
import torch.nn.functional as F
|
6 |
-
|
7 |
-
import math
|
8 |
-
|
9 |
-
|
10 |
-
def is_iterable(maybe_iterable):
|
11 |
-
return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)
|
12 |
-
|
13 |
-
|
14 |
-
def flatten_list(tens_list):
|
15 |
-
"""
|
16 |
-
flatten_list
|
17 |
-
"""
|
18 |
-
if not is_iterable(tens_list):
|
19 |
-
return tens_list
|
20 |
-
|
21 |
-
return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )
|
22 |
-
|
23 |
-
|
24 |
-
#These modules always assumes batch_first
|
25 |
-
class bidirectionalRNN(nn.Module):
|
26 |
-
"""
|
27 |
-
bidirectionalRNN
|
28 |
-
"""
|
29 |
-
def __init__(self, inputRNN, num_layers=1, dropout = 0):
|
30 |
-
super(bidirectionalRNN, self).__init__()
|
31 |
-
self.dropout = dropout
|
32 |
-
self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)
|
33 |
-
self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)
|
34 |
-
self.rnns = nn.ModuleList([self.fwd, self.bckwrd])
|
35 |
-
|
36 |
-
#collect hidden option will return all hidden/cell states from entire RNN
|
37 |
-
def forward(self, input, collect_hidden=False):
|
38 |
-
"""
|
39 |
-
forward()
|
40 |
-
"""
|
41 |
-
seq_len = input.size(0)
|
42 |
-
bsz = input.size(1)
|
43 |
-
|
44 |
-
fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))
|
45 |
-
bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))
|
46 |
-
|
47 |
-
output = torch.cat( [fwd_out, bckwrd_out], -1 )
|
48 |
-
hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )
|
49 |
-
|
50 |
-
return output, hiddens
|
51 |
-
|
52 |
-
def reset_parameters(self):
|
53 |
-
"""
|
54 |
-
reset_parameters()
|
55 |
-
"""
|
56 |
-
for rnn in self.rnns:
|
57 |
-
rnn.reset_parameters()
|
58 |
-
|
59 |
-
def init_hidden(self, bsz):
|
60 |
-
"""
|
61 |
-
init_hidden()
|
62 |
-
"""
|
63 |
-
for rnn in self.rnns:
|
64 |
-
rnn.init_hidden(bsz)
|
65 |
-
|
66 |
-
def detach_hidden(self):
|
67 |
-
"""
|
68 |
-
detach_hidden()
|
69 |
-
"""
|
70 |
-
for rnn in self.rnns:
|
71 |
-
rnn.detachHidden()
|
72 |
-
|
73 |
-
def reset_hidden(self, bsz):
|
74 |
-
"""
|
75 |
-
reset_hidden()
|
76 |
-
"""
|
77 |
-
for rnn in self.rnns:
|
78 |
-
rnn.reset_hidden(bsz)
|
79 |
-
|
80 |
-
def init_inference(self, bsz):
|
81 |
-
"""
|
82 |
-
init_inference()
|
83 |
-
"""
|
84 |
-
for rnn in self.rnns:
|
85 |
-
rnn.init_inference(bsz)
|
86 |
-
|
87 |
-
|
88 |
-
#assumes hidden_state[0] of inputRNN is output hidden state
|
89 |
-
#constructor either takes an RNNCell or list of RNN layers
|
90 |
-
class stackedRNN(nn.Module):
|
91 |
-
"""
|
92 |
-
stackedRNN
|
93 |
-
"""
|
94 |
-
def __init__(self, inputRNN, num_layers=1, dropout=0):
|
95 |
-
super(stackedRNN, self).__init__()
|
96 |
-
|
97 |
-
self.dropout = dropout
|
98 |
-
|
99 |
-
if isinstance(inputRNN, RNNCell):
|
100 |
-
self.rnns = [inputRNN]
|
101 |
-
for i in range(num_layers-1):
|
102 |
-
self.rnns.append(inputRNN.new_like(inputRNN.output_size))
|
103 |
-
elif isinstance(inputRNN, list):
|
104 |
-
assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
|
105 |
-
self.rnns=inputRNN
|
106 |
-
else:
|
107 |
-
raise RuntimeError()
|
108 |
-
|
109 |
-
self.nLayers = len(self.rnns)
|
110 |
-
|
111 |
-
self.rnns = nn.ModuleList(self.rnns)
|
112 |
-
|
113 |
-
|
114 |
-
'''
|
115 |
-
Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
|
116 |
-
If collect hidden will also return Tuple(
|
117 |
-
[n_hidden_states][sequence steps] Tensor([layer][batch size][features])
|
118 |
-
)
|
119 |
-
If not collect hidden will also return Tuple(
|
120 |
-
[n_hidden_states] Tensor([layer][batch size][features])
|
121 |
-
'''
|
122 |
-
def forward(self, input, collect_hidden=False, reverse=False):
|
123 |
-
"""
|
124 |
-
forward()
|
125 |
-
"""
|
126 |
-
seq_len = input.size(0)
|
127 |
-
bsz = input.size(1)
|
128 |
-
inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)
|
129 |
-
|
130 |
-
hidden_states = [[] for i in range(self.nLayers)]
|
131 |
-
outputs = []
|
132 |
-
|
133 |
-
for seq in inp_iter:
|
134 |
-
for layer in range(self.nLayers):
|
135 |
-
|
136 |
-
if layer == 0:
|
137 |
-
prev_out = input[seq]
|
138 |
-
|
139 |
-
outs = self.rnns[layer](prev_out)
|
140 |
-
|
141 |
-
if collect_hidden:
|
142 |
-
hidden_states[layer].append(outs)
|
143 |
-
elif seq == seq_len-1:
|
144 |
-
hidden_states[layer].append(outs)
|
145 |
-
|
146 |
-
prev_out = outs[0]
|
147 |
-
|
148 |
-
outputs.append(prev_out)
|
149 |
-
|
150 |
-
if reverse:
|
151 |
-
outputs = list(reversed(outputs))
|
152 |
-
'''
|
153 |
-
At this point outputs is in format:
|
154 |
-
list( [seq_length] x Tensor([bsz][features]) )
|
155 |
-
need to convert it to:
|
156 |
-
list( Tensor([seq_length][bsz][features]) )
|
157 |
-
'''
|
158 |
-
output = flatten_list(outputs)
|
159 |
-
|
160 |
-
'''
|
161 |
-
hidden_states at this point is in format:
|
162 |
-
list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
|
163 |
-
need to convert it to:
|
164 |
-
For not collect hidden:
|
165 |
-
list( [hidden_states] x Tensor([layer][bsz][features]) )
|
166 |
-
For collect hidden:
|
167 |
-
list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
|
168 |
-
'''
|
169 |
-
if not collect_hidden:
|
170 |
-
seq_len = 1
|
171 |
-
n_hid = self.rnns[0].n_hidden_states
|
172 |
-
new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]
|
173 |
-
|
174 |
-
|
175 |
-
for i in range(n_hid):
|
176 |
-
for j in range(seq_len):
|
177 |
-
for k in range(self.nLayers):
|
178 |
-
new_hidden[i][j][k] = hidden_states[k][j][i]
|
179 |
-
|
180 |
-
hidden_states = new_hidden
|
181 |
-
#Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
|
182 |
-
#Reverse seq_length if reverse
|
183 |
-
if reverse:
|
184 |
-
hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)
|
185 |
-
|
186 |
-
#flatten layer dimension into tensor
|
187 |
-
hiddens = list( list(
|
188 |
-
flatten_list(seq) for seq in hidden )
|
189 |
-
for hidden in hidden_states )
|
190 |
-
|
191 |
-
#Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
|
192 |
-
#Remove seq_length dimension if not collect_hidden
|
193 |
-
if not collect_hidden:
|
194 |
-
hidden_states = list( entry[0] for entry in hidden_states)
|
195 |
-
return output, hidden_states
|
196 |
-
|
197 |
-
def reset_parameters(self):
|
198 |
-
"""
|
199 |
-
reset_parameters()
|
200 |
-
"""
|
201 |
-
for rnn in self.rnns:
|
202 |
-
rnn.reset_parameters()
|
203 |
-
|
204 |
-
def init_hidden(self, bsz):
|
205 |
-
"""
|
206 |
-
init_hidden()
|
207 |
-
"""
|
208 |
-
for rnn in self.rnns:
|
209 |
-
rnn.init_hidden(bsz)
|
210 |
-
|
211 |
-
def detach_hidden(self):
|
212 |
-
"""
|
213 |
-
detach_hidden()
|
214 |
-
"""
|
215 |
-
for rnn in self.rnns:
|
216 |
-
rnn.detach_hidden()
|
217 |
-
|
218 |
-
def reset_hidden(self, bsz):
|
219 |
-
"""
|
220 |
-
reset_hidden()
|
221 |
-
"""
|
222 |
-
for rnn in self.rnns:
|
223 |
-
rnn.reset_hidden(bsz)
|
224 |
-
|
225 |
-
def init_inference(self, bsz):
|
226 |
-
"""
|
227 |
-
init_inference()
|
228 |
-
"""
|
229 |
-
for rnn in self.rnns:
|
230 |
-
rnn.init_inference(bsz)
|
231 |
-
|
232 |
-
class RNNCell(nn.Module):
|
233 |
-
"""
|
234 |
-
RNNCell
|
235 |
-
gate_multiplier is related to the architecture you're working with
|
236 |
-
For LSTM-like it will be 4 and GRU-like will be 3.
|
237 |
-
Always assumes input is NOT batch_first.
|
238 |
-
Output size that's not hidden size will use output projection
|
239 |
-
Hidden_states is number of hidden states that are needed for cell
|
240 |
-
if one will go directly to cell as tensor, if more will go as list
|
241 |
-
"""
|
242 |
-
def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):
|
243 |
-
super(RNNCell, self).__init__()
|
244 |
-
|
245 |
-
self.gate_multiplier = gate_multiplier
|
246 |
-
self.input_size = input_size
|
247 |
-
self.hidden_size = hidden_size
|
248 |
-
self.cell = cell
|
249 |
-
self.bias = bias
|
250 |
-
self.output_size = output_size
|
251 |
-
if output_size is None:
|
252 |
-
self.output_size = hidden_size
|
253 |
-
|
254 |
-
self.gate_size = gate_multiplier * self.hidden_size
|
255 |
-
self.n_hidden_states = n_hidden_states
|
256 |
-
|
257 |
-
self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size))
|
258 |
-
self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size))
|
259 |
-
|
260 |
-
#Check if there's recurrent projection
|
261 |
-
if(self.output_size != self.hidden_size):
|
262 |
-
self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size))
|
263 |
-
|
264 |
-
self.b_ih = self.b_hh = None
|
265 |
-
if self.bias:
|
266 |
-
self.b_ih = nn.Parameter(torch.empty(self.gate_size))
|
267 |
-
self.b_hh = nn.Parameter(torch.empty(self.gate_size))
|
268 |
-
|
269 |
-
#hidden states for forward
|
270 |
-
self.hidden = [ None for states in range(self.n_hidden_states)]
|
271 |
-
|
272 |
-
self.reset_parameters()
|
273 |
-
|
274 |
-
def new_like(self, new_input_size=None):
|
275 |
-
"""
|
276 |
-
new_like()
|
277 |
-
"""
|
278 |
-
if new_input_size is None:
|
279 |
-
new_input_size = self.input_size
|
280 |
-
|
281 |
-
return type(self)(self.gate_multiplier,
|
282 |
-
new_input_size,
|
283 |
-
self.hidden_size,
|
284 |
-
self.cell,
|
285 |
-
self.n_hidden_states,
|
286 |
-
self.bias,
|
287 |
-
self.output_size)
|
288 |
-
|
289 |
-
|
290 |
-
#Use xavier where we can (weights), otherwise use uniform (bias)
|
291 |
-
def reset_parameters(self, gain=1):
|
292 |
-
"""
|
293 |
-
reset_parameters()
|
294 |
-
"""
|
295 |
-
stdev = 1.0 / math.sqrt(self.hidden_size)
|
296 |
-
for param in self.parameters():
|
297 |
-
param.data.uniform_(-stdev, stdev)
|
298 |
-
'''
|
299 |
-
Xavier reset:
|
300 |
-
def reset_parameters(self, gain=1):
|
301 |
-
stdv = 1.0 / math.sqrt(self.gate_size)
|
302 |
-
|
303 |
-
for param in self.parameters():
|
304 |
-
if (param.dim() > 1):
|
305 |
-
torch.nn.init.xavier_normal(param, gain)
|
306 |
-
else:
|
307 |
-
param.data.uniform_(-stdv, stdv)
|
308 |
-
'''
|
309 |
-
def init_hidden(self, bsz):
|
310 |
-
"""
|
311 |
-
init_hidden()
|
312 |
-
"""
|
313 |
-
for param in self.parameters():
|
314 |
-
if param is not None:
|
315 |
-
a_param = param
|
316 |
-
break
|
317 |
-
|
318 |
-
for i, _ in enumerate(self.hidden):
|
319 |
-
if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):
|
320 |
-
|
321 |
-
if i==0:
|
322 |
-
hidden_size = self.output_size
|
323 |
-
else:
|
324 |
-
hidden_size = self.hidden_size
|
325 |
-
|
326 |
-
tens = a_param.data.new(bsz, hidden_size).zero_()
|
327 |
-
self.hidden[i] = Variable(tens, requires_grad=False)
|
328 |
-
|
329 |
-
|
330 |
-
def reset_hidden(self, bsz):
|
331 |
-
"""
|
332 |
-
reset_hidden()
|
333 |
-
"""
|
334 |
-
for i, _ in enumerate(self.hidden):
|
335 |
-
self.hidden[i] = None
|
336 |
-
self.init_hidden(bsz)
|
337 |
-
|
338 |
-
def detach_hidden(self):
|
339 |
-
"""
|
340 |
-
detach_hidden()
|
341 |
-
"""
|
342 |
-
for i, _ in enumerate(self.hidden):
|
343 |
-
if self.hidden[i] is None:
|
344 |
-
raise RuntimeError("Must initialize hidden state before you can detach it")
|
345 |
-
for i, _ in enumerate(self.hidden):
|
346 |
-
self.hidden[i] = self.hidden[i].detach()
|
347 |
-
|
348 |
-
def forward(self, input):
|
349 |
-
"""
|
350 |
-
forward()
|
351 |
-
if not inited or bsz has changed this will create hidden states
|
352 |
-
"""
|
353 |
-
self.init_hidden(input.size()[0])
|
354 |
-
|
355 |
-
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
|
356 |
-
self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
|
357 |
-
if(self.n_hidden_states > 1):
|
358 |
-
self.hidden = list(self.hidden)
|
359 |
-
else:
|
360 |
-
self.hidden=[self.hidden]
|
361 |
-
|
362 |
-
if self.output_size != self.hidden_size:
|
363 |
-
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
|
364 |
-
|
365 |
-
return tuple(self.hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/RNN/__init__.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
from .models import LSTM, GRU, ReLU, Tanh, mLSTM
|
2 |
-
|
3 |
-
__all__ = ['models']
|
|
|
|
|
|
|
|
apex/apex/RNN/cells.py
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
from .RNNBackend import RNNCell
|
6 |
-
|
7 |
-
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
|
8 |
-
|
9 |
-
import math
|
10 |
-
|
11 |
-
|
12 |
-
class mLSTMRNNCell(RNNCell):
|
13 |
-
"""
|
14 |
-
mLSTMRNNCell
|
15 |
-
"""
|
16 |
-
|
17 |
-
def __init__(self, input_size, hidden_size, bias = False, output_size = None):
|
18 |
-
gate_multiplier = 4
|
19 |
-
super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
|
20 |
-
|
21 |
-
self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size))
|
22 |
-
self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size))
|
23 |
-
|
24 |
-
self.reset_parameters()
|
25 |
-
|
26 |
-
def forward(self, input):
|
27 |
-
"""
|
28 |
-
mLSTMRNNCell.forward()
|
29 |
-
"""
|
30 |
-
#if not inited or bsz has changed this will create hidden states
|
31 |
-
self.init_hidden(input.size()[0])
|
32 |
-
|
33 |
-
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
|
34 |
-
|
35 |
-
self.hidden = list(
|
36 |
-
self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
|
37 |
-
b_ih=self.b_ih, b_hh=self.b_hh)
|
38 |
-
)
|
39 |
-
|
40 |
-
if self.output_size != self.hidden_size:
|
41 |
-
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
|
42 |
-
return tuple(self.hidden)
|
43 |
-
|
44 |
-
|
45 |
-
def new_like(self, new_input_size=None):
|
46 |
-
if new_input_size is None:
|
47 |
-
new_input_size = self.input_size
|
48 |
-
|
49 |
-
return type(self)(
|
50 |
-
new_input_size,
|
51 |
-
self.hidden_size,
|
52 |
-
self.bias,
|
53 |
-
self.output_size)
|
54 |
-
|
55 |
-
def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
|
56 |
-
"""
|
57 |
-
mLSTMCell
|
58 |
-
"""
|
59 |
-
|
60 |
-
if input.is_cuda:
|
61 |
-
igates = F.linear(input, w_ih)
|
62 |
-
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
|
63 |
-
hgates = F.linear(m, w_hh)
|
64 |
-
|
65 |
-
state = fusedBackend.LSTMFused.apply
|
66 |
-
return state(igates, hgates, hidden[1], b_ih, b_hh)
|
67 |
-
|
68 |
-
hx, cx = hidden
|
69 |
-
|
70 |
-
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
|
71 |
-
gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)
|
72 |
-
|
73 |
-
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
74 |
-
|
75 |
-
ingate = F.sigmoid(ingate)
|
76 |
-
forgetgate = F.sigmoid(forgetgate)
|
77 |
-
cellgate = F.tanh(cellgate)
|
78 |
-
outgate = F.sigmoid(outgate)
|
79 |
-
|
80 |
-
cy = (forgetgate * cx) + (ingate * cellgate)
|
81 |
-
hy = outgate * F.tanh(cy)
|
82 |
-
|
83 |
-
return hy, cy
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/RNN/models.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
|
4 |
-
|
5 |
-
from apex import deprecated_warning
|
6 |
-
from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
|
7 |
-
from .cells import mLSTMRNNCell, mLSTMCell
|
8 |
-
|
9 |
-
def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
|
10 |
-
"""
|
11 |
-
:class:`toRNNBackend`
|
12 |
-
"""
|
13 |
-
|
14 |
-
deprecated_warning("`apex.RNN` is deprecated and will be removed by the end of February 2023.")
|
15 |
-
if bidirectional:
|
16 |
-
return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
|
17 |
-
else:
|
18 |
-
return stackedRNN(inputRNN, num_layers, dropout = dropout)
|
19 |
-
|
20 |
-
|
21 |
-
def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
|
22 |
-
"""
|
23 |
-
:class:`LSTM`
|
24 |
-
"""
|
25 |
-
inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
|
26 |
-
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
|
27 |
-
|
28 |
-
def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
|
29 |
-
"""
|
30 |
-
:class:`GRU`
|
31 |
-
"""
|
32 |
-
inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
|
33 |
-
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
|
34 |
-
|
35 |
-
def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
|
36 |
-
"""
|
37 |
-
:class:`ReLU`
|
38 |
-
"""
|
39 |
-
inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
|
40 |
-
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
|
41 |
-
|
42 |
-
def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
|
43 |
-
"""
|
44 |
-
:class:`Tanh`
|
45 |
-
"""
|
46 |
-
inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
|
47 |
-
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
|
48 |
-
|
49 |
-
def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
|
50 |
-
"""
|
51 |
-
:class:`mLSTM`
|
52 |
-
"""
|
53 |
-
inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
|
54 |
-
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/__init__.py
DELETED
@@ -1,68 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import warnings
|
3 |
-
|
4 |
-
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
|
5 |
-
import torch
|
6 |
-
|
7 |
-
|
8 |
-
__all__ = ["amp", "fp16_utils", "optimizers", "normalization", "transformer"]
|
9 |
-
|
10 |
-
|
11 |
-
if torch.distributed.is_available():
|
12 |
-
from . import parallel
|
13 |
-
__all__.append("parallel")
|
14 |
-
|
15 |
-
from . import amp
|
16 |
-
from . import fp16_utils
|
17 |
-
|
18 |
-
# For optimizers and normalization there is no Python fallback.
|
19 |
-
# Absence of cuda backend is a hard error.
|
20 |
-
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
|
21 |
-
# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext
|
22 |
-
# so they expect those backends to be available, but for some reason they actually aren't
|
23 |
-
# available (for example because they built improperly in a way that isn't revealed until
|
24 |
-
# load time) the error message is timely and visible.
|
25 |
-
from . import optimizers
|
26 |
-
from . import normalization
|
27 |
-
from . import transformer
|
28 |
-
|
29 |
-
|
30 |
-
# Logging utilities for apex.transformer module
|
31 |
-
class RankInfoFormatter(logging.Formatter):
|
32 |
-
|
33 |
-
def format(self, record):
|
34 |
-
from apex.transformer.parallel_state import get_rank_info
|
35 |
-
record.rank_info = get_rank_info()
|
36 |
-
return super().format(record)
|
37 |
-
|
38 |
-
|
39 |
-
_library_root_logger = logging.getLogger(__name__)
|
40 |
-
handler = logging.StreamHandler()
|
41 |
-
handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
|
42 |
-
_library_root_logger.addHandler(handler)
|
43 |
-
_library_root_logger.propagate = False
|
44 |
-
|
45 |
-
|
46 |
-
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
|
47 |
-
cudnn_available = torch.backends.cudnn.is_available()
|
48 |
-
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
|
49 |
-
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
|
50 |
-
warnings.warn(
|
51 |
-
f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
|
52 |
-
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
|
53 |
-
)
|
54 |
-
return False
|
55 |
-
return True
|
56 |
-
|
57 |
-
|
58 |
-
class DeprecatedFeatureWarning(FutureWarning):
|
59 |
-
pass
|
60 |
-
|
61 |
-
|
62 |
-
def deprecated_warning(msg: str) -> None:
|
63 |
-
if (
|
64 |
-
not torch.distributed.is_available
|
65 |
-
or not torch.distributed.is_initialized()
|
66 |
-
or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)
|
67 |
-
):
|
68 |
-
warnings.warn(msg, DeprecatedFeatureWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/_autocast_utils.py
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
from typing import Optional, Sequence
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
|
6 |
-
__all__ = ["_cast_if_autocast_enabled"]
|
7 |
-
|
8 |
-
|
9 |
-
def _get_autocast_dtypes() -> Sequence[torch.dtype]:
|
10 |
-
if torch.cuda.is_bf16_supported():
|
11 |
-
return [torch.half, torch.bfloat16]
|
12 |
-
return [torch.half]
|
13 |
-
|
14 |
-
|
15 |
-
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
16 |
-
if not torch.is_autocast_enabled():
|
17 |
-
return torch.float or dtype
|
18 |
-
else:
|
19 |
-
return torch.get_autocast_gpu_dtype()
|
20 |
-
|
21 |
-
|
22 |
-
def _cast_if_autocast_enabled(*args):
|
23 |
-
if not torch.is_autocast_enabled():
|
24 |
-
return args
|
25 |
-
else:
|
26 |
-
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/README.md
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
# amp: Automatic Mixed Precision
|
2 |
-
|
3 |
-
## Annotating User Functions
|
4 |
-
|
5 |
-
Nearly all PyTorch user code needs nothing more than the two steps
|
6 |
-
above to use amp. After all, custom layers are built out of simpler
|
7 |
-
PyTorch components, and amp already can see those.
|
8 |
-
|
9 |
-
However, any custom C++ or CUDA code is outside of amp's (default)
|
10 |
-
view of things. For example, suppose I implemented a new recurrent
|
11 |
-
cell called a "forgetful recurrent unit" that calls directly into a
|
12 |
-
CUDA backend:
|
13 |
-
|
14 |
-
```python
|
15 |
-
from backend import FRUBackend
|
16 |
-
|
17 |
-
def fru(input, hidden, weight, bias):
|
18 |
-
# call to CUDA code
|
19 |
-
FRUBackend(input, hidden, weight, bias)
|
20 |
-
```
|
21 |
-
|
22 |
-
In this case, it is possible to get a runtime type mismatch. For
|
23 |
-
example, you might have `input` in fp16, and `weight` in fp32, and amp
|
24 |
-
doesn't have the visibility to insert an appropriate cast.
|
25 |
-
|
26 |
-
amp exposes two ways to handle "invisible" backend code: function
|
27 |
-
annotations and explicit registration.
|
28 |
-
|
29 |
-
#### Function annotation
|
30 |
-
|
31 |
-
The first way to handle backend code is a set of function annotations:
|
32 |
-
|
33 |
-
- `@amp.half_function`
|
34 |
-
- `@amp.float_function`
|
35 |
-
- `@amp.promote_function`
|
36 |
-
|
37 |
-
These correspond to:
|
38 |
-
|
39 |
-
- Cast all arguments to fp16
|
40 |
-
- Cast all argumnets fo fp32
|
41 |
-
- If there are any type mismatches, cast everything to the widest type
|
42 |
-
|
43 |
-
In our example, we believe that the FRU unit is fp16-safe and will get
|
44 |
-
performance gains from casting its arguments to fp16, so we write:
|
45 |
-
|
46 |
-
```python
|
47 |
-
@amp.half_function
|
48 |
-
def fru(input, hidden, weight, bias):
|
49 |
-
#...
|
50 |
-
```
|
51 |
-
|
52 |
-
#### Explicit registration
|
53 |
-
|
54 |
-
The other way to handle backend code is with explicit function
|
55 |
-
registration:
|
56 |
-
|
57 |
-
- `amp.register_half_function(module, function_name)`
|
58 |
-
- `amp.register_float_function(module, function_name)`
|
59 |
-
- `amp.register_promote_function(module, function_name)`
|
60 |
-
|
61 |
-
When using this API, `module` is the containing class or module for
|
62 |
-
the function, and `function_name` is the _string_ name of the
|
63 |
-
function. Note that the function must be registered before the call to
|
64 |
-
`amp.initalize()`.
|
65 |
-
|
66 |
-
For our FRU unit, we can register the backend function directly:
|
67 |
-
|
68 |
-
```python
|
69 |
-
import backend
|
70 |
-
|
71 |
-
amp.register_half_function(backend, 'FRUBackend')
|
72 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
from .amp import init, half_function, float_function, promote_function,\
|
2 |
-
register_half_function, register_float_function, register_promote_function
|
3 |
-
from .handle import scale_loss, disable_casts
|
4 |
-
from .frontend import initialize, state_dict, load_state_dict
|
5 |
-
from ._amp_state import master_params, _amp_state
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/__version__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
VERSION = (0, 1, 0)
|
2 |
-
__version__ = '.'.join(map(str, VERSION))
|
|
|
|
|
|
apex/apex/amp/_amp_state.py
DELETED
@@ -1,59 +0,0 @@
|
|
1 |
-
# This is a "header object" that allows different amp modules to communicate.
|
2 |
-
# I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
|
3 |
-
# But apparently it's ok:
|
4 |
-
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
|
5 |
-
import torch
|
6 |
-
|
7 |
-
|
8 |
-
class AmpState(object):
|
9 |
-
def __init__(self):
|
10 |
-
self.hard_override=False
|
11 |
-
self.allow_incoming_model_not_fp32 = False
|
12 |
-
self.verbosity=1
|
13 |
-
|
14 |
-
|
15 |
-
# Attribute stash. Could also just stash things as global module attributes.
|
16 |
-
_amp_state = AmpState()
|
17 |
-
|
18 |
-
|
19 |
-
def warn_or_err(msg):
|
20 |
-
if _amp_state.hard_override:
|
21 |
-
print("Warning: " + msg)
|
22 |
-
else:
|
23 |
-
raise RuntimeError(msg)
|
24 |
-
# I'm not sure if allowing hard_override is a good idea.
|
25 |
-
# + " If you're sure you know what you're doing, supply " +
|
26 |
-
# "hard_override=True to amp.initialize.")
|
27 |
-
|
28 |
-
|
29 |
-
def maybe_print(msg, rank0=False):
|
30 |
-
distributed = torch.distributed.is_available() and \
|
31 |
-
torch.distributed.is_initialized() and \
|
32 |
-
torch.distributed.get_world_size() > 1
|
33 |
-
if _amp_state.verbosity > 0:
|
34 |
-
if rank0:
|
35 |
-
if distributed:
|
36 |
-
if torch.distributed.get_rank() == 0:
|
37 |
-
print(msg)
|
38 |
-
else:
|
39 |
-
print(msg)
|
40 |
-
else:
|
41 |
-
print(msg)
|
42 |
-
|
43 |
-
|
44 |
-
# def iter_params(param_groups):
|
45 |
-
# for group in param_groups:
|
46 |
-
# for p in group['params']:
|
47 |
-
# yield p
|
48 |
-
|
49 |
-
|
50 |
-
def master_params(optimizer):
|
51 |
-
"""
|
52 |
-
Generator expression that iterates over the params owned by ``optimizer``.
|
53 |
-
|
54 |
-
Args:
|
55 |
-
optimizer: An optimizer previously returned from ``amp.initialize``.
|
56 |
-
"""
|
57 |
-
for group in optimizer.param_groups:
|
58 |
-
for p in group['params']:
|
59 |
-
yield p
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/_initialize.py
DELETED
@@ -1,265 +0,0 @@
|
|
1 |
-
import collections.abc as container_abcs
|
2 |
-
from types import MethodType
|
3 |
-
import functools
|
4 |
-
import sys
|
5 |
-
import warnings
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from ._amp_state import _amp_state, warn_or_err
|
11 |
-
from .handle import disable_casts
|
12 |
-
from .scaler import LossScaler
|
13 |
-
from ._process_optimizer import _process_optimizer
|
14 |
-
from apex.fp16_utils import convert_network
|
15 |
-
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
|
16 |
-
from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
|
17 |
-
|
18 |
-
if torch.distributed.is_available():
|
19 |
-
from ..parallel import DistributedDataParallel as apex_DDP
|
20 |
-
from ..parallel.LARC import LARC
|
21 |
-
|
22 |
-
|
23 |
-
def to_type(dtype, t):
|
24 |
-
if isinstance(t, torch.Tensor):
|
25 |
-
if not t.is_cuda:
|
26 |
-
# This should not be a hard error, since it may be legitimate.
|
27 |
-
warnings.warn("An input tensor was not cuda.")
|
28 |
-
# GANs require this.
|
29 |
-
# if t.requires_grad:
|
30 |
-
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
|
31 |
-
# "its gradients will not be properly allreduced by DDP.")
|
32 |
-
if t.is_floating_point():
|
33 |
-
return t.to(dtype)
|
34 |
-
return t
|
35 |
-
else:
|
36 |
-
# Trust the user's custom batch type, that's all I can do here.
|
37 |
-
return t.to(dtype)
|
38 |
-
|
39 |
-
|
40 |
-
# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
|
41 |
-
def applier(value, fn):
|
42 |
-
if isinstance(value, torch.Tensor):
|
43 |
-
return fn(value)
|
44 |
-
elif isinstance(value, str):
|
45 |
-
return value
|
46 |
-
elif isinstance(value, np.ndarray):
|
47 |
-
return value
|
48 |
-
elif hasattr(value, "to"): # Allow handling of custom batch classes
|
49 |
-
return fn(value)
|
50 |
-
elif isinstance(value, container_abcs.Mapping):
|
51 |
-
return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
|
52 |
-
elif isinstance(value, container_abcs.Iterable):
|
53 |
-
return type(value)(applier(v, fn) for v in value)
|
54 |
-
else:
|
55 |
-
# Do I want this to fire off even if someone chooses to pass something ordinary like
|
56 |
-
# an int or float? May be more annoying than it's worth.
|
57 |
-
# print("Warning: unrecognized type in applier. If your input data is a custom class, "
|
58 |
-
# "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
|
59 |
-
# "Amp will check for your custom to() and invoke it to cast the batch's "
|
60 |
-
# "floating-point Tensors to the appropriate type. "
|
61 |
-
# "Also, if your data is a custom class, it is your responsibility to ensure that "
|
62 |
-
# "any Tensors you want to be cuda are already cuda."
|
63 |
-
return value
|
64 |
-
|
65 |
-
|
66 |
-
def check_models(models):
|
67 |
-
for model in models:
|
68 |
-
parallel_type = None
|
69 |
-
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
70 |
-
parallel_type = "torch.nn.parallel.DistributedDataParallel"
|
71 |
-
if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
|
72 |
-
parallel_type = "apex.parallel.DistributedDataParallel"
|
73 |
-
if isinstance(model, torch.nn.parallel.DataParallel):
|
74 |
-
parallel_type = "torch.nn.parallel.DataParallel"
|
75 |
-
if parallel_type is not None:
|
76 |
-
raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
|
77 |
-
"Parallel wrappers should only be applied to the model(s) AFTER \n"
|
78 |
-
"the model(s) have been returned from amp.initialize.")
|
79 |
-
|
80 |
-
|
81 |
-
def check_params_fp32(models):
|
82 |
-
for model in models:
|
83 |
-
for name, param in model.named_parameters():
|
84 |
-
if param.is_floating_point():
|
85 |
-
if 'Half' in param.type():
|
86 |
-
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
|
87 |
-
"When using amp.initialize, you do not need to call .half() on your model\n"
|
88 |
-
"before passing it, no matter what optimization level you choose.".format(
|
89 |
-
name, param.type()))
|
90 |
-
elif not param.is_cuda:
|
91 |
-
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
|
92 |
-
"When using amp.initialize, you need to provide a model with parameters\n"
|
93 |
-
"located on a CUDA device before passing it no matter what optimization level\n"
|
94 |
-
"you chose. Use model.to('cuda') to use the default device.".format(
|
95 |
-
name, param.type()))
|
96 |
-
|
97 |
-
# Backward compatibility for PyTorch 0.4
|
98 |
-
if hasattr(model, 'named_buffers'):
|
99 |
-
buf_iter = model.named_buffers()
|
100 |
-
else:
|
101 |
-
buf_iter = model._buffers
|
102 |
-
for obj in buf_iter:
|
103 |
-
if type(obj)==tuple:
|
104 |
-
name, buf = obj
|
105 |
-
else:
|
106 |
-
name, buf = obj, buf_iter[obj]
|
107 |
-
if buf.is_floating_point():
|
108 |
-
if 'Half' in buf.type():
|
109 |
-
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
|
110 |
-
"When using amp.initialize, you do not need to call .half() on your model\n"
|
111 |
-
"before passing it, no matter what optimization level you choose.".format(
|
112 |
-
name, buf.type()))
|
113 |
-
elif not buf.is_cuda:
|
114 |
-
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
|
115 |
-
"When using amp.initialize, you need to provide a model with buffers\n"
|
116 |
-
"located on a CUDA device before passing it no matter what optimization level\n"
|
117 |
-
"you chose. Use model.to('cuda') to use the default device.".format(
|
118 |
-
name, buf.type()))
|
119 |
-
|
120 |
-
|
121 |
-
def check_optimizers(optimizers):
|
122 |
-
for optim in optimizers:
|
123 |
-
bad_optim_type = None
|
124 |
-
if isinstance(optim, FP16_Optimizer_general):
|
125 |
-
bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
|
126 |
-
if isinstance(optim, FP16_Optimizer_for_fused):
|
127 |
-
bad_optim_type = "apex.optimizers.FP16_Optimizer"
|
128 |
-
if bad_optim_type is not None:
|
129 |
-
raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
|
130 |
-
"The optimizer(s) passed to amp.initialize() must be bare \n"
|
131 |
-
"instances of either ordinary Pytorch optimizers, or Apex fused \n"
|
132 |
-
"optimizers.\n")
|
133 |
-
|
134 |
-
|
135 |
-
class O2StateDictHook(object):
|
136 |
-
def __init__(self, fn):
|
137 |
-
self.fn = fn
|
138 |
-
|
139 |
-
def __call__(self, module, state_dict, prefix, local_metadata):
|
140 |
-
for key in state_dict:
|
141 |
-
param = state_dict[key]
|
142 |
-
if 'Half' in param.type():
|
143 |
-
param = param.to(torch.float32)
|
144 |
-
state_dict[key] = param
|
145 |
-
|
146 |
-
|
147 |
-
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
|
148 |
-
from .amp import init as amp_init
|
149 |
-
|
150 |
-
optimizers_was_list = False
|
151 |
-
if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
|
152 |
-
optimizers = [optimizers]
|
153 |
-
elif optimizers is None:
|
154 |
-
optimizers = []
|
155 |
-
elif isinstance(optimizers, list):
|
156 |
-
optimizers_was_list = True
|
157 |
-
check_optimizers(optimizers)
|
158 |
-
else:
|
159 |
-
check_optimizers([optimizers])
|
160 |
-
raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
|
161 |
-
|
162 |
-
if isinstance(models, torch.nn.Module):
|
163 |
-
models_was_list = False
|
164 |
-
models = [models]
|
165 |
-
elif isinstance(models, list):
|
166 |
-
models_was_list = True
|
167 |
-
else:
|
168 |
-
raise TypeError("models must be either a single model or a list of models.")
|
169 |
-
|
170 |
-
check_models(models)
|
171 |
-
|
172 |
-
if not _amp_state.allow_incoming_model_not_fp32:
|
173 |
-
check_params_fp32(models)
|
174 |
-
|
175 |
-
# In the future, when FP16_Optimizer can be deprecated and master weights can
|
176 |
-
# become an attribute, remember to stash master weights before casting the model.
|
177 |
-
|
178 |
-
if properties.cast_model_type:
|
179 |
-
if properties.keep_batchnorm_fp32:
|
180 |
-
for model in models:
|
181 |
-
convert_network(model, properties.cast_model_type)
|
182 |
-
else:
|
183 |
-
for model in models:
|
184 |
-
model.to(properties.cast_model_type)
|
185 |
-
|
186 |
-
input_caster = functools.partial(to_type, properties.cast_model_type)
|
187 |
-
if cast_model_outputs is not None:
|
188 |
-
output_caster = functools.partial(to_type, cast_model_outputs)
|
189 |
-
else:
|
190 |
-
output_caster = functools.partial(to_type, torch.float32)
|
191 |
-
|
192 |
-
for model in models:
|
193 |
-
# Patch the forward method to cast incoming data to the correct type, and
|
194 |
-
# outgoing data to float32, so "the user never needs to call .half()."
|
195 |
-
# I like writing things explicitly more than decorators.
|
196 |
-
def patch_forward(old_fwd):
|
197 |
-
def new_fwd(*args, **kwargs):
|
198 |
-
output = old_fwd(*applier(args, input_caster),
|
199 |
-
**applier(kwargs, input_caster))
|
200 |
-
return applier(output, output_caster)
|
201 |
-
return new_fwd
|
202 |
-
|
203 |
-
model.forward = patch_forward(model.forward)
|
204 |
-
|
205 |
-
# State dict trick to recast any preexisting per-param state tensors
|
206 |
-
for optimizer in optimizers:
|
207 |
-
optimizer.load_state_dict(optimizer.state_dict())
|
208 |
-
|
209 |
-
# patch model.state_dict() to return float32 params
|
210 |
-
for model in models:
|
211 |
-
for module in model.modules():
|
212 |
-
module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))
|
213 |
-
|
214 |
-
elif cast_model_outputs is not None:
|
215 |
-
output_caster = functools.partial(to_type, cast_model_outputs)
|
216 |
-
|
217 |
-
for model in models:
|
218 |
-
def patch_forward(old_fwd):
|
219 |
-
def new_fwd(*args, **kwargs):
|
220 |
-
output = old_fwd(*args, **kwargs)
|
221 |
-
return applier(output, output_caster)
|
222 |
-
return new_fwd
|
223 |
-
|
224 |
-
model.forward = patch_forward(model.forward)
|
225 |
-
|
226 |
-
for i, optimizer in enumerate(optimizers):
|
227 |
-
optimizers[i] = _process_optimizer(optimizer, properties)
|
228 |
-
|
229 |
-
_amp_state.loss_scalers = []
|
230 |
-
for _ in range(num_losses):
|
231 |
-
_amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
|
232 |
-
min_loss_scale=_amp_state.min_loss_scale,
|
233 |
-
max_loss_scale=_amp_state.max_loss_scale))
|
234 |
-
|
235 |
-
if properties.patch_torch_functions:
|
236 |
-
# handle is unused here. It's accessible later through a global value anyway.
|
237 |
-
handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
|
238 |
-
for optimizer in optimizers:
|
239 |
-
# Disable Amp casting for the optimizer step, because it should only be
|
240 |
-
# applied to FP32 master params anyway.
|
241 |
-
def patch_step(old_step):
|
242 |
-
def new_step(self, *args, **kwargs):
|
243 |
-
with disable_casts():
|
244 |
-
output = old_step(*args, **kwargs)
|
245 |
-
return output
|
246 |
-
return new_step
|
247 |
-
|
248 |
-
optimizer.step = MethodType(patch_step(optimizer.step), optimizer)
|
249 |
-
|
250 |
-
if optimizers_was_list:
|
251 |
-
if models_was_list:
|
252 |
-
return models, optimizers
|
253 |
-
else:
|
254 |
-
return models[0], optimizers
|
255 |
-
else:
|
256 |
-
if models_was_list:
|
257 |
-
if len(optimizers) == 0:
|
258 |
-
return models
|
259 |
-
else:
|
260 |
-
return models, optimizers[0]
|
261 |
-
else:
|
262 |
-
if len(optimizers) == 0:
|
263 |
-
return models[0]
|
264 |
-
else:
|
265 |
-
return models[0], optimizers[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/_process_optimizer.py
DELETED
@@ -1,489 +0,0 @@
|
|
1 |
-
import types
|
2 |
-
from ..fp16_utils import master_params_to_model_params
|
3 |
-
from ..multi_tensor_apply import multi_tensor_applier
|
4 |
-
from ._amp_state import maybe_print
|
5 |
-
import torch
|
6 |
-
from ..optimizers import FusedSGD
|
7 |
-
|
8 |
-
|
9 |
-
class AmpOptimizerState(object):
|
10 |
-
def __init__(self):
|
11 |
-
pass
|
12 |
-
|
13 |
-
|
14 |
-
def _master_params_to_model_params(self):
|
15 |
-
stash = self._amp_stash
|
16 |
-
if multi_tensor_applier.available:
|
17 |
-
if len(stash.all_fp16_params) > 0:
|
18 |
-
multi_tensor_applier(
|
19 |
-
stash.multi_tensor_scale,
|
20 |
-
stash.dummy_overflow_buf,
|
21 |
-
[stash.all_fp32_from_fp16_params, stash.all_fp16_params],
|
22 |
-
1.0)
|
23 |
-
else:
|
24 |
-
for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
|
25 |
-
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
|
26 |
-
|
27 |
-
|
28 |
-
def lazy_init_with_master_weights(self):
|
29 |
-
stash = self._amp_stash
|
30 |
-
stash.fp16_groups = []
|
31 |
-
stash.fp32_from_fp16_groups = []
|
32 |
-
stash.fp32_from_fp32_groups = []
|
33 |
-
for i, param_group in enumerate(self.param_groups):
|
34 |
-
# maybe_print("FP16_Optimizer processing param group {}:".format(i))
|
35 |
-
fp16_params_this_group = []
|
36 |
-
fp32_params_this_group = []
|
37 |
-
fp32_from_fp16_params_this_group = []
|
38 |
-
for i, param in enumerate(param_group['params']):
|
39 |
-
if param.requires_grad:
|
40 |
-
if param.type() == 'torch.cuda.HalfTensor':
|
41 |
-
# maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
|
42 |
-
# .format(param.size()))
|
43 |
-
fp16_params_this_group.append(param)
|
44 |
-
master_param = param.detach().clone().float()
|
45 |
-
master_param.requires_grad = True
|
46 |
-
param_group['params'][i] = master_param
|
47 |
-
fp32_from_fp16_params_this_group.append(master_param)
|
48 |
-
# Reset existing state dict key to the new master param.
|
49 |
-
# We still need to recast per-param state tensors, if any, to FP32.
|
50 |
-
if param in self.state:
|
51 |
-
self.state[master_param] = self.state.pop(param)
|
52 |
-
elif param.type() == 'torch.cuda.FloatTensor':
|
53 |
-
# maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
|
54 |
-
# .format(param.size()))
|
55 |
-
fp32_params_this_group.append(param)
|
56 |
-
param_group['params'][i] = param
|
57 |
-
else:
|
58 |
-
raise TypeError("Optimizer's parameters must be either "
|
59 |
-
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
|
60 |
-
"Received {}".format(param.type()))
|
61 |
-
|
62 |
-
stash.fp16_groups.append(fp16_params_this_group)
|
63 |
-
stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
|
64 |
-
stash.fp32_from_fp32_groups.append(fp32_params_this_group)
|
65 |
-
|
66 |
-
stash.all_fp16_params = []
|
67 |
-
for group in stash.fp16_groups:
|
68 |
-
stash.all_fp16_params += group
|
69 |
-
|
70 |
-
stash.all_fp32_from_fp16_params = []
|
71 |
-
for group in stash.fp32_from_fp16_groups:
|
72 |
-
stash.all_fp32_from_fp16_params += group
|
73 |
-
|
74 |
-
stash.all_fp32_from_fp32_params = []
|
75 |
-
for group in stash.fp32_from_fp32_groups:
|
76 |
-
stash.all_fp32_from_fp32_params += group
|
77 |
-
|
78 |
-
# all_fp16_grad_stash is only needed for fused optimizers.
|
79 |
-
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
|
80 |
-
# stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
|
81 |
-
stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
|
82 |
-
|
83 |
-
for param in stash.all_fp32_from_fp16_params:
|
84 |
-
param.grad = None
|
85 |
-
|
86 |
-
for param in stash.all_fp32_from_fp32_params:
|
87 |
-
param.grad = None
|
88 |
-
|
89 |
-
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
|
90 |
-
self.load_state_dict(self.state_dict())
|
91 |
-
|
92 |
-
|
93 |
-
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
|
94 |
-
grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
|
95 |
-
|
96 |
-
# not much to do if scale == 1.0 and static scaling
|
97 |
-
if scaler.loss_scale() == 1.0 and not scaler.dynamic:
|
98 |
-
# Clear the stash.
|
99 |
-
for i in range(len(stashed_grads)):
|
100 |
-
stashed_grads[i] = None
|
101 |
-
return
|
102 |
-
|
103 |
-
if scale_override is not None:
|
104 |
-
grads_have_scale, stashed_have_scale, out_scale = scale_override
|
105 |
-
|
106 |
-
# This is a lot of python overhead...
|
107 |
-
grads_needing_unscale = []
|
108 |
-
grads_needing_unscale_with_stash = []
|
109 |
-
stashed = []
|
110 |
-
for param, stashed_grad in zip(params, stashed_grads):
|
111 |
-
if param.grad is None and stashed_grad is not None:
|
112 |
-
param.grad = stashed_grad
|
113 |
-
elif param.grad is not None and stashed_grad is None:
|
114 |
-
grads_needing_unscale.append(param.grad)
|
115 |
-
elif param.grad is not None and stashed_grad is not None:
|
116 |
-
grads_needing_unscale_with_stash.append(param.grad)
|
117 |
-
stashed.append(stashed_grad)
|
118 |
-
else: # param.grad is None and stashed_grad is None
|
119 |
-
continue
|
120 |
-
|
121 |
-
# unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
|
122 |
-
if len(grads_needing_unscale) > 0:
|
123 |
-
scaler.unscale(
|
124 |
-
grads_needing_unscale,
|
125 |
-
grads_needing_unscale,
|
126 |
-
None, # unused_scale, currently present to avoid API breakage elsewhere
|
127 |
-
models_are_masters=True,
|
128 |
-
scale_override=grads_have_scale/out_scale)
|
129 |
-
|
130 |
-
if len(grads_needing_unscale_with_stash) > 0:
|
131 |
-
scaler.unscale_with_stashed(
|
132 |
-
grads_needing_unscale_with_stash,
|
133 |
-
stashed,
|
134 |
-
grads_needing_unscale_with_stash,
|
135 |
-
scale_override=(grads_have_scale, stashed_have_scale, out_scale))
|
136 |
-
|
137 |
-
# Clear the stash.
|
138 |
-
for i in range(len(stashed_grads)):
|
139 |
-
stashed_grads[i] = None
|
140 |
-
|
141 |
-
|
142 |
-
def prepare_backward_with_master_weights(self):
|
143 |
-
stash = self._amp_stash
|
144 |
-
|
145 |
-
self._amp_lazy_init()
|
146 |
-
|
147 |
-
for i, param in enumerate(stash.all_fp16_params):
|
148 |
-
# Set up to leverage grad copy elision.
|
149 |
-
# This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
|
150 |
-
param.grad = None
|
151 |
-
|
152 |
-
# for i, param in enumerate(stash.all_fp32_from_fp16_params):
|
153 |
-
# stash.all_fp32_from_fp16_grad_stash[i] = param.grad
|
154 |
-
|
155 |
-
for i, param in enumerate(stash.all_fp32_from_fp32_params):
|
156 |
-
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
|
157 |
-
# Set up to leverage grad copy elision:
|
158 |
-
param.grad = None
|
159 |
-
|
160 |
-
|
161 |
-
def post_backward_with_master_weights(self, scaler):
|
162 |
-
stash = self._amp_stash
|
163 |
-
|
164 |
-
self._amp_lazy_init()
|
165 |
-
|
166 |
-
# This is a lot of python overhead...
|
167 |
-
fp16_grads_needing_unscale = []
|
168 |
-
new_fp32_grads = []
|
169 |
-
fp16_grads_needing_unscale_with_stash = []
|
170 |
-
preexisting_fp32_grads = []
|
171 |
-
for fp16_param, fp32_param in zip(stash.all_fp16_params,
|
172 |
-
stash.all_fp32_from_fp16_params):
|
173 |
-
if fp16_param.grad is None and fp32_param.grad is not None:
|
174 |
-
continue
|
175 |
-
elif fp16_param.grad is not None and fp32_param.grad is None:
|
176 |
-
fp32_param.grad = torch.empty_like(fp32_param)
|
177 |
-
fp16_grads_needing_unscale.append(fp16_param.grad)
|
178 |
-
new_fp32_grads.append(fp32_param.grad)
|
179 |
-
elif fp16_param.grad is not None and fp32_param.grad is not None:
|
180 |
-
fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
|
181 |
-
preexisting_fp32_grads.append(fp32_param.grad)
|
182 |
-
else: # fp16_param.grad is None and fp32_param.grad is None:
|
183 |
-
continue
|
184 |
-
|
185 |
-
if len(fp16_grads_needing_unscale) > 0:
|
186 |
-
scaler.unscale(
|
187 |
-
fp16_grads_needing_unscale,
|
188 |
-
new_fp32_grads,
|
189 |
-
scaler.loss_scale(),
|
190 |
-
models_are_masters=False)
|
191 |
-
|
192 |
-
if len(fp16_grads_needing_unscale_with_stash) > 0:
|
193 |
-
scaler.unscale_with_stashed(
|
194 |
-
fp16_grads_needing_unscale_with_stash,
|
195 |
-
preexisting_fp32_grads,
|
196 |
-
preexisting_fp32_grads)
|
197 |
-
|
198 |
-
# fp32 params can be treated as they would be in the "no_master_weights" case.
|
199 |
-
post_backward_models_are_masters(
|
200 |
-
scaler,
|
201 |
-
stash.all_fp32_from_fp32_params,
|
202 |
-
stash.all_fp32_from_fp32_grad_stash)
|
203 |
-
|
204 |
-
|
205 |
-
def lazy_init_no_master_weights(self):
|
206 |
-
stash = self._amp_stash
|
207 |
-
stash.all_fp16_params = []
|
208 |
-
stash.all_fp32_params = []
|
209 |
-
for i, param_group in enumerate(self.param_groups):
|
210 |
-
for i, param in enumerate(param_group['params']):
|
211 |
-
if param.type() == 'torch.cuda.HalfTensor':
|
212 |
-
stash.all_fp16_params.append(param)
|
213 |
-
elif param.type() == 'torch.cuda.FloatTensor':
|
214 |
-
stash.all_fp32_params.append(param)
|
215 |
-
else:
|
216 |
-
raise TypeError("Optimizer's parameters must be either "
|
217 |
-
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
|
218 |
-
"Received {}".format(param.type()))
|
219 |
-
|
220 |
-
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
|
221 |
-
stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
|
222 |
-
|
223 |
-
|
224 |
-
def prepare_backward_no_master_weights(self):
|
225 |
-
stash = self._amp_stash
|
226 |
-
|
227 |
-
self._amp_lazy_init()
|
228 |
-
|
229 |
-
for i, param in enumerate(stash.all_fp16_params):
|
230 |
-
stash.all_fp16_grad_stash[i] = param.grad
|
231 |
-
# Set up to leverage grad copy elision:
|
232 |
-
param.grad = None
|
233 |
-
|
234 |
-
for i, param in enumerate(stash.all_fp32_params):
|
235 |
-
stash.all_fp32_grad_stash[i] = param.grad
|
236 |
-
# Set up to leverage grad copy elision:
|
237 |
-
param.grad = None
|
238 |
-
|
239 |
-
|
240 |
-
def post_backward_no_master_weights(self, scaler):
|
241 |
-
stash = self._amp_stash
|
242 |
-
|
243 |
-
self._amp_lazy_init()
|
244 |
-
|
245 |
-
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
|
246 |
-
(stash.all_fp32_params, stash.all_fp32_grad_stash))
|
247 |
-
|
248 |
-
for params, stashed_grads in split_types:
|
249 |
-
post_backward_models_are_masters(scaler, params, stashed_grads)
|
250 |
-
|
251 |
-
|
252 |
-
#####################################################################################
|
253 |
-
# FusedSGD versions
|
254 |
-
#####################################################################################
|
255 |
-
|
256 |
-
# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
|
257 |
-
# outside the kernel, so we must accumulate directly into the model grads.
|
258 |
-
def prepare_backward_with_master_weights_FusedSGD(self):
|
259 |
-
if self.materialize_master_grads:
|
260 |
-
prepare_backward_with_master_weights(self)
|
261 |
-
else:
|
262 |
-
stash = self._amp_stash
|
263 |
-
|
264 |
-
self._amp_lazy_init()
|
265 |
-
|
266 |
-
for i, param in enumerate(stash.all_fp16_params):
|
267 |
-
stash.all_fp16_grad_stash[i] = param.grad
|
268 |
-
# Set up to leverage grad copy elision:
|
269 |
-
param.grad = None
|
270 |
-
|
271 |
-
for i, param in enumerate(stash.all_fp32_from_fp32_params):
|
272 |
-
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
|
273 |
-
# Set up to leverage grad copy elision:
|
274 |
-
param.grad = None
|
275 |
-
|
276 |
-
|
277 |
-
def post_backward_with_master_weights_FusedSGD(self, scaler):
|
278 |
-
if self.materialize_master_grads:
|
279 |
-
post_backward_with_master_weights(self, scaler)
|
280 |
-
else:
|
281 |
-
stash = self._amp_stash
|
282 |
-
|
283 |
-
self._amp_lazy_init()
|
284 |
-
|
285 |
-
grads_have_scale = scaler.loss_scale()
|
286 |
-
stashed_have_scale = self.most_recent_scale
|
287 |
-
out_scale = grads_have_scale
|
288 |
-
if self.scale_set_by_backward:
|
289 |
-
out_scale = min(grads_have_scale, self.most_recent_scale)
|
290 |
-
|
291 |
-
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
|
292 |
-
(stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
|
293 |
-
|
294 |
-
|
295 |
-
# unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
|
296 |
-
# stashed_grads are scaled by self.most_recent_scale.
|
297 |
-
for params, stashed_grads in split_types:
|
298 |
-
post_backward_models_are_masters(scaler, params, stashed_grads,
|
299 |
-
(grads_have_scale, stashed_have_scale, out_scale))
|
300 |
-
|
301 |
-
self.most_recent_scale = out_scale
|
302 |
-
self.scale_set_by_backward = True
|
303 |
-
|
304 |
-
|
305 |
-
def prepare_backward_no_master_weights_FusedSGD(self):
|
306 |
-
prepare_backward_no_master_weights(self)
|
307 |
-
|
308 |
-
|
309 |
-
def post_backward_no_master_weights_FusedSGD(self, scaler):
|
310 |
-
post_backward_no_master_weights(self, scaler)
|
311 |
-
|
312 |
-
|
313 |
-
def _amp_lazy_init(self):
|
314 |
-
stash = self._amp_stash
|
315 |
-
|
316 |
-
if not stash.lazy_init_called:
|
317 |
-
self._lazy_init_maybe_master_weights()
|
318 |
-
stash.lazy_init_called = True
|
319 |
-
|
320 |
-
|
321 |
-
def _process_optimizer(optimizer, properties):
|
322 |
-
if hasattr(optimizer, "_amp_stash"):
|
323 |
-
raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
|
324 |
-
else:
|
325 |
-
optimizer._amp_stash = AmpOptimizerState()
|
326 |
-
|
327 |
-
optimizer._amp_stash.lazy_init_called = False
|
328 |
-
optimizer._amp_stash.already_patched = False
|
329 |
-
optimizer._amp_stash.params_have_scaled_gradients = False
|
330 |
-
|
331 |
-
for name in ("_lazy_init_maybe_master_weights",
|
332 |
-
"_master_params_to_model_params",
|
333 |
-
"_prepare_amp_backward",
|
334 |
-
"_post_amp_backward",
|
335 |
-
"_amp_lazy_init"):
|
336 |
-
if hasattr(optimizer, name):
|
337 |
-
raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
|
338 |
-
|
339 |
-
# TODO: Centralize exposure and import error checking for the C backend.
|
340 |
-
if multi_tensor_applier.available:
|
341 |
-
import amp_C
|
342 |
-
optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
|
343 |
-
optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
|
344 |
-
optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
|
345 |
-
|
346 |
-
if properties.master_weights:
|
347 |
-
optimizer._lazy_init_maybe_master_weights = types.MethodType(
|
348 |
-
lazy_init_with_master_weights, optimizer)
|
349 |
-
|
350 |
-
optimizer._master_params_to_model_params = types.MethodType(
|
351 |
-
_master_params_to_model_params, optimizer)
|
352 |
-
|
353 |
-
old_step = optimizer.step
|
354 |
-
def new_step(self, closure=None):
|
355 |
-
if closure is not None:
|
356 |
-
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
|
357 |
-
retval = old_step()
|
358 |
-
if not isinstance(self, FusedSGD):
|
359 |
-
self._master_params_to_model_params()
|
360 |
-
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
|
361 |
-
for param in self._amp_stash.all_fp32_from_fp16_params:
|
362 |
-
param.grad = None
|
363 |
-
return retval
|
364 |
-
optimizer.step = types.MethodType(new_step, optimizer)
|
365 |
-
|
366 |
-
old_zero_grad = optimizer.zero_grad
|
367 |
-
def new_zero_grad(self):
|
368 |
-
stash = self._amp_stash
|
369 |
-
self._amp_lazy_init()
|
370 |
-
# Zero the model grads.
|
371 |
-
for param in stash.all_fp16_params:
|
372 |
-
if param.grad is not None:
|
373 |
-
param.grad.detach_()
|
374 |
-
param.grad.zero_()
|
375 |
-
for param in stash.all_fp32_from_fp32_params:
|
376 |
-
if param.grad is not None:
|
377 |
-
param.grad.detach_()
|
378 |
-
param.grad.zero_()
|
379 |
-
# Clear the master grads that are independent of model grads
|
380 |
-
for param in self._amp_stash.all_fp32_from_fp16_params:
|
381 |
-
param.grad = None
|
382 |
-
optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
|
383 |
-
|
384 |
-
if isinstance(optimizer, FusedSGD):
|
385 |
-
optimizer._prepare_amp_backward = types.MethodType(
|
386 |
-
prepare_backward_with_master_weights_FusedSGD, optimizer)
|
387 |
-
optimizer._post_amp_backward = types.MethodType(
|
388 |
-
post_backward_with_master_weights_FusedSGD, optimizer)
|
389 |
-
else:
|
390 |
-
optimizer._prepare_amp_backward = types.MethodType(
|
391 |
-
prepare_backward_with_master_weights, optimizer)
|
392 |
-
optimizer._post_amp_backward = types.MethodType(
|
393 |
-
post_backward_with_master_weights, optimizer)
|
394 |
-
else:
|
395 |
-
optimizer._lazy_init_maybe_master_weights = types.MethodType(
|
396 |
-
lazy_init_no_master_weights, optimizer)
|
397 |
-
|
398 |
-
if isinstance(optimizer, FusedSGD):
|
399 |
-
optimizer._prepare_amp_backward = types.MethodType(
|
400 |
-
prepare_backward_no_master_weights_FusedSGD, optimizer)
|
401 |
-
optimizer._post_amp_backward = types.MethodType(
|
402 |
-
post_backward_no_master_weights_FusedSGD, optimizer)
|
403 |
-
else:
|
404 |
-
optimizer._prepare_amp_backward = types.MethodType(
|
405 |
-
prepare_backward_no_master_weights, optimizer)
|
406 |
-
optimizer._post_amp_backward = types.MethodType(
|
407 |
-
post_backward_no_master_weights, optimizer)
|
408 |
-
|
409 |
-
optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
|
410 |
-
|
411 |
-
old_add_param_group = optimizer.add_param_group
|
412 |
-
|
413 |
-
def new_add_param_group(self, new_group):
|
414 |
-
stash = self._amp_stash
|
415 |
-
|
416 |
-
if not stash.lazy_init_called:
|
417 |
-
self._lazy_init_maybe_master_weights()
|
418 |
-
stash.lazy_init_called = True
|
419 |
-
|
420 |
-
assert isinstance(new_group, dict), "param group must be a dict"
|
421 |
-
|
422 |
-
new_params = new_group['params']
|
423 |
-
if isinstance(new_params, torch.Tensor):
|
424 |
-
new_group['params'] = [new_params]
|
425 |
-
elif isinstance(new_params, set):
|
426 |
-
raise TypeError('optimizer parameters need to be organized in ordered collections, but '
|
427 |
-
'the ordering of tensors in sets will change between runs. Please use a list instead.')
|
428 |
-
else:
|
429 |
-
new_group['params'] = list(new_params)
|
430 |
-
|
431 |
-
if properties.master_weights:
|
432 |
-
# Mutate new_group in-place to use FP32 master params
|
433 |
-
fp16_params_this_group = []
|
434 |
-
fp32_params_this_group = []
|
435 |
-
fp32_from_fp16_params_this_group = []
|
436 |
-
for i, param in enumerate(new_group['params']):
|
437 |
-
if param.requires_grad:
|
438 |
-
if param.type() == 'torch.cuda.HalfTensor':
|
439 |
-
fp16_params_this_group.append(param)
|
440 |
-
master_param = param.detach().clone().float()
|
441 |
-
master_param.requires_grad = True
|
442 |
-
new_group['params'][i] = master_param
|
443 |
-
fp32_from_fp16_params_this_group.append(master_param)
|
444 |
-
elif param.type() == 'torch.cuda.FloatTensor':
|
445 |
-
fp32_params_this_group.append(param)
|
446 |
-
new_group['params'][i] = param
|
447 |
-
else:
|
448 |
-
raise TypeError("Optimizer's parameters must be either "
|
449 |
-
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
|
450 |
-
"Received {}".format(param.type()))
|
451 |
-
|
452 |
-
stash.fp16_groups.append(fp16_params_this_group)
|
453 |
-
stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
|
454 |
-
stash.fp32_from_fp32_groups.append(fp32_params_this_group)
|
455 |
-
|
456 |
-
stash.all_fp16_params += fp16_params_this_group
|
457 |
-
stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
|
458 |
-
stash.all_fp32_from_fp32_params += fp32_params_this_group
|
459 |
-
|
460 |
-
# stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
|
461 |
-
stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]
|
462 |
-
|
463 |
-
# It should be ok to let params be added with existing .grad attributes.
|
464 |
-
# for param in fp16_params_this_group:
|
465 |
-
# param.grad = None
|
466 |
-
|
467 |
-
# for param in fp32_from_fp16_params_this_group:
|
468 |
-
# param.grad = None
|
469 |
-
|
470 |
-
# for param in stash.fp32_params_this_group:
|
471 |
-
# param.grad = None
|
472 |
-
else:
|
473 |
-
for param in new_group['params']:
|
474 |
-
if param.type() == 'torch.cuda.HalfTensor':
|
475 |
-
stash.all_fp16_params.append(param)
|
476 |
-
stash.all_fp16_grad_stash.append(None)
|
477 |
-
elif param.type() == 'torch.cuda.FloatTensor':
|
478 |
-
stash.all_fp32_params.append(param)
|
479 |
-
stash.all_fp32_grad_stash.append(None)
|
480 |
-
else:
|
481 |
-
raise TypeError("Optimizer's parameters must be either "
|
482 |
-
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
|
483 |
-
"Received {}".format(param.type()))
|
484 |
-
|
485 |
-
old_add_param_group(new_group)
|
486 |
-
|
487 |
-
optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)
|
488 |
-
|
489 |
-
return optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/amp.py
DELETED
@@ -1,183 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
import itertools
|
3 |
-
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from . import compat, rnn_compat, utils, wrap
|
7 |
-
from .handle import AmpHandle, NoOpHandle
|
8 |
-
from .lists import functional_overrides, torch_overrides, tensor_overrides
|
9 |
-
from ._amp_state import _amp_state
|
10 |
-
from .frontend import *
|
11 |
-
|
12 |
-
|
13 |
-
_DECORATOR_HANDLE = None
|
14 |
-
_USER_CAST_REGISTRY = set()
|
15 |
-
_USER_PROMOTE_REGISTRY = set()
|
16 |
-
|
17 |
-
|
18 |
-
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
|
19 |
-
def wrapper(*args, **kwargs):
|
20 |
-
handle = _DECORATOR_HANDLE
|
21 |
-
if handle is None or not handle.is_active():
|
22 |
-
return orig_fn(*args, **kwargs)
|
23 |
-
inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
|
24 |
-
handle.verbose)
|
25 |
-
return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
|
26 |
-
return wrapper
|
27 |
-
|
28 |
-
|
29 |
-
# Decorator form
|
30 |
-
def half_function(fn):
|
31 |
-
from apex import deprecated_warning
|
32 |
-
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
|
33 |
-
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
|
34 |
-
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
|
35 |
-
|
36 |
-
|
37 |
-
def float_function(fn):
|
38 |
-
from apex import deprecated_warning
|
39 |
-
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
|
40 |
-
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
|
41 |
-
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
|
42 |
-
|
43 |
-
|
44 |
-
def promote_function(fn):
|
45 |
-
from apex import deprecated_warning
|
46 |
-
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
|
47 |
-
wrap_fn = functools.partial(wrap.make_promote_wrapper)
|
48 |
-
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
|
49 |
-
|
50 |
-
|
51 |
-
# Registry form
|
52 |
-
def register_half_function(module, name):
|
53 |
-
if not hasattr(module, name):
|
54 |
-
raise ValueError('No function named {} in module {}.'.format(
|
55 |
-
name, module))
|
56 |
-
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
|
57 |
-
|
58 |
-
|
59 |
-
def register_float_function(module, name):
|
60 |
-
if not hasattr(module, name):
|
61 |
-
raise ValueError('No function named {} in module {}.'.format(
|
62 |
-
name, module))
|
63 |
-
_USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
|
64 |
-
|
65 |
-
|
66 |
-
def register_promote_function(module, name):
|
67 |
-
if not hasattr(module, name):
|
68 |
-
raise ValueError('No function named {} in module {}.'.format(
|
69 |
-
name, module))
|
70 |
-
_USER_PROMOTE_REGISTRY.add((module, name))
|
71 |
-
|
72 |
-
|
73 |
-
# Top-level function to insert _all_ the hooks.
|
74 |
-
def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
|
75 |
-
global _DECORATOR_HANDLE
|
76 |
-
|
77 |
-
if not enabled:
|
78 |
-
handle = NoOpHandle()
|
79 |
-
_DECORATOR_HANDLE = handle
|
80 |
-
return handle
|
81 |
-
|
82 |
-
handle = AmpHandle(loss_scale, enable_caching, verbose)
|
83 |
-
|
84 |
-
# 0) Force-{fp16, fp32} for user-annotated functions
|
85 |
-
for mod, fn, cast_fn in _USER_CAST_REGISTRY:
|
86 |
-
try_caching = (cast_fn == utils.maybe_half)
|
87 |
-
wrap.cached_cast(mod, fn, cast_fn, handle,
|
88 |
-
try_caching, verbose)
|
89 |
-
_USER_CAST_REGISTRY.clear()
|
90 |
-
|
91 |
-
# 0.5) Force-promote for user-annotated functions
|
92 |
-
for mod, fn in _USER_PROMOTE_REGISTRY:
|
93 |
-
wrap.promote(mod, fn, handle, verbose)
|
94 |
-
_USER_PROMOTE_REGISTRY.clear()
|
95 |
-
|
96 |
-
# 1) Force-{fp16, fp32} on white- / black-list functions
|
97 |
-
override_modules = [functional_overrides,
|
98 |
-
torch_overrides,
|
99 |
-
tensor_overrides]
|
100 |
-
cast_table = [('FP16_FUNCS', utils.maybe_half),
|
101 |
-
('FP32_FUNCS', utils.maybe_float)]
|
102 |
-
for module, (list_name, cast_fn) in itertools.product(override_modules,
|
103 |
-
cast_table):
|
104 |
-
for fn in getattr(module, list_name):
|
105 |
-
try_caching = (cast_fn == utils.maybe_half)
|
106 |
-
wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
|
107 |
-
try_caching, verbose)
|
108 |
-
|
109 |
-
# 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
|
110 |
-
# methods on FloatTensor, since they're distinct types.
|
111 |
-
if compat.tensor_is_float_tensor():
|
112 |
-
for fn in tensor_overrides.FP16_FUNCS:
|
113 |
-
wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
|
114 |
-
handle, try_caching=True, verbose=verbose)
|
115 |
-
for fn in tensor_overrides.FP32_FUNCS:
|
116 |
-
wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
|
117 |
-
handle, try_caching=False, verbose=verbose)
|
118 |
-
|
119 |
-
# 2) Enable type-promotion on multi-arg functions and methods.
|
120 |
-
# NB: special handling for sequence fns (e.g. `torch.cat`).
|
121 |
-
promote_modules = [torch_overrides, tensor_overrides]
|
122 |
-
promote_table = [('CASTS', wrap.promote),
|
123 |
-
('SEQUENCE_CASTS', wrap.sequence_promote)]
|
124 |
-
for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
|
125 |
-
promote_table):
|
126 |
-
for fn in getattr(promote_mod, list_name):
|
127 |
-
promote_fn(promote_mod.MODULE, fn, handle, verbose)
|
128 |
-
|
129 |
-
# 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
|
130 |
-
if compat.tensor_is_float_tensor():
|
131 |
-
for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
|
132 |
-
torch.cuda.HalfTensor],
|
133 |
-
promote_table):
|
134 |
-
for fn in getattr(tensor_overrides, list_name):
|
135 |
-
promote_fn(cls, fn, handle, verbose)
|
136 |
-
|
137 |
-
# 3) For any in-place version of a blacklist function, error if any input is fp16.
|
138 |
-
# NB: this is overly conservative.
|
139 |
-
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
|
140 |
-
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
|
141 |
-
|
142 |
-
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
|
143 |
-
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
|
144 |
-
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
|
145 |
-
if compat.tensor_is_float_tensor():
|
146 |
-
wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
|
147 |
-
|
148 |
-
# 4) For other in-place methods, match the type of self tensor
|
149 |
-
for fn in utils.as_inplace(itertools.chain(
|
150 |
-
tensor_overrides.FP16_FUNCS,
|
151 |
-
tensor_overrides.CASTS)):
|
152 |
-
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
|
153 |
-
if compat.tensor_is_float_tensor():
|
154 |
-
wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
|
155 |
-
wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
|
156 |
-
|
157 |
-
# 5) RNNs + RNN cells are whitelisted specially
|
158 |
-
if rnn_compat.has_old_rnns():
|
159 |
-
wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
|
160 |
-
if not rnn_compat.has_old_rnns():
|
161 |
-
# Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
|
162 |
-
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
|
163 |
-
# Wrap all the rnns
|
164 |
-
for x in rnn_compat.RNN_NAMES:
|
165 |
-
wrap.new_rnn_cast(x.upper(), handle, verbose)
|
166 |
-
|
167 |
-
# Wrap all the RNN cells
|
168 |
-
rnn_compat.whitelist_rnn_cells(handle, verbose)
|
169 |
-
|
170 |
-
# 6) Place error+print message on banned functions.
|
171 |
-
# Or, if allow_banned, then cast to FP32.
|
172 |
-
for fn, err_msg in functional_overrides.BANNED_FUNCS:
|
173 |
-
if allow_banned:
|
174 |
-
wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
|
175 |
-
handle, try_caching=True, verbose=verbose)
|
176 |
-
else:
|
177 |
-
wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
|
178 |
-
|
179 |
-
_DECORATOR_HANDLE = handle
|
180 |
-
|
181 |
-
_amp_state.handle = handle
|
182 |
-
|
183 |
-
return handle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/compat.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
# True for post-0.4, when Variables/Tensors merged.
|
4 |
-
def variable_is_tensor():
|
5 |
-
v = torch.autograd.Variable()
|
6 |
-
return isinstance(v, torch.Tensor)
|
7 |
-
|
8 |
-
def tensor_is_variable():
|
9 |
-
x = torch.Tensor()
|
10 |
-
return type(x) == torch.autograd.Variable
|
11 |
-
|
12 |
-
# False for post-0.4
|
13 |
-
def tensor_is_float_tensor():
|
14 |
-
x = torch.Tensor()
|
15 |
-
return type(x) == torch.FloatTensor
|
16 |
-
|
17 |
-
# Akin to `torch.is_tensor`, but returns True for Variable
|
18 |
-
# objects in pre-0.4.
|
19 |
-
def is_tensor_like(x):
|
20 |
-
return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)
|
21 |
-
|
22 |
-
# Wraps `torch.is_floating_point` if present, otherwise checks
|
23 |
-
# the suffix of `x.type()`.
|
24 |
-
def is_floating_point(x):
|
25 |
-
if hasattr(torch, 'is_floating_point'):
|
26 |
-
return torch.is_floating_point(x)
|
27 |
-
try:
|
28 |
-
torch_type = x.type()
|
29 |
-
return torch_type.endswith('FloatTensor') or \
|
30 |
-
torch_type.endswith('HalfTensor') or \
|
31 |
-
torch_type.endswith('DoubleTensor')
|
32 |
-
except AttributeError:
|
33 |
-
return False
|
34 |
-
|
35 |
-
def scalar_python_val(x):
|
36 |
-
if hasattr(x, 'item'):
|
37 |
-
return x.item()
|
38 |
-
else:
|
39 |
-
if isinstance(x, torch.autograd.Variable):
|
40 |
-
return x.data[0]
|
41 |
-
else:
|
42 |
-
return x[0]
|
43 |
-
|
44 |
-
# Accounts for the possibility that some ops may be removed from a namespace.
|
45 |
-
def filter_attrs(module, attrs):
|
46 |
-
return list(attrname for attrname in attrs if hasattr(module, attrname))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/frontend.py
DELETED
@@ -1,446 +0,0 @@
|
|
1 |
-
from collections import OrderedDict
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
from ._initialize import _initialize
|
6 |
-
from ._amp_state import _amp_state, warn_or_err, maybe_print
|
7 |
-
|
8 |
-
|
9 |
-
class Properties(object):
|
10 |
-
"""
|
11 |
-
This class has two purposes: to establish a set of default properties,
|
12 |
-
and to route setting of these attributes through __setattr__ so that (in theory)
|
13 |
-
they can be checked for consistency with other existing args.
|
14 |
-
"""
|
15 |
-
def __init__(self):
|
16 |
-
self.options = {
|
17 |
-
"enabled" : False,
|
18 |
-
"opt_level" : None,
|
19 |
-
"cast_model_type" : None,
|
20 |
-
"patch_torch_functions" : False,
|
21 |
-
"keep_batchnorm_fp32" : None,
|
22 |
-
"master_weights" : None,
|
23 |
-
"loss_scale" : 1.0,
|
24 |
-
# Reserved for future functionality
|
25 |
-
# "fused_optimizer" : False,
|
26 |
-
# "enable_ddp_interop" : False,
|
27 |
-
}
|
28 |
-
|
29 |
-
"""
|
30 |
-
This function allows updating several options at a time without routing through
|
31 |
-
__setattr__ checks, to avoid "you can't get there from here" scenarios.
|
32 |
-
Currently not intended to be exposed; users are expected to select an opt_level
|
33 |
-
and apply consistent modifications.
|
34 |
-
"""
|
35 |
-
def _update_options_dict(self, new_options):
|
36 |
-
for k, v in new_options:
|
37 |
-
if k in self.options:
|
38 |
-
self.options[k] = v
|
39 |
-
else:
|
40 |
-
raise ValueError("Tried to set unexpected option {}".format(k))
|
41 |
-
"""
|
42 |
-
The members of "options" are not direct attributes of self, so access attempts
|
43 |
-
will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
|
44 |
-
"""
|
45 |
-
def __getattr__(self, name):
|
46 |
-
if "options" in self.__dict__:
|
47 |
-
options = self.__dict__["options"]
|
48 |
-
if name in options:
|
49 |
-
return options[name]
|
50 |
-
raise AttributeError("'{}' object has no attribute '{}'".format(
|
51 |
-
type(self).__name__, name))
|
52 |
-
|
53 |
-
def __setattr__(self, name, value):
|
54 |
-
if "options" in self.__dict__:
|
55 |
-
if name in self.options:
|
56 |
-
# print("setting {} {}".format(name, value))
|
57 |
-
if name == "cast_model_type":
|
58 |
-
if self.opt_level == "O1" and value is not None:
|
59 |
-
if value is not False:
|
60 |
-
if value is not torch.float32:
|
61 |
-
warn_or_err("O1 inserts casts around Torch functions rather than "
|
62 |
-
"model weights, so with O1, the model weights themselves "
|
63 |
-
"should remain FP32. If you wish to cast the model to a "
|
64 |
-
"different type, use opt_level='O2' or 'O3'. " +
|
65 |
-
"cast_model_type was {}".format(value))
|
66 |
-
self.options[name] = value
|
67 |
-
elif name == "patch_torch_functions":
|
68 |
-
if self.opt_level != "O1" and value:
|
69 |
-
warn_or_err("Currently, patch_torch_functions=True should only be set by "
|
70 |
-
"selecting opt_level='O1'.")
|
71 |
-
self.options[name] = value
|
72 |
-
elif name == "keep_batchnorm_fp32":
|
73 |
-
if self.opt_level == "O1" and value is not None:
|
74 |
-
warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
|
75 |
-
"to run in FP32, so keep_batchnorm_fp32 should be None." +
|
76 |
-
" keep_batchnorm_fp32 was {}".format(value))
|
77 |
-
if value == "False":
|
78 |
-
self.options[name] = False
|
79 |
-
elif value == "True":
|
80 |
-
self.options[name] = True
|
81 |
-
else:
|
82 |
-
assert (value is True or value is False or value is None),\
|
83 |
-
"keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
|
84 |
-
"or None, found keep_batchnorm_fp32={}".format(value)
|
85 |
-
self.options[name] = value
|
86 |
-
elif name == "master_weights":
|
87 |
-
if self.opt_level == "O1" and value is not None:
|
88 |
-
warn_or_err("It doesn't make sense to use master_weights with O1. "
|
89 |
-
"With O1, your model weights themselves should be FP32.")
|
90 |
-
self.options[name] = value
|
91 |
-
elif name == "loss_scale":
|
92 |
-
if value == "dynamic":
|
93 |
-
self.options[name] = value
|
94 |
-
else:
|
95 |
-
self.options[name] = float(value)
|
96 |
-
else:
|
97 |
-
self.options[name] = value
|
98 |
-
else:
|
99 |
-
super(Properties, self).__setattr__(name, value)
|
100 |
-
|
101 |
-
|
102 |
-
""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
|
103 |
-
|
104 |
-
class O3:
|
105 |
-
brief = "O3: Pure FP16 training."
|
106 |
-
more = "Calls .half() on your model, converting the entire model to FP16.\n"\
|
107 |
-
"A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
|
108 |
-
"so you don't need to change your data pipeline.\n"\
|
109 |
-
"This mode is useful for establishing a performance ceiling.\n"\
|
110 |
-
"It's also possible training may 'just work' in this mode.\n"\
|
111 |
-
"If not, try other optimization levels."
|
112 |
-
|
113 |
-
def __call__(self, properties):
|
114 |
-
properties.enabled = True
|
115 |
-
properties.opt_level = "O3"
|
116 |
-
properties.cast_model_type = torch.float16
|
117 |
-
properties.patch_torch_functions = False
|
118 |
-
properties.keep_batchnorm_fp32 = False
|
119 |
-
properties.master_weights = False
|
120 |
-
properties.loss_scale = 1.0
|
121 |
-
# properties.fused_optimizer = False
|
122 |
-
# properties.enable_ddp_interop = False
|
123 |
-
return properties # modified in place so this isn't really necessary
|
124 |
-
|
125 |
-
|
126 |
-
class O2:
|
127 |
-
brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
|
128 |
-
more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
|
129 |
-
"to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
|
130 |
-
"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
|
131 |
-
"your data pipeline.\n"\
|
132 |
-
"O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
|
133 |
-
"these master weights, then copy the master weights into the FP16 model weights.\n"\
|
134 |
-
"Master weights can also improve convergence and stability."
|
135 |
-
|
136 |
-
def __call__(self, properties):
|
137 |
-
properties.enabled = True
|
138 |
-
properties.opt_level = "O2"
|
139 |
-
properties.cast_model_type = torch.float16
|
140 |
-
properties.patch_torch_functions = False
|
141 |
-
properties.keep_batchnorm_fp32 = True
|
142 |
-
properties.master_weights = True
|
143 |
-
properties.loss_scale = "dynamic"
|
144 |
-
# properties.fused_optimizer = False
|
145 |
-
# properties.enable_ddp_interop = False
|
146 |
-
return properties # modified in place so this isn't really necessary
|
147 |
-
|
148 |
-
|
149 |
-
class O1:
|
150 |
-
brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
|
151 |
-
more = "The type of your model's weights is not altered. However, internally,\n"\
|
152 |
-
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
|
153 |
-
"while operations that might benefit from the additional stability of FP32 are patched\n"\
|
154 |
-
"to cast their inputs to fp32.\n"\
|
155 |
-
"O1 is the safest way to try mixed precision training, and is recommended when\n"\
|
156 |
-
"trying mixed precision training for the first time."
|
157 |
-
|
158 |
-
def __call__(self, properties):
|
159 |
-
properties.enabled = True
|
160 |
-
properties.opt_level = "O1"
|
161 |
-
properties.cast_model_type = None
|
162 |
-
properties.patch_torch_functions = True
|
163 |
-
properties.keep_batchnorm_fp32 = None
|
164 |
-
properties.master_weights = None
|
165 |
-
properties.loss_scale = "dynamic"
|
166 |
-
# properties.fused_optimizer = False
|
167 |
-
# properties.enable_ddp_interop = False
|
168 |
-
return properties # modified in place so this isn't really necessary
|
169 |
-
|
170 |
-
|
171 |
-
class O0:
|
172 |
-
brief = "O0: Pure FP32 training.\n"
|
173 |
-
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
|
174 |
-
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
|
175 |
-
"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
|
176 |
-
|
177 |
-
def __call__(self, properties):
|
178 |
-
properties.enabled = True
|
179 |
-
properties.opt_level = "O0"
|
180 |
-
properties.cast_model_type = torch.float32
|
181 |
-
properties.patch_torch_functions = False
|
182 |
-
properties.keep_batchnorm_fp32 = None
|
183 |
-
properties.master_weights = False
|
184 |
-
properties.loss_scale = 1.0
|
185 |
-
# properties.fused_optimizer = False
|
186 |
-
# properties.enable_ddp_interop = False
|
187 |
-
return properties # modified in place so this isn't really necessary
|
188 |
-
|
189 |
-
|
190 |
-
opt_levels = {"O3": O3(),
|
191 |
-
"O2": O2(),
|
192 |
-
"O1": O1(),
|
193 |
-
"O0": O0()}
|
194 |
-
|
195 |
-
|
196 |
-
# allow user to directly pass Properties struct as well?
|
197 |
-
def initialize(
|
198 |
-
models,
|
199 |
-
optimizers=None,
|
200 |
-
enabled=True,
|
201 |
-
opt_level="O1",
|
202 |
-
cast_model_type=None,
|
203 |
-
patch_torch_functions=None,
|
204 |
-
keep_batchnorm_fp32=None,
|
205 |
-
master_weights=None,
|
206 |
-
loss_scale=None,
|
207 |
-
cast_model_outputs=None,
|
208 |
-
num_losses=1,
|
209 |
-
verbosity=1,
|
210 |
-
min_loss_scale=None,
|
211 |
-
max_loss_scale=2.**24
|
212 |
-
):
|
213 |
-
"""
|
214 |
-
Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
|
215 |
-
chosen ``opt_level`` and overridden properties, if any.
|
216 |
-
|
217 |
-
``amp.initialize`` should be called **after** you have finished
|
218 |
-
constructing your model(s) and
|
219 |
-
optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
|
220 |
-
See `Distributed training`_ in the Imagenet example.
|
221 |
-
|
222 |
-
Currently, ``amp.initialize`` should only be called **once**,
|
223 |
-
although it can process an arbitrary number of
|
224 |
-
models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
|
225 |
-
If you think your use case requires ``amp.initialize`` to be called more than once,
|
226 |
-
`let us know`_.
|
227 |
-
|
228 |
-
Any property keyword argument that is not ``None`` will be interpreted as a manual override.
|
229 |
-
|
230 |
-
To prevent having to rewrite anything else in your script, name the returned models/optimizers
|
231 |
-
to replace the passed models/optimizers, as in the code sample below.
|
232 |
-
|
233 |
-
Args:
|
234 |
-
models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
|
235 |
-
optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
|
236 |
-
REQUIRED for training, optional for inference.
|
237 |
-
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
|
238 |
-
should run as if Amp were not present.
|
239 |
-
opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
|
240 |
-
"O0", "O1", "O2", and "O3", explained in detail above.
|
241 |
-
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
|
242 |
-
above.
|
243 |
-
patch_torch_functions (bool, optional, default=None): Optional property override.
|
244 |
-
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
|
245 |
-
passed as a string, must be the string "True" or "False".
|
246 |
-
master_weights (bool, optional, default=None): Optional property override.
|
247 |
-
loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
|
248 |
-
must be a string representing a number, e.g., "128.0", or the string "dynamic".
|
249 |
-
cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
|
250 |
-
of your model(s) are always cast to a particular type regardless of ``opt_level``.
|
251 |
-
num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
|
252 |
-
passes you plan to use. When used in conjunction with the ``loss_id`` argument to
|
253 |
-
``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
|
254 |
-
which can improve stability. See "Multiple models/optimizers/losses"
|
255 |
-
under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
|
256 |
-
support multiple losses/backward passes, but use a single global loss scale
|
257 |
-
for all of them.
|
258 |
-
verbosity (int, default=1): Set to 0 to suppress Amp-related output.
|
259 |
-
min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
|
260 |
-
loss scaling. The default value of None means that no floor is imposed.
|
261 |
-
If dynamic loss scaling is not used, `min_loss_scale` is ignored.
|
262 |
-
max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
|
263 |
-
dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
|
264 |
-
|
265 |
-
Returns:
|
266 |
-
Model(s) and optimizer(s) modified according to the ``opt_level``.
|
267 |
-
If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
|
268 |
-
also be a list.
|
269 |
-
|
270 |
-
Permissible invocations::
|
271 |
-
|
272 |
-
model, optim = amp.initialize(model, optim,...)
|
273 |
-
model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
|
274 |
-
[model1, model2], optim = amp.initialize([model1, model2], optim,...)
|
275 |
-
[model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
|
276 |
-
|
277 |
-
# This is not an exhaustive list of the cross product of options that are possible,
|
278 |
-
# just a set of examples.
|
279 |
-
model, optim = amp.initialize(model, optim, opt_level="O0")
|
280 |
-
model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
|
281 |
-
|
282 |
-
model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
|
283 |
-
model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
|
284 |
-
|
285 |
-
model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
|
286 |
-
model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
|
287 |
-
model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
|
288 |
-
|
289 |
-
model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
|
290 |
-
model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
|
291 |
-
model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
|
292 |
-
|
293 |
-
The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
|
294 |
-
|
295 |
-
.. _`Distributed training`:
|
296 |
-
https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
|
297 |
-
|
298 |
-
.. _`Imagenet example`:
|
299 |
-
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
|
300 |
-
|
301 |
-
.. _`Advanced Amp Usage`:
|
302 |
-
https://nvidia.github.io/apex/advanced.html
|
303 |
-
|
304 |
-
.. _`Advanced Amp Usage topic`:
|
305 |
-
https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
|
306 |
-
|
307 |
-
.. _`let us know`:
|
308 |
-
https://github.com/NVIDIA/apex/issues
|
309 |
-
"""
|
310 |
-
from apex import deprecated_warning
|
311 |
-
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
|
312 |
-
_amp_state.opt_properties = Properties()
|
313 |
-
_amp_state.verbosity = verbosity
|
314 |
-
|
315 |
-
if not enabled:
|
316 |
-
if optimizers is None:
|
317 |
-
return models
|
318 |
-
else:
|
319 |
-
return models, optimizers
|
320 |
-
|
321 |
-
if not torch.backends.cudnn.enabled:
|
322 |
-
raise RuntimeError(
|
323 |
-
"Amp requires torch.backends.cudnn.enabled = True")
|
324 |
-
|
325 |
-
if opt_level not in opt_levels:
|
326 |
-
raise RuntimeError(
|
327 |
-
"Unexpected optimization level {}. ".format(opt_level) +
|
328 |
-
"Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
|
329 |
-
"not the number zero.")
|
330 |
-
else:
|
331 |
-
_amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
|
332 |
-
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
|
333 |
-
maybe_print("Defaults for this optimization level are:", True)
|
334 |
-
for k, v in _amp_state.opt_properties.options.items():
|
335 |
-
maybe_print("{:22} : {}".format(k, v), True)
|
336 |
-
|
337 |
-
_amp_state.min_loss_scale = min_loss_scale
|
338 |
-
_amp_state.max_loss_scale = max_loss_scale
|
339 |
-
|
340 |
-
maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
|
341 |
-
# I chose to have the keyword arguments listed directly in the argument list,
|
342 |
-
# instead of **kwargs, so I can't use kwargs.items() here.
|
343 |
-
if enabled is not None:
|
344 |
-
_amp_state.opt_properties.enabled = enabled
|
345 |
-
if opt_level is not None:
|
346 |
-
_amp_state.opt_properties.opt_level = opt_level
|
347 |
-
if cast_model_type is not None:
|
348 |
-
_amp_state.opt_properties.cast_model_type = cast_model_type
|
349 |
-
if patch_torch_functions is not None:
|
350 |
-
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
|
351 |
-
if keep_batchnorm_fp32 is not None:
|
352 |
-
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
|
353 |
-
if master_weights is not None:
|
354 |
-
_amp_state.opt_properties.master_weights = master_weights
|
355 |
-
if loss_scale is not None:
|
356 |
-
_amp_state.opt_properties.loss_scale = loss_scale
|
357 |
-
|
358 |
-
maybe_print("After processing overrides, optimization options are:", True)
|
359 |
-
for k, v in _amp_state.opt_properties.options.items():
|
360 |
-
maybe_print("{:22} : {}".format(k, v), True)
|
361 |
-
|
362 |
-
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
|
363 |
-
|
364 |
-
|
365 |
-
def state_dict(destination=None):
|
366 |
-
if destination is None:
|
367 |
-
destination = OrderedDict()
|
368 |
-
|
369 |
-
for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
|
370 |
-
destination['loss_scaler%d' % idx] = {
|
371 |
-
'loss_scale': loss_scaler.loss_scale(),
|
372 |
-
'unskipped': loss_scaler._unskipped,
|
373 |
-
}
|
374 |
-
return destination
|
375 |
-
|
376 |
-
|
377 |
-
def load_state_dict(state_dict):
|
378 |
-
# Check if state_dict containes the same number of loss_scalers as current setup
|
379 |
-
if len(state_dict) != len(_amp_state.loss_scalers):
|
380 |
-
print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(
|
381 |
-
len(state_dict), len(_amp_state.loss_scalers)))
|
382 |
-
|
383 |
-
state_dict = state_dict.copy()
|
384 |
-
|
385 |
-
nb_loss_scalers = len(_amp_state.loss_scalers)
|
386 |
-
unexpected_keys = []
|
387 |
-
# Initialize idx outside, since unexpected_keys will increase it if enumerate is used
|
388 |
-
idx = 0
|
389 |
-
for key in state_dict:
|
390 |
-
if 'loss_scaler' not in key:
|
391 |
-
unexpected_keys.append(key)
|
392 |
-
else:
|
393 |
-
if idx > (nb_loss_scalers - 1):
|
394 |
-
print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(
|
395 |
-
idx, nb_loss_scalers))
|
396 |
-
break
|
397 |
-
_amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']
|
398 |
-
_amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']
|
399 |
-
idx += 1
|
400 |
-
|
401 |
-
if len(unexpected_keys) > 0:
|
402 |
-
raise RuntimeError(
|
403 |
-
'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format(
|
404 |
-
', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
405 |
-
|
406 |
-
|
407 |
-
# TODO: is this necessary/useful?
|
408 |
-
# def check_option_consistency(enabled=True,
|
409 |
-
# opt_level=None,
|
410 |
-
# cast_model_type=None,
|
411 |
-
# patch_torch_functions=None,
|
412 |
-
# keep_batchnorm_fp32=None,
|
413 |
-
# master_weights=None,
|
414 |
-
# loss_scale=None,
|
415 |
-
# enable_ddp_interop=None,
|
416 |
-
# hard_override=False):
|
417 |
-
# """
|
418 |
-
# Utility function that enables users to quickly check if the option combination they intend
|
419 |
-
# to use is permitted. ``check_option_consistency`` does not require models or optimizers
|
420 |
-
# to be constructed, and can be called at any point in the script. ``check_option_consistency``
|
421 |
-
# is totally self-contained; it does not set any amp global state or affect anything outside
|
422 |
-
# of itself.
|
423 |
-
# """
|
424 |
-
#
|
425 |
-
# if not enabled:
|
426 |
-
# return
|
427 |
-
#
|
428 |
-
# if opt_level not in opt_levels:
|
429 |
-
# raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
|
430 |
-
# else:
|
431 |
-
# opt_properties = opt_levels[opt_level](Properties())
|
432 |
-
# print("Selected optimization level {}", opt_levels[opt_level].brief)
|
433 |
-
# print("Defaults for this optimization level are:")
|
434 |
-
# for k, v in opt_properties.options:
|
435 |
-
# print("{:22} : {}".format(k, v))
|
436 |
-
#
|
437 |
-
# print("Processing user overrides (additional kwargs that are not None)...")
|
438 |
-
# for k, v in kwargs:
|
439 |
-
# if k not in _amp_state.opt_properties.options:
|
440 |
-
# raise RuntimeError("Unexpected kwarg {}".format(k))
|
441 |
-
# if v is not None:
|
442 |
-
# setattr(opt_properties, k, v)
|
443 |
-
#
|
444 |
-
# print("After processing overrides, optimization options are:")
|
445 |
-
# for k, v in opt_properties.options:
|
446 |
-
# print("{:22} : {}".format(k, v))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/handle.py
DELETED
@@ -1,281 +0,0 @@
|
|
1 |
-
import contextlib
|
2 |
-
import warnings
|
3 |
-
import sys
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from . import utils
|
7 |
-
from .opt import OptimWrapper
|
8 |
-
from .scaler import LossScaler
|
9 |
-
from ._amp_state import _amp_state, master_params, maybe_print
|
10 |
-
|
11 |
-
if torch.distributed.is_available():
|
12 |
-
from ..parallel.LARC import LARC
|
13 |
-
|
14 |
-
|
15 |
-
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
|
16 |
-
@contextlib.contextmanager
|
17 |
-
def scale_loss(loss,
|
18 |
-
optimizers,
|
19 |
-
loss_id=0,
|
20 |
-
model=None,
|
21 |
-
delay_unscale=False,
|
22 |
-
delay_overflow_check=False):
|
23 |
-
"""
|
24 |
-
On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
|
25 |
-
``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
|
26 |
-
|
27 |
-
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
28 |
-
scaled_loss.backward()
|
29 |
-
|
30 |
-
On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
|
31 |
-
and unscaled, so that ``optimizer.step()`` can be called.
|
32 |
-
|
33 |
-
.. note::
|
34 |
-
If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
|
35 |
-
can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
|
36 |
-
any FP16 gradients are copied to FP32 master gradients before being unscaled.
|
37 |
-
``optimizer.step()`` will then apply the unscaled master gradients to the master params.
|
38 |
-
|
39 |
-
.. warning::
|
40 |
-
If Amp is using explicit FP32 master params, only the FP32 master gradients will be
|
41 |
-
unscaled. The direct ``.grad`` attributes of any FP16
|
42 |
-
model params will remain scaled after context manager exit.
|
43 |
-
This subtlety affects gradient clipping. See "Gradient clipping" under
|
44 |
-
`Advanced Amp Usage`_ for best practices.
|
45 |
-
|
46 |
-
Args:
|
47 |
-
loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
|
48 |
-
manager yields is simply ``loss.float()*loss_scale``, so in principle
|
49 |
-
``loss`` could have more than one element, as long as you call
|
50 |
-
``backward()`` on ``scaled_loss`` appropriately within the context manager body.
|
51 |
-
optimizers: All optimizer(s) for which the current backward pass is creating gradients.
|
52 |
-
Must be an optimizer or list of optimizers returned from an earlier call
|
53 |
-
to ``amp.initialize``. For example use with multiple optimizers, see
|
54 |
-
"Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
|
55 |
-
loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
|
56 |
-
to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
|
57 |
-
must be an integer between 0 and ``num_losses`` that tells Amp which loss is
|
58 |
-
being used for the current backward pass. See "Multiple models/optimizers/losses"
|
59 |
-
under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
|
60 |
-
will use the default global loss scaler for this backward pass.
|
61 |
-
model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
|
62 |
-
optimizations.
|
63 |
-
delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
|
64 |
-
the default value of ``False`` is strongly recommended.
|
65 |
-
If ``True``, Amp will not unscale the gradients or perform model->master
|
66 |
-
gradient copies on context manager exit.
|
67 |
-
``delay_unscale=True`` is a minor ninja performance optimization and can result
|
68 |
-
in weird gotchas (especially with multiple models/optimizers/losses),
|
69 |
-
so only use it if you know what you're doing.
|
70 |
-
"Gradient accumulation across iterations" under `Advanced Amp Usage`_
|
71 |
-
illustrates a situation where this CAN (but does not need to) be used.
|
72 |
-
|
73 |
-
.. warning::
|
74 |
-
If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
|
75 |
-
called yet after context manager exit, and must wait for another, later backward context
|
76 |
-
manager invocation with ``delay_unscale`` left to False.
|
77 |
-
|
78 |
-
.. _`Advanced Amp Usage`:
|
79 |
-
https://nvidia.github.io/apex/advanced.html
|
80 |
-
"""
|
81 |
-
if not hasattr(_amp_state, "opt_properties"):
|
82 |
-
raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
|
83 |
-
"model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
|
84 |
-
"before `with amp.scale_loss`.")
|
85 |
-
|
86 |
-
if not _amp_state.opt_properties.enabled:
|
87 |
-
yield loss
|
88 |
-
return
|
89 |
-
|
90 |
-
if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
|
91 |
-
optimizers = [optimizers]
|
92 |
-
|
93 |
-
loss_scaler = _amp_state.loss_scalers[loss_id]
|
94 |
-
loss_scale = loss_scaler.loss_scale()
|
95 |
-
|
96 |
-
if ((not _amp_state.opt_properties.master_weights)
|
97 |
-
and (not loss_scaler.dynamic)
|
98 |
-
and loss_scale == 1.0):
|
99 |
-
yield loss.float()
|
100 |
-
# Needing to drop the cache here as well is an ugly gotcha.
|
101 |
-
# But for now I think it's necessary to short-circuit.
|
102 |
-
# Probably ok to skip this if not delay_unscale
|
103 |
-
if _amp_state.opt_properties.patch_torch_functions:
|
104 |
-
_amp_state.handle._clear_cache()
|
105 |
-
return
|
106 |
-
|
107 |
-
if not delay_unscale:
|
108 |
-
if isinstance(optimizers, list):
|
109 |
-
for optimizer in optimizers:
|
110 |
-
if not optimizer._amp_stash.params_have_scaled_gradients:
|
111 |
-
optimizer._prepare_amp_backward()
|
112 |
-
|
113 |
-
yield (loss.float())*loss_scale
|
114 |
-
|
115 |
-
if delay_unscale:
|
116 |
-
for optimizer in optimizers:
|
117 |
-
optimizer._amp_stash.params_have_scaled_gradients = True
|
118 |
-
else:
|
119 |
-
# FusedSGD may take care of unscaling as part of their step() methods.
|
120 |
-
# if not isinstance(optimizers, FP16_Optimizer_for_fused):
|
121 |
-
loss_scaler.clear_overflow_state()
|
122 |
-
for optimizer in optimizers:
|
123 |
-
optimizer._post_amp_backward(loss_scaler)
|
124 |
-
optimizer._amp_stash.params_have_scaled_gradients = False
|
125 |
-
# For future fused optimizers that enable sync-free dynamic loss scaling,
|
126 |
-
# should_skip will always be False.
|
127 |
-
should_skip = False if delay_overflow_check else loss_scaler.update_scale()
|
128 |
-
if should_skip:
|
129 |
-
for optimizer in optimizers:
|
130 |
-
if not optimizer._amp_stash.already_patched:
|
131 |
-
# Close on loss_scaler and loss_id as well, to be safe. Probably not
|
132 |
-
# necessary because amp.scale_loss is already creating a temporary scope.
|
133 |
-
def patch_step(opt, loss_scaler, loss_id):
|
134 |
-
opt_step = opt.step
|
135 |
-
def skip_step(closure=None):
|
136 |
-
if closure is not None:
|
137 |
-
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
|
138 |
-
maybe_print(("Gradient overflow. Skipping step, loss scaler " +
|
139 |
-
"{} reducing loss scale to {}").format(loss_id,
|
140 |
-
loss_scaler.loss_scale()))
|
141 |
-
# TODO: I don't like the special casing for different optimizer implementations.
|
142 |
-
# Maybe skip should delegate to a method owned by the optimizers themselves.
|
143 |
-
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
|
144 |
-
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
|
145 |
-
for param in opt._amp_stash.all_fp32_from_fp16_params:
|
146 |
-
param.grad = None
|
147 |
-
if hasattr(opt, "most_recent_scale"):
|
148 |
-
opt.most_recent_scale = 1.0
|
149 |
-
opt.scale_set_by_backward = False
|
150 |
-
opt.step = opt_step
|
151 |
-
opt._amp_stash.already_patched = False
|
152 |
-
return skip_step
|
153 |
-
optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
|
154 |
-
optimizer._amp_stash.already_patched = True
|
155 |
-
|
156 |
-
# Probably ok to skip this if not delay_unscale
|
157 |
-
if _amp_state.opt_properties.patch_torch_functions:
|
158 |
-
_amp_state.handle._clear_cache()
|
159 |
-
|
160 |
-
|
161 |
-
# Free function version of AmpHandle.disable_casts, another step on the
|
162 |
-
# path to removing the concept of "AmpHandle"
|
163 |
-
@contextlib.contextmanager
|
164 |
-
def disable_casts():
|
165 |
-
_amp_state.handle._is_active = False
|
166 |
-
yield
|
167 |
-
_amp_state.handle._is_active = True
|
168 |
-
|
169 |
-
|
170 |
-
class AmpHandle(object):
|
171 |
-
def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
|
172 |
-
self._enable_caching = enable_caching
|
173 |
-
self._verbose = verbose
|
174 |
-
self._cache = dict()
|
175 |
-
self._default_scaler = LossScaler(loss_scale)
|
176 |
-
self._is_active = True
|
177 |
-
self._all_wrappers = []
|
178 |
-
|
179 |
-
def is_active(self):
|
180 |
-
return self._is_active
|
181 |
-
|
182 |
-
@contextlib.contextmanager
|
183 |
-
def _disable_casts(self):
|
184 |
-
self._is_active = False
|
185 |
-
yield
|
186 |
-
self._is_active = True
|
187 |
-
|
188 |
-
def wrap_optimizer(self, optimizer, num_loss=1):
|
189 |
-
self._default_scaler = None
|
190 |
-
return OptimWrapper(optimizer, self, num_loss)
|
191 |
-
|
192 |
-
@contextlib.contextmanager
|
193 |
-
def scale_loss(self, loss, optimizer):
|
194 |
-
raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
|
195 |
-
"documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
|
196 |
-
"https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
|
197 |
-
|
198 |
-
if not self.is_active():
|
199 |
-
yield loss
|
200 |
-
return
|
201 |
-
|
202 |
-
if self._default_scaler is None:
|
203 |
-
raise RuntimeError(
|
204 |
-
'After calling `handle.wrap_optimizer()`, you must explicitly ' +
|
205 |
-
'use `optimizer.scale_loss(loss)`.')
|
206 |
-
|
207 |
-
# TODO: this code block is duplicated here and `opt.py`. Unify.
|
208 |
-
loss_scale = self._default_scaler.loss_scale()
|
209 |
-
yield loss * loss_scale
|
210 |
-
|
211 |
-
self._default_scaler.clear_overflow_state()
|
212 |
-
self._default_scaler.unscale(
|
213 |
-
master_params(optimizer),
|
214 |
-
master_params(optimizer),
|
215 |
-
loss_scale)
|
216 |
-
should_skip = self._default_scaler.update_scale()
|
217 |
-
if should_skip:
|
218 |
-
optimizer_step = optimizer.step
|
219 |
-
def skip_step():
|
220 |
-
maybe_print('Gradient overflow, skipping update')
|
221 |
-
optimizer.step = optimizer_step
|
222 |
-
optimizer.step = skip_step
|
223 |
-
|
224 |
-
self._clear_cache()
|
225 |
-
|
226 |
-
def _clear_cache(self):
|
227 |
-
self._cache.clear()
|
228 |
-
|
229 |
-
# Experimental support for saving / restoring uncasted versions of functions
|
230 |
-
def _save_func(self, mod, fn, func):
|
231 |
-
self._all_wrappers.append((mod, fn, func))
|
232 |
-
|
233 |
-
def _deactivate(self):
|
234 |
-
for mod, fn, func in self._all_wrappers:
|
235 |
-
utils.set_func(mod, fn, func)
|
236 |
-
self._all_wrappers = []
|
237 |
-
|
238 |
-
@property
|
239 |
-
def has_cache(self):
|
240 |
-
return self._enable_caching
|
241 |
-
|
242 |
-
@property
|
243 |
-
def cache(self):
|
244 |
-
return self._cache
|
245 |
-
|
246 |
-
def remove_cache(self, param):
|
247 |
-
if self.has_cache and param in self.cache:
|
248 |
-
del self.cache[param]
|
249 |
-
|
250 |
-
@property
|
251 |
-
def verbose(self):
|
252 |
-
return self._verbose
|
253 |
-
|
254 |
-
class NoOpHandle(object):
|
255 |
-
def is_active(self):
|
256 |
-
return False
|
257 |
-
|
258 |
-
@contextlib.contextmanager
|
259 |
-
def _disable_casts(self):
|
260 |
-
yield
|
261 |
-
|
262 |
-
def wrap_optimizer(self, optimizer, num_loss=1):
|
263 |
-
return OptimWrapper(optimizer, self, num_loss)
|
264 |
-
|
265 |
-
@contextlib.contextmanager
|
266 |
-
def scale_loss(self, loss, optimizer):
|
267 |
-
yield loss
|
268 |
-
|
269 |
-
@property
|
270 |
-
def has_cache(self):
|
271 |
-
return False
|
272 |
-
|
273 |
-
@property
|
274 |
-
def verbose(self):
|
275 |
-
return False
|
276 |
-
|
277 |
-
def _clear_cache(self):
|
278 |
-
pass
|
279 |
-
|
280 |
-
def _deactivate(self):
|
281 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/lists/__init__.py
DELETED
File without changes
|
apex/apex/amp/lists/functional_overrides.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
|
2 |
-
# TODO: think about the following two. They do weird things.
|
3 |
-
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
|
4 |
-
# - torch.nn.utils.weight_norm
|
5 |
-
|
6 |
-
# Notes:
|
7 |
-
# F.instance_norm uses batch_norm internally. Which correctly handles
|
8 |
-
# fp16 in/out with fp32 weights. So we shouldn't do anything for
|
9 |
-
# either of these.
|
10 |
-
# F.normalize calls `input.norm()` internally, so it's redundant, but
|
11 |
-
# kept here in case impl. changes.
|
12 |
-
# F.cosine_similarity is same: calls `x.norm()` internally.
|
13 |
-
|
14 |
-
import torch.nn.functional
|
15 |
-
|
16 |
-
MODULE = torch.nn.functional
|
17 |
-
|
18 |
-
FP16_FUNCS = [
|
19 |
-
'conv1d',
|
20 |
-
'conv2d',
|
21 |
-
'conv3d',
|
22 |
-
'conv_transpose1d',
|
23 |
-
'conv_transpose2d',
|
24 |
-
'conv_transpose3d',
|
25 |
-
'conv_tbc', # Undocumented / maybe new?
|
26 |
-
'linear',
|
27 |
-
]
|
28 |
-
|
29 |
-
FP32_FUNCS = [
|
30 |
-
|
31 |
-
# Interpolation/Upsampling TODO: Remove for 1.2
|
32 |
-
'interpolate',
|
33 |
-
'grid_sample',
|
34 |
-
|
35 |
-
# Pointwise
|
36 |
-
'softplus',
|
37 |
-
'softmin',
|
38 |
-
'log_softmax',
|
39 |
-
'softmax',
|
40 |
-
'gelu',
|
41 |
-
|
42 |
-
# Normalization
|
43 |
-
'layer_norm',
|
44 |
-
'group_norm',
|
45 |
-
'local_response_norm',
|
46 |
-
'normalize',
|
47 |
-
'cosine_similarity',
|
48 |
-
|
49 |
-
# Loss functions
|
50 |
-
# TODO: which of these can be fp16?
|
51 |
-
'poisson_nll_loss',
|
52 |
-
'cosine_embedding_loss',
|
53 |
-
'cross_entropy',
|
54 |
-
'hinge_embedding_loss',
|
55 |
-
'kl_div',
|
56 |
-
'l1_loss',
|
57 |
-
'mse_loss',
|
58 |
-
'margin_ranking_loss',
|
59 |
-
'multilabel_margin_loss',
|
60 |
-
'multilabel_soft_margin_loss',
|
61 |
-
'multi_margin_loss',
|
62 |
-
'nll_loss',
|
63 |
-
'binary_cross_entropy_with_logits',
|
64 |
-
'smooth_l1_loss',
|
65 |
-
'soft_margin_loss',
|
66 |
-
'triplet_margin_loss',
|
67 |
-
'ctc_loss'
|
68 |
-
]
|
69 |
-
|
70 |
-
BANNED_FUNCS = [
|
71 |
-
('binary_cross_entropy',
|
72 |
-
("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
|
73 |
-
"It requires that the output of the previous function be already a FloatTensor. \n\n"
|
74 |
-
"Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
|
75 |
-
" torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
|
76 |
-
"that is compatible with amp.\nAnother option is to add\n"
|
77 |
-
" amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
|
78 |
-
"If you _really_ know what you are doing, you can disable this warning by passing "
|
79 |
-
"allow_banned=True to `amp.init()`."))
|
80 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/lists/tensor_overrides.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
from .. import compat
|
2 |
-
from . import torch_overrides
|
3 |
-
|
4 |
-
import importlib
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
# if compat.variable_is_tensor() and not compat.tensor_is_variable():
|
9 |
-
MODULE = torch.Tensor
|
10 |
-
# else:
|
11 |
-
# MODULE = torch.autograd.Variable
|
12 |
-
|
13 |
-
|
14 |
-
FP16_FUNCS = compat.filter_attrs(MODULE, [
|
15 |
-
'__matmul__',
|
16 |
-
])
|
17 |
-
|
18 |
-
FP32_FUNCS = compat.filter_attrs(MODULE, [
|
19 |
-
'__ipow__',
|
20 |
-
'__pow__',
|
21 |
-
'__rpow__',
|
22 |
-
|
23 |
-
# Cast to fp32 before transfer to CPU
|
24 |
-
'cpu',
|
25 |
-
])
|
26 |
-
|
27 |
-
CASTS = compat.filter_attrs(MODULE, [
|
28 |
-
'__add__',
|
29 |
-
'__div__',
|
30 |
-
'__eq__',
|
31 |
-
'__ge__',
|
32 |
-
'__gt__',
|
33 |
-
'__iadd__',
|
34 |
-
'__idiv__',
|
35 |
-
'__imul__',
|
36 |
-
'__isub__',
|
37 |
-
'__itruediv__',
|
38 |
-
'__le__',
|
39 |
-
'__lt__',
|
40 |
-
'__mul__',
|
41 |
-
'__ne__',
|
42 |
-
'__radd__',
|
43 |
-
'__rdiv__',
|
44 |
-
'__rmul__',
|
45 |
-
'__rsub__',
|
46 |
-
'__rtruediv__',
|
47 |
-
'__sub__',
|
48 |
-
'__truediv__',
|
49 |
-
])
|
50 |
-
|
51 |
-
# None of these, but here to make code cleaner.
|
52 |
-
SEQUENCE_CASTS = []
|
53 |
-
|
54 |
-
# We need to grab all the methods from torch_overrides and add them to
|
55 |
-
# the Tensor lists as well, as almost all methods are duplicated
|
56 |
-
# between `torch` and `torch.Tensor` (and check with `hasattr`,
|
57 |
-
# because a few random ones aren't defined on Tensor)
|
58 |
-
_self_mod = importlib.import_module(__name__)
|
59 |
-
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
|
60 |
-
lst = getattr(_self_mod, attrname)
|
61 |
-
for fn in getattr(torch_overrides, attrname):
|
62 |
-
if hasattr(MODULE, fn):
|
63 |
-
lst.append(fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/lists/torch_overrides.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from .. import utils
|
4 |
-
|
5 |
-
MODULE = torch
|
6 |
-
|
7 |
-
FP16_FUNCS = [
|
8 |
-
# Low level functions wrapped by torch.nn layers.
|
9 |
-
# The wrapper layers contain the weights which are then passed in as a parameter
|
10 |
-
# to these functions.
|
11 |
-
'conv1d',
|
12 |
-
'conv2d',
|
13 |
-
'conv3d',
|
14 |
-
'conv_transpose1d',
|
15 |
-
'conv_transpose2d',
|
16 |
-
'conv_transpose3d',
|
17 |
-
'conv_tbc',
|
18 |
-
'prelu',
|
19 |
-
|
20 |
-
# BLAS
|
21 |
-
'addmm',
|
22 |
-
'addmv',
|
23 |
-
'addr',
|
24 |
-
'matmul',
|
25 |
-
'mm',
|
26 |
-
'mv',
|
27 |
-
]
|
28 |
-
|
29 |
-
FP32_FUNCS = [
|
30 |
-
# Pointwise
|
31 |
-
'acos',
|
32 |
-
'asin',
|
33 |
-
'cosh',
|
34 |
-
'erfinv',
|
35 |
-
'exp',
|
36 |
-
'expm1',
|
37 |
-
'log',
|
38 |
-
'log10',
|
39 |
-
'log2',
|
40 |
-
'reciprocal',
|
41 |
-
'rsqrt',
|
42 |
-
'sinh',
|
43 |
-
'tan',
|
44 |
-
|
45 |
-
# Other math
|
46 |
-
'pow',
|
47 |
-
|
48 |
-
# Reduction
|
49 |
-
'cumprod',
|
50 |
-
'cumsum',
|
51 |
-
'dist',
|
52 |
-
# 'mean',
|
53 |
-
'norm',
|
54 |
-
'prod',
|
55 |
-
'std',
|
56 |
-
'sum',
|
57 |
-
'var',
|
58 |
-
|
59 |
-
# Misc
|
60 |
-
'renorm'
|
61 |
-
]
|
62 |
-
|
63 |
-
version_strings = torch.__version__.split('.')
|
64 |
-
version_major = version_strings[0]
|
65 |
-
version_minor = version_strings[1]
|
66 |
-
version_num = float(version_major + "." + version_minor)
|
67 |
-
# Before torch 1.1, mean must be blacklisted.
|
68 |
-
if version_num < 1.1:
|
69 |
-
FP32_FUNCS.append('mean')
|
70 |
-
|
71 |
-
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
|
72 |
-
# check the CUDA version -- if at least 9.1, then put the bmm
|
73 |
-
# functions on the fp16 list. Otherwise, put them on the fp32 list.
|
74 |
-
_bmms = ['addbmm',
|
75 |
-
'baddbmm',
|
76 |
-
'bmm']
|
77 |
-
|
78 |
-
if utils.is_cuda_enabled():
|
79 |
-
# workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
|
80 |
-
if utils.get_cuda_version() >= (9, 1, 0):
|
81 |
-
FP16_FUNCS.extend(_bmms)
|
82 |
-
else:
|
83 |
-
FP32_FUNCS.extend(_bmms)
|
84 |
-
|
85 |
-
# Multi-tensor fns that may need type promotion
|
86 |
-
CASTS = [
|
87 |
-
# Multi-tensor math
|
88 |
-
'addcdiv',
|
89 |
-
'addcmul',
|
90 |
-
'atan2',
|
91 |
-
'cross',
|
92 |
-
'bilinear',
|
93 |
-
'dot',
|
94 |
-
|
95 |
-
# Element-wise _or_ tensor-wise math
|
96 |
-
'add',
|
97 |
-
'div',
|
98 |
-
'mul',
|
99 |
-
|
100 |
-
# Comparison
|
101 |
-
'eq',
|
102 |
-
'equal',
|
103 |
-
'ge',
|
104 |
-
'gt',
|
105 |
-
'le',
|
106 |
-
'lt',
|
107 |
-
'ne'
|
108 |
-
]
|
109 |
-
|
110 |
-
# Functions that take sequence arguments. We need to inspect the whole
|
111 |
-
# sequence and cast to the widest type.
|
112 |
-
SEQUENCE_CASTS = [
|
113 |
-
'cat',
|
114 |
-
'stack'
|
115 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/opt.py
DELETED
@@ -1,103 +0,0 @@
|
|
1 |
-
import contextlib
|
2 |
-
import warnings
|
3 |
-
|
4 |
-
from .scaler import LossScaler, master_params
|
5 |
-
from ._amp_state import maybe_print
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
|
9 |
-
class OptimWrapper(object):
|
10 |
-
def __init__(self, optimizer, amp_handle, num_loss):
|
11 |
-
self._optimizer = optimizer
|
12 |
-
self._amp_handle = amp_handle
|
13 |
-
self._num_loss = num_loss
|
14 |
-
self._loss_idx = 0
|
15 |
-
self._skip_next = [False] * num_loss
|
16 |
-
self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
|
17 |
-
|
18 |
-
@contextlib.contextmanager
|
19 |
-
def scale_loss(self, loss):
|
20 |
-
if not self._amp_handle.is_active():
|
21 |
-
yield loss
|
22 |
-
return
|
23 |
-
|
24 |
-
# When there are multiple losses per-optimizer, we need
|
25 |
-
# to save out current grad accumulation, since we won't be
|
26 |
-
# able to unscale this particulare loss once the grads are
|
27 |
-
# all mixed together.
|
28 |
-
cached_grads = []
|
29 |
-
if self._loss_idx > 0:
|
30 |
-
for p in master_params(self._optimizer):
|
31 |
-
if p.grad is not None:
|
32 |
-
cached_grads.append(p.grad.data.detach().clone())
|
33 |
-
else:
|
34 |
-
cached_grads.append(None)
|
35 |
-
self._optimizer.zero_grad()
|
36 |
-
|
37 |
-
loss_scale = self._cur_loss_scaler().loss_scale()
|
38 |
-
yield loss * loss_scale
|
39 |
-
|
40 |
-
self._cur_loss_scaler().clear_overflow_state()
|
41 |
-
self._cur_loss_scaler().unscale(
|
42 |
-
master_params(self._optimizer),
|
43 |
-
master_params(self._optimizer),
|
44 |
-
loss_scale)
|
45 |
-
self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
|
46 |
-
self._loss_idx += 1
|
47 |
-
|
48 |
-
if len(cached_grads) > 0:
|
49 |
-
for p, cached_grad in zip(master_params(self._optimizer),
|
50 |
-
cached_grads):
|
51 |
-
if cached_grad is not None:
|
52 |
-
p.grad.data.add_(cached_grad)
|
53 |
-
cached_grads = []
|
54 |
-
|
55 |
-
def _cur_loss_scaler(self):
|
56 |
-
assert 0 <= self._loss_idx < self._num_loss
|
57 |
-
return self._loss_scaler[self._loss_idx]
|
58 |
-
|
59 |
-
def step(self, closure=None):
|
60 |
-
if not self._amp_handle.is_active():
|
61 |
-
return self._optimizer.step(closure=closure)
|
62 |
-
|
63 |
-
self._loss_idx = 0
|
64 |
-
|
65 |
-
for group in self._optimizer.param_groups:
|
66 |
-
for p in group['params']:
|
67 |
-
self._amp_handle.remove_cache(p)
|
68 |
-
|
69 |
-
if closure is not None:
|
70 |
-
raise NotImplementedError(
|
71 |
-
'The `closure` argument is unsupported by the amp ' +
|
72 |
-
'optimizer wrapper.')
|
73 |
-
if any(self._skip_next):
|
74 |
-
maybe_print('Gradient overflow, skipping update')
|
75 |
-
self._skip_next = [False] * self._num_loss
|
76 |
-
else:
|
77 |
-
return self._optimizer.step(closure=closure)
|
78 |
-
|
79 |
-
# Forward any attribute lookups
|
80 |
-
def __getattr__(self, attr):
|
81 |
-
return getattr(self._optimizer, attr)
|
82 |
-
|
83 |
-
# Forward all torch.optim.Optimizer methods
|
84 |
-
def __getstate__(self):
|
85 |
-
return self._optimizer.__getstate__()
|
86 |
-
|
87 |
-
def __setstate__(self):
|
88 |
-
return self._optimizer.__setstate__()
|
89 |
-
|
90 |
-
def __repr__(self):
|
91 |
-
return self._optimizer.__repr__()
|
92 |
-
|
93 |
-
def state_dict(self):
|
94 |
-
return self._optimizer.state_dict()
|
95 |
-
|
96 |
-
def load_state_dict(self, state_dict):
|
97 |
-
return self._optimizer.load_state_dict(state_dict)
|
98 |
-
|
99 |
-
def zero_grad(self):
|
100 |
-
return self._optimizer.zero_grad()
|
101 |
-
|
102 |
-
def add_param_group(self, param_group):
|
103 |
-
return self._optimizer.add_param_group(param_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/rnn_compat.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
from . import utils, wrap
|
2 |
-
|
3 |
-
import torch
|
4 |
-
_VF = torch._C._VariableFunctions
|
5 |
-
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
|
6 |
-
|
7 |
-
def _gen_VF_wrapper(name):
|
8 |
-
def wrapper(*args, **kwargs):
|
9 |
-
return getattr(_VF, name)(*args, **kwargs)
|
10 |
-
return wrapper
|
11 |
-
|
12 |
-
# Some python magic to generate an object that has the rnn cell functions
|
13 |
-
# defined on it, all of which call into corresponding _VF version.
|
14 |
-
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
|
15 |
-
# imported at module scope within torch.nn.modules.rnn). This should
|
16 |
-
# not affect third-party importers of _VF.py.
|
17 |
-
class VariableFunctionsShim(object):
|
18 |
-
def __init__(self):
|
19 |
-
for name in RNN_NAMES:
|
20 |
-
for suffix in ['', '_cell']:
|
21 |
-
fn_name = name + suffix
|
22 |
-
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
|
23 |
-
|
24 |
-
def has_old_rnns():
|
25 |
-
try:
|
26 |
-
torch.nn.backends.thnn.backend.LSTMCell
|
27 |
-
return True
|
28 |
-
except:
|
29 |
-
return False
|
30 |
-
|
31 |
-
def whitelist_rnn_cells(handle, verbose):
|
32 |
-
# Different module + function names in old/new RNN cases
|
33 |
-
if has_old_rnns():
|
34 |
-
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
|
35 |
-
mod = torch.nn.backends.thnn.backend
|
36 |
-
else:
|
37 |
-
fn_names = [x + '_cell' for x in RNN_NAMES]
|
38 |
-
mod = torch.nn.modules.rnn._VF
|
39 |
-
assert isinstance(mod, VariableFunctionsShim)
|
40 |
-
|
41 |
-
# Insert casts on cell functions
|
42 |
-
for fn in fn_names:
|
43 |
-
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
|
44 |
-
try_caching=True, verbose=verbose)
|
45 |
-
|
46 |
-
if has_old_rnns():
|
47 |
-
# Special handling of `backward` for fused gru / lstm:
|
48 |
-
# The `backward` method calls Tensor.sum() (blacklist) internally,
|
49 |
-
# and then the resulting grad_input has the wrong type.
|
50 |
-
# TODO: where else is this a problem?
|
51 |
-
for rnn_type in ['GRUFused', 'LSTMFused']:
|
52 |
-
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
|
53 |
-
wrap.disable_casts(mod, 'backward', handle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/scaler.py
DELETED
@@ -1,217 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from ..multi_tensor_apply import multi_tensor_applier
|
3 |
-
from ._amp_state import _amp_state, master_params, maybe_print
|
4 |
-
from itertools import product
|
5 |
-
|
6 |
-
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
|
7 |
-
# Exception handling for 18.04 compatibility
|
8 |
-
if check_overflow:
|
9 |
-
cpu_sum = float(model_grad.float().sum())
|
10 |
-
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
|
11 |
-
return True
|
12 |
-
|
13 |
-
if master_grad is not model_grad: # copy_ probably internally short-circuits this
|
14 |
-
master_grad.copy_(model_grad)
|
15 |
-
if scale != 1.0:
|
16 |
-
master_grad.mul_(scale)
|
17 |
-
return False
|
18 |
-
|
19 |
-
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
|
20 |
-
# Exception handling for 18.04 compatibility
|
21 |
-
if check_overflow:
|
22 |
-
cpu_sum = float(model_grad.float().sum())
|
23 |
-
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
|
24 |
-
return True
|
25 |
-
|
26 |
-
# if master_grad is not model_grad: # copy_ probably internally short-circuits this
|
27 |
-
# master_grad.copy_(model_grad)
|
28 |
-
assert stashed_grad.dtype == master_grad.dtype
|
29 |
-
converted_model_grad = model_grad.data.to(master_grad.dtype)
|
30 |
-
master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
|
31 |
-
return False
|
32 |
-
|
33 |
-
class LossScaler(object):
|
34 |
-
warned_no_fused_kernel = False
|
35 |
-
warned_unscaling_non_fp32_grad = False
|
36 |
-
has_fused_kernel = False
|
37 |
-
|
38 |
-
def __init__(self,
|
39 |
-
loss_scale,
|
40 |
-
init_scale=2.**16,
|
41 |
-
scale_factor=2.,
|
42 |
-
scale_window=2000,
|
43 |
-
min_loss_scale=None,
|
44 |
-
max_loss_scale=2.**24):
|
45 |
-
if loss_scale == "dynamic":
|
46 |
-
self.dynamic = True
|
47 |
-
self._loss_scale = min(max_loss_scale, init_scale)
|
48 |
-
else:
|
49 |
-
self.dynamic = False
|
50 |
-
self._loss_scale = loss_scale
|
51 |
-
self._max_loss_scale = max_loss_scale
|
52 |
-
self._min_loss_scale = min_loss_scale
|
53 |
-
self._scale_seq_len = scale_window
|
54 |
-
self._unskipped = 0
|
55 |
-
self._has_overflow = False
|
56 |
-
self._overflow_buf = torch.cuda.IntTensor([0])
|
57 |
-
if multi_tensor_applier.available:
|
58 |
-
import amp_C
|
59 |
-
LossScaler.has_fused_kernel = multi_tensor_applier.available
|
60 |
-
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
|
61 |
-
LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
|
62 |
-
else:
|
63 |
-
if not LossScaler.warned_no_fused_kernel:
|
64 |
-
maybe_print(
|
65 |
-
"Warning: multi_tensor_applier fused unscale kernel is unavailable, "
|
66 |
-
"possibly because apex was installed without --cuda_ext --cpp_ext. "
|
67 |
-
"Using Python fallback. Original ImportError was: " +
|
68 |
-
repr(multi_tensor_applier.import_err),
|
69 |
-
True)
|
70 |
-
LossScaler.has_fused_kernel = False
|
71 |
-
LossScaler.warned_no_fused_kernel = True
|
72 |
-
|
73 |
-
def loss_scale(self):
|
74 |
-
return self._loss_scale
|
75 |
-
|
76 |
-
def unscale_python(self, model_grads, master_grads, scale):
|
77 |
-
for model, master in zip(model_grads, master_grads):
|
78 |
-
if model is not None:
|
79 |
-
if not LossScaler.warned_unscaling_non_fp32_grad:
|
80 |
-
if master.dtype != torch.float32:
|
81 |
-
maybe_print(
|
82 |
-
"Attempting to unscale a grad with type {} ".format(master.type()) +
|
83 |
-
"Unscaling non-fp32 grads may indicate an error. "
|
84 |
-
"When using Amp, you don't need to call .half() on your model.")
|
85 |
-
LossScaler.warned_unscaling_non_fp32_grad = True
|
86 |
-
self._has_overflow = scale_check_overflow_python(model,
|
87 |
-
master,
|
88 |
-
1./scale,
|
89 |
-
self.dynamic)
|
90 |
-
if self._has_overflow and self.dynamic:
|
91 |
-
break
|
92 |
-
|
93 |
-
# unused_scale keeps some of the old API alive for hopefully a short time.
|
94 |
-
def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
|
95 |
-
if self._has_overflow:
|
96 |
-
return
|
97 |
-
|
98 |
-
scale = self._loss_scale
|
99 |
-
if scale_override is not None:
|
100 |
-
scale = scale_override
|
101 |
-
|
102 |
-
if scale == 1.0 and models_are_masters and not self.dynamic:
|
103 |
-
return
|
104 |
-
|
105 |
-
if LossScaler.has_fused_kernel:
|
106 |
-
# if (not LossScaler.warned_unscaling_non_fp32_grad
|
107 |
-
# and master_grads[0].dtype == torch.float16):
|
108 |
-
# print("Warning: unscaling grads that are not FP32. "
|
109 |
-
# "Unscaling non-fp32 grads may indicate an error. "
|
110 |
-
# "When using Amp, you don't need to call .half() on your model.")
|
111 |
-
# # Setting this to True unconditionally allows the possibility of an escape
|
112 |
-
# # if never-before-seen non-fp32 grads are created in some later iteration.
|
113 |
-
# LossScaler.warned_unscaling_non_fp32_grad = True
|
114 |
-
multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
|
115 |
-
self._overflow_buf,
|
116 |
-
[model_grads, master_grads],
|
117 |
-
1./scale)
|
118 |
-
else:
|
119 |
-
self.unscale_python(model_grads, master_grads, scale)
|
120 |
-
|
121 |
-
# Defer to update_scale
|
122 |
-
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
123 |
-
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
124 |
-
# self._has_overflow = self._overflow_buf.item()
|
125 |
-
|
126 |
-
def unscale_with_stashed_python(self,
|
127 |
-
model_grads,
|
128 |
-
stashed_master_grads,
|
129 |
-
master_grads,
|
130 |
-
a,
|
131 |
-
b):
|
132 |
-
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
|
133 |
-
if model is None and stashed is None:
|
134 |
-
continue
|
135 |
-
else:
|
136 |
-
if not LossScaler.warned_unscaling_non_fp32_grad:
|
137 |
-
if master.dtype != torch.float32:
|
138 |
-
maybe_print(
|
139 |
-
"Attempting to unscale a grad with type {} ".format(master.type()) +
|
140 |
-
"Unscaling non-fp32 grads may indicate an error. "
|
141 |
-
"When using Amp, you don't need to call .half() on your model.")
|
142 |
-
LossScaler.warned_unscaling_non_fp32_grad = True
|
143 |
-
self._has_overflow = axpby_check_overflow_python(model,
|
144 |
-
stashed,
|
145 |
-
master,
|
146 |
-
a,
|
147 |
-
b,
|
148 |
-
self.dynamic)
|
149 |
-
if self._has_overflow and self.dynamic:
|
150 |
-
break
|
151 |
-
|
152 |
-
def unscale_with_stashed(self,
|
153 |
-
model_grads,
|
154 |
-
stashed_master_grads,
|
155 |
-
master_grads,
|
156 |
-
scale_override=None):
|
157 |
-
if self._has_overflow:
|
158 |
-
return
|
159 |
-
|
160 |
-
grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
|
161 |
-
if scale_override is not None:
|
162 |
-
grads_have_scale, stashed_have_scale, out_scale = scale_override
|
163 |
-
|
164 |
-
if LossScaler.has_fused_kernel:
|
165 |
-
if (not LossScaler.warned_unscaling_non_fp32_grad
|
166 |
-
and master_grads[0].dtype == torch.float16):
|
167 |
-
print("Warning: unscaling grads that are not FP32. "
|
168 |
-
"Unscaling non-fp32 grads may indicate an error. "
|
169 |
-
"When using Amp, you don't need to call .half() on your model.")
|
170 |
-
# Setting this to True unconditionally allows the possibility of an escape
|
171 |
-
# if never-before-seen non-fp32 grads are created in some later iteration.
|
172 |
-
LossScaler.warned_unscaling_non_fp32_grad = True
|
173 |
-
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
|
174 |
-
self._overflow_buf,
|
175 |
-
[model_grads, stashed_master_grads, master_grads],
|
176 |
-
out_scale/grads_have_scale, # 1./scale,
|
177 |
-
out_scale/stashed_have_scale, # 1.0,
|
178 |
-
0) # check only arg 0, aka the incoming model grads, for infs
|
179 |
-
else:
|
180 |
-
self.unscale_with_stashed_python(model_grads,
|
181 |
-
stashed_master_grads,
|
182 |
-
master_grads,
|
183 |
-
out_scale/grads_have_scale,
|
184 |
-
out_scale/stashed_have_scale)
|
185 |
-
|
186 |
-
# Defer to update_scale
|
187 |
-
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
188 |
-
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
189 |
-
# self._has_overflow = self._overflow_buf.item()
|
190 |
-
|
191 |
-
def clear_overflow_state(self):
|
192 |
-
self._has_overflow = False
|
193 |
-
if self.has_fused_kernel:
|
194 |
-
self._overflow_buf.zero_()
|
195 |
-
|
196 |
-
# Separate so unscale() can be called more that once before updating.
|
197 |
-
def update_scale(self):
|
198 |
-
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
199 |
-
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
200 |
-
self._has_overflow = self._overflow_buf.item()
|
201 |
-
|
202 |
-
if self._has_overflow and self.dynamic:
|
203 |
-
should_skip = True
|
204 |
-
if(self._min_loss_scale):
|
205 |
-
self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
|
206 |
-
else:
|
207 |
-
self._loss_scale = self._loss_scale/2.
|
208 |
-
self._unskipped = 0
|
209 |
-
else:
|
210 |
-
should_skip = False
|
211 |
-
self._unskipped += 1
|
212 |
-
|
213 |
-
if self._unskipped == self._scale_seq_len and self.dynamic:
|
214 |
-
self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
|
215 |
-
self._unskipped = 0
|
216 |
-
|
217 |
-
return should_skip
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/utils.py
DELETED
@@ -1,210 +0,0 @@
|
|
1 |
-
from . import compat
|
2 |
-
|
3 |
-
import functools
|
4 |
-
import itertools
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
def is_cuda_enabled():
|
9 |
-
return torch.version.cuda is not None
|
10 |
-
|
11 |
-
def get_cuda_version():
|
12 |
-
return tuple(int(x) for x in torch.version.cuda.split('.'))
|
13 |
-
|
14 |
-
def is_fp_tensor(x):
|
15 |
-
if is_nested(x):
|
16 |
-
# Fast-fail version of all(is_fp_tensor)
|
17 |
-
for y in x:
|
18 |
-
if not is_fp_tensor(y):
|
19 |
-
return False
|
20 |
-
return True
|
21 |
-
return compat.is_tensor_like(x) and compat.is_floating_point(x)
|
22 |
-
|
23 |
-
def is_nested(x):
|
24 |
-
return isinstance(x, tuple) or isinstance(x, list)
|
25 |
-
|
26 |
-
def should_cache(x):
|
27 |
-
if is_nested(x):
|
28 |
-
# Fast-fail version of all(should_cache)
|
29 |
-
for y in x:
|
30 |
-
if not should_cache(y):
|
31 |
-
return False
|
32 |
-
return True
|
33 |
-
return isinstance(x, torch.nn.parameter.Parameter) and \
|
34 |
-
type_string(x) == 'FloatTensor'
|
35 |
-
|
36 |
-
def collect_fp_tensor_types(args, kwargs):
|
37 |
-
def collect_types(x, types):
|
38 |
-
if is_nested(x):
|
39 |
-
for y in x:
|
40 |
-
collect_types(y, types)
|
41 |
-
else:
|
42 |
-
types.add(type_string(x))
|
43 |
-
|
44 |
-
all_args = itertools.chain(args, kwargs.values())
|
45 |
-
types = set()
|
46 |
-
for x in all_args:
|
47 |
-
if is_fp_tensor(x):
|
48 |
-
collect_types(x, types)
|
49 |
-
return types
|
50 |
-
|
51 |
-
def type_string(x):
|
52 |
-
return x.type().split('.')[-1]
|
53 |
-
|
54 |
-
def maybe_half(x, name='', verbose=False):
|
55 |
-
if is_nested(x):
|
56 |
-
return type(x)([maybe_half(y) for y in x])
|
57 |
-
|
58 |
-
if not x.is_cuda or type_string(x) == 'HalfTensor':
|
59 |
-
return x
|
60 |
-
else:
|
61 |
-
if verbose:
|
62 |
-
print('Float->Half ({})'.format(name))
|
63 |
-
return x.half()
|
64 |
-
|
65 |
-
def maybe_float(x, name='', verbose=False):
|
66 |
-
if is_nested(x):
|
67 |
-
return type(x)([maybe_float(y) for y in x])
|
68 |
-
|
69 |
-
if not x.is_cuda or type_string(x) == 'FloatTensor':
|
70 |
-
return x
|
71 |
-
else:
|
72 |
-
if verbose:
|
73 |
-
print('Half->Float ({})'.format(name))
|
74 |
-
return x.float()
|
75 |
-
|
76 |
-
# NB: returneds casted `args`, mutates `kwargs` in-place
|
77 |
-
def casted_args(cast_fn, args, kwargs):
|
78 |
-
new_args = []
|
79 |
-
for x in args:
|
80 |
-
if is_fp_tensor(x):
|
81 |
-
new_args.append(cast_fn(x))
|
82 |
-
else:
|
83 |
-
new_args.append(x)
|
84 |
-
for k in kwargs:
|
85 |
-
val = kwargs[k]
|
86 |
-
if is_fp_tensor(val):
|
87 |
-
kwargs[k] = cast_fn(val)
|
88 |
-
return new_args
|
89 |
-
|
90 |
-
def cached_cast(cast_fn, x, cache):
|
91 |
-
if is_nested(x):
|
92 |
-
return type(x)([cached_cast(y) for y in x])
|
93 |
-
if x in cache:
|
94 |
-
cached_x = cache[x]
|
95 |
-
if x.requires_grad and cached_x.requires_grad:
|
96 |
-
# Make sure x is actually cached_x's autograd parent.
|
97 |
-
if cached_x.grad_fn.next_functions[1][0].variable is not x:
|
98 |
-
raise RuntimeError("x and cache[x] both require grad, but x is not "
|
99 |
-
"cache[x]'s parent. This is likely an error.")
|
100 |
-
# During eval, it's possible to end up caching casted weights with
|
101 |
-
# requires_grad=False. On the next training iter, if cached_x is found
|
102 |
-
# and reused from the cache, it will not actually have x as its parent.
|
103 |
-
# Therefore, we choose to invalidate the cache (and force refreshing the cast)
|
104 |
-
# if x.requires_grad and cached_x.requires_grad do not match.
|
105 |
-
#
|
106 |
-
# During eval (i.e. running under with torch.no_grad()) the invalidation
|
107 |
-
# check would cause the cached value to be dropped every time, because
|
108 |
-
# cached_x would always be created with requires_grad=False, while x would
|
109 |
-
# still have requires_grad=True. This would render the cache effectively
|
110 |
-
# useless during eval. Therefore, if we are running under the no_grad()
|
111 |
-
# context manager (torch.is_grad_enabled=False) we elide the invalidation
|
112 |
-
# check, and use the cached value even though its requires_grad flag doesn't
|
113 |
-
# match. During eval, we don't care that there's no autograd-graph
|
114 |
-
# connection between x and cached_x.
|
115 |
-
if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
|
116 |
-
del cache[x]
|
117 |
-
else:
|
118 |
-
return cached_x
|
119 |
-
|
120 |
-
casted_x = cast_fn(x)
|
121 |
-
cache[x] = casted_x
|
122 |
-
return casted_x
|
123 |
-
|
124 |
-
def verbosify(cast_fn, fn_name, verbose):
|
125 |
-
if verbose:
|
126 |
-
return functools.partial(cast_fn, name=fn_name, verbose=verbose)
|
127 |
-
else:
|
128 |
-
return cast_fn
|
129 |
-
|
130 |
-
def as_inplace(fns):
|
131 |
-
for x in fns:
|
132 |
-
yield x + '_'
|
133 |
-
|
134 |
-
def has_func(mod, fn):
|
135 |
-
if isinstance(mod, dict):
|
136 |
-
return fn in mod
|
137 |
-
else:
|
138 |
-
return hasattr(mod, fn)
|
139 |
-
|
140 |
-
def get_func(mod, fn):
|
141 |
-
if isinstance(mod, dict):
|
142 |
-
return mod[fn]
|
143 |
-
else:
|
144 |
-
return getattr(mod, fn)
|
145 |
-
|
146 |
-
def set_func(mod, fn, new_fn):
|
147 |
-
if isinstance(mod, dict):
|
148 |
-
mod[fn] = new_fn
|
149 |
-
else:
|
150 |
-
setattr(mod, fn, new_fn)
|
151 |
-
|
152 |
-
def set_func_save(handle, mod, fn, new_fn):
|
153 |
-
cur_fn = get_func(mod, fn)
|
154 |
-
handle._save_func(mod, fn, cur_fn)
|
155 |
-
set_func(mod, fn, new_fn)
|
156 |
-
|
157 |
-
# A couple problems get solved here:
|
158 |
-
# - The flat_weight buffer is disconnected from autograd graph,
|
159 |
-
# so the fp16 weights need to be derived from the input weights
|
160 |
-
# to this forward call, not the flat buffer.
|
161 |
-
# - The ordering of weights in the flat buffer is...idiosyncratic.
|
162 |
-
# First problem is solved with combination of set_ (to set up
|
163 |
-
# correct storage) and copy_ (so the fp16 weight derives from the
|
164 |
-
# fp32 one in autograd.
|
165 |
-
# Second is solved by doing ptr arithmetic on the fp32 weights
|
166 |
-
# to derive the correct offset.
|
167 |
-
#
|
168 |
-
# TODO: maybe this should actually use
|
169 |
-
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
|
170 |
-
# on first iter and cache the right offsets. Ugh.
|
171 |
-
def synthesize_flattened_rnn_weights(fp32_weights,
|
172 |
-
fp16_flat_tensor,
|
173 |
-
rnn_fn='',
|
174 |
-
verbose=False):
|
175 |
-
fp16_weights = []
|
176 |
-
fp32_base_ptr = fp32_weights[0][0].data_ptr()
|
177 |
-
for layer_weights in fp32_weights:
|
178 |
-
fp16_layer_weights = []
|
179 |
-
for w_fp32 in layer_weights:
|
180 |
-
w_fp16 = w_fp32.new().half()
|
181 |
-
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
|
182 |
-
w_fp16.set_(fp16_flat_tensor.storage(),
|
183 |
-
offset,
|
184 |
-
w_fp32.shape)
|
185 |
-
w_fp16.copy_(w_fp32)
|
186 |
-
if verbose:
|
187 |
-
print('Float->Half ({})'.format(rnn_fn))
|
188 |
-
fp16_layer_weights.append(w_fp16)
|
189 |
-
fp16_weights.append(fp16_layer_weights)
|
190 |
-
return fp16_weights
|
191 |
-
|
192 |
-
# Roughly same as above, just the `fp32_weights` aren't nested.
|
193 |
-
# Code kept separate for readability.
|
194 |
-
def new_synthesize_flattened_rnn_weights(fp32_weights,
|
195 |
-
fp16_flat_tensor,
|
196 |
-
rnn_fn='',
|
197 |
-
verbose=False):
|
198 |
-
fp16_weights = []
|
199 |
-
fp32_base_ptr = fp32_weights[0].data_ptr()
|
200 |
-
for w_fp32 in fp32_weights:
|
201 |
-
w_fp16 = w_fp32.new().half()
|
202 |
-
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
|
203 |
-
w_fp16.set_(fp16_flat_tensor.storage(),
|
204 |
-
offset,
|
205 |
-
w_fp32.shape)
|
206 |
-
w_fp16.copy_(w_fp32)
|
207 |
-
if verbose:
|
208 |
-
print('Float->Half ({})'.format(rnn_fn))
|
209 |
-
fp16_weights.append(w_fp16)
|
210 |
-
return fp16_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/amp/wrap.py
DELETED
@@ -1,276 +0,0 @@
|
|
1 |
-
from . import compat
|
2 |
-
from . import utils
|
3 |
-
from ._amp_state import _amp_state
|
4 |
-
from . import rnn_compat
|
5 |
-
|
6 |
-
import functools
|
7 |
-
|
8 |
-
import torch
|
9 |
-
|
10 |
-
def make_cast_wrapper(orig_fn, cast_fn, handle,
|
11 |
-
try_caching=False):
|
12 |
-
@functools.wraps(orig_fn)
|
13 |
-
def wrapper(*args, **kwargs):
|
14 |
-
if not handle.is_active():
|
15 |
-
return orig_fn(*args, **kwargs)
|
16 |
-
|
17 |
-
if try_caching and handle.has_cache:
|
18 |
-
args = list(args)
|
19 |
-
for i in range(len(args)):
|
20 |
-
if utils.should_cache(args[i]):
|
21 |
-
args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
|
22 |
-
for k in kwargs:
|
23 |
-
if utils.should_cache(kwargs[k]):
|
24 |
-
kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
|
25 |
-
new_args = utils.casted_args(cast_fn,
|
26 |
-
args,
|
27 |
-
kwargs)
|
28 |
-
return orig_fn(*new_args, **kwargs)
|
29 |
-
return wrapper
|
30 |
-
|
31 |
-
def cached_cast(mod, fn, cast_fn, handle,
|
32 |
-
try_caching=False, verbose=False):
|
33 |
-
if not utils.has_func(mod, fn):
|
34 |
-
return
|
35 |
-
|
36 |
-
orig_fn = utils.get_func(mod, fn)
|
37 |
-
cast_fn = utils.verbosify(cast_fn, fn, verbose)
|
38 |
-
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
|
39 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
40 |
-
|
41 |
-
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
|
42 |
-
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
|
43 |
-
# is on the new API and I am free to get rid of handle, I can clean this up.
|
44 |
-
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
|
45 |
-
@functools.wraps(orig_fn)
|
46 |
-
def wrapper(*args, **kwargs):
|
47 |
-
if not _amp_state.handle.is_active():
|
48 |
-
return orig_fn(*args, **kwargs)
|
49 |
-
|
50 |
-
types = utils.collect_fp_tensor_types(args, kwargs)
|
51 |
-
|
52 |
-
if len(types) <= 1:
|
53 |
-
return orig_fn(*args, **kwargs)
|
54 |
-
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
|
55 |
-
new_args = utils.casted_args(cast_fn,
|
56 |
-
args,
|
57 |
-
kwargs)
|
58 |
-
return orig_fn(*new_args, **kwargs)
|
59 |
-
else:
|
60 |
-
raise NotImplementedError('Do not know how to handle ' +
|
61 |
-
'these types to promote: {}'
|
62 |
-
.format(types))
|
63 |
-
return wrapper
|
64 |
-
|
65 |
-
def promote(mod, fn, handle, verbose=False):
|
66 |
-
orig_fn = utils.get_func(mod, fn)
|
67 |
-
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
|
68 |
-
wrapper = make_promote_wrapper(orig_fn, maybe_float)
|
69 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
70 |
-
|
71 |
-
def sequence_promote(mod, fn, handle, verbose=False):
|
72 |
-
orig_fn = utils.get_func(mod, fn)
|
73 |
-
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
|
74 |
-
@functools.wraps(orig_fn)
|
75 |
-
def wrapper(seq, *args, **kwargs):
|
76 |
-
if not _amp_state.handle.is_active():
|
77 |
-
return orig_fn(seq, *args, **kwargs)
|
78 |
-
|
79 |
-
types = set([utils.type_string(x) for x in seq])
|
80 |
-
if len(types) <= 1:
|
81 |
-
return orig_fn(seq, *args, **kwargs)
|
82 |
-
elif types == set(['HalfTensor', 'FloatTensor']):
|
83 |
-
cast_seq = utils.casted_args(maybe_float,
|
84 |
-
seq, {})
|
85 |
-
return orig_fn(cast_seq, *args, **kwargs)
|
86 |
-
else:
|
87 |
-
# TODO: other mixed-type cases aren't due to amp.
|
88 |
-
# Just pass through?
|
89 |
-
return orig_fn(seq, *args, **kwargs)
|
90 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
91 |
-
|
92 |
-
def promote_match_arg0(mod, fn, handle, verbose=False):
|
93 |
-
if not utils.has_func(mod, fn):
|
94 |
-
return
|
95 |
-
|
96 |
-
orig_fn = utils.get_func(mod, fn)
|
97 |
-
@functools.wraps(orig_fn)
|
98 |
-
def wrapper(arg0, *args, **kwargs):
|
99 |
-
assert compat.is_tensor_like(arg0)
|
100 |
-
if not _amp_state.handle.is_active():
|
101 |
-
return orig_fn(arg0, *args, **kwargs)
|
102 |
-
|
103 |
-
if utils.type_string(arg0) == 'HalfTensor':
|
104 |
-
cast_fn = utils.maybe_half
|
105 |
-
elif utils.type_string(arg0) == 'FloatTensor':
|
106 |
-
cast_fn = utils.maybe_float
|
107 |
-
else:
|
108 |
-
return orig_fn(arg0, *args, **kwargs)
|
109 |
-
cast_fn = utils.verbosify(cast_fn, fn, verbose)
|
110 |
-
new_args = utils.casted_args(cast_fn, args, kwargs)
|
111 |
-
return orig_fn(arg0, *new_args, **kwargs)
|
112 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
113 |
-
|
114 |
-
def err_if_any_half(mod, fn, handle, custom_err_msg=None):
|
115 |
-
if not utils.has_func(mod, fn):
|
116 |
-
return
|
117 |
-
|
118 |
-
orig_fn = utils.get_func(mod, fn)
|
119 |
-
@functools.wraps(orig_fn)
|
120 |
-
def wrapper(*args, **kwargs):
|
121 |
-
types = utils.collect_fp_tensor_types(args, kwargs)
|
122 |
-
if 'HalfTensor' in types:
|
123 |
-
if custom_err_msg:
|
124 |
-
raise NotImplementedError(custom_err_msg)
|
125 |
-
else:
|
126 |
-
raise NotImplementedError('Cannot call in-place function ' +
|
127 |
-
'{} with fp16 arguments.'.format(fn))
|
128 |
-
else:
|
129 |
-
return orig_fn(*args, **kwargs)
|
130 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
131 |
-
|
132 |
-
def err_if_arg0_half(mod, fn, handle, verbose=False):
|
133 |
-
if not utils.has_func(mod, fn):
|
134 |
-
return
|
135 |
-
|
136 |
-
orig_fn = utils.get_func(mod, fn)
|
137 |
-
@functools.wraps(orig_fn)
|
138 |
-
def wrapper(arg0, *args, **kwargs):
|
139 |
-
assert compat.is_tensor_like(arg0)
|
140 |
-
if utils.type_string(arg0) == 'HalfTensor':
|
141 |
-
raise NotImplementedError('Cannot call in-place method ' +
|
142 |
-
'{} on fp16 Tensors.'.format(fn))
|
143 |
-
else:
|
144 |
-
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
|
145 |
-
new_args = utils.casted_args(cast_fn, args, kwargs)
|
146 |
-
return orig_fn(arg0, *new_args, **kwargs)
|
147 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
148 |
-
|
149 |
-
# Current RNN approach:
|
150 |
-
# - Wrap top-level `RNN` function in thnn backend
|
151 |
-
# - Will call into either CudnnRNN or AutogradRNN
|
152 |
-
# - Each of these are factory functions that return a per-iter
|
153 |
-
# `forward` function
|
154 |
-
# - We interpose on the factory function to:
|
155 |
-
# 1) Interpose on the actual forward function and put in casts
|
156 |
-
# 2) Insert an fp16 `flat_weight` if necessary
|
157 |
-
def rnn_cast(backend, fn, handle, verbose=False):
|
158 |
-
orig_rnn = utils.get_func(backend, fn)
|
159 |
-
@functools.wraps(orig_rnn)
|
160 |
-
def rnn_wrapper(*args, **kwargs):
|
161 |
-
flat_weight = kwargs.get('flat_weight')
|
162 |
-
if flat_weight is not None:
|
163 |
-
# We replace `flat_weight` with an uninitialized fp16
|
164 |
-
# Tensor. The "actual" weight tensors (provided in `forward`),
|
165 |
-
# will then be set up as ptrs into the buffer and have the
|
166 |
-
# corresponding fp32 values copied in.
|
167 |
-
# We need to call `copy` on the "actual" weights so that the
|
168 |
-
# autograd graph correctly backprops from the wgrads computed
|
169 |
-
# inside cuDNN (on fp16 weights) into the fp32 weights.
|
170 |
-
assert utils.type_string(flat_weight) == 'FloatTensor'
|
171 |
-
if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
|
172 |
-
# Pre-0.4. A little slower, since it zeros out memory.
|
173 |
-
flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
|
174 |
-
else:
|
175 |
-
flat_weight_fp16 = torch.empty_like(flat_weight,
|
176 |
-
dtype=torch.float16)
|
177 |
-
kwargs['flat_weight'] = flat_weight_fp16
|
178 |
-
else:
|
179 |
-
flat_weight_fp16 = None
|
180 |
-
|
181 |
-
forward = orig_rnn(*args, **kwargs)
|
182 |
-
@functools.wraps(forward)
|
183 |
-
def fwd_wrapper(*fargs, **fkwargs):
|
184 |
-
assert len(fargs) == 3 or len(fargs) == 4
|
185 |
-
inputs, weights, hiddens = fargs[:3]
|
186 |
-
assert utils.is_fp_tensor(inputs)
|
187 |
-
assert isinstance(weights, list)
|
188 |
-
cast_fn = utils.verbosify(utils.maybe_half,
|
189 |
-
fn,
|
190 |
-
verbose)
|
191 |
-
new_args = []
|
192 |
-
|
193 |
-
# 0) Inputs
|
194 |
-
new_args.append(cast_fn(inputs))
|
195 |
-
|
196 |
-
# 1) Weights
|
197 |
-
if flat_weight_fp16 is not None:
|
198 |
-
fp16_weights = utils.synthesize_flattened_rnn_weights(
|
199 |
-
weights, flat_weight_fp16, fn, verbose)
|
200 |
-
else:
|
201 |
-
fp16_weights = [[cast_fn(w) for w in layer]
|
202 |
-
for layer in weights]
|
203 |
-
new_args.append(fp16_weights)
|
204 |
-
|
205 |
-
# 2) Inputs: either a tuple (for LSTM) or single tensor
|
206 |
-
if isinstance(hiddens, tuple):
|
207 |
-
new_args.append(tuple(cast_fn(x) for x in hiddens))
|
208 |
-
elif utils.is_fp_tensor(hiddens):
|
209 |
-
new_args.append(cast_fn(hiddens))
|
210 |
-
else:
|
211 |
-
# Hiddens can, in principle, be `None` -- pass through
|
212 |
-
new_args.append(hiddens)
|
213 |
-
|
214 |
-
# 3) Batch sizes (0.4 or later only)
|
215 |
-
if len(fargs) == 4:
|
216 |
-
new_args.append(fargs[3])
|
217 |
-
|
218 |
-
return forward(*new_args, **fkwargs)
|
219 |
-
return fwd_wrapper
|
220 |
-
utils.set_func_save(handle, backend, fn, rnn_wrapper)
|
221 |
-
|
222 |
-
def new_rnn_cast(fn, handle, verbose=False):
|
223 |
-
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
|
224 |
-
# For rnn backend calls that route through _rnn_impls, we must patch the ref
|
225 |
-
# that _rnn_impls stashed. For rnn backend calls that directly invoke
|
226 |
-
# _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
|
227 |
-
# which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
|
228 |
-
if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
|
229 |
-
mod = torch.nn.modules.rnn._rnn_impls
|
230 |
-
else:
|
231 |
-
mod = torch.nn.modules.rnn._VF
|
232 |
-
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
|
233 |
-
fn = fn.lower()
|
234 |
-
orig_fn = utils.get_func(mod, fn)
|
235 |
-
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
|
236 |
-
@functools.wraps(orig_fn)
|
237 |
-
def wrapper(*args, **kwargs):
|
238 |
-
# Exact call signature from modules/rnn.py
|
239 |
-
assert len(args) == 9
|
240 |
-
assert len(kwargs) == 0
|
241 |
-
|
242 |
-
if not _amp_state.handle.is_active():
|
243 |
-
return orig_fn(*args, **kwargs)
|
244 |
-
|
245 |
-
if isinstance(args[6], bool):
|
246 |
-
params_idx = 2 # Not PackedSequence case
|
247 |
-
else:
|
248 |
-
params_idx = 3 # PackedSequence case
|
249 |
-
|
250 |
-
new_args = []
|
251 |
-
for i, arg in enumerate(args):
|
252 |
-
if i == params_idx:
|
253 |
-
num_params = sum([x.numel() for x in arg])
|
254 |
-
fp16_weight_buf = args[0].new_empty((num_params,),
|
255 |
-
dtype=torch.half)
|
256 |
-
casted_weights = utils.new_synthesize_flattened_rnn_weights(
|
257 |
-
arg, fp16_weight_buf, fn, verbose)
|
258 |
-
new_args.append(casted_weights)
|
259 |
-
elif utils.is_fp_tensor(arg):
|
260 |
-
new_args.append(cast_fn(arg))
|
261 |
-
else:
|
262 |
-
new_args.append(arg)
|
263 |
-
|
264 |
-
return orig_fn(*new_args)
|
265 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
266 |
-
|
267 |
-
def disable_casts(mod, fn, handle):
|
268 |
-
if not utils.has_func(mod, fn):
|
269 |
-
return
|
270 |
-
|
271 |
-
orig_fn = utils.get_func(mod, fn)
|
272 |
-
@functools.wraps(orig_fn)
|
273 |
-
def wrapper(*args, **kwargs):
|
274 |
-
with handle._disable_casts():
|
275 |
-
return orig_fn(*args, **kwargs)
|
276 |
-
utils.set_func_save(handle, mod, fn, wrapper)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/__init__.py
DELETED
File without changes
|
apex/apex/contrib/bottleneck/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .bottleneck import Bottleneck, SpatialBottleneck
|
2 |
-
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
|
|
|
|
|
|
apex/apex/contrib/bottleneck/bottleneck.py
DELETED
@@ -1,749 +0,0 @@
|
|
1 |
-
import functools as func
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.distributed as dist
|
5 |
-
from torch import nn
|
6 |
-
|
7 |
-
from apex import check_cudnn_version_and_warn
|
8 |
-
import fast_bottleneck
|
9 |
-
import nccl_p2p_cuda as inc
|
10 |
-
|
11 |
-
|
12 |
-
assert check_cudnn_version_and_warn(__name__, 8400)
|
13 |
-
|
14 |
-
|
15 |
-
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
16 |
-
weight_tensor_nchw = tensor
|
17 |
-
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
|
18 |
-
|
19 |
-
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
|
20 |
-
scale = weight * running_var.rsqrt()
|
21 |
-
bias = bias - running_mean * scale
|
22 |
-
w_scale.copy_(scale)
|
23 |
-
w_bias.copy_(bias)
|
24 |
-
|
25 |
-
def compute_scale_bias_method(nhwc, args):
|
26 |
-
for arg in args:
|
27 |
-
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
|
28 |
-
compute_scale_bias_one(nhwc, *arg)
|
29 |
-
|
30 |
-
class FrozenBatchNorm2d(torch.jit.ScriptModule):
|
31 |
-
"""
|
32 |
-
BatchNorm2d where the batch statistics and the affine parameters are fixed
|
33 |
-
"""
|
34 |
-
def __init__(self, n):
|
35 |
-
super(FrozenBatchNorm2d, self).__init__()
|
36 |
-
self.register_buffer("weight", torch.ones(n))
|
37 |
-
self.register_buffer("bias", torch.zeros(n))
|
38 |
-
self.register_buffer("running_mean", torch.zeros(n))
|
39 |
-
self.register_buffer("running_var", torch.ones(n))
|
40 |
-
|
41 |
-
@torch.jit.script_method
|
42 |
-
def get_scale_bias(self, nhwc):
|
43 |
-
# type: (bool) -> List[torch.Tensor]
|
44 |
-
scale = self.weight * self.running_var.rsqrt()
|
45 |
-
bias = self.bias - self.running_mean * scale
|
46 |
-
if nhwc:
|
47 |
-
scale = scale.reshape(1, 1, 1, -1)
|
48 |
-
bias = bias.reshape(1, 1, 1, -1)
|
49 |
-
else:
|
50 |
-
scale = scale.reshape(1, -1, 1, 1)
|
51 |
-
bias = bias.reshape(1, -1, 1, 1)
|
52 |
-
return scale, bias
|
53 |
-
|
54 |
-
@torch.jit.script_method
|
55 |
-
def forward(self, x):
|
56 |
-
scale, bias = self.get_scale_bias(False)
|
57 |
-
return x * scale + bias
|
58 |
-
|
59 |
-
@torch.jit.script
|
60 |
-
def drelu_dscale1(grad_o, output, scale1):
|
61 |
-
relu_mask = (output>0)
|
62 |
-
dx_relu = relu_mask * grad_o
|
63 |
-
g1 = dx_relu * scale1
|
64 |
-
return g1, dx_relu
|
65 |
-
|
66 |
-
@torch.jit.script
|
67 |
-
def drelu_dscale2(grad_o, output, scale1, scale2):
|
68 |
-
relu_mask = (output>0)
|
69 |
-
dx_relu = relu_mask * grad_o
|
70 |
-
g1 = dx_relu * scale1
|
71 |
-
g2 = dx_relu * scale2
|
72 |
-
return g1, g2
|
73 |
-
|
74 |
-
class BottleneckFunction(torch.autograd.Function):
|
75 |
-
@staticmethod
|
76 |
-
def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
|
77 |
-
# TODO: clean up order of tensors
|
78 |
-
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
|
79 |
-
ctx.downsample = len(conv) > 3
|
80 |
-
if ctx.downsample:
|
81 |
-
args.append(conv[3])
|
82 |
-
args.append(scale[3])
|
83 |
-
args.append(bias[3])
|
84 |
-
|
85 |
-
# weight buffers are always in nhwc while shape can be nhwc or channels_last
|
86 |
-
# here we pass in flag and let c++ handle it
|
87 |
-
# alternatively, we can put all sizes into a fixed format and pass it in
|
88 |
-
outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)
|
89 |
-
ctx.save_for_backward(*(args+outputs))
|
90 |
-
# save relu outputs for drelu
|
91 |
-
ctx.nhwc = nhwc
|
92 |
-
ctx.stride_1x1 = stride_1x1
|
93 |
-
return outputs[2]
|
94 |
-
|
95 |
-
# backward relu is not exposed, MUL with mask used now
|
96 |
-
# only support dgrad
|
97 |
-
@staticmethod
|
98 |
-
def backward(ctx, grad_o):
|
99 |
-
outputs = ctx.saved_tensors[-3:]
|
100 |
-
|
101 |
-
if ctx.downsample:
|
102 |
-
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
|
103 |
-
else:
|
104 |
-
grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])
|
105 |
-
|
106 |
-
# create input vector for backward
|
107 |
-
t_list = [*ctx.saved_tensors[0:10]]
|
108 |
-
t_list.append(grad_conv3)
|
109 |
-
t_list.append(grad_conv4)
|
110 |
-
|
111 |
-
# outputs used for wgrad and generating drelu mask
|
112 |
-
t_list.append(outputs[0])
|
113 |
-
t_list.append(outputs[1])
|
114 |
-
|
115 |
-
# in case there is downsample
|
116 |
-
if ctx.downsample:
|
117 |
-
t_list.append(ctx.saved_tensors[10])
|
118 |
-
|
119 |
-
grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)
|
120 |
-
|
121 |
-
return (None, None, None, None, *grads)
|
122 |
-
|
123 |
-
bottleneck_function = BottleneckFunction.apply
|
124 |
-
|
125 |
-
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
126 |
-
"""3x3 convolution with padding"""
|
127 |
-
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
128 |
-
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
129 |
-
|
130 |
-
def conv1x1(in_planes, out_planes, stride=1):
|
131 |
-
"""1x1 convolution"""
|
132 |
-
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
133 |
-
|
134 |
-
class Bottleneck(torch.nn.Module):
|
135 |
-
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
136 |
-
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
137 |
-
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
138 |
-
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
139 |
-
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
140 |
-
# here we put it at 1x1
|
141 |
-
|
142 |
-
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
|
143 |
-
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):
|
144 |
-
super(Bottleneck, self).__init__()
|
145 |
-
if groups != 1:
|
146 |
-
raise RuntimeError('Only support groups == 1')
|
147 |
-
if dilation != 1:
|
148 |
-
raise RuntimeError('Only support dilation == 1')
|
149 |
-
if norm_func == None:
|
150 |
-
norm_func = FrozenBatchNorm2d
|
151 |
-
else:
|
152 |
-
raise RuntimeError('Only support frozen BN now.')
|
153 |
-
|
154 |
-
if stride != 1 or in_channels != out_channels:
|
155 |
-
self.downsample = nn.Sequential(
|
156 |
-
conv1x1(in_channels, out_channels, stride),
|
157 |
-
norm_func(out_channels),
|
158 |
-
)
|
159 |
-
else:
|
160 |
-
self.downsample = None
|
161 |
-
|
162 |
-
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
163 |
-
self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
|
164 |
-
self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
|
165 |
-
self.conv3 = conv1x1(bottleneck_channels, out_channels)
|
166 |
-
self.relu = nn.ReLU(inplace=True)
|
167 |
-
self.stride = stride
|
168 |
-
|
169 |
-
self.bn1 = norm_func(bottleneck_channels)
|
170 |
-
self.bn2 = norm_func(bottleneck_channels)
|
171 |
-
self.bn3 = norm_func(out_channels)
|
172 |
-
self.w_scale = None
|
173 |
-
|
174 |
-
self.use_cudnn = use_cudnn
|
175 |
-
|
176 |
-
# setup conv weights
|
177 |
-
self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
|
178 |
-
if self.downsample is not None:
|
179 |
-
self.w_conv.append(self.downsample[0].weight)
|
180 |
-
|
181 |
-
# init weight in nchw format before possible transpose
|
182 |
-
for w in self.w_conv:
|
183 |
-
kaiming_uniform_(w, a=1)
|
184 |
-
|
185 |
-
# TODO: prevent unsupported case usage
|
186 |
-
# support cases
|
187 |
-
# native cudnn
|
188 |
-
# normal yes no
|
189 |
-
# channel_last yes yes
|
190 |
-
# explicit_nhwc no yes
|
191 |
-
self.explicit_nhwc = explicit_nhwc
|
192 |
-
if self.explicit_nhwc:
|
193 |
-
for p in self.parameters():
|
194 |
-
with torch.no_grad():
|
195 |
-
p.data = p.data.permute(0,2,3,1).contiguous()
|
196 |
-
|
197 |
-
return
|
198 |
-
|
199 |
-
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
|
200 |
-
# This method must be called before cuda graphing.
|
201 |
-
# The callable it returns can be called anytime.
|
202 |
-
# Calling this method will prevent these from being computed every forward call.
|
203 |
-
def get_scale_bias_callable(self):
|
204 |
-
self.w_scale, self.w_bias, args = [], [], []
|
205 |
-
batch_norms = [self.bn1, self.bn2, self.bn3]
|
206 |
-
if self.downsample is not None:
|
207 |
-
batch_norms.append(self.downsample[1])
|
208 |
-
for bn in batch_norms:
|
209 |
-
s = torch.empty_like(bn.weight)
|
210 |
-
b = torch.empty_like(s)
|
211 |
-
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
|
212 |
-
if self.explicit_nhwc:
|
213 |
-
self.w_scale.append( s.reshape(1, 1, 1, -1) )
|
214 |
-
self.w_bias.append( b.reshape(1, 1, 1, -1) )
|
215 |
-
else:
|
216 |
-
self.w_scale.append( s.reshape(1, -1, 1, 1) )
|
217 |
-
self.w_bias.append( b.reshape(1, -1, 1, 1) )
|
218 |
-
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
|
219 |
-
|
220 |
-
def forward(self, x):
|
221 |
-
if self.use_cudnn:
|
222 |
-
if self.w_scale is None:
|
223 |
-
# calculate scale/bias from registered buffers
|
224 |
-
# TODO: make this better
|
225 |
-
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
|
226 |
-
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
|
227 |
-
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
|
228 |
-
w_scale = [s1, s2, s3]
|
229 |
-
w_bias = [b1, b2, b3]
|
230 |
-
if self.downsample is not None:
|
231 |
-
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
|
232 |
-
w_scale.append(s4)
|
233 |
-
w_bias.append(b4)
|
234 |
-
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
|
235 |
-
else:
|
236 |
-
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
|
237 |
-
return out
|
238 |
-
|
239 |
-
if self.explicit_nhwc:
|
240 |
-
raise RuntimeError('explicit nhwc with native ops is not supported.')
|
241 |
-
|
242 |
-
# fallback to native ops
|
243 |
-
identity = x
|
244 |
-
|
245 |
-
out = self.conv1(x)
|
246 |
-
out = self.bn1(out)
|
247 |
-
out = self.relu(out)
|
248 |
-
|
249 |
-
out = self.conv2(out)
|
250 |
-
out = self.bn2(out)
|
251 |
-
out = self.relu(out)
|
252 |
-
|
253 |
-
out = self.conv3(out)
|
254 |
-
out = self.bn3(out)
|
255 |
-
|
256 |
-
if self.downsample is not None:
|
257 |
-
identity = self.downsample(x)
|
258 |
-
|
259 |
-
out += identity
|
260 |
-
out = self.relu(out)
|
261 |
-
|
262 |
-
return out
|
263 |
-
|
264 |
-
|
265 |
-
class SpatialBottleneckFunction(torch.autograd.Function):
|
266 |
-
@staticmethod
|
267 |
-
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
|
268 |
-
if spatial_group_size > 1:
|
269 |
-
stream1 = spatial_halo_exchanger.stream1
|
270 |
-
stream2 = spatial_halo_exchanger.stream2
|
271 |
-
stream3 = spatial_halo_exchanger.stream3
|
272 |
-
|
273 |
-
# TODO: clean up order of tensors
|
274 |
-
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
|
275 |
-
ctx.downsample = len(conv) > 3
|
276 |
-
if ctx.downsample:
|
277 |
-
args.append(conv[3])
|
278 |
-
args.append(scale[3])
|
279 |
-
args.append(bias[3])
|
280 |
-
|
281 |
-
# weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
|
282 |
-
# here we pass in flag and let c++ handle it
|
283 |
-
# alternatively, we can put all sizes into a fixed format and pass it in
|
284 |
-
outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
|
285 |
-
fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
|
286 |
-
|
287 |
-
if spatial_group_size > 1:
|
288 |
-
out1 = outputs[0]
|
289 |
-
if explicit_nhwc:
|
290 |
-
N,Hs,W,C = list(out1.shape)
|
291 |
-
memory_format = torch.contiguous_format
|
292 |
-
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
|
293 |
-
else:
|
294 |
-
N,C,Hs,W = list(out1.shape)
|
295 |
-
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
|
296 |
-
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
|
297 |
-
stream1.wait_stream(torch.cuda.current_stream())
|
298 |
-
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
|
299 |
-
with torch.cuda.stream(stream1):
|
300 |
-
if explicit_nhwc:
|
301 |
-
top_out1_halo = out1_pad[:,:1,:,:]
|
302 |
-
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
|
303 |
-
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
|
304 |
-
else:
|
305 |
-
top_out1_halo = out1_pad[:,:,:1,:]
|
306 |
-
btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
|
307 |
-
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
|
308 |
-
if spatial_method == 1:
|
309 |
-
# overlap mid convolution with halo transfer
|
310 |
-
if spatial_group_rank < spatial_group_size-1:
|
311 |
-
stream2.wait_stream(stream1)
|
312 |
-
with torch.cuda.stream(stream2):
|
313 |
-
if explicit_nhwc:
|
314 |
-
btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
|
315 |
-
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
|
316 |
-
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
|
317 |
-
else:
|
318 |
-
btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
|
319 |
-
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
|
320 |
-
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
|
321 |
-
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
|
322 |
-
if spatial_group_rank > 0:
|
323 |
-
with torch.cuda.stream(stream1):
|
324 |
-
if explicit_nhwc:
|
325 |
-
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
|
326 |
-
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
|
327 |
-
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
|
328 |
-
else:
|
329 |
-
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
|
330 |
-
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
|
331 |
-
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
|
332 |
-
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
|
333 |
-
if use_delay_kernel: inc.add_delay(10)
|
334 |
-
elif spatial_method != 2 and spatial_method != 3:
|
335 |
-
assert(False), "spatial_method must be 1, 2 or 3"
|
336 |
-
|
337 |
-
if spatial_group_size <= 1:
|
338 |
-
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
|
339 |
-
elif spatial_method == 1:
|
340 |
-
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
|
341 |
-
with torch.cuda.stream(stream3):
|
342 |
-
if explicit_nhwc:
|
343 |
-
out1_pad[:,1:Hs+1,:,:].copy_(out1)
|
344 |
-
else:
|
345 |
-
out1_pad[:,:,1:Hs+1,:].copy_(out1)
|
346 |
-
elif spatial_method == 2:
|
347 |
-
# wait for halo transfer to finish before doing a full convolution of padded x
|
348 |
-
if explicit_nhwc:
|
349 |
-
out1_pad[:,1:Hs+1,:,:].copy_(out1)
|
350 |
-
else:
|
351 |
-
out1_pad[:,:,1:Hs+1,:].copy_(out1)
|
352 |
-
torch.cuda.current_stream().wait_stream(stream1)
|
353 |
-
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
|
354 |
-
elif spatial_method == 3:
|
355 |
-
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
|
356 |
-
with torch.cuda.stream(stream3):
|
357 |
-
if explicit_nhwc:
|
358 |
-
out1_pad[:,1:Hs+1,:,:].copy_(out1)
|
359 |
-
else:
|
360 |
-
out1_pad[:,:,1:Hs+1,:].copy_(out1)
|
361 |
-
|
362 |
-
# compute halo cells for outputs[1] (out2)
|
363 |
-
if spatial_group_size > 1:
|
364 |
-
out2 = outputs[1]
|
365 |
-
if explicit_nhwc:
|
366 |
-
top_out2_halo = out2[:,:1,:,:]
|
367 |
-
btm_out2_halo = out2[:,Hs-1:,:,:]
|
368 |
-
else:
|
369 |
-
top_out2_halo = out2[:,:,:1,:]
|
370 |
-
btm_out2_halo = out2[:,:,Hs-1:,:]
|
371 |
-
if spatial_method == 1:
|
372 |
-
if spatial_group_rank > 0:
|
373 |
-
torch.cuda.current_stream().wait_stream(stream1)
|
374 |
-
top_out2_halo.copy_(top_out2)
|
375 |
-
if spatial_group_rank < spatial_group_size-1:
|
376 |
-
torch.cuda.current_stream().wait_stream(stream2)
|
377 |
-
btm_out2_halo.copy_(btm_out2)
|
378 |
-
elif spatial_method == 3:
|
379 |
-
# Note
|
380 |
-
# out2 halo correction cannot overlap with anything since it has
|
381 |
-
# to wait for out2_mask to finish, but itself has to finish before
|
382 |
-
# the first kernel of _forward_rest can launch.
|
383 |
-
# At least we can overlap the two halo correction kernels.
|
384 |
-
if spatial_group_rank < spatial_group_size-1:
|
385 |
-
stream2.wait_stream(stream1) # wait for halo transfers to finish
|
386 |
-
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
|
387 |
-
with torch.cuda.stream(stream2):
|
388 |
-
w1by3 = args[2][:,2:3,:,:].clone()
|
389 |
-
btm_out1_halo = btm_out1_halo.clone()
|
390 |
-
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
|
391 |
-
btm_out2_halo.copy_(btm_out2)
|
392 |
-
if spatial_group_rank > 0:
|
393 |
-
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
|
394 |
-
with torch.cuda.stream(stream1):
|
395 |
-
w1by3 = args[2][:,:1,:,:].clone()
|
396 |
-
top_out1_halo = top_out1_halo.clone()
|
397 |
-
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
|
398 |
-
top_out2_halo.copy_(top_out2)
|
399 |
-
if spatial_group_rank < spatial_group_size-1:
|
400 |
-
torch.cuda.current_stream().wait_stream(stream2)
|
401 |
-
if spatial_group_rank > 0:
|
402 |
-
torch.cuda.current_stream().wait_stream(stream1)
|
403 |
-
|
404 |
-
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
|
405 |
-
# save halos for backward pass
|
406 |
-
if spatial_group_size > 1:
|
407 |
-
if spatial_method != 2:
|
408 |
-
# make sure copy of mid-section of out1 into out1_pad is done before exiting
|
409 |
-
torch.cuda.current_stream().wait_stream(stream3)
|
410 |
-
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
|
411 |
-
else:
|
412 |
-
ctx.save_for_backward(*(args+outputs))
|
413 |
-
# save relu outputs for drelu
|
414 |
-
ctx.explicit_nhwc = explicit_nhwc
|
415 |
-
ctx.stride_1x1 = stride_1x1
|
416 |
-
ctx.spatial_group_size = spatial_group_size
|
417 |
-
if spatial_group_size > 1:
|
418 |
-
ctx.spatial_group_rank = spatial_group_rank
|
419 |
-
ctx.spatial_halo_exchanger = spatial_halo_exchanger
|
420 |
-
ctx.spatial_method = spatial_method
|
421 |
-
ctx.use_delay_kernel = use_delay_kernel
|
422 |
-
ctx.thresholdTop = thresholdTop
|
423 |
-
ctx.thresholdBottom = thresholdBottom
|
424 |
-
ctx.stream1 = stream1
|
425 |
-
ctx.stream2 = stream2
|
426 |
-
ctx.stream3 = stream3
|
427 |
-
return outputs[2]
|
428 |
-
|
429 |
-
# backward relu is not exposed, MUL with mask used now
|
430 |
-
# only support dgrad
|
431 |
-
@staticmethod
|
432 |
-
def backward(ctx, grad_o):
|
433 |
-
if ctx.spatial_group_size > 1:
|
434 |
-
out1_pad = ctx.saved_tensors[-1]
|
435 |
-
outputs = ctx.saved_tensors[-4:-1]
|
436 |
-
else:
|
437 |
-
outputs = ctx.saved_tensors[-3:]
|
438 |
-
|
439 |
-
if ctx.downsample:
|
440 |
-
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
|
441 |
-
else:
|
442 |
-
grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])
|
443 |
-
|
444 |
-
# create input vector for backward
|
445 |
-
t_list = [*ctx.saved_tensors[0:10]]
|
446 |
-
t_list.append(grad_conv3)
|
447 |
-
t_list.append(grad_conv4)
|
448 |
-
|
449 |
-
# outputs used for wgrad and generating drelu mask
|
450 |
-
t_list.append(outputs[0])
|
451 |
-
t_list.append(outputs[1])
|
452 |
-
|
453 |
-
# in case there is downsample
|
454 |
-
if ctx.downsample:
|
455 |
-
t_list.append(ctx.saved_tensors[10])
|
456 |
-
|
457 |
-
grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
|
458 |
-
wgrad3_stream = torch.cuda.Stream()
|
459 |
-
wgrad3_stream.wait_stream(torch.cuda.current_stream())
|
460 |
-
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
|
461 |
-
wgrad2_stream = torch.cuda.Stream()
|
462 |
-
wgrad2_stream.wait_stream(torch.cuda.current_stream())
|
463 |
-
# do halo exchange of grad_out2 here
|
464 |
-
# compute halo cells for grad_out1
|
465 |
-
if ctx.spatial_group_size > 1:
|
466 |
-
if ctx.explicit_nhwc:
|
467 |
-
N,Hs,W,C = list(grad_out2.shape)
|
468 |
-
else:
|
469 |
-
N,C,Hs,W = list(grad_out2.shape)
|
470 |
-
relu1 = t_list[12]
|
471 |
-
ctx.stream1.wait_stream(torch.cuda.current_stream())
|
472 |
-
with torch.cuda.stream(ctx.stream1):
|
473 |
-
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
|
474 |
-
# copy halos to send buffer
|
475 |
-
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
|
476 |
-
# 1 -> halo recompute approach
|
477 |
-
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
|
478 |
-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
|
479 |
-
ctx.stream2.wait_stream(ctx.stream1)
|
480 |
-
with torch.cuda.stream(ctx.stream2):
|
481 |
-
if ctx.explicit_nhwc:
|
482 |
-
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
|
483 |
-
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
|
484 |
-
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
|
485 |
-
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
|
486 |
-
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
|
487 |
-
btm_fat_relu_halo[:,2:,:,:].zero_()
|
488 |
-
else:
|
489 |
-
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
|
490 |
-
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
|
491 |
-
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
|
492 |
-
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
|
493 |
-
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
|
494 |
-
btm_fat_relu_halo[:,:,2:,:].zero_()
|
495 |
-
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
|
496 |
-
if ctx.explicit_nhwc:
|
497 |
-
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
|
498 |
-
else:
|
499 |
-
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
|
500 |
-
if ctx.spatial_group_rank > 0:
|
501 |
-
with torch.cuda.stream(ctx.stream1):
|
502 |
-
if ctx.explicit_nhwc:
|
503 |
-
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
|
504 |
-
top_fat_halo[:,:1,:,:].copy_(top_halo)
|
505 |
-
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
|
506 |
-
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
|
507 |
-
top_fat_relu_halo[:,:1,:,:].zero_()
|
508 |
-
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
|
509 |
-
else:
|
510 |
-
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
|
511 |
-
top_fat_halo[:,:,:1,:].copy_(top_halo)
|
512 |
-
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
|
513 |
-
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
|
514 |
-
top_fat_relu_halo[:,:,:1,:].zero_()
|
515 |
-
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
|
516 |
-
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
|
517 |
-
if ctx.explicit_nhwc:
|
518 |
-
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
|
519 |
-
else:
|
520 |
-
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
|
521 |
-
if ctx.use_delay_kernel: inc.add_delay(10)
|
522 |
-
elif ctx.spatial_method != 3:
|
523 |
-
assert(False), "spatial_method must be 1, 2 or 3"
|
524 |
-
|
525 |
-
# compute grad_out1 for internal cells
|
526 |
-
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
|
527 |
-
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
|
528 |
-
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
|
529 |
-
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
|
530 |
-
|
531 |
-
# apply halo cells to grad_out1
|
532 |
-
if ctx.spatial_group_size > 1:
|
533 |
-
w = t_list[2]
|
534 |
-
z = t_list[4]
|
535 |
-
relu1 = t_list[12]
|
536 |
-
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
|
537 |
-
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
|
538 |
-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
|
539 |
-
torch.cuda.current_stream().wait_stream(ctx.stream2)
|
540 |
-
if ctx.explicit_nhwc:
|
541 |
-
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
|
542 |
-
else:
|
543 |
-
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
|
544 |
-
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
|
545 |
-
if ctx.spatial_group_rank > 0:
|
546 |
-
torch.cuda.current_stream().wait_stream(ctx.stream1)
|
547 |
-
if ctx.explicit_nhwc:
|
548 |
-
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
|
549 |
-
else:
|
550 |
-
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
|
551 |
-
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
|
552 |
-
elif ctx.spatial_method == 3:
|
553 |
-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
|
554 |
-
if ctx.explicit_nhwc:
|
555 |
-
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
|
556 |
-
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
|
557 |
-
else:
|
558 |
-
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
|
559 |
-
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
|
560 |
-
w1by3 = w[:,:1,:,:].clone()
|
561 |
-
ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
|
562 |
-
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
|
563 |
-
with torch.cuda.stream(ctx.stream2):
|
564 |
-
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
|
565 |
-
btm_grad_out1.copy_(btm_grad_out1_halo)
|
566 |
-
if ctx.spatial_group_rank > 0:
|
567 |
-
if ctx.explicit_nhwc:
|
568 |
-
top_relu_halo = relu1[:,:1,:,:].clone()
|
569 |
-
top_grad_out1 = grad_out1[:,:1,:,:]
|
570 |
-
else:
|
571 |
-
top_relu_halo = relu1[:,:,:1,:].clone()
|
572 |
-
top_grad_out1 = grad_out1[:,:,:1,:]
|
573 |
-
w1by3 = w[:,2:,:,:].clone()
|
574 |
-
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
|
575 |
-
with torch.cuda.stream(ctx.stream1):
|
576 |
-
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
|
577 |
-
top_grad_out1.copy_(top_grad_out1_halo)
|
578 |
-
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
|
579 |
-
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
|
580 |
-
if ctx.spatial_group_rank > 0:
|
581 |
-
torch.cuda.current_stream().wait_stream(ctx.stream1)
|
582 |
-
|
583 |
-
wgrad1_stream = torch.cuda.Stream()
|
584 |
-
wgrad1_stream.wait_stream(torch.cuda.current_stream())
|
585 |
-
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
|
586 |
-
with torch.cuda.stream(wgrad3_stream):
|
587 |
-
fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
|
588 |
-
with torch.cuda.stream(wgrad2_stream):
|
589 |
-
if ctx.spatial_group_size > 1:
|
590 |
-
fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
|
591 |
-
else:
|
592 |
-
fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
|
593 |
-
with torch.cuda.stream(wgrad1_stream):
|
594 |
-
fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
|
595 |
-
torch.cuda.current_stream().wait_stream(wgrad3_stream)
|
596 |
-
torch.cuda.current_stream().wait_stream(wgrad2_stream)
|
597 |
-
torch.cuda.current_stream().wait_stream(wgrad1_stream)
|
598 |
-
|
599 |
-
return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
|
600 |
-
|
601 |
-
spatial_bottleneck_function = SpatialBottleneckFunction.apply
|
602 |
-
|
603 |
-
class SpatialBottleneck(torch.nn.Module):
|
604 |
-
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
605 |
-
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
606 |
-
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
607 |
-
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
608 |
-
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
609 |
-
# here we put it at 1x1
|
610 |
-
|
611 |
-
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
|
612 |
-
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
|
613 |
-
spatial_parallel_args=None):
|
614 |
-
super(SpatialBottleneck, self).__init__()
|
615 |
-
if groups != 1:
|
616 |
-
raise RuntimeError('Only support groups == 1')
|
617 |
-
if dilation != 1:
|
618 |
-
raise RuntimeError('Only support dilation == 1')
|
619 |
-
if norm_func == None:
|
620 |
-
norm_func = FrozenBatchNorm2d
|
621 |
-
else:
|
622 |
-
raise RuntimeError('Only support frozen BN now.')
|
623 |
-
|
624 |
-
if stride != 1 or in_channels != out_channels:
|
625 |
-
self.downsample = nn.Sequential(
|
626 |
-
conv1x1(in_channels, out_channels, stride),
|
627 |
-
norm_func(out_channels),
|
628 |
-
)
|
629 |
-
else:
|
630 |
-
self.downsample = None
|
631 |
-
|
632 |
-
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
633 |
-
self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
|
634 |
-
self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
|
635 |
-
self.conv3 = conv1x1(bottleneck_channels, out_channels)
|
636 |
-
self.relu = nn.ReLU(inplace=True)
|
637 |
-
self.stride = stride
|
638 |
-
|
639 |
-
self.bn1 = norm_func(bottleneck_channels)
|
640 |
-
self.bn2 = norm_func(bottleneck_channels)
|
641 |
-
self.bn3 = norm_func(out_channels)
|
642 |
-
self.w_scale = None
|
643 |
-
|
644 |
-
self.use_cudnn = use_cudnn
|
645 |
-
|
646 |
-
# setup conv weights
|
647 |
-
self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
|
648 |
-
if self.downsample is not None:
|
649 |
-
self.w_conv.append(self.downsample[0].weight)
|
650 |
-
|
651 |
-
# init weight in nchw format before possible transpose
|
652 |
-
for w in self.w_conv:
|
653 |
-
kaiming_uniform_(w, a=1)
|
654 |
-
|
655 |
-
self.thresholdTop, self.thresholdBottom = None, None
|
656 |
-
|
657 |
-
# TODO: prevent unsupported case usage
|
658 |
-
# support cases
|
659 |
-
# native cudnn
|
660 |
-
# normal yes no
|
661 |
-
# channel_last yes yes
|
662 |
-
# explicit_nhwc no yes
|
663 |
-
self.explicit_nhwc = explicit_nhwc
|
664 |
-
if self.explicit_nhwc:
|
665 |
-
for p in self.parameters():
|
666 |
-
with torch.no_grad():
|
667 |
-
p.data = p.data.permute(0,2,3,1).contiguous()
|
668 |
-
|
669 |
-
# spatial communicator
|
670 |
-
if spatial_parallel_args is None:
|
671 |
-
self.spatial_parallel_args = (1, 0, None, None, 0, False)
|
672 |
-
else:
|
673 |
-
self.spatial_parallel_args = spatial_parallel_args
|
674 |
-
return
|
675 |
-
|
676 |
-
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
|
677 |
-
# This method must be called before cuda graphing.
|
678 |
-
# The callable it returns can be called anytime.
|
679 |
-
# Calling this method will prevent these from being computed every forward call.
|
680 |
-
def get_scale_bias_callable(self):
|
681 |
-
self.w_scale, self.w_bias, args = [], [], []
|
682 |
-
batch_norms = [self.bn1, self.bn2, self.bn3]
|
683 |
-
if self.downsample is not None:
|
684 |
-
batch_norms.append(self.downsample[1])
|
685 |
-
for bn in batch_norms:
|
686 |
-
s = torch.empty_like(bn.weight)
|
687 |
-
b = torch.empty_like(s)
|
688 |
-
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
|
689 |
-
if self.explicit_nhwc:
|
690 |
-
self.w_scale.append( s.reshape(1, 1, 1, -1) )
|
691 |
-
self.w_bias.append( b.reshape(1, 1, 1, -1) )
|
692 |
-
else:
|
693 |
-
self.w_scale.append( s.reshape(1, -1, 1, 1) )
|
694 |
-
self.w_bias.append( b.reshape(1, -1, 1, 1) )
|
695 |
-
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
|
696 |
-
|
697 |
-
def forward(self, x):
|
698 |
-
if self.use_cudnn:
|
699 |
-
if self.thresholdTop is None:
|
700 |
-
spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
|
701 |
-
if self.explicit_nhwc:
|
702 |
-
N,H,W,C = list(x.shape)
|
703 |
-
else:
|
704 |
-
N,C,H,W = list(x.shape)
|
705 |
-
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
|
706 |
-
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
|
707 |
-
|
708 |
-
if self.w_scale is None:
|
709 |
-
# calculate scale/bias from registered buffers
|
710 |
-
# TODO: make this better
|
711 |
-
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
|
712 |
-
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
|
713 |
-
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
|
714 |
-
w_scale = [s1, s2, s3]
|
715 |
-
w_bias = [b1, b2, b3]
|
716 |
-
if self.downsample is not None:
|
717 |
-
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
|
718 |
-
w_scale.append(s4)
|
719 |
-
w_bias.append(b4)
|
720 |
-
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
|
721 |
-
else:
|
722 |
-
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
|
723 |
-
return out
|
724 |
-
|
725 |
-
if self.explicit_nhwc:
|
726 |
-
raise RuntimeError('explicit nhwc with native ops is not supported.')
|
727 |
-
|
728 |
-
# fallback to native ops
|
729 |
-
identity = x
|
730 |
-
|
731 |
-
out = self.conv1(x)
|
732 |
-
out = self.bn1(out)
|
733 |
-
out = self.relu(out)
|
734 |
-
|
735 |
-
out = self.conv2(out)
|
736 |
-
out = self.bn2(out)
|
737 |
-
out = self.relu(out)
|
738 |
-
|
739 |
-
out = self.conv3(out)
|
740 |
-
out = self.bn3(out)
|
741 |
-
|
742 |
-
if self.downsample is not None:
|
743 |
-
identity = self.downsample(x)
|
744 |
-
|
745 |
-
out += identity
|
746 |
-
out = self.relu(out)
|
747 |
-
|
748 |
-
return out
|
749 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/bottleneck/halo_exchangers.py
DELETED
@@ -1,180 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.distributed as dist
|
3 |
-
from torch import nn
|
4 |
-
import nccl_p2p_cuda as inc
|
5 |
-
import peer_memory_cuda as pm
|
6 |
-
|
7 |
-
# Communication free halo exchanger.
|
8 |
-
# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
|
9 |
-
# NB! This is only useful for performance testing.
|
10 |
-
# NB! Do not use for actual production runs
|
11 |
-
class HaloExchanger(object):
|
12 |
-
def __init__(self, ranks, rank_in_group):
|
13 |
-
self.stream1 = torch.cuda.Stream()
|
14 |
-
self.stream2 = torch.cuda.Stream()
|
15 |
-
self.stream3 = torch.cuda.Stream()
|
16 |
-
self.group_size = len(ranks)
|
17 |
-
self.ranks = ranks
|
18 |
-
self.rank_in_group = rank_in_group
|
19 |
-
self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
|
20 |
-
self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
|
21 |
-
self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
|
22 |
-
self.left_zero = True if rank_in_group == 0 else False
|
23 |
-
self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
|
24 |
-
self.right_zero = True if rank_in_group == self.group_size - 1 else False
|
25 |
-
|
26 |
-
class HaloExchangerNoComm(HaloExchanger):
|
27 |
-
def __init__(self, ranks, rank_in_group):
|
28 |
-
super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
|
29 |
-
|
30 |
-
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
|
31 |
-
if left_input_halo is None:
|
32 |
-
return right_output_halo, left_output_halo
|
33 |
-
else:
|
34 |
-
left_input_halo.copy_(right_output_halo)
|
35 |
-
right_input_halo.copy_(left_output_halo)
|
36 |
-
|
37 |
-
class HaloExchangerAllGather(HaloExchanger):
|
38 |
-
def __init__(self, ranks, rank_in_group, comm):
|
39 |
-
super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
|
40 |
-
# self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
|
41 |
-
self.comm = comm
|
42 |
-
|
43 |
-
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
|
44 |
-
N,Hh,W,C = list(left_output_halo.shape)
|
45 |
-
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
|
46 |
-
send_halos[:,:Hh,:,:].copy_(left_output_halo)
|
47 |
-
send_halos[:,Hh:,:,:].copy_(right_output_halo)
|
48 |
-
all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
|
49 |
-
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
|
50 |
-
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
|
51 |
-
ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
|
52 |
-
ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
|
53 |
-
if left_input_halo is None:
|
54 |
-
if self.left_zero:
|
55 |
-
ag_left_input_halo.zero_()
|
56 |
-
if self.right_zero:
|
57 |
-
ag_right_input_halo.zero_()
|
58 |
-
return ag_left_input_halo, ag_right_input_halo
|
59 |
-
else:
|
60 |
-
if self.left_zero:
|
61 |
-
left_input_halo.zero_()
|
62 |
-
else:
|
63 |
-
left_input_halo.copy_(ag_left_input_halo)
|
64 |
-
if self.right_zero:
|
65 |
-
right_input_halo.zero_()
|
66 |
-
else:
|
67 |
-
right_input_halo.copy_(ag_right_input_halo)
|
68 |
-
|
69 |
-
class HaloExchangerSendRecv(HaloExchanger):
|
70 |
-
def __init__(self, ranks, rank_in_group):
|
71 |
-
super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
|
72 |
-
nccl_id = inc.get_unique_nccl_id(1).cuda()
|
73 |
-
torch.distributed.broadcast(nccl_id, 0)
|
74 |
-
nccl_id = nccl_id.cpu()
|
75 |
-
print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
|
76 |
-
# Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
|
77 |
-
# This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
|
78 |
-
# it cannot be accessed from another class.
|
79 |
-
# TODO: Figure out a way to avoid creating a second global communicator
|
80 |
-
assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
|
81 |
-
self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
|
82 |
-
|
83 |
-
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
|
84 |
-
if left_input_halo is None:
|
85 |
-
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo)
|
86 |
-
return left_input_halo, right_input_halo
|
87 |
-
else:
|
88 |
-
inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
|
89 |
-
|
90 |
-
class HaloExchangerPeer(HaloExchanger):
|
91 |
-
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
|
92 |
-
super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
|
93 |
-
self.diagnostics = False
|
94 |
-
self.explicit_nhwc = explicit_nhwc
|
95 |
-
self.numSM = numSM
|
96 |
-
self.peer_pool = peer_pool
|
97 |
-
|
98 |
-
def _allocate_peer_tensor(self, halo):
|
99 |
-
|
100 |
-
# Compute size in bytes
|
101 |
-
# Note: Pad buffer so each CUDA block gets required buffer size
|
102 |
-
size = 4 * halo.numel() * halo.element_size()
|
103 |
-
size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
|
104 |
-
size = (size + size_per_block - 1) // size_per_block * size_per_block
|
105 |
-
|
106 |
-
# Construct dtype peer buffer with desired size
|
107 |
-
shape = [1, 1, 1, size // halo.element_size()]
|
108 |
-
return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)
|
109 |
-
|
110 |
-
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
|
111 |
-
inplace = False if left_input_halo is None and right_input_halo is None else True
|
112 |
-
if not inplace:
|
113 |
-
left_input_halo = torch.empty_like(right_output_halo)
|
114 |
-
right_input_halo = torch.empty_like(left_output_halo)
|
115 |
-
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
|
116 |
-
left_tx = self._allocate_peer_tensor(left_input_halo)
|
117 |
-
right_tx = self._allocate_peer_tensor(right_input_halo)
|
118 |
-
pm.push_pull_halos_1d(
|
119 |
-
self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group,
|
120 |
-
self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
|
121 |
-
self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
|
122 |
-
)
|
123 |
-
if not inplace:
|
124 |
-
return left_input_halo, right_input_halo
|
125 |
-
|
126 |
-
# Class that combines input volume with halos from neighbors (1d).
|
127 |
-
class HaloPadder:
|
128 |
-
def __init__(self, halo_ex):
|
129 |
-
self.halo_ex = halo_ex
|
130 |
-
self.stream1 = torch.cuda.Stream()
|
131 |
-
self.stream2 = torch.cuda.Stream()
|
132 |
-
|
133 |
-
def __call__(self, y, half_halo, explicit_nhwc, H_split):
|
134 |
-
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
|
135 |
-
if explicit_nhwc:
|
136 |
-
N,H,W,C = list(y.shape)
|
137 |
-
if H_split:
|
138 |
-
padded_shape = [N,H+2*half_halo,W,C]
|
139 |
-
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
|
140 |
-
yleft = ypad[:,:half_halo,:,:]
|
141 |
-
ymid = ypad[:,half_halo:H+half_halo,:,:]
|
142 |
-
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
|
143 |
-
oleft = y[:,:half_halo,:,:]
|
144 |
-
oright = y[:,H-half_halo:,:,:]
|
145 |
-
else:
|
146 |
-
padded_shape = [N,H,W+2*half_halo,C]
|
147 |
-
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
|
148 |
-
yleft = ypad[:,:,:half_halo,:]
|
149 |
-
ymid = ypad[:,:,half_halo:W+half_halo,:]
|
150 |
-
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
|
151 |
-
oleft = y[:,:,:half_halo,:]
|
152 |
-
oright = y[:,:,W-half_halo:,:]
|
153 |
-
else:
|
154 |
-
N,C,H,W = list(y.shape)
|
155 |
-
if H_split:
|
156 |
-
padded_shape = [N,C,H+2*half_halo,W]
|
157 |
-
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
|
158 |
-
yleft = ypad[:,:,:half_halo,:]
|
159 |
-
ymid = ypad[:,:,half_halo:H+half_halo,:]
|
160 |
-
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
|
161 |
-
oleft = y[:,:,:half_halo,:]
|
162 |
-
oright = y[:,:,H-half_halo:,:]
|
163 |
-
else:
|
164 |
-
padded_shape = [N,C,H,W+2*half_halo]
|
165 |
-
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
|
166 |
-
yleft = ypad[:,:,:,:half_halo]
|
167 |
-
ymid = ypad[:,:,:,half_halo:W+half_halo]
|
168 |
-
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
|
169 |
-
oleft = y[:,:,:,:half_halo]
|
170 |
-
oright = y[:,:,:,W-half_halo:]
|
171 |
-
with torch.cuda.stream(self.stream1):
|
172 |
-
self.halo_ex(oleft, oright, yleft, yright)
|
173 |
-
with torch.cuda.stream(self.stream2):
|
174 |
-
ymid.copy_(y)
|
175 |
-
return ypad
|
176 |
-
|
177 |
-
def wait(self):
|
178 |
-
current_stream = torch.cuda.current_stream()
|
179 |
-
current_stream.wait_stream(self.stream1)
|
180 |
-
current_stream.wait_stream(self.stream2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/bottleneck/test.py
DELETED
@@ -1,71 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from bottleneck import Bottleneck
|
3 |
-
torch.manual_seed(23337)
|
4 |
-
|
5 |
-
# use True to print layerwise sum for all outputs in reference code path
|
6 |
-
DEBUG = False#True
|
7 |
-
|
8 |
-
for stride, o_channel in [(1,32), (1,128), (2,32)]:
|
9 |
-
print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
|
10 |
-
a_ = torch.randn(17,32,28,28)
|
11 |
-
|
12 |
-
a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
|
13 |
-
model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)
|
14 |
-
|
15 |
-
# test model
|
16 |
-
b = model(a)
|
17 |
-
b.mean().backward()
|
18 |
-
d_grad = a.grad.float()
|
19 |
-
a.grad = None
|
20 |
-
torch.cuda.synchronize()
|
21 |
-
|
22 |
-
if DEBUG:
|
23 |
-
print("[DEBUG] ref dx :", d_grad.sum().item())
|
24 |
-
# print wgrad. we don't need to reset since later cpp print before accumulation
|
25 |
-
for i, w in enumerate(model.w_conv):
|
26 |
-
print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())
|
27 |
-
|
28 |
-
wgrads = []
|
29 |
-
for w in model.w_conv:
|
30 |
-
wgrads.append(w.grad.float())
|
31 |
-
|
32 |
-
model.use_cudnn = True
|
33 |
-
model.zero_grad()
|
34 |
-
c = model(a)
|
35 |
-
c.mean().backward()
|
36 |
-
|
37 |
-
torch.cuda.synchronize()
|
38 |
-
print("comparing native and channels_last:")
|
39 |
-
print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item())
|
40 |
-
print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
|
41 |
-
for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
|
42 |
-
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
|
43 |
-
|
44 |
-
nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()
|
45 |
-
nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()
|
46 |
-
for p,q in zip(model.parameters(), nhwc_model.parameters()):
|
47 |
-
# model's storage is already in nhwc, we clone and assign to explicit nhwc model
|
48 |
-
q.data.copy_(p.data.permute(0,2,3,1).contiguous())
|
49 |
-
for p,q in zip(model.buffers(), nhwc_model.buffers()):
|
50 |
-
q.data.copy_(p.data)
|
51 |
-
|
52 |
-
d = nhwc_model(nhwc_a)
|
53 |
-
d.mean().backward()
|
54 |
-
torch.cuda.synchronize()
|
55 |
-
|
56 |
-
# reset reference to cudnn channels_last permute
|
57 |
-
#c_s = c.storage().tolist()
|
58 |
-
#d_s = d.storage().tolist()
|
59 |
-
#print(max([x-y for x,y in zip(c_s,d_s)]))
|
60 |
-
c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()
|
61 |
-
d_grad = a.grad.float().permute(0,2,3,1).contiguous()
|
62 |
-
wgrads = []
|
63 |
-
for w in model.w_conv:
|
64 |
-
wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())
|
65 |
-
|
66 |
-
torch.cuda.synchronize()
|
67 |
-
print("comparing nhwc and channels_last:")
|
68 |
-
print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item())
|
69 |
-
print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
|
70 |
-
for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
|
71 |
-
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/clip_grad/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .clip_grad import clip_grad_norm_
|
|
|
|
apex/apex/contrib/clip_grad/clip_grad.py
DELETED
@@ -1,128 +0,0 @@
|
|
1 |
-
from typing import Union, Iterable
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
_kernel_import_succeeded = False
|
6 |
-
try:
|
7 |
-
import amp_C
|
8 |
-
from apex.multi_tensor_apply import multi_tensor_applier
|
9 |
-
_kernel_import_succeeded = True
|
10 |
-
except ImportError:
|
11 |
-
_kernel_import_succeeded = False
|
12 |
-
|
13 |
-
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
|
14 |
-
|
15 |
-
|
16 |
-
def clip_grad_norm_(
|
17 |
-
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
|
18 |
-
error_if_nonfinite: bool = False) -> torch.Tensor:
|
19 |
-
r"""Clips gradient norm of an iterable of parameters.
|
20 |
-
|
21 |
-
The norm is computed over all gradients together, as if they were
|
22 |
-
concatenated into a single vector. Gradients are modified in-place.
|
23 |
-
|
24 |
-
This is identical to torch.nn.utils.clip_grad_norm_, except it
|
25 |
-
uses a fused CUDA kernel when computing the 2-norm of GPU tensors
|
26 |
-
in float32 and float16.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
30 |
-
single Tensor that will have gradients normalized
|
31 |
-
max_norm (float or int): max norm of the gradients
|
32 |
-
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
33 |
-
infinity norm.
|
34 |
-
error_if_nonfinite (bool): if True, an error is thrown if the total
|
35 |
-
norm of the gradients from :attr:`parameters` is ``nan``,
|
36 |
-
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
37 |
-
|
38 |
-
Returns:
|
39 |
-
Total norm of the parameters (viewed as a single vector).
|
40 |
-
|
41 |
-
"""
|
42 |
-
if isinstance(parameters, torch.Tensor):
|
43 |
-
parameters = [parameters]
|
44 |
-
parameters = [p for p in parameters if p.grad is not None]
|
45 |
-
max_norm = float(max_norm)
|
46 |
-
norm_type = float(norm_type)
|
47 |
-
|
48 |
-
# Trivial case
|
49 |
-
if len(parameters) == 0:
|
50 |
-
return torch.tensor(0.)
|
51 |
-
|
52 |
-
# Fallback implementation
|
53 |
-
if not (_kernel_import_succeeded
|
54 |
-
and norm_type == 2.0
|
55 |
-
and any(p.is_cuda for p in parameters)):
|
56 |
-
return torch.nn.utils.clip_grad_norm_(
|
57 |
-
parameters,
|
58 |
-
max_norm,
|
59 |
-
norm_type=norm_type,
|
60 |
-
error_if_nonfinite = error_if_nonfinite,
|
61 |
-
)
|
62 |
-
|
63 |
-
# Find fp32 and fp16 gradients on GPU
|
64 |
-
device = next(p.device for p in parameters if p.is_cuda)
|
65 |
-
grads_fp32, grads_fp16, grads_misc = [], [], []
|
66 |
-
for p in parameters:
|
67 |
-
grad = p.grad.detach()
|
68 |
-
if p.dtype == torch.float32 and p.device == device:
|
69 |
-
grads_fp32.append(grad)
|
70 |
-
elif p.dtype == torch.float16 and p.device == device:
|
71 |
-
grads_fp16.append(grad)
|
72 |
-
else:
|
73 |
-
grads_misc.append(grad)
|
74 |
-
|
75 |
-
# Compute gradient L2 norms
|
76 |
-
norms = []
|
77 |
-
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
|
78 |
-
if grads_fp32:
|
79 |
-
norms.append(
|
80 |
-
multi_tensor_applier(
|
81 |
-
amp_C.multi_tensor_l2norm,
|
82 |
-
dummy_overflow_buf,
|
83 |
-
[grads_fp32],
|
84 |
-
False,
|
85 |
-
)[0]
|
86 |
-
)
|
87 |
-
if grads_fp16:
|
88 |
-
norms.append(
|
89 |
-
multi_tensor_applier(
|
90 |
-
amp_C.multi_tensor_l2norm,
|
91 |
-
dummy_overflow_buf,
|
92 |
-
[grads_fp16],
|
93 |
-
False,
|
94 |
-
)[0],
|
95 |
-
)
|
96 |
-
for g in grads_misc:
|
97 |
-
norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
|
98 |
-
total_norm = torch.linalg.norm(torch.cat(norms))
|
99 |
-
|
100 |
-
# Check for non-finite values
|
101 |
-
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
102 |
-
raise RuntimeError(
|
103 |
-
f'The total norm of order {norm_type} for gradients from '
|
104 |
-
'`parameters` is non-finite, so it cannot be clipped. To disable '
|
105 |
-
'this error and scale the gradients by the non-finite norm anyway, '
|
106 |
-
'set `error_if_nonfinite=False`')
|
107 |
-
|
108 |
-
# Scale gradients
|
109 |
-
clip_coef = max_norm / (total_norm + 1e-6)
|
110 |
-
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
111 |
-
if grads_fp32:
|
112 |
-
multi_tensor_applier(
|
113 |
-
amp_C.multi_tensor_scale,
|
114 |
-
dummy_overflow_buf,
|
115 |
-
[grads_fp32, grads_fp32],
|
116 |
-
clip_coef_clamped,
|
117 |
-
)
|
118 |
-
if grads_fp16:
|
119 |
-
multi_tensor_applier(
|
120 |
-
amp_C.multi_tensor_scale,
|
121 |
-
dummy_overflow_buf,
|
122 |
-
[grads_fp16, grads_fp16],
|
123 |
-
clip_coef_clamped,
|
124 |
-
)
|
125 |
-
for g in grads_misc:
|
126 |
-
g.mul_(clip_coef_clamped.to(g.device))
|
127 |
-
|
128 |
-
return total_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/conv_bias_relu/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU, ConvFrozenScaleBiasReLU
|
2 |
-
|
|
|
|
|
|
apex/apex/contrib/conv_bias_relu/conv_bias_relu.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
import pdb
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch.autograd import gradcheck
|
5 |
-
|
6 |
-
from apex import check_cudnn_version_and_warn
|
7 |
-
import fused_conv_bias_relu
|
8 |
-
|
9 |
-
check_cudnn_version_and_warn(__name__, 8400)
|
10 |
-
|
11 |
-
|
12 |
-
class ConvBiasReLU_(torch.autograd.Function):
|
13 |
-
@staticmethod
|
14 |
-
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
|
15 |
-
def forward(ctx, x, weight, bias, padding, stride):
|
16 |
-
outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
|
17 |
-
ctx.save_for_backward(x, weight, outputs[0])
|
18 |
-
ctx.padding = padding
|
19 |
-
ctx.stride = stride
|
20 |
-
|
21 |
-
return outputs[0]
|
22 |
-
|
23 |
-
@staticmethod
|
24 |
-
@torch.cuda.amp.custom_bwd
|
25 |
-
def backward(ctx, grad_output):
|
26 |
-
bwd_args = [*ctx.saved_tensors, grad_output]
|
27 |
-
padding = ctx.padding
|
28 |
-
stride = ctx.stride
|
29 |
-
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
|
30 |
-
|
31 |
-
return grads[0], grads[1], grads[2], None, None
|
32 |
-
|
33 |
-
|
34 |
-
class ConvBiasMaskReLU_(torch.autograd.Function):
|
35 |
-
@staticmethod
|
36 |
-
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
|
37 |
-
def forward(ctx, x, weight, bias, mask, padding, stride):
|
38 |
-
outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
|
39 |
-
ctx.save_for_backward(x, weight, outputs[0])
|
40 |
-
ctx.padding = padding
|
41 |
-
ctx.stride = stride
|
42 |
-
|
43 |
-
return outputs[0]
|
44 |
-
|
45 |
-
@staticmethod
|
46 |
-
@torch.cuda.amp.custom_bwd
|
47 |
-
def backward(ctx, grad_output):
|
48 |
-
bwd_args = [*ctx.saved_tensors, grad_output]
|
49 |
-
padding = ctx.padding
|
50 |
-
stride = ctx.stride
|
51 |
-
grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)
|
52 |
-
|
53 |
-
return grads[0], grads[1], grads[2], None, None, None
|
54 |
-
|
55 |
-
|
56 |
-
class ConvBias_(torch.autograd.Function):
|
57 |
-
@staticmethod
|
58 |
-
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
|
59 |
-
def forward(ctx, x, weight, bias, padding, stride):
|
60 |
-
outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
|
61 |
-
ctx.save_for_backward(x, weight)
|
62 |
-
ctx.padding = padding
|
63 |
-
ctx.stride = stride
|
64 |
-
|
65 |
-
return outputs[0]
|
66 |
-
|
67 |
-
@staticmethod
|
68 |
-
@torch.cuda.amp.custom_bwd
|
69 |
-
def backward(ctx, grad_output):
|
70 |
-
bwd_args = [*ctx.saved_tensors, grad_output]
|
71 |
-
padding = ctx.padding
|
72 |
-
stride = ctx.stride
|
73 |
-
grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)
|
74 |
-
|
75 |
-
return grads[0], grads[1], grads[2], None, None
|
76 |
-
|
77 |
-
|
78 |
-
class ConvFrozenScaleBiasReLU_(torch.autograd.Function):
|
79 |
-
@staticmethod
|
80 |
-
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
|
81 |
-
def forward(ctx, x, weight, scale, bias, padding, stride):
|
82 |
-
output = fused_conv_bias_relu.forward_cscale_cbias_relu([x, weight, scale, bias], padding, stride)
|
83 |
-
ctx.save_for_backward(x, weight, scale, output)
|
84 |
-
ctx.padding = padding
|
85 |
-
ctx.stride = stride
|
86 |
-
|
87 |
-
return output
|
88 |
-
|
89 |
-
@staticmethod
|
90 |
-
@torch.cuda.amp.custom_bwd
|
91 |
-
def backward(ctx, grad_output):
|
92 |
-
bwd_args = [*ctx.saved_tensors, grad_output]
|
93 |
-
padding = ctx.padding
|
94 |
-
stride = ctx.stride
|
95 |
-
grads = fused_conv_bias_relu.backward_cscale_cbias_relu(bwd_args, padding, stride)
|
96 |
-
|
97 |
-
return grads[0], grads[1], None, None, None, None
|
98 |
-
|
99 |
-
|
100 |
-
ConvBiasReLU = ConvBiasReLU_.apply
|
101 |
-
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
|
102 |
-
ConvBias = ConvBias_.apply
|
103 |
-
ConvFrozenScaleBiasReLU = ConvFrozenScaleBiasReLU_.apply
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/bottleneck/bottleneck.cpp
DELETED
The diff for this file is too large to render.
See raw diff
|
|
apex/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp
DELETED
@@ -1,2153 +0,0 @@
|
|
1 |
-
#include <ATen/ATen.h>
|
2 |
-
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
|
3 |
-
#include <torch/extension.h>
|
4 |
-
#include <torch/torch.h>
|
5 |
-
#include <vector>
|
6 |
-
#include <cudnn_frontend.h>
|
7 |
-
|
8 |
-
#include <iostream>
|
9 |
-
|
10 |
-
#ifdef DEBUG
|
11 |
-
#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
|
12 |
-
#else
|
13 |
-
#define DEBUG_MSG(str) do { } while ( false )
|
14 |
-
#endif
|
15 |
-
|
16 |
-
#ifdef DEBUG_CUDNN
|
17 |
-
#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
|
18 |
-
#else
|
19 |
-
#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
|
20 |
-
#endif
|
21 |
-
|
22 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
23 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous")
|
24 |
-
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
25 |
-
|
26 |
-
#define checkCudnnErr(...) \
|
27 |
-
do { \
|
28 |
-
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
|
29 |
-
if (err) { \
|
30 |
-
return; \
|
31 |
-
} \
|
32 |
-
} while (0)
|
33 |
-
|
34 |
-
|
35 |
-
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
|
36 |
-
if (code) {
|
37 |
-
printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
|
38 |
-
return 1;
|
39 |
-
}
|
40 |
-
return 0;
|
41 |
-
}
|
42 |
-
|
43 |
-
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
|
44 |
-
#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function
|
45 |
-
|
46 |
-
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) {
|
47 |
-
if (code != cudaSuccess)
|
48 |
-
{
|
49 |
-
const char * errorMessage = cudaGetErrorString(code);
|
50 |
-
fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
|
51 |
-
if (abort){
|
52 |
-
cudaDeviceReset();
|
53 |
-
exit(code);
|
54 |
-
}
|
55 |
-
}
|
56 |
-
}
|
57 |
-
|
58 |
-
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
|
59 |
-
// For INT8x4 and INT8x32 we still compute standard strides here to input
|
60 |
-
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
|
61 |
-
if (filterFormat == CUDNN_TENSOR_NCHW) {
|
62 |
-
strideA[nbDims - 1] = 1;
|
63 |
-
for (int64_t d = nbDims - 2; d >= 0; d--) {
|
64 |
-
strideA[d] = strideA[d + 1] * dimA[d + 1];
|
65 |
-
}
|
66 |
-
} else {
|
67 |
-
// Here we assume that the format is CUDNN_TENSOR_NHWC
|
68 |
-
strideA[1] = 1;
|
69 |
-
strideA[nbDims - 1] = strideA[1] * dimA[1];
|
70 |
-
for (int64_t d = nbDims - 2; d >= 2; d--) {
|
71 |
-
strideA[d] = strideA[d + 1] * dimA[d + 1];
|
72 |
-
}
|
73 |
-
strideA[0] = strideA[2] * dimA[2];
|
74 |
-
}
|
75 |
-
}
|
76 |
-
|
77 |
-
|
78 |
-
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
|
79 |
-
return ((filterDim - 1) * dilation) + 1;
|
80 |
-
}
|
81 |
-
|
82 |
-
|
83 |
-
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
|
84 |
-
return tensorDim + (2 * pad);
|
85 |
-
}
|
86 |
-
|
87 |
-
|
88 |
-
int getFwdConvOutputDim(int tensorDim,
|
89 |
-
int pad,
|
90 |
-
int filterDim,
|
91 |
-
int stride,
|
92 |
-
int dilation) {
|
93 |
-
int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
|
94 |
-
return (p);
|
95 |
-
}
|
96 |
-
|
97 |
-
|
98 |
-
// create a cache for plan
|
99 |
-
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
|
100 |
-
|
101 |
-
|
102 |
-
std::string getConvFusionString(int64_t* x_dim_padded,
|
103 |
-
int64_t* padA,
|
104 |
-
int64_t* convstrideA,
|
105 |
-
int64_t* dilationA,
|
106 |
-
int64_t* w_dim_padded,
|
107 |
-
cudnnDataType_t dataType,
|
108 |
-
std::string fusion_string) {
|
109 |
-
|
110 |
-
for(int i=0;i<4;i++) {
|
111 |
-
fusion_string += 'X';
|
112 |
-
fusion_string += std::to_string(x_dim_padded[i]);
|
113 |
-
}
|
114 |
-
for(int i=0;i<4;i++) {
|
115 |
-
fusion_string += 'W';
|
116 |
-
fusion_string += std::to_string(w_dim_padded[i]);
|
117 |
-
}
|
118 |
-
for(int i=0;i<2;i++) {
|
119 |
-
fusion_string += 'P';
|
120 |
-
fusion_string += std::to_string(padA[i]);
|
121 |
-
}
|
122 |
-
for(int i=0;i<2;i++) {
|
123 |
-
fusion_string += 'S';
|
124 |
-
fusion_string += std::to_string(convstrideA[i]);
|
125 |
-
}
|
126 |
-
for(int i=0;i<2;i++) {
|
127 |
-
fusion_string += 'D';
|
128 |
-
fusion_string += std::to_string(dilationA[i]);
|
129 |
-
}
|
130 |
-
fusion_string += 'T';
|
131 |
-
fusion_string += std::to_string(dataType);
|
132 |
-
return fusion_string;
|
133 |
-
}
|
134 |
-
|
135 |
-
|
136 |
-
cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
|
137 |
-
std::stringstream& log_buf,
|
138 |
-
cudnn_frontend::OperationGraph& opGraph,
|
139 |
-
std::string cache_string,
|
140 |
-
bool use_heuristic = true){
|
141 |
-
auto it = plan_cache.find(cache_string);
|
142 |
-
if (it != plan_cache.end()) {
|
143 |
-
DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
|
144 |
-
return it->second;
|
145 |
-
} else {
|
146 |
-
DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
|
147 |
-
if (use_heuristic) {
|
148 |
-
// TODO: confirm which mode to use
|
149 |
-
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
150 |
-
.setOperationGraph(opGraph)
|
151 |
-
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
|
152 |
-
.build();
|
153 |
-
auto engine_config_count = heuristics.getEngineConfigCount();
|
154 |
-
auto& engine_configs = heuristics.getEngineConfig(engine_config_count);
|
155 |
-
for (int64_t count = 0; count < engine_config_count; count++) {
|
156 |
-
try {
|
157 |
-
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
|
158 |
-
.setHandle(handle_)
|
159 |
-
.setEngineConfig(engine_configs[count], opGraph.getTag())
|
160 |
-
.build()));
|
161 |
-
break;
|
162 |
-
} catch (cudnn_frontend::cudnnException e) {
|
163 |
-
// Throw exception if all engines failed
|
164 |
-
if (count == (engine_config_count - 1)) {
|
165 |
-
throw e;
|
166 |
-
} else {
|
167 |
-
continue;
|
168 |
-
}
|
169 |
-
}
|
170 |
-
}
|
171 |
-
} else {
|
172 |
-
// How many engines support this operation graph ?
|
173 |
-
auto total_engines = opGraph.getEngineCount();
|
174 |
-
DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
|
175 |
-
// We have to randomly pick one engine from [0, total_engines)
|
176 |
-
// Selecting "0" by default
|
177 |
-
auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
|
178 |
-
DEBUG_CUDNN_MSG(log_buf, engine.describe());
|
179 |
-
auto& knobs = engine.getSupportedKnobs();
|
180 |
-
for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
|
181 |
-
DEBUG_CUDNN_MSG(log_buf, it->describe());
|
182 |
-
}
|
183 |
-
if (knobs.begin() != knobs.end()) {
|
184 |
-
DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
|
185 |
-
knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
|
186 |
-
DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
|
187 |
-
}
|
188 |
-
|
189 |
-
// Createmplacee the requisite engine config
|
190 |
-
auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
|
191 |
-
DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
|
192 |
-
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
|
193 |
-
}
|
194 |
-
|
195 |
-
return plan_cache.find(cache_string)->second;
|
196 |
-
}
|
197 |
-
}
|
198 |
-
|
199 |
-
|
200 |
-
void
|
201 |
-
run_conv_bias(int64_t* x_dim,
|
202 |
-
int64_t* w_dim,
|
203 |
-
int64_t* y_dim,
|
204 |
-
int64_t* conv_pad,
|
205 |
-
int64_t* convstride,
|
206 |
-
int64_t* dilation,
|
207 |
-
cudnnDataType_t dataType,
|
208 |
-
at::Half* devPtrX,
|
209 |
-
at::Half* devPtrW,
|
210 |
-
at::Half* devPtrB,
|
211 |
-
at::Half* devPtrY) {
|
212 |
-
|
213 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
214 |
-
std::stringstream log_buf;
|
215 |
-
|
216 |
-
try {
|
217 |
-
int convDim = 2;
|
218 |
-
float alpha = 1.0f;
|
219 |
-
float beta = 0.0f;
|
220 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
221 |
-
|
222 |
-
// Creates the necessary tensor descriptors
|
223 |
-
int64_t stride[4];
|
224 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
225 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
226 |
-
.setDim(4, x_dim)
|
227 |
-
.setStrides(4, stride)
|
228 |
-
.setId('x')
|
229 |
-
.setAlignment(16)
|
230 |
-
.setDataType(dataType)
|
231 |
-
.build();
|
232 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
233 |
-
|
234 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
235 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
236 |
-
.setDim(4, w_dim)
|
237 |
-
.setStrides(4, stride)
|
238 |
-
.setId('w')
|
239 |
-
.setAlignment(16)
|
240 |
-
.setDataType(dataType)
|
241 |
-
.build();
|
242 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
243 |
-
|
244 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
245 |
-
auto afterConvTensor = cudnn_frontend::TensorBuilder()
|
246 |
-
.setDim(4, y_dim)
|
247 |
-
.setStrides(4, stride)
|
248 |
-
.setId('c')
|
249 |
-
.setAlignment(16)
|
250 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
251 |
-
.setVirtual()
|
252 |
-
.build();
|
253 |
-
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
|
254 |
-
|
255 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
256 |
-
auto bTensor = cudnn_frontend::TensorBuilder()
|
257 |
-
.setDim(4, b_dim)
|
258 |
-
.setStrides(4, stride)
|
259 |
-
.setId('b')
|
260 |
-
.setAlignment(16)
|
261 |
-
.setDataType(dataType)
|
262 |
-
.build();
|
263 |
-
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
|
264 |
-
|
265 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
266 |
-
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
|
267 |
-
.setDim(4, y_dim)
|
268 |
-
.setStrides(4, stride)
|
269 |
-
.setId('y')
|
270 |
-
.setAlignment(16)
|
271 |
-
.setDataType(dataType)
|
272 |
-
.build();
|
273 |
-
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
|
274 |
-
|
275 |
-
// Define the bias operation
|
276 |
-
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
|
277 |
-
.setMode(CUDNN_POINTWISE_ADD)
|
278 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
279 |
-
.build();
|
280 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
281 |
-
|
282 |
-
// Define the convolution problem
|
283 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
284 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
285 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
286 |
-
.setNDims(convDim)
|
287 |
-
.setStrides(convDim, convstride)
|
288 |
-
.setPrePadding(convDim, conv_pad)
|
289 |
-
.setPostPadding(convDim, conv_pad)
|
290 |
-
.setDilation(convDim, dilation)
|
291 |
-
.build();
|
292 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
293 |
-
|
294 |
-
|
295 |
-
// Create a convolution Node
|
296 |
-
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
|
297 |
-
.setxDesc(xTensor)
|
298 |
-
.setwDesc(wTensor)
|
299 |
-
.setyDesc(afterConvTensor)
|
300 |
-
.setcDesc(convDesc)
|
301 |
-
.setAlpha(alpha)
|
302 |
-
.setBeta(beta)
|
303 |
-
.build();
|
304 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
305 |
-
|
306 |
-
// Create a Bias Node.
|
307 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
308 |
-
.setxDesc(conv_op.getOutputTensor())
|
309 |
-
.setbDesc(bTensor)
|
310 |
-
.setyDesc(afterBiasTensor)
|
311 |
-
.setpwDesc(biasDesc)
|
312 |
-
.build();
|
313 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
314 |
-
|
315 |
-
// Create an Operation Graph. In this case it is convolution bias activation
|
316 |
-
std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &bias_op};
|
317 |
-
|
318 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
319 |
-
.setHandle(handle_)
|
320 |
-
.setOperationGraph(2, ops.data())
|
321 |
-
.build();
|
322 |
-
|
323 |
-
// Create string encoding for plan caching
|
324 |
-
auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
|
325 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
326 |
-
|
327 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
328 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
329 |
-
|
330 |
-
auto workspace_size = plan.getWorkspaceSize();
|
331 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
332 |
-
|
333 |
-
void* workspace_ptr = nullptr;
|
334 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
335 |
-
if (workspace_size > 0) {
|
336 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
337 |
-
}
|
338 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
|
339 |
-
int64_t uids[] = {'x', 'w', 'b', 'y'};
|
340 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
341 |
-
.setWorkspacePointer(workspace_ptr)
|
342 |
-
.setDataPointers(4, data_ptrs)
|
343 |
-
.setUids(4, uids)
|
344 |
-
.build();
|
345 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
346 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
347 |
-
checkCudnnErr(status);
|
348 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
349 |
-
} catch (cudnn_frontend::cudnnException e) {
|
350 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
351 |
-
}
|
352 |
-
}
|
353 |
-
|
354 |
-
|
355 |
-
void
|
356 |
-
run_conv_bias_mask_relu(int64_t* x_dim,
|
357 |
-
int64_t* w_dim,
|
358 |
-
int64_t* y_dim,
|
359 |
-
int64_t* conv_pad,
|
360 |
-
int64_t* conv_stride,
|
361 |
-
int64_t* conv_dilation,
|
362 |
-
cudnnDataType_t dataType,
|
363 |
-
at::Half* devPtrX,
|
364 |
-
at::Half* devPtrW,
|
365 |
-
at::Half* devPtrB,
|
366 |
-
int8_t* devPtrM,
|
367 |
-
at::Half* devPtrY) {
|
368 |
-
|
369 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
370 |
-
std::stringstream log_buf;
|
371 |
-
|
372 |
-
try {
|
373 |
-
int conv_dim = 2;
|
374 |
-
float alpha = 1.0f;
|
375 |
-
float beta = 0.0f;
|
376 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
377 |
-
|
378 |
-
// Creates the necessary tensor descriptors
|
379 |
-
int64_t stride[4];
|
380 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
381 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
382 |
-
.setDim(4, x_dim)
|
383 |
-
.setStrides(4, stride)
|
384 |
-
.setId('x')
|
385 |
-
.setAlignment(16)
|
386 |
-
.setDataType(dataType)
|
387 |
-
.build();
|
388 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
389 |
-
|
390 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
391 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
392 |
-
.setDim(4, w_dim)
|
393 |
-
.setStrides(4, stride)
|
394 |
-
.setId('w')
|
395 |
-
.setAlignment(16)
|
396 |
-
.setDataType(dataType)
|
397 |
-
.build();
|
398 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
399 |
-
|
400 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
401 |
-
auto mTensor = cudnn_frontend::TensorBuilder()
|
402 |
-
.setDim(4, y_dim)
|
403 |
-
.setStrides(4, stride)
|
404 |
-
.setId('m')
|
405 |
-
.setAlignment(16)
|
406 |
-
.setDataType(CUDNN_DATA_INT8)
|
407 |
-
.build();
|
408 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
409 |
-
|
410 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
411 |
-
auto afterConvTensor = cudnn_frontend::TensorBuilder()
|
412 |
-
.setDim(4, y_dim)
|
413 |
-
.setStrides(4, stride)
|
414 |
-
.setId('c')
|
415 |
-
.setAlignment(16)
|
416 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
417 |
-
.setVirtual()
|
418 |
-
.build();
|
419 |
-
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
|
420 |
-
|
421 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
422 |
-
auto bTensor = cudnn_frontend::TensorBuilder()
|
423 |
-
.setDim(4, b_dim)
|
424 |
-
.setStrides(4, stride)
|
425 |
-
.setId('b')
|
426 |
-
.setAlignment(16)
|
427 |
-
.setDataType(dataType)
|
428 |
-
.build();
|
429 |
-
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
|
430 |
-
|
431 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
432 |
-
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
|
433 |
-
.setDim(4, y_dim)
|
434 |
-
.setStrides(4, stride)
|
435 |
-
.setId('B')
|
436 |
-
.setAlignment(16)
|
437 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
438 |
-
.setVirtual()
|
439 |
-
.build();
|
440 |
-
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
|
441 |
-
|
442 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
443 |
-
auto afterMaskTensor = cudnn_frontend::TensorBuilder()
|
444 |
-
.setDim(4, y_dim)
|
445 |
-
.setStrides(4, stride)
|
446 |
-
.setId('M')
|
447 |
-
.setAlignment(16)
|
448 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
449 |
-
.setVirtual()
|
450 |
-
.build();
|
451 |
-
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
|
452 |
-
|
453 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
454 |
-
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
|
455 |
-
.setDim(4, y_dim)
|
456 |
-
.setStrides(4, stride)
|
457 |
-
.setId('y')
|
458 |
-
.setAlignment(16)
|
459 |
-
.setDataType(dataType)
|
460 |
-
.build();
|
461 |
-
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
|
462 |
-
|
463 |
-
// Define the convolution problem
|
464 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
465 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
466 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
467 |
-
.setNDims(conv_dim)
|
468 |
-
.setStrides(conv_dim, conv_stride)
|
469 |
-
.setPrePadding(conv_dim, conv_pad)
|
470 |
-
.setPostPadding(conv_dim, conv_pad)
|
471 |
-
.setDilation(conv_dim, conv_dilation)
|
472 |
-
.build();
|
473 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
474 |
-
|
475 |
-
// Define the bias operation
|
476 |
-
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
|
477 |
-
.setMode(CUDNN_POINTWISE_ADD)
|
478 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
479 |
-
.build();
|
480 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
481 |
-
|
482 |
-
// Define the mask operation
|
483 |
-
auto maskDesc = cudnn_frontend::PointWiseDescBuilder()
|
484 |
-
.setMode(CUDNN_POINTWISE_MUL)
|
485 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
486 |
-
.build();
|
487 |
-
|
488 |
-
// Define the activation operation
|
489 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
490 |
-
.setMode(CUDNN_POINTWISE_RELU_FWD)
|
491 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
492 |
-
.build();
|
493 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
494 |
-
|
495 |
-
// Create a convolution Node
|
496 |
-
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
|
497 |
-
.setxDesc(xTensor)
|
498 |
-
.setwDesc(wTensor)
|
499 |
-
.setyDesc(afterConvTensor)
|
500 |
-
.setcDesc(convDesc)
|
501 |
-
.setAlpha(alpha)
|
502 |
-
.setBeta(beta)
|
503 |
-
.build();
|
504 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
505 |
-
|
506 |
-
// Create a Bias Node
|
507 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
508 |
-
.setxDesc(conv_op.getOutputTensor())
|
509 |
-
.setbDesc(bTensor)
|
510 |
-
.setyDesc(afterBiasTensor)
|
511 |
-
.setpwDesc(biasDesc)
|
512 |
-
.build();
|
513 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
514 |
-
|
515 |
-
// create a Mask Node
|
516 |
-
auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
517 |
-
.setxDesc(bias_op.getOutputTensor())
|
518 |
-
.setbDesc(mTensor)
|
519 |
-
.setyDesc(afterMaskTensor)
|
520 |
-
.setpwDesc(maskDesc)
|
521 |
-
.build();
|
522 |
-
DEBUG_CUDNN_MSG(log_buf, mask_op.describe());
|
523 |
-
|
524 |
-
// Create an Activation Node
|
525 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
526 |
-
.setxDesc(mask_op.getOutputTensor())
|
527 |
-
.setyDesc(afterReLUTensor)
|
528 |
-
.setpwDesc(actDesc)
|
529 |
-
.build();
|
530 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
531 |
-
|
532 |
-
// Create an Operation Graph. In this case it is convolution bias activation
|
533 |
-
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &bias_op, &mask_op, &act_op};
|
534 |
-
|
535 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
536 |
-
.setHandle(handle_)
|
537 |
-
.setOperationGraph(4, ops.data())
|
538 |
-
.build();
|
539 |
-
|
540 |
-
// Create string encoding for plan caching
|
541 |
-
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
|
542 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
543 |
-
|
544 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
545 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
546 |
-
|
547 |
-
auto workspace_size = plan.getWorkspaceSize();
|
548 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
549 |
-
|
550 |
-
void* workspace_ptr = nullptr;
|
551 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
552 |
-
if (workspace_size > 0) {
|
553 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
554 |
-
}
|
555 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY};
|
556 |
-
int64_t uids[] = {'x', 'w', 'b', 'm', 'y'};
|
557 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
558 |
-
.setWorkspacePointer(workspace_ptr)
|
559 |
-
.setDataPointers(5, data_ptrs)
|
560 |
-
.setUids(5, uids)
|
561 |
-
.build();
|
562 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
563 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
564 |
-
checkCudnnErr(status);
|
565 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
566 |
-
} catch (cudnn_frontend::cudnnException e) {
|
567 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
568 |
-
}
|
569 |
-
}
|
570 |
-
|
571 |
-
|
572 |
-
void
|
573 |
-
run_conv_cscale_cbias_relu(int64_t* x_dim,
|
574 |
-
int64_t* w_dim,
|
575 |
-
int64_t* y_dim,
|
576 |
-
int64_t* conv_pad,
|
577 |
-
int64_t* conv_stride,
|
578 |
-
int64_t* conv_dilation,
|
579 |
-
cudnnDataType_t dataType,
|
580 |
-
at::Half* devPtrX,
|
581 |
-
at::Half* devPtrW,
|
582 |
-
at::Half* devPtrS,
|
583 |
-
at::Half* devPtrB,
|
584 |
-
at::Half* devPtrY) {
|
585 |
-
|
586 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
587 |
-
std::stringstream log_buf;
|
588 |
-
|
589 |
-
try {
|
590 |
-
int conv_dim = 2;
|
591 |
-
float alpha = 1.0f;
|
592 |
-
float beta = 0.0f;
|
593 |
-
int64_t s_dim[] = {1, y_dim[1], 1, 1};
|
594 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
595 |
-
|
596 |
-
// Creates the necessary tensor descriptors
|
597 |
-
int64_t stride[4];
|
598 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
599 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
600 |
-
.setDim(4, x_dim)
|
601 |
-
.setStrides(4, stride)
|
602 |
-
.setId('x')
|
603 |
-
.setAlignment(16)
|
604 |
-
.setDataType(dataType)
|
605 |
-
.build();
|
606 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
607 |
-
|
608 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
609 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
610 |
-
.setDim(4, w_dim)
|
611 |
-
.setStrides(4, stride)
|
612 |
-
.setId('w')
|
613 |
-
.setAlignment(16)
|
614 |
-
.setDataType(dataType)
|
615 |
-
.build();
|
616 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
617 |
-
|
618 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
619 |
-
auto afterConvTensor = cudnn_frontend::TensorBuilder()
|
620 |
-
.setDim(4, y_dim)
|
621 |
-
.setStrides(4, stride)
|
622 |
-
.setId('c')
|
623 |
-
.setAlignment(16)
|
624 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
625 |
-
.setVirtual()
|
626 |
-
.build();
|
627 |
-
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
|
628 |
-
|
629 |
-
generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
630 |
-
auto sTensor = cudnn_frontend::TensorBuilder()
|
631 |
-
.setDim(4, s_dim)
|
632 |
-
.setStrides(4, stride)
|
633 |
-
.setId('s')
|
634 |
-
.setAlignment(16)
|
635 |
-
.setDataType(dataType)
|
636 |
-
.build();
|
637 |
-
DEBUG_CUDNN_MSG(log_buf, sTensor.describe());
|
638 |
-
|
639 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
640 |
-
auto afterScaleTensor = cudnn_frontend::TensorBuilder()
|
641 |
-
.setDim(4, y_dim)
|
642 |
-
.setStrides(4, stride)
|
643 |
-
.setId('S')
|
644 |
-
.setAlignment(16)
|
645 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
646 |
-
.setVirtual()
|
647 |
-
.build();
|
648 |
-
DEBUG_CUDNN_MSG(log_buf, afterScaleTensor.describe());
|
649 |
-
|
650 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
651 |
-
auto bTensor = cudnn_frontend::TensorBuilder()
|
652 |
-
.setDim(4, b_dim)
|
653 |
-
.setStrides(4, stride)
|
654 |
-
.setId('b')
|
655 |
-
.setAlignment(16)
|
656 |
-
.setDataType(dataType)
|
657 |
-
.build();
|
658 |
-
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
|
659 |
-
|
660 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
661 |
-
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
|
662 |
-
.setDim(4, y_dim)
|
663 |
-
.setStrides(4, stride)
|
664 |
-
.setId('B')
|
665 |
-
.setAlignment(16)
|
666 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
667 |
-
.setVirtual()
|
668 |
-
.build();
|
669 |
-
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
|
670 |
-
|
671 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
672 |
-
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
|
673 |
-
.setDim(4, y_dim)
|
674 |
-
.setStrides(4, stride)
|
675 |
-
.setId('y')
|
676 |
-
.setAlignment(16)
|
677 |
-
.setDataType(dataType)
|
678 |
-
.build();
|
679 |
-
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
|
680 |
-
|
681 |
-
// Define the convolution problem
|
682 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
683 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
684 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
685 |
-
.setNDims(conv_dim)
|
686 |
-
.setStrides(conv_dim, conv_stride)
|
687 |
-
.setPrePadding(conv_dim, conv_pad)
|
688 |
-
.setPostPadding(conv_dim, conv_pad)
|
689 |
-
.setDilation(conv_dim, conv_dilation)
|
690 |
-
.build();
|
691 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
692 |
-
|
693 |
-
// Define the scale operation
|
694 |
-
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
|
695 |
-
.setMode(CUDNN_POINTWISE_MUL)
|
696 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
697 |
-
.build();
|
698 |
-
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
|
699 |
-
|
700 |
-
// Define the bias operation
|
701 |
-
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
|
702 |
-
.setMode(CUDNN_POINTWISE_ADD)
|
703 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
704 |
-
.build();
|
705 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
706 |
-
|
707 |
-
// Define the activation operation
|
708 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
709 |
-
.setMode(CUDNN_POINTWISE_RELU_FWD)
|
710 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
711 |
-
.build();
|
712 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
713 |
-
|
714 |
-
// Create a convolution Node
|
715 |
-
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
|
716 |
-
.setxDesc(xTensor)
|
717 |
-
.setwDesc(wTensor)
|
718 |
-
.setyDesc(afterConvTensor)
|
719 |
-
.setcDesc(convDesc)
|
720 |
-
.setAlpha(alpha)
|
721 |
-
.setBeta(beta)
|
722 |
-
.build();
|
723 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
724 |
-
|
725 |
-
// Create a scale Node.
|
726 |
-
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
727 |
-
.setxDesc(conv_op.getOutputTensor())
|
728 |
-
.setbDesc(sTensor)
|
729 |
-
.setyDesc(afterScaleTensor)
|
730 |
-
.setpwDesc(scaleDesc)
|
731 |
-
.build();
|
732 |
-
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
|
733 |
-
|
734 |
-
// Create a Bias Node.
|
735 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
736 |
-
.setxDesc(scale_op.getOutputTensor())
|
737 |
-
.setbDesc(bTensor)
|
738 |
-
.setyDesc(afterBiasTensor)
|
739 |
-
.setpwDesc(biasDesc)
|
740 |
-
.build();
|
741 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
742 |
-
|
743 |
-
// Create an Activation Node.
|
744 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
745 |
-
.setxDesc(bias_op.getOutputTensor())
|
746 |
-
.setyDesc(afterReLUTensor)
|
747 |
-
.setpwDesc(actDesc)
|
748 |
-
.build();
|
749 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
750 |
-
|
751 |
-
// Create an Operation Graph. In this case it is convolution bias activation
|
752 |
-
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &scale_op, &bias_op, &act_op};
|
753 |
-
|
754 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
755 |
-
.setHandle(handle_)
|
756 |
-
.setOperationGraph(ops.size(), ops.data())
|
757 |
-
.build();
|
758 |
-
|
759 |
-
// Create string encoding for plan caching
|
760 |
-
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
|
761 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
762 |
-
|
763 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
764 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
765 |
-
|
766 |
-
auto workspace_size = plan.getWorkspaceSize();
|
767 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
768 |
-
|
769 |
-
void* workspace_ptr = nullptr;
|
770 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
771 |
-
if (workspace_size > 0) {
|
772 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
773 |
-
}
|
774 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrS, devPtrB, devPtrY};
|
775 |
-
int64_t uids[] = {'x', 'w', 's', 'b', 'y'};
|
776 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
777 |
-
.setWorkspacePointer(workspace_ptr)
|
778 |
-
.setDataPointers(5, data_ptrs)
|
779 |
-
.setUids(5, uids)
|
780 |
-
.build();
|
781 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
782 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
783 |
-
checkCudnnErr(status);
|
784 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
785 |
-
} catch (cudnn_frontend::cudnnException e) {
|
786 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
787 |
-
}
|
788 |
-
}
|
789 |
-
|
790 |
-
|
791 |
-
void
|
792 |
-
run_conv_bias_relu(int64_t* x_dim,
|
793 |
-
int64_t* w_dim,
|
794 |
-
int64_t* y_dim,
|
795 |
-
int64_t* conv_pad,
|
796 |
-
int64_t* conv_stride,
|
797 |
-
int64_t* conv_dilation,
|
798 |
-
cudnnDataType_t dataType,
|
799 |
-
at::Half* devPtrX,
|
800 |
-
at::Half* devPtrW,
|
801 |
-
at::Half* devPtrB,
|
802 |
-
at::Half* devPtrY) {
|
803 |
-
|
804 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
805 |
-
std::stringstream log_buf;
|
806 |
-
|
807 |
-
try {
|
808 |
-
int conv_dim = 2;
|
809 |
-
float alpha = 1.0f;
|
810 |
-
float beta = 0.0f;
|
811 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
812 |
-
|
813 |
-
// Creates the necessary tensor descriptors
|
814 |
-
int64_t stride[4];
|
815 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
816 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
817 |
-
.setDim(4, x_dim)
|
818 |
-
.setStrides(4, stride)
|
819 |
-
.setId('x')
|
820 |
-
.setAlignment(16)
|
821 |
-
.setDataType(dataType)
|
822 |
-
.build();
|
823 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
824 |
-
|
825 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
826 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
827 |
-
.setDim(4, w_dim)
|
828 |
-
.setStrides(4, stride)
|
829 |
-
.setId('w')
|
830 |
-
.setAlignment(16)
|
831 |
-
.setDataType(dataType)
|
832 |
-
.build();
|
833 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
834 |
-
|
835 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
836 |
-
auto afterConvTensor = cudnn_frontend::TensorBuilder()
|
837 |
-
.setDim(4, y_dim)
|
838 |
-
.setStrides(4, stride)
|
839 |
-
.setId('c')
|
840 |
-
.setAlignment(16)
|
841 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
842 |
-
.setVirtual()
|
843 |
-
.build();
|
844 |
-
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
|
845 |
-
|
846 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
847 |
-
auto bTensor = cudnn_frontend::TensorBuilder()
|
848 |
-
.setDim(4, b_dim)
|
849 |
-
.setStrides(4, stride)
|
850 |
-
.setId('b')
|
851 |
-
.setAlignment(16)
|
852 |
-
.setDataType(dataType)
|
853 |
-
.build();
|
854 |
-
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
|
855 |
-
|
856 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
857 |
-
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
|
858 |
-
.setDim(4, y_dim)
|
859 |
-
.setStrides(4, stride)
|
860 |
-
.setId('B')
|
861 |
-
.setAlignment(16)
|
862 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
863 |
-
.setVirtual()
|
864 |
-
.build();
|
865 |
-
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
|
866 |
-
|
867 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
868 |
-
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
|
869 |
-
.setDim(4, y_dim)
|
870 |
-
.setStrides(4, stride)
|
871 |
-
.setId('y')
|
872 |
-
.setAlignment(16)
|
873 |
-
.setDataType(dataType)
|
874 |
-
.build();
|
875 |
-
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
|
876 |
-
|
877 |
-
// Define the convolution problem
|
878 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
879 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
880 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
881 |
-
.setNDims(conv_dim)
|
882 |
-
.setStrides(conv_dim, conv_stride)
|
883 |
-
.setPrePadding(conv_dim, conv_pad)
|
884 |
-
.setPostPadding(conv_dim, conv_pad)
|
885 |
-
.setDilation(conv_dim, conv_dilation)
|
886 |
-
.build();
|
887 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
888 |
-
|
889 |
-
// Define the bias operation
|
890 |
-
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
|
891 |
-
.setMode(CUDNN_POINTWISE_ADD)
|
892 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
893 |
-
.build();
|
894 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
895 |
-
|
896 |
-
// Define the activation operation
|
897 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
898 |
-
.setMode(CUDNN_POINTWISE_RELU_FWD)
|
899 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
900 |
-
.build();
|
901 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
902 |
-
|
903 |
-
// Create a convolution Node
|
904 |
-
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
|
905 |
-
.setxDesc(xTensor)
|
906 |
-
.setwDesc(wTensor)
|
907 |
-
.setyDesc(afterConvTensor)
|
908 |
-
.setcDesc(convDesc)
|
909 |
-
.setAlpha(alpha)
|
910 |
-
.setBeta(beta)
|
911 |
-
.build();
|
912 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
913 |
-
|
914 |
-
// Create a Bias Node.
|
915 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
916 |
-
.setxDesc(conv_op.getOutputTensor())
|
917 |
-
.setbDesc(bTensor)
|
918 |
-
.setyDesc(afterBiasTensor)
|
919 |
-
.setpwDesc(biasDesc)
|
920 |
-
.build();
|
921 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
922 |
-
|
923 |
-
// Create an Activation Node.
|
924 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
925 |
-
.setxDesc(bias_op.getOutputTensor())
|
926 |
-
.setyDesc(afterReLUTensor)
|
927 |
-
.setpwDesc(actDesc)
|
928 |
-
.build();
|
929 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
930 |
-
|
931 |
-
// Create an Operation Graph. In this case it is convolution bias activation
|
932 |
-
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &bias_op, &act_op};
|
933 |
-
|
934 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
935 |
-
.setHandle(handle_)
|
936 |
-
.setOperationGraph(3, ops.data())
|
937 |
-
.build();
|
938 |
-
|
939 |
-
// Create string encoding for plan caching
|
940 |
-
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
|
941 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
942 |
-
|
943 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
944 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
945 |
-
|
946 |
-
auto workspace_size = plan.getWorkspaceSize();
|
947 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
948 |
-
|
949 |
-
void* workspace_ptr = nullptr;
|
950 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
951 |
-
if (workspace_size > 0) {
|
952 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
953 |
-
}
|
954 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
|
955 |
-
int64_t uids[] = {'x', 'w', 'b', 'y'};
|
956 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
957 |
-
.setWorkspacePointer(workspace_ptr)
|
958 |
-
.setDataPointers(4, data_ptrs)
|
959 |
-
.setUids(4, uids)
|
960 |
-
.build();
|
961 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
962 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
963 |
-
checkCudnnErr(status);
|
964 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
965 |
-
} catch (cudnn_frontend::cudnnException e) {
|
966 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
967 |
-
}
|
968 |
-
}
|
969 |
-
|
970 |
-
|
971 |
-
void
|
972 |
-
run_drelu_dscale(int64_t* dy_dim,
|
973 |
-
cudnnDataType_t dataType,
|
974 |
-
at::Half* devPtrDY,
|
975 |
-
at::Half* devPtrR,
|
976 |
-
at::Half* devPtrS,
|
977 |
-
at::Half* devPtrDX) {
|
978 |
-
|
979 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
980 |
-
std::stringstream log_buf;
|
981 |
-
|
982 |
-
try {
|
983 |
-
int convDim = 2;
|
984 |
-
float alpha = 1.0f;
|
985 |
-
float beta = 0.0f;
|
986 |
-
int64_t s_dim[] = {1, dy_dim[1], 1, 1};
|
987 |
-
|
988 |
-
// Creates the necessary tensor descriptors
|
989 |
-
int64_t stride[4];
|
990 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
991 |
-
auto dyTensor = cudnn_frontend::TensorBuilder()
|
992 |
-
.setDim(4, dy_dim)
|
993 |
-
.setStrides(4, stride)
|
994 |
-
.setId('y')
|
995 |
-
.setAlignment(16)
|
996 |
-
.setDataType(dataType)
|
997 |
-
.build();
|
998 |
-
DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());
|
999 |
-
|
1000 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1001 |
-
auto rTensor = cudnn_frontend::TensorBuilder()
|
1002 |
-
.setDim(4, dy_dim)
|
1003 |
-
.setStrides(4, stride)
|
1004 |
-
.setId('r')
|
1005 |
-
.setAlignment(16)
|
1006 |
-
.setDataType(dataType)
|
1007 |
-
.build();
|
1008 |
-
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
|
1009 |
-
|
1010 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1011 |
-
auto inActGradTensor = cudnn_frontend::TensorBuilder()
|
1012 |
-
.setDim(4, dy_dim)
|
1013 |
-
.setStrides(4, stride)
|
1014 |
-
.setId('R')
|
1015 |
-
.setAlignment(16)
|
1016 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1017 |
-
.setVirtual()
|
1018 |
-
.build();
|
1019 |
-
DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());
|
1020 |
-
|
1021 |
-
generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1022 |
-
auto scaleTensor = cudnn_frontend::TensorBuilder()
|
1023 |
-
.setDim(4, s_dim)
|
1024 |
-
.setStrides(4, stride)
|
1025 |
-
.setId('s')
|
1026 |
-
.setAlignment(16)
|
1027 |
-
.setDataType(dataType)
|
1028 |
-
.build();
|
1029 |
-
DEBUG_CUDNN_MSG(log_buf, scaleTensor.describe());
|
1030 |
-
|
1031 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1032 |
-
auto dxTensor = cudnn_frontend::TensorBuilder()
|
1033 |
-
.setDim(4, dy_dim)
|
1034 |
-
.setStrides(4, stride)
|
1035 |
-
.setId('x')
|
1036 |
-
.setAlignment(16)
|
1037 |
-
.setDataType(dataType)
|
1038 |
-
.build();
|
1039 |
-
DEBUG_CUDNN_MSG(log_buf, dxTensor.describe());
|
1040 |
-
|
1041 |
-
// Define the activation backward operation
|
1042 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
1043 |
-
.setMode(CUDNN_POINTWISE_RELU_BWD)
|
1044 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1045 |
-
.build();
|
1046 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
1047 |
-
|
1048 |
-
// Define the bias backward operation
|
1049 |
-
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
|
1050 |
-
.setMode(CUDNN_POINTWISE_MUL)
|
1051 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1052 |
-
.build();
|
1053 |
-
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
|
1054 |
-
|
1055 |
-
// Create an relu backward Node
|
1056 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
1057 |
-
.setdyDesc(dyTensor)
|
1058 |
-
.setxDesc(rTensor)
|
1059 |
-
.setdxDesc(inActGradTensor)
|
1060 |
-
.setpwDesc(actDesc)
|
1061 |
-
.build();
|
1062 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
1063 |
-
|
1064 |
-
// Create bias node
|
1065 |
-
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
1066 |
-
.setxDesc(inActGradTensor)
|
1067 |
-
.setbDesc(scaleTensor)
|
1068 |
-
.setyDesc(dxTensor)
|
1069 |
-
.setpwDesc(scaleDesc)
|
1070 |
-
.build();
|
1071 |
-
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
|
1072 |
-
|
1073 |
-
// Create an Operation Graph. In this case it is bias only
|
1074 |
-
std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &scale_op};
|
1075 |
-
|
1076 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
1077 |
-
.setHandle(handle_)
|
1078 |
-
.setOperationGraph(ops.size(), ops.data())
|
1079 |
-
.build();
|
1080 |
-
|
1081 |
-
// Create string encoding for plan caching
|
1082 |
-
// creating unique dummy values
|
1083 |
-
int64_t pad_dummy[] = {40, 40};
|
1084 |
-
int64_t stride_dummy[] = {40, 40};
|
1085 |
-
int64_t dilation_dummy[] = {40, 40};
|
1086 |
-
auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, s_dim, dataType, opGraph.getTag());
|
1087 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
1088 |
-
|
1089 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
1090 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
1091 |
-
|
1092 |
-
auto workspace_size = plan.getWorkspaceSize();
|
1093 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
1094 |
-
|
1095 |
-
void* workspace_ptr = nullptr;
|
1096 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
1097 |
-
if (workspace_size > 0) {
|
1098 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
1099 |
-
}
|
1100 |
-
void* data_ptrs[] = {devPtrDY, devPtrR, devPtrS, devPtrDX};
|
1101 |
-
int64_t uids[] = {'y', 'r', 's', 'x'};
|
1102 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
1103 |
-
.setWorkspacePointer(workspace_ptr)
|
1104 |
-
.setDataPointers(4, data_ptrs)
|
1105 |
-
.setUids(4, uids)
|
1106 |
-
.build();
|
1107 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
1108 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
1109 |
-
checkCudnnErr(status);
|
1110 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
1111 |
-
} catch (cudnn_frontend::cudnnException e) {
|
1112 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
1113 |
-
}
|
1114 |
-
}
|
1115 |
-
|
1116 |
-
|
1117 |
-
void
|
1118 |
-
run_drelu_dbias(int64_t* dy_dim,
|
1119 |
-
cudnnDataType_t dataType,
|
1120 |
-
at::Half* devPtrDY,
|
1121 |
-
at::Half* devPtrR,
|
1122 |
-
at::Half* devPtrDR,
|
1123 |
-
float* devPtrDB) {
|
1124 |
-
|
1125 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
1126 |
-
std::stringstream log_buf;
|
1127 |
-
|
1128 |
-
try {
|
1129 |
-
int convDim = 2;
|
1130 |
-
float alpha = 1.0f;
|
1131 |
-
float beta = 0.0f;
|
1132 |
-
int64_t b_dim[] = {1, dy_dim[1], 1, 1};
|
1133 |
-
|
1134 |
-
// Creates the necessary tensor descriptors
|
1135 |
-
int64_t stride[4];
|
1136 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1137 |
-
auto dyTensor = cudnn_frontend::TensorBuilder()
|
1138 |
-
.setDim(4, dy_dim)
|
1139 |
-
.setStrides(4, stride)
|
1140 |
-
.setId('x')
|
1141 |
-
.setAlignment(16)
|
1142 |
-
.setDataType(dataType)
|
1143 |
-
.build();
|
1144 |
-
DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());
|
1145 |
-
|
1146 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1147 |
-
auto rTensor = cudnn_frontend::TensorBuilder()
|
1148 |
-
.setDim(4, dy_dim)
|
1149 |
-
.setStrides(4, stride)
|
1150 |
-
.setId('r')
|
1151 |
-
.setAlignment(16)
|
1152 |
-
.setDataType(dataType)
|
1153 |
-
.build();
|
1154 |
-
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
|
1155 |
-
|
1156 |
-
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1157 |
-
auto inActGradTensor = cudnn_frontend::TensorBuilder()
|
1158 |
-
.setDim(4, dy_dim)
|
1159 |
-
.setStrides(4, stride)
|
1160 |
-
.setId('R')
|
1161 |
-
.setAlignment(16)
|
1162 |
-
.setDataType(dataType)
|
1163 |
-
.build();
|
1164 |
-
DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());
|
1165 |
-
|
1166 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1167 |
-
auto biasGradTensor = cudnn_frontend::TensorBuilder()
|
1168 |
-
.setDim(4, b_dim)
|
1169 |
-
.setStrides(4, stride)
|
1170 |
-
.setId('y')
|
1171 |
-
.setAlignment(16)
|
1172 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1173 |
-
.build();
|
1174 |
-
DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe());
|
1175 |
-
|
1176 |
-
// Define the activation backward operation
|
1177 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
1178 |
-
.setMode(CUDNN_POINTWISE_RELU_BWD)
|
1179 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1180 |
-
.build();
|
1181 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
1182 |
-
|
1183 |
-
// Define the bias backward operation
|
1184 |
-
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
|
1185 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1186 |
-
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
|
1187 |
-
.build();
|
1188 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
1189 |
-
|
1190 |
-
// Create an relu backward Node
|
1191 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
1192 |
-
.setdyDesc(dyTensor)
|
1193 |
-
.setxDesc(rTensor)
|
1194 |
-
.setdxDesc(inActGradTensor)
|
1195 |
-
.setpwDesc(actDesc)
|
1196 |
-
.build();
|
1197 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
1198 |
-
|
1199 |
-
// Create bias node
|
1200 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
|
1201 |
-
.setxDesc(inActGradTensor)
|
1202 |
-
.setyDesc(biasGradTensor)
|
1203 |
-
.setreductionDesc(biasDesc)
|
1204 |
-
.build();
|
1205 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
1206 |
-
|
1207 |
-
// Create an Operation Graph. In this case it is bias only
|
1208 |
-
std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &bias_op};
|
1209 |
-
|
1210 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
1211 |
-
.setHandle(handle_)
|
1212 |
-
.setOperationGraph(ops.size(), ops.data())
|
1213 |
-
.build();
|
1214 |
-
|
1215 |
-
// Create string encoding for plan caching
|
1216 |
-
// creating unique dummy values
|
1217 |
-
int64_t pad_dummy[] = {20, 20};
|
1218 |
-
int64_t stride_dummy[] = {20, 20};
|
1219 |
-
int64_t dilation_dummy[] = {20, 20};
|
1220 |
-
auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
|
1221 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
1222 |
-
|
1223 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
1224 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
1225 |
-
|
1226 |
-
auto workspace_size = plan.getWorkspaceSize();
|
1227 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
1228 |
-
|
1229 |
-
void* workspace_ptr = nullptr;
|
1230 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
1231 |
-
if (workspace_size > 0) {
|
1232 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
1233 |
-
}
|
1234 |
-
void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB};
|
1235 |
-
int64_t uids[] = {'x', 'r', 'R', 'y'};
|
1236 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
1237 |
-
.setWorkspacePointer(workspace_ptr)
|
1238 |
-
.setDataPointers(4, data_ptrs)
|
1239 |
-
.setUids(4, uids)
|
1240 |
-
.build();
|
1241 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
1242 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
1243 |
-
checkCudnnErr(status);
|
1244 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
1245 |
-
} catch (cudnn_frontend::cudnnException e) {
|
1246 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
1247 |
-
}
|
1248 |
-
}
|
1249 |
-
|
1250 |
-
|
1251 |
-
void
|
1252 |
-
run_dconv_drelu_dbias(int64_t* x_dim,
|
1253 |
-
int64_t* w_dim,
|
1254 |
-
int64_t* y_dim,
|
1255 |
-
int64_t* pad,
|
1256 |
-
int64_t* convstride,
|
1257 |
-
int64_t* dilation,
|
1258 |
-
cudnnDataType_t dataType,
|
1259 |
-
at::Half* devPtrX,
|
1260 |
-
at::Half* devPtrW,
|
1261 |
-
at::Half* devPtrR,
|
1262 |
-
at::Half* devPtrRg,
|
1263 |
-
float* devPtrY) {
|
1264 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
1265 |
-
std::stringstream log_buf;
|
1266 |
-
try {
|
1267 |
-
int convDim = 2;
|
1268 |
-
float alpha = 1.0f;
|
1269 |
-
float beta = 0.0f;
|
1270 |
-
int64_t b_dim[] = {1, x_dim[1], 1, 1};
|
1271 |
-
|
1272 |
-
int64_t stride[4];
|
1273 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1274 |
-
auto outConvGradTensor = cudnn_frontend::TensorBuilder()
|
1275 |
-
.setDim(4, y_dim)
|
1276 |
-
.setStrides(4, stride)
|
1277 |
-
.setId('x')
|
1278 |
-
.setAlignment(16)
|
1279 |
-
.setDataType(dataType)
|
1280 |
-
.build();
|
1281 |
-
DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe());
|
1282 |
-
|
1283 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1284 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
1285 |
-
.setDim(4, w_dim)
|
1286 |
-
.setStrides(4, stride)
|
1287 |
-
.setId('w')
|
1288 |
-
.setAlignment(16)
|
1289 |
-
.setDataType(dataType)
|
1290 |
-
.build();
|
1291 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
1292 |
-
|
1293 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1294 |
-
auto inConvGradTensor = cudnn_frontend::TensorBuilder()
|
1295 |
-
.setDim(4, x_dim)
|
1296 |
-
.setStrides(4, stride)
|
1297 |
-
.setId('A')
|
1298 |
-
.setAlignment(16)
|
1299 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1300 |
-
.setVirtual()
|
1301 |
-
.build();
|
1302 |
-
DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe());
|
1303 |
-
|
1304 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1305 |
-
auto rTensor = cudnn_frontend::TensorBuilder()
|
1306 |
-
.setDim(4, x_dim)
|
1307 |
-
.setStrides(4, stride)
|
1308 |
-
.setId('r')
|
1309 |
-
.setAlignment(16)
|
1310 |
-
.setDataType(dataType)
|
1311 |
-
.build();
|
1312 |
-
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
|
1313 |
-
|
1314 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1315 |
-
auto inReLUGradTensor = cudnn_frontend::TensorBuilder()
|
1316 |
-
.setDim(4, x_dim)
|
1317 |
-
.setStrides(4, stride)
|
1318 |
-
.setId('R')
|
1319 |
-
.setAlignment(16)
|
1320 |
-
.setDataType(dataType)
|
1321 |
-
.build();
|
1322 |
-
DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe());
|
1323 |
-
|
1324 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1325 |
-
auto inBiasGradTensor = cudnn_frontend::TensorBuilder()
|
1326 |
-
.setDim(4, b_dim)
|
1327 |
-
.setStrides(4, stride)
|
1328 |
-
.setId('y')
|
1329 |
-
.setAlignment(16)
|
1330 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1331 |
-
.build();
|
1332 |
-
DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe());
|
1333 |
-
|
1334 |
-
// Define the convolution problem
|
1335 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
1336 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1337 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
1338 |
-
.setNDims(convDim)
|
1339 |
-
.setStrides(convDim, convstride)
|
1340 |
-
.setPrePadding(convDim, pad)
|
1341 |
-
.setPostPadding(convDim, pad)
|
1342 |
-
.setDilation(convDim, dilation)
|
1343 |
-
.build();
|
1344 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
1345 |
-
|
1346 |
-
// Define the activation backward operation
|
1347 |
-
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
|
1348 |
-
.setMode(CUDNN_POINTWISE_RELU_BWD)
|
1349 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1350 |
-
.build();
|
1351 |
-
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
|
1352 |
-
|
1353 |
-
// Define the bias backward operation
|
1354 |
-
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
|
1355 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1356 |
-
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
|
1357 |
-
.build();
|
1358 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
1359 |
-
|
1360 |
-
// Create a convolution Node
|
1361 |
-
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
|
1362 |
-
.setdyDesc(outConvGradTensor)
|
1363 |
-
.setwDesc(wTensor)
|
1364 |
-
.setdxDesc(inConvGradTensor)
|
1365 |
-
.setcDesc(convDesc)
|
1366 |
-
.setAlpha(alpha)
|
1367 |
-
.setBeta(beta)
|
1368 |
-
.build();
|
1369 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
1370 |
-
|
1371 |
-
// Create an relu backward Node
|
1372 |
-
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
|
1373 |
-
.setdyDesc(inConvGradTensor)
|
1374 |
-
.setxDesc(rTensor)
|
1375 |
-
.setdxDesc(inReLUGradTensor)
|
1376 |
-
.setpwDesc(actDesc)
|
1377 |
-
.build();
|
1378 |
-
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
|
1379 |
-
|
1380 |
-
// Create bias node
|
1381 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
|
1382 |
-
.setxDesc(inReLUGradTensor)
|
1383 |
-
.setyDesc(inBiasGradTensor)
|
1384 |
-
.setreductionDesc(biasDesc)
|
1385 |
-
.build();
|
1386 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
1387 |
-
|
1388 |
-
// Create an Operation Graph. In this case it is bias only
|
1389 |
-
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &bias_op};
|
1390 |
-
|
1391 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
1392 |
-
.setHandle(handle_)
|
1393 |
-
.setOperationGraph(ops.size(), ops.data())
|
1394 |
-
.build();
|
1395 |
-
|
1396 |
-
// Create string encoding for plan caching
|
1397 |
-
auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
|
1398 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
1399 |
-
|
1400 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
1401 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
1402 |
-
|
1403 |
-
auto workspace_size = plan.getWorkspaceSize();
|
1404 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
1405 |
-
|
1406 |
-
void* workspace_ptr = nullptr;
|
1407 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
1408 |
-
if (workspace_size > 0) {
|
1409 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
1410 |
-
}
|
1411 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY};
|
1412 |
-
int64_t uids[] = {'x', 'w', 'r', 'R', 'y'};
|
1413 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
1414 |
-
.setWorkspacePointer(workspace_ptr)
|
1415 |
-
.setDataPointers(5, data_ptrs)
|
1416 |
-
.setUids(5, uids)
|
1417 |
-
.build();
|
1418 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
1419 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
1420 |
-
checkCudnnErr(status);
|
1421 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
1422 |
-
} catch (cudnn_frontend::cudnnException e) {
|
1423 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
1424 |
-
}
|
1425 |
-
|
1426 |
-
}
|
1427 |
-
|
1428 |
-
|
1429 |
-
void
|
1430 |
-
run_dconv(int64_t* x_dim,
|
1431 |
-
int64_t* w_dim,
|
1432 |
-
int64_t* y_dim,
|
1433 |
-
int64_t* conv_pad,
|
1434 |
-
int64_t* conv_stride,
|
1435 |
-
int64_t* conv_dilation,
|
1436 |
-
cudnnDataType_t dataType,
|
1437 |
-
at::Half* devPtrX,
|
1438 |
-
at::Half* devPtrW,
|
1439 |
-
at::Half* devPtrY,
|
1440 |
-
cudnnBackendDescriptorType_t mode) {
|
1441 |
-
|
1442 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
1443 |
-
std::stringstream log_buf;
|
1444 |
-
|
1445 |
-
try {
|
1446 |
-
int conv_dim = 2;
|
1447 |
-
float alpha = 1.0f;
|
1448 |
-
float beta = 0.0f;
|
1449 |
-
|
1450 |
-
// Define the convolution problem
|
1451 |
-
int64_t stride[4];
|
1452 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1453 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
1454 |
-
.setDim(4, x_dim)
|
1455 |
-
.setStrides(4, stride)
|
1456 |
-
.setId('x')
|
1457 |
-
.setAlignment(16)
|
1458 |
-
.setDataType(dataType)
|
1459 |
-
.build();
|
1460 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
1461 |
-
|
1462 |
-
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1463 |
-
auto wTensor = cudnn_frontend::TensorBuilder()
|
1464 |
-
.setDim(4, w_dim)
|
1465 |
-
.setStrides(4, stride)
|
1466 |
-
.setId('w')
|
1467 |
-
.setAlignment(16)
|
1468 |
-
.setDataType(dataType)
|
1469 |
-
.build();
|
1470 |
-
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
|
1471 |
-
|
1472 |
-
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1473 |
-
auto yTensor = cudnn_frontend::TensorBuilder()
|
1474 |
-
.setDim(4, y_dim)
|
1475 |
-
.setStrides(4, stride)
|
1476 |
-
.setId('y')
|
1477 |
-
.setAlignment(16)
|
1478 |
-
.setDataType(dataType)
|
1479 |
-
.build();
|
1480 |
-
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
|
1481 |
-
|
1482 |
-
|
1483 |
-
// Define the convolution problem
|
1484 |
-
auto convDesc = cudnn_frontend::ConvDescBuilder()
|
1485 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1486 |
-
.setMathMode(CUDNN_CROSS_CORRELATION)
|
1487 |
-
.setNDims(conv_dim)
|
1488 |
-
.setStrides(conv_dim, conv_stride)
|
1489 |
-
.setPrePadding(conv_dim, conv_pad)
|
1490 |
-
.setPostPadding(conv_dim, conv_pad)
|
1491 |
-
.setDilation(conv_dim, conv_dilation)
|
1492 |
-
.build();
|
1493 |
-
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
|
1494 |
-
|
1495 |
-
// Create a convolution node
|
1496 |
-
// mode should be one of following
|
1497 |
-
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
|
1498 |
-
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
1499 |
-
auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
|
1500 |
-
if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
|
1501 |
-
conv_op_builder.setdxDesc(xTensor)
|
1502 |
-
.setwDesc(wTensor)
|
1503 |
-
.setdyDesc(yTensor)
|
1504 |
-
.setcDesc(convDesc);
|
1505 |
-
}
|
1506 |
-
else {
|
1507 |
-
conv_op_builder.setxDesc(xTensor)
|
1508 |
-
.setdwDesc(wTensor)
|
1509 |
-
.setdyDesc(yTensor)
|
1510 |
-
.setcDesc(convDesc);
|
1511 |
-
}
|
1512 |
-
auto conv_op = conv_op_builder
|
1513 |
-
.setAlpha(alpha)
|
1514 |
-
.setBeta(beta)
|
1515 |
-
.build();
|
1516 |
-
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
|
1517 |
-
|
1518 |
-
// Create an Operation Graph. In this case it is convolution add bias activation
|
1519 |
-
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
|
1520 |
-
|
1521 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
1522 |
-
.setHandle(handle_)
|
1523 |
-
.setOperationGraph(ops.size(), ops.data())
|
1524 |
-
.build();
|
1525 |
-
|
1526 |
-
// Create string encoding for plan caching
|
1527 |
-
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
|
1528 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
1529 |
-
|
1530 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
1531 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
1532 |
-
|
1533 |
-
auto workspace_size = plan.getWorkspaceSize();
|
1534 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
1535 |
-
|
1536 |
-
void* workspace_ptr = nullptr;
|
1537 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
1538 |
-
if (workspace_size > 0) {
|
1539 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
1540 |
-
}
|
1541 |
-
void* data_ptrs[] = {devPtrX, devPtrW, devPtrY};
|
1542 |
-
int64_t uids[] = {'x', 'w', 'y'};
|
1543 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
1544 |
-
.setWorkspacePointer(workspace_ptr)
|
1545 |
-
.setDataPointers(3, data_ptrs)
|
1546 |
-
.setUids(3, uids)
|
1547 |
-
.build();
|
1548 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
1549 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
1550 |
-
checkCudnnErr(status);
|
1551 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
1552 |
-
} catch (cudnn_frontend::cudnnException e) {
|
1553 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
1554 |
-
}
|
1555 |
-
}
|
1556 |
-
|
1557 |
-
|
1558 |
-
void
|
1559 |
-
run_dbias(int64_t* x_dim,
|
1560 |
-
cudnnDataType_t dataType,
|
1561 |
-
at::Half* devPtrX,
|
1562 |
-
float* devPtrY) {
|
1563 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
1564 |
-
std::stringstream log_buf;
|
1565 |
-
try {
|
1566 |
-
int convDim = 2;
|
1567 |
-
int64_t b_dim[] = {1, x_dim[1], 1, 1};
|
1568 |
-
|
1569 |
-
int64_t stride[4];
|
1570 |
-
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1571 |
-
auto xTensor = cudnn_frontend::TensorBuilder()
|
1572 |
-
.setDim(4, x_dim)
|
1573 |
-
.setStrides(4, stride)
|
1574 |
-
.setId('x')
|
1575 |
-
.setAlignment(16)
|
1576 |
-
.setDataType(dataType)
|
1577 |
-
.build();
|
1578 |
-
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
|
1579 |
-
|
1580 |
-
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
|
1581 |
-
auto yTensor = cudnn_frontend::TensorBuilder()
|
1582 |
-
.setDim(4, b_dim)
|
1583 |
-
.setStrides(4, stride)
|
1584 |
-
.setId('y')
|
1585 |
-
.setAlignment(16)
|
1586 |
-
.setDataType(CUDNN_DATA_FLOAT)
|
1587 |
-
.build();
|
1588 |
-
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
|
1589 |
-
|
1590 |
-
// Define the bias backward operation
|
1591 |
-
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
|
1592 |
-
.setMathPrecision(CUDNN_DATA_FLOAT)
|
1593 |
-
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
|
1594 |
-
.build();
|
1595 |
-
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
|
1596 |
-
|
1597 |
-
// Create bias node
|
1598 |
-
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
|
1599 |
-
.setxDesc(xTensor)
|
1600 |
-
.setyDesc(yTensor)
|
1601 |
-
.setreductionDesc(biasDesc)
|
1602 |
-
.build();
|
1603 |
-
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
|
1604 |
-
|
1605 |
-
// Create an Operation Graph. In this case it is bias only
|
1606 |
-
std::array<cudnn_frontend::Operation const*, 1> ops = {&bias_op};
|
1607 |
-
|
1608 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder()
|
1609 |
-
.setHandle(handle_)
|
1610 |
-
.setOperationGraph(ops.size(), ops.data())
|
1611 |
-
.build();
|
1612 |
-
|
1613 |
-
// Create string encoding for plan caching
|
1614 |
-
int64_t pad_dummy[] = {10, 10};
|
1615 |
-
int64_t stride_dummy[] = {10, 10};
|
1616 |
-
int64_t dilation_dummy[] = {10, 10};
|
1617 |
-
auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
|
1618 |
-
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
|
1619 |
-
|
1620 |
-
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
|
1621 |
-
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
|
1622 |
-
|
1623 |
-
auto workspace_size = plan.getWorkspaceSize();
|
1624 |
-
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
|
1625 |
-
|
1626 |
-
void* workspace_ptr = nullptr;
|
1627 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
1628 |
-
if (workspace_size > 0) {
|
1629 |
-
workspace_ptr = workspace_tensor.data_ptr<float>();
|
1630 |
-
}
|
1631 |
-
void* data_ptrs[] = {devPtrX, devPtrY};
|
1632 |
-
int64_t uids[] = {'x', 'y'};
|
1633 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
1634 |
-
.setWorkspacePointer(workspace_ptr)
|
1635 |
-
.setDataPointers(2, data_ptrs)
|
1636 |
-
.setUids(2, uids)
|
1637 |
-
.build();
|
1638 |
-
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
|
1639 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
1640 |
-
checkCudnnErr(status);
|
1641 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
1642 |
-
} catch (cudnn_frontend::cudnnException e) {
|
1643 |
-
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
|
1644 |
-
}
|
1645 |
-
|
1646 |
-
}
|
1647 |
-
|
1648 |
-
|
1649 |
-
std::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
1650 |
-
std::cout << std::fixed;
|
1651 |
-
|
1652 |
-
// create output vector
|
1653 |
-
std::vector<at::Tensor> outputs;
|
1654 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
1655 |
-
|
1656 |
-
// setup dimensions
|
1657 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
1658 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
1659 |
-
|
1660 |
-
// All dim calculation after this order of n,c,h,w
|
1661 |
-
int axis[] = {0, 1, 2, 3};
|
1662 |
-
for (int dim = 0; dim < 4; dim++) {
|
1663 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
1664 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
1665 |
-
}
|
1666 |
-
|
1667 |
-
// output dim in n,c,h,w used by backend
|
1668 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
1669 |
-
|
1670 |
-
// use these fixed values
|
1671 |
-
int64_t conv_pad[] = {padding, padding};
|
1672 |
-
int64_t conv_stride[] = {stride, stride};
|
1673 |
-
int64_t conv_dilation[] = {1, 1};
|
1674 |
-
|
1675 |
-
// compute output from pad/stride/dilation
|
1676 |
-
y_dim[0] = x_dim[0];
|
1677 |
-
y_dim[1] = w_dim[0];
|
1678 |
-
for (int dim = 0; dim < 2; dim++) {
|
1679 |
-
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
|
1680 |
-
}
|
1681 |
-
|
1682 |
-
// run
|
1683 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
1684 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
1685 |
-
at::Half* b = inputs[2].data_ptr<at::Half>();
|
1686 |
-
int8_t* m = inputs[3].data_ptr<int8_t>();
|
1687 |
-
auto out = at::empty(y_dim, inputs[0].type(), output_format);
|
1688 |
-
at::Half* y = out.data_ptr<at::Half>();
|
1689 |
-
|
1690 |
-
run_conv_bias_mask_relu(x_dim,
|
1691 |
-
w_dim,
|
1692 |
-
y_dim,
|
1693 |
-
conv_pad,
|
1694 |
-
conv_stride,
|
1695 |
-
conv_dilation,
|
1696 |
-
CUDNN_DATA_HALF,
|
1697 |
-
x,
|
1698 |
-
w,
|
1699 |
-
b,
|
1700 |
-
m,
|
1701 |
-
y);
|
1702 |
-
|
1703 |
-
DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item<float>());
|
1704 |
-
|
1705 |
-
outputs.push_back(out);
|
1706 |
-
|
1707 |
-
return outputs;
|
1708 |
-
}
|
1709 |
-
|
1710 |
-
|
1711 |
-
at::Tensor conv_cscale_cbias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
1712 |
-
std::cout << std::fixed;
|
1713 |
-
|
1714 |
-
// setup dimensions
|
1715 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
1716 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
1717 |
-
|
1718 |
-
// All dim calculation after this order of n,c,h,w
|
1719 |
-
int axis[] = {0, 1, 2, 3};
|
1720 |
-
for (int dim = 0; dim < 4; dim++) {
|
1721 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
1722 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
1723 |
-
}
|
1724 |
-
|
1725 |
-
// output dim in n,c,h,w used by backend
|
1726 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
1727 |
-
|
1728 |
-
// use these fixed values
|
1729 |
-
int64_t conv_pad[] = {padding, padding};
|
1730 |
-
int64_t conv_stride[] = {stride, stride};
|
1731 |
-
int64_t conv_dilation[] = {1, 1};
|
1732 |
-
|
1733 |
-
// compute output from pad/stride/dilation
|
1734 |
-
y_dim[0] = x_dim[0];
|
1735 |
-
y_dim[1] = w_dim[0];
|
1736 |
-
for (int dim = 0; dim < 2; dim++) {
|
1737 |
-
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
|
1738 |
-
}
|
1739 |
-
|
1740 |
-
// run
|
1741 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
1742 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
1743 |
-
at::Half* s = inputs[2].data_ptr<at::Half>();
|
1744 |
-
at::Half* b = inputs[3].data_ptr<at::Half>();
|
1745 |
-
auto out = at::empty(y_dim, inputs[0].type(), at::MemoryFormat::ChannelsLast);
|
1746 |
-
at::Half* y = out.data_ptr<at::Half>();
|
1747 |
-
|
1748 |
-
run_conv_cscale_cbias_relu(x_dim,
|
1749 |
-
w_dim,
|
1750 |
-
y_dim,
|
1751 |
-
conv_pad,
|
1752 |
-
conv_stride,
|
1753 |
-
conv_dilation,
|
1754 |
-
CUDNN_DATA_HALF,
|
1755 |
-
x,
|
1756 |
-
w,
|
1757 |
-
s,
|
1758 |
-
b,
|
1759 |
-
y);
|
1760 |
-
|
1761 |
-
DEBUG_MSG("[DEBUG] conv-cscale-cbias-relu : " << y.to(at::kFloat).sum().item<float>());
|
1762 |
-
|
1763 |
-
return out;
|
1764 |
-
}
|
1765 |
-
|
1766 |
-
|
1767 |
-
std::vector<at::Tensor> conv_cscale_cbias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
1768 |
-
bool requires_grad = inputs[0].requires_grad();
|
1769 |
-
|
1770 |
-
for (int i = 0; i <= 4; i++) {
|
1771 |
-
CHECK_INPUT(inputs[i]);
|
1772 |
-
}
|
1773 |
-
|
1774 |
-
std::cout << std::fixed;
|
1775 |
-
|
1776 |
-
// create output vector
|
1777 |
-
std::vector<at::Tensor> outputs;
|
1778 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
1779 |
-
|
1780 |
-
// setup dimensions
|
1781 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
1782 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
1783 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
1784 |
-
|
1785 |
-
// All dim calculation after this order of n,c,h,w
|
1786 |
-
int axis[] = {0, 1, 2, 3};
|
1787 |
-
for (int dim = 0; dim < 4; dim++) {
|
1788 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
1789 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
1790 |
-
y_dim[dim] = inputs[3].size(axis[dim]);
|
1791 |
-
}
|
1792 |
-
|
1793 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
1794 |
-
|
1795 |
-
int64_t conv_pad[] = {padding, padding};
|
1796 |
-
int64_t conv_stride[] = {stride, stride};
|
1797 |
-
int64_t conv_dilation[] = {1, 1};
|
1798 |
-
|
1799 |
-
// run
|
1800 |
-
// drelu-dbias
|
1801 |
-
at::Half* dy = inputs[4].data_ptr<at::Half>();
|
1802 |
-
at::Half* r = inputs[3].data_ptr<at::Half>();
|
1803 |
-
auto s = inputs[2].data_ptr<at::Half>();
|
1804 |
-
auto dscale = at::empty_like(inputs[4]);
|
1805 |
-
at::Half* ds = dscale.data_ptr<at::Half>();
|
1806 |
-
|
1807 |
-
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
|
1808 |
-
run_drelu_dscale(y_dim,
|
1809 |
-
CUDNN_DATA_HALF,
|
1810 |
-
dy,
|
1811 |
-
r,
|
1812 |
-
s,
|
1813 |
-
ds);
|
1814 |
-
|
1815 |
-
// conv wgrad
|
1816 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
1817 |
-
auto wgrad = at::empty_like(inputs[1]);
|
1818 |
-
at::Half* dw = wgrad.data_ptr<at::Half>();
|
1819 |
-
run_dconv(x_dim,
|
1820 |
-
w_dim,
|
1821 |
-
y_dim,
|
1822 |
-
conv_pad,
|
1823 |
-
conv_stride,
|
1824 |
-
conv_dilation,
|
1825 |
-
CUDNN_DATA_HALF,
|
1826 |
-
x,
|
1827 |
-
dw,
|
1828 |
-
ds,
|
1829 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
|
1830 |
-
|
1831 |
-
// conv dgrad
|
1832 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
1833 |
-
auto dgrad = at::empty_like(inputs[0]);
|
1834 |
-
at::Half* dx = dgrad.data_ptr<at::Half>();
|
1835 |
-
run_dconv(x_dim,
|
1836 |
-
w_dim,
|
1837 |
-
y_dim,
|
1838 |
-
conv_pad,
|
1839 |
-
conv_stride,
|
1840 |
-
conv_dilation,
|
1841 |
-
CUDNN_DATA_HALF,
|
1842 |
-
dx,
|
1843 |
-
w,
|
1844 |
-
ds,
|
1845 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
|
1846 |
-
|
1847 |
-
outputs.push_back(dgrad);
|
1848 |
-
outputs.push_back(wgrad);
|
1849 |
-
|
1850 |
-
return outputs;
|
1851 |
-
}
|
1852 |
-
|
1853 |
-
|
1854 |
-
std::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
1855 |
-
std::cout << std::fixed;
|
1856 |
-
|
1857 |
-
// create output vector
|
1858 |
-
std::vector<at::Tensor> outputs;
|
1859 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
1860 |
-
|
1861 |
-
// setup dimensions
|
1862 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
1863 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
1864 |
-
|
1865 |
-
// All dim calculation after this order of n,c,h,w
|
1866 |
-
int axis[] = {0, 1, 2, 3};
|
1867 |
-
for (int dim = 0; dim < 4; dim++) {
|
1868 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
1869 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
1870 |
-
}
|
1871 |
-
|
1872 |
-
// output dim in n,c,h,w used by backend
|
1873 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
1874 |
-
|
1875 |
-
// use these fixed values
|
1876 |
-
int64_t conv_pad[] = {padding, padding};
|
1877 |
-
int64_t conv_stride[] = {stride, stride};
|
1878 |
-
int64_t conv_dilation[] = {1, 1};
|
1879 |
-
|
1880 |
-
// compute output from pad/stride/dilation
|
1881 |
-
y_dim[0] = x_dim[0];
|
1882 |
-
y_dim[1] = w_dim[0];
|
1883 |
-
for (int dim = 0; dim < 2; dim++) {
|
1884 |
-
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
|
1885 |
-
}
|
1886 |
-
|
1887 |
-
// run
|
1888 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
1889 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
1890 |
-
at::Half* b = inputs[2].data_ptr<at::Half>();
|
1891 |
-
auto out = at::empty(y_dim, inputs[0].type(), output_format);
|
1892 |
-
at::Half* y = out.data_ptr<at::Half>();
|
1893 |
-
|
1894 |
-
run_conv_bias_relu(x_dim,
|
1895 |
-
w_dim,
|
1896 |
-
y_dim,
|
1897 |
-
conv_pad,
|
1898 |
-
conv_stride,
|
1899 |
-
conv_dilation,
|
1900 |
-
CUDNN_DATA_HALF,
|
1901 |
-
x,
|
1902 |
-
w,
|
1903 |
-
b,
|
1904 |
-
y);
|
1905 |
-
|
1906 |
-
DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item<float>());
|
1907 |
-
|
1908 |
-
outputs.push_back(out);
|
1909 |
-
|
1910 |
-
return outputs;
|
1911 |
-
}
|
1912 |
-
|
1913 |
-
|
1914 |
-
std::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
1915 |
-
bool requires_grad = inputs[0].requires_grad();
|
1916 |
-
|
1917 |
-
for (int i = 0; i <= 3; i++) {
|
1918 |
-
CHECK_INPUT(inputs[i]);
|
1919 |
-
}
|
1920 |
-
|
1921 |
-
std::cout << std::fixed;
|
1922 |
-
|
1923 |
-
// create output vector
|
1924 |
-
std::vector<at::Tensor> outputs;
|
1925 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
1926 |
-
|
1927 |
-
// setup dimensions
|
1928 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
1929 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
1930 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
1931 |
-
|
1932 |
-
// All dim calculation after this order of n,c,h,w
|
1933 |
-
int axis[] = {0, 1, 2, 3};
|
1934 |
-
for (int dim = 0; dim < 4; dim++) {
|
1935 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
1936 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
1937 |
-
y_dim[dim] = inputs[3].size(axis[dim]);
|
1938 |
-
}
|
1939 |
-
|
1940 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
1941 |
-
|
1942 |
-
int64_t conv_pad[] = {padding, padding};
|
1943 |
-
int64_t conv_stride[] = {stride, stride};
|
1944 |
-
int64_t conv_dilation[] = {1, 1};
|
1945 |
-
|
1946 |
-
// run
|
1947 |
-
// drelu-dbias
|
1948 |
-
at::Half* dy = inputs[3].data_ptr<at::Half>();
|
1949 |
-
at::Half* r = inputs[2].data_ptr<at::Half>();
|
1950 |
-
auto drelu = at::empty_like(inputs[2]);
|
1951 |
-
at::Half* dr = drelu.data_ptr<at::Half>();
|
1952 |
-
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
|
1953 |
-
auto bgrad = at::empty(b_dim, options, output_format);
|
1954 |
-
float* db = bgrad.data_ptr<float>();
|
1955 |
-
run_drelu_dbias(y_dim,
|
1956 |
-
CUDNN_DATA_HALF,
|
1957 |
-
dy,
|
1958 |
-
r,
|
1959 |
-
dr,
|
1960 |
-
db);
|
1961 |
-
|
1962 |
-
// conv wgrad
|
1963 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
1964 |
-
auto wgrad = at::empty_like(inputs[1]);
|
1965 |
-
at::Half* dw = wgrad.data_ptr<at::Half>();
|
1966 |
-
run_dconv(x_dim,
|
1967 |
-
w_dim,
|
1968 |
-
y_dim,
|
1969 |
-
conv_pad,
|
1970 |
-
conv_stride,
|
1971 |
-
conv_dilation,
|
1972 |
-
CUDNN_DATA_HALF,
|
1973 |
-
x,
|
1974 |
-
dw,
|
1975 |
-
dr,
|
1976 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
|
1977 |
-
|
1978 |
-
// conv dgrad
|
1979 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
1980 |
-
auto dgrad = at::empty_like(inputs[0]);
|
1981 |
-
at::Half* dx = dgrad.data_ptr<at::Half>();
|
1982 |
-
run_dconv(x_dim,
|
1983 |
-
w_dim,
|
1984 |
-
y_dim,
|
1985 |
-
conv_pad,
|
1986 |
-
conv_stride,
|
1987 |
-
conv_dilation,
|
1988 |
-
CUDNN_DATA_HALF,
|
1989 |
-
dx,
|
1990 |
-
w,
|
1991 |
-
dr,
|
1992 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
|
1993 |
-
|
1994 |
-
outputs.push_back(dgrad);
|
1995 |
-
outputs.push_back(wgrad);
|
1996 |
-
outputs.push_back(bgrad);
|
1997 |
-
|
1998 |
-
return outputs;
|
1999 |
-
|
2000 |
-
}
|
2001 |
-
|
2002 |
-
std::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
2003 |
-
std::cout << std::fixed;
|
2004 |
-
|
2005 |
-
// create output vector
|
2006 |
-
std::vector<at::Tensor> outputs;
|
2007 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
2008 |
-
|
2009 |
-
// setup dimensions
|
2010 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
2011 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
2012 |
-
|
2013 |
-
// All dim calculation after this order of n,c,h,w
|
2014 |
-
int axis[] = {0, 1, 2, 3};
|
2015 |
-
for (int dim = 0; dim < 4; dim++) {
|
2016 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
2017 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
2018 |
-
}
|
2019 |
-
|
2020 |
-
// output dim in n,c,h,w used by backend
|
2021 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
2022 |
-
|
2023 |
-
// use these fixed values
|
2024 |
-
int64_t conv_pad[] = {padding, padding};
|
2025 |
-
int64_t conv_stride[] = {stride, stride};
|
2026 |
-
int64_t conv_dilation[] = {1, 1};
|
2027 |
-
|
2028 |
-
// compute output from pad/stride/dilation
|
2029 |
-
y_dim[0] = x_dim[0];
|
2030 |
-
y_dim[1] = w_dim[0];
|
2031 |
-
for (int dim = 0; dim < 2; dim++) {
|
2032 |
-
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
|
2033 |
-
}
|
2034 |
-
|
2035 |
-
// run
|
2036 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
2037 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
2038 |
-
at::Half* b = inputs[2].data_ptr<at::Half>();
|
2039 |
-
auto out = at::empty(y_dim, inputs[0].type(), output_format);
|
2040 |
-
at::Half* y = out.data_ptr<at::Half>();
|
2041 |
-
|
2042 |
-
run_conv_bias(x_dim,
|
2043 |
-
w_dim,
|
2044 |
-
y_dim,
|
2045 |
-
conv_pad,
|
2046 |
-
conv_stride,
|
2047 |
-
conv_dilation,
|
2048 |
-
CUDNN_DATA_HALF,
|
2049 |
-
x,
|
2050 |
-
w,
|
2051 |
-
b,
|
2052 |
-
y);
|
2053 |
-
|
2054 |
-
DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item<float>());
|
2055 |
-
|
2056 |
-
outputs.push_back(out);
|
2057 |
-
|
2058 |
-
return outputs;
|
2059 |
-
}
|
2060 |
-
|
2061 |
-
|
2062 |
-
std::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
|
2063 |
-
bool requires_grad = inputs[0].requires_grad();
|
2064 |
-
|
2065 |
-
for (int i = 0; i <= 2; i++) {
|
2066 |
-
CHECK_INPUT(inputs[i]);
|
2067 |
-
}
|
2068 |
-
|
2069 |
-
std::cout << std::fixed;
|
2070 |
-
|
2071 |
-
// create output vector
|
2072 |
-
std::vector<at::Tensor> outputs;
|
2073 |
-
auto output_format = at::MemoryFormat::ChannelsLast;
|
2074 |
-
|
2075 |
-
// setup dimensions
|
2076 |
-
int64_t x_dim[] = {0, 0, 0, 0};
|
2077 |
-
int64_t w_dim[] = {0, 0, 0, 0};
|
2078 |
-
int64_t y_dim[] = {0, 0, 0, 0};
|
2079 |
-
|
2080 |
-
// All dim calculation after this order of n,c,h,w
|
2081 |
-
int axis[] = {0, 1, 2, 3};
|
2082 |
-
for (int dim = 0; dim < 4; dim++) {
|
2083 |
-
x_dim[dim] = inputs[0].size(axis[dim]);
|
2084 |
-
w_dim[dim] = inputs[1].size(axis[dim]);
|
2085 |
-
y_dim[dim] = inputs[2].size(axis[dim]);
|
2086 |
-
}
|
2087 |
-
|
2088 |
-
int64_t b_dim[] = {1, y_dim[1], 1, 1};
|
2089 |
-
|
2090 |
-
int64_t conv_pad[] = {padding, padding};
|
2091 |
-
int64_t conv_stride[] = {stride, stride};
|
2092 |
-
int64_t conv_dilation[] = {1, 1};
|
2093 |
-
|
2094 |
-
// run
|
2095 |
-
// dbias
|
2096 |
-
at::Half* dy = inputs[2].data_ptr<at::Half>();
|
2097 |
-
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
|
2098 |
-
auto bgrad = at::empty(b_dim, options, output_format);
|
2099 |
-
float* db = bgrad.data_ptr<float>();
|
2100 |
-
run_dbias(y_dim,
|
2101 |
-
CUDNN_DATA_HALF,
|
2102 |
-
dy,
|
2103 |
-
db);
|
2104 |
-
|
2105 |
-
// conv wgrad
|
2106 |
-
at::Half* x = inputs[0].data_ptr<at::Half>();
|
2107 |
-
auto wgrad = at::empty_like(inputs[1]);
|
2108 |
-
at::Half* dw = wgrad.data_ptr<at::Half>();
|
2109 |
-
run_dconv(x_dim,
|
2110 |
-
w_dim,
|
2111 |
-
y_dim,
|
2112 |
-
conv_pad,
|
2113 |
-
conv_stride,
|
2114 |
-
conv_dilation,
|
2115 |
-
CUDNN_DATA_HALF,
|
2116 |
-
x,
|
2117 |
-
dw,
|
2118 |
-
dy,
|
2119 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
|
2120 |
-
|
2121 |
-
// conv dgrad
|
2122 |
-
at::Half* w = inputs[1].data_ptr<at::Half>();
|
2123 |
-
auto dgrad = at::empty_like(inputs[0]);
|
2124 |
-
at::Half* dx = dgrad.data_ptr<at::Half>();
|
2125 |
-
run_dconv(x_dim,
|
2126 |
-
w_dim,
|
2127 |
-
y_dim,
|
2128 |
-
conv_pad,
|
2129 |
-
conv_stride,
|
2130 |
-
conv_dilation,
|
2131 |
-
CUDNN_DATA_HALF,
|
2132 |
-
dx,
|
2133 |
-
w,
|
2134 |
-
dy,
|
2135 |
-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
|
2136 |
-
|
2137 |
-
outputs.push_back(dgrad);
|
2138 |
-
outputs.push_back(wgrad);
|
2139 |
-
outputs.push_back(bgrad);
|
2140 |
-
|
2141 |
-
return outputs;
|
2142 |
-
}
|
2143 |
-
|
2144 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
2145 |
-
m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward");
|
2146 |
-
m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward");
|
2147 |
-
m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward");
|
2148 |
-
m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward");
|
2149 |
-
m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward");
|
2150 |
-
m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU");
|
2151 |
-
m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, "Fused Conv-(const)Scale-(const)Bias-ReLU backward");
|
2152 |
-
}
|
2153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp
DELETED
@@ -1,163 +0,0 @@
|
|
1 |
-
#include <ATen/ATen.h>
|
2 |
-
#include <torch/extension.h>
|
3 |
-
#include <torch/torch.h>
|
4 |
-
#include <vector>
|
5 |
-
|
6 |
-
#include <iostream>
|
7 |
-
|
8 |
-
#include "norm_sample.h"
|
9 |
-
|
10 |
-
// define this enum:
|
11 |
-
enum bn_type { BN_FWD, BN_BWD };
|
12 |
-
|
13 |
-
// this is a global variable
|
14 |
-
static std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;
|
15 |
-
|
16 |
-
at::Tensor gbn_forward(const at::Tensor& x,
|
17 |
-
const at::Tensor& scale,
|
18 |
-
const at::Tensor& bias,
|
19 |
-
const at::Tensor& running_mean,
|
20 |
-
const at::Tensor& running_var,
|
21 |
-
const at::Tensor& minibatch_mean,
|
22 |
-
const at::Tensor& minibatch_inv_var,
|
23 |
-
const float momentum,
|
24 |
-
const float epsilon,
|
25 |
-
const int64_t bn_group,
|
26 |
-
const int rank_id,
|
27 |
-
const std::vector<int64_t> &peer_buffers) {
|
28 |
-
|
29 |
-
int64_t N = x.size(0);
|
30 |
-
int64_t C = x.size(1);
|
31 |
-
int64_t H = x.size(2);
|
32 |
-
int64_t W = x.size(3);
|
33 |
-
|
34 |
-
int64_t tensorDims[] = {N, C, H, W};
|
35 |
-
int64_t peerDims[] = {bn_group, 4*C, 1, 1};
|
36 |
-
int64_t perChannelDims[] = {1, C, 1, 1};
|
37 |
-
int64_t epsilonDims[] = {1, 1, 1, 1};
|
38 |
-
|
39 |
-
// Allocate output tensor
|
40 |
-
at::Tensor y = at::empty_like(x);
|
41 |
-
|
42 |
-
std::vector<void*> void_peer_buffers;
|
43 |
-
for (int64_t addr : peer_buffers) {
|
44 |
-
void_peer_buffers.push_back((void*)addr);
|
45 |
-
}
|
46 |
-
|
47 |
-
// we need the peer size for the buffer reset
|
48 |
-
size_t peer_size = 1;
|
49 |
-
for (size_t i = 0; i < 4; ++i){
|
50 |
-
peer_size *= peerDims[i];
|
51 |
-
}
|
52 |
-
|
53 |
-
// sanity check
|
54 |
-
assert(bn_group == void_peer_buffers.size());
|
55 |
-
|
56 |
-
// check if plan already exists
|
57 |
-
std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
|
58 |
-
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
|
59 |
-
auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
|
60 |
-
gbn_plan_cache.emplace(fv, std::move(plan));
|
61 |
-
}
|
62 |
-
|
63 |
-
// get plan and handle
|
64 |
-
auto plan = gbn_plan_cache.find(fv)->second;
|
65 |
-
|
66 |
-
// execute
|
67 |
-
execute_batch_norm_forward(plan,
|
68 |
-
x.data_ptr(),
|
69 |
-
y.data_ptr(),
|
70 |
-
scale.data_ptr(),
|
71 |
-
bias.data_ptr(),
|
72 |
-
running_mean.data_ptr(),
|
73 |
-
running_var.data_ptr(),
|
74 |
-
running_mean.data_ptr(),
|
75 |
-
running_var.data_ptr(),
|
76 |
-
minibatch_mean.data_ptr(),
|
77 |
-
minibatch_inv_var.data_ptr(),
|
78 |
-
void_peer_buffers,
|
79 |
-
static_cast<double>(epsilon),
|
80 |
-
static_cast<double>(momentum),
|
81 |
-
peer_size,
|
82 |
-
rank_id);
|
83 |
-
|
84 |
-
return y;
|
85 |
-
}
|
86 |
-
|
87 |
-
std::vector<at::Tensor> gbn_backward(
|
88 |
-
const at::Tensor& x,
|
89 |
-
const at::Tensor& dy,
|
90 |
-
const at::Tensor& scale,
|
91 |
-
const at::Tensor& minibatch_mean,
|
92 |
-
const at::Tensor& minibatch_inv_var,
|
93 |
-
const float epsilon,
|
94 |
-
const int64_t bn_group,
|
95 |
-
const int rank_id,
|
96 |
-
const std::vector<int64_t> &peer_buffers) {
|
97 |
-
|
98 |
-
int64_t N = x.size(0);
|
99 |
-
int64_t C = x.size(1);
|
100 |
-
int64_t H = x.size(2);
|
101 |
-
int64_t W = x.size(3);
|
102 |
-
|
103 |
-
int64_t tensorDims[] = {N, C, H, W};
|
104 |
-
int64_t peerDims[] = {bn_group, 4*C, 1, 1};
|
105 |
-
int64_t perChannelDims[] = {1, C, 1, 1};
|
106 |
-
int64_t epsilonDims[] = {1, 1, 1, 1};
|
107 |
-
|
108 |
-
// Allocate output tensor
|
109 |
-
// outputs
|
110 |
-
at::Tensor x_grad, scale_grad, bias_grad;
|
111 |
-
|
112 |
-
// Allocate outputs
|
113 |
-
x_grad = at::empty_like(x);
|
114 |
-
scale_grad = at::empty_like(scale);
|
115 |
-
bias_grad = at::empty_like(scale);
|
116 |
-
|
117 |
-
std::vector<void*> void_peer_buffers;
|
118 |
-
for (int64_t addr : peer_buffers) {
|
119 |
-
void_peer_buffers.push_back((void*)addr);
|
120 |
-
}
|
121 |
-
|
122 |
-
// we need the peer size for the buffer reset
|
123 |
-
size_t peer_size = 1;
|
124 |
-
for (size_t i = 0; i < 4; ++i){
|
125 |
-
peer_size *= peerDims[i];
|
126 |
-
}
|
127 |
-
|
128 |
-
assert(bn_group == void_peer_buffers.size());
|
129 |
-
|
130 |
-
std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
|
131 |
-
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
|
132 |
-
auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
|
133 |
-
gbn_plan_cache.emplace(fv, std::move(plan));
|
134 |
-
}
|
135 |
-
|
136 |
-
// get plan and handle
|
137 |
-
auto plan = gbn_plan_cache.find(fv)->second;
|
138 |
-
|
139 |
-
// execute
|
140 |
-
execute_batch_norm_backward(plan,
|
141 |
-
x.data_ptr(),
|
142 |
-
dy.data_ptr(),
|
143 |
-
scale.data_ptr(),
|
144 |
-
minibatch_mean.data_ptr(),
|
145 |
-
minibatch_inv_var.data_ptr(),
|
146 |
-
void_peer_buffers,
|
147 |
-
x_grad.data_ptr(),
|
148 |
-
scale_grad.data_ptr(),
|
149 |
-
bias_grad.data_ptr(),
|
150 |
-
static_cast<double>(epsilon),
|
151 |
-
peer_size,
|
152 |
-
rank_id);
|
153 |
-
|
154 |
-
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
|
155 |
-
}
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
161 |
-
m.def("forward", &gbn_forward, "Group batch norm forward");
|
162 |
-
m.def("backward", &gbn_backward, "Group batch backward");
|
163 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp
DELETED
@@ -1,479 +0,0 @@
|
|
1 |
-
/*
|
2 |
-
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
-
* copy of this software and associated documentation files (the "Software"),
|
6 |
-
* to deal in the Software without restriction, including without limitation
|
7 |
-
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
-
* and/or sell copies of the Software, and to permit persons to whom the
|
9 |
-
* Software is furnished to do so, subject to the following conditions:
|
10 |
-
*
|
11 |
-
* The above copyright notice and this permission notice shall be included in
|
12 |
-
* all copies or substantial portions of the Software.
|
13 |
-
*
|
14 |
-
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
-
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
-
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
-
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
-
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
-
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
-
* DEALINGS IN THE SOFTWARE.
|
21 |
-
*/
|
22 |
-
|
23 |
-
#include "norm_sample.h"
|
24 |
-
#include <cudnn_frontend.h>
|
25 |
-
#include "cudnn_backend.h"
|
26 |
-
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
|
27 |
-
#include <torch/extension.h>
|
28 |
-
#include <torch/torch.h>
|
29 |
-
|
30 |
-
// some helpers
|
31 |
-
int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line) {
|
32 |
-
if (code) {
|
33 |
-
printf("CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int)code, cudaGetErrorString(code), expr);
|
34 |
-
return 1;
|
35 |
-
}
|
36 |
-
return 0;
|
37 |
-
}
|
38 |
-
|
39 |
-
int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
|
40 |
-
if (code) {
|
41 |
-
printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
|
42 |
-
return 1;
|
43 |
-
}
|
44 |
-
return 0;
|
45 |
-
}
|
46 |
-
|
47 |
-
bool
|
48 |
-
AllowAll(cudnnBackendDescriptor_t engine_config) {
|
49 |
-
(void)engine_config;
|
50 |
-
return false;
|
51 |
-
}
|
52 |
-
|
53 |
-
void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) {
|
54 |
-
// For INT8x4 and INT8x32 we still compute standard strides here to input
|
55 |
-
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
|
56 |
-
if (filterFormat == CUDNN_TENSOR_NCHW) {
|
57 |
-
strideA[nbDims - 1] = 1;
|
58 |
-
for (int64_t d = nbDims - 2; d >= 0; d--) {
|
59 |
-
strideA[d] = strideA[d + 1] * dimA[d + 1];
|
60 |
-
}
|
61 |
-
} else {
|
62 |
-
// Here we assume that the format is CUDNN_TENSOR_NHWC
|
63 |
-
strideA[1] = 1;
|
64 |
-
strideA[nbDims - 1] = strideA[1] * dimA[1];
|
65 |
-
for (int64_t d = nbDims - 2; d >= 2; d--) {
|
66 |
-
strideA[d] = strideA[d + 1] * dimA[d + 1];
|
67 |
-
}
|
68 |
-
strideA[0] = strideA[2] * dimA[2];
|
69 |
-
}
|
70 |
-
}
|
71 |
-
|
72 |
-
|
73 |
-
// runtime
|
74 |
-
cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t *tensorDims,
|
75 |
-
int64_t *perChannelSum,
|
76 |
-
int64_t *epsilon,
|
77 |
-
int64_t *peerDims,
|
78 |
-
cudnnDataType_t data_type) {
|
79 |
-
|
80 |
-
// get the cudnn handle
|
81 |
-
cudnnHandle_t handle = torch::native::getCudnnHandle();
|
82 |
-
|
83 |
-
// Creates the necessary tensor descriptors
|
84 |
-
int64_t tensor_stride[4];
|
85 |
-
int64_t stride[4];
|
86 |
-
int64_t peer_stride[4];
|
87 |
-
|
88 |
-
// NHWC format. GenerateStrides() takes care of this. Howeever, tensor dims should still be NCHW
|
89 |
-
generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
90 |
-
generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
91 |
-
|
92 |
-
auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type,
|
93 |
-
int64_t id) {
|
94 |
-
return cudnn_frontend::TensorBuilder()
|
95 |
-
.setDim(4, tensorDims)
|
96 |
-
.setStrides(4, tensor_stride)
|
97 |
-
.setId(id)
|
98 |
-
.setAlignment(16)
|
99 |
-
.setDataType(type)
|
100 |
-
.build();
|
101 |
-
};
|
102 |
-
|
103 |
-
auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type,
|
104 |
-
int64_t id) {
|
105 |
-
return cudnn_frontend::TensorBuilder()
|
106 |
-
.setDim(4, tensorDims)
|
107 |
-
.setStrides(4, peer_stride)
|
108 |
-
.setId(id)
|
109 |
-
.setAlignment(16)
|
110 |
-
.setDataType(type)
|
111 |
-
.build();
|
112 |
-
};
|
113 |
-
|
114 |
-
|
115 |
-
generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
116 |
-
|
117 |
-
auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) {
|
118 |
-
return cudnn_frontend::TensorBuilder()
|
119 |
-
.setDim(4, perChannelSum)
|
120 |
-
.setStrides(4, stride)
|
121 |
-
.setId(id)
|
122 |
-
.setAlignment(16)
|
123 |
-
.setDataType(type)
|
124 |
-
.build();
|
125 |
-
};
|
126 |
-
|
127 |
-
auto xTensor = tensor_create(data_type, 100);
|
128 |
-
auto yTensor = tensor_create(data_type, 101);
|
129 |
-
auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102);
|
130 |
-
auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103);
|
131 |
-
auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104);
|
132 |
-
auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105);
|
133 |
-
auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106);
|
134 |
-
auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107);
|
135 |
-
auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108);
|
136 |
-
auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109);
|
137 |
-
|
138 |
-
|
139 |
-
int64_t epsilon_stride[4];
|
140 |
-
generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
141 |
-
auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) {
|
142 |
-
return cudnn_frontend::TensorBuilder()
|
143 |
-
.setDim(4, epsilon)
|
144 |
-
.setStrides(4, epsilon_stride)
|
145 |
-
.setId(id)
|
146 |
-
.setAlignment(16)
|
147 |
-
.setDataType(type)
|
148 |
-
.setByValue(true)
|
149 |
-
.build();
|
150 |
-
};
|
151 |
-
|
152 |
-
auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110);
|
153 |
-
auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111);
|
154 |
-
|
155 |
-
// Create the two peer stat tensors. Jump IDs in case we need to add more tensors with UIDs
|
156 |
-
std::vector<cudnn_frontend::Tensor_v8> peerStatTensors;
|
157 |
-
for (size_t i = 112; i < 112 + peerDims[0]; ++i) {
|
158 |
-
peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i));
|
159 |
-
}
|
160 |
-
|
161 |
-
#if (CUDNN_VERSION >= 8500)
|
162 |
-
// Batch normalization
|
163 |
-
cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM;
|
164 |
-
|
165 |
-
// Forward training
|
166 |
-
cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING;
|
167 |
-
|
168 |
-
//Create a Finalize node
|
169 |
-
auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)
|
170 |
-
.setNormalizationMode(normalizationMode)
|
171 |
-
.setNormFwdPhase(phase)
|
172 |
-
.setxDesc(xTensor)
|
173 |
-
.setScaleAndBias(scaleTensor, biasTensor)
|
174 |
-
.setPrevRunningMeanAndVar(inMeanTensor, inVarTensor)
|
175 |
-
.setNextRunningMeanAndVar(outMeanTensor, outVarTensor)
|
176 |
-
.setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor)
|
177 |
-
.setEpsilonTensor(epsilonTensor)
|
178 |
-
.setExpDecayFactorTensor(expDecayTensor)
|
179 |
-
.setPeerStatTensor(peerStatTensors)
|
180 |
-
.setyDesc(yTensor)
|
181 |
-
.build();
|
182 |
-
|
183 |
-
std::array<cudnn_frontend::Operation const*, 1> ops = {&batch_norm_op};
|
184 |
-
#else
|
185 |
-
std::array<cudnn_frontend::Operation const*, 0> ops = {};
|
186 |
-
#endif
|
187 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build();
|
188 |
-
//std::cout << opGraph.describe() << std::endl;
|
189 |
-
|
190 |
-
cudnn_frontend::EngineConfigList filtered_configs;
|
191 |
-
auto statuses =
|
192 |
-
cudnn_frontend::get_heuristics_list<2>({"heuristics_instant"
|
193 |
-
, "heuristics_fallback"
|
194 |
-
}, opGraph,::AllowAll, filtered_configs, true);
|
195 |
-
|
196 |
-
//std::cout << "get_heuristics_list Statuses: ";
|
197 |
-
//for (auto i = 0u ; i < statuses.size(); i++) {
|
198 |
-
// std::cout << cudnn_frontend::to_string(statuses[i]) << " ";
|
199 |
-
//}
|
200 |
-
//std::cout << std::endl;
|
201 |
-
//std::cout << "Filter config list has " << filtered_configs.size() << " configurations " << std::endl;
|
202 |
-
|
203 |
-
// some verbose printing:
|
204 |
-
//std::cout << "Tensor shape: (" << tensorDims[0] << ", " << tensorDims[1] << ", " << tensorDims[2] << ", " << tensorDims[3] << ")" << std::endl;
|
205 |
-
|
206 |
-
auto plan_builder = [&filtered_configs, &opGraph, &handle]() {
|
207 |
-
for (auto i = 0u; i < filtered_configs.size(); i++) {
|
208 |
-
try {
|
209 |
-
auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build();
|
210 |
-
return plan;
|
211 |
-
} catch (cudnn_frontend::cudnnException &e) {
|
212 |
-
continue;
|
213 |
-
}
|
214 |
-
}
|
215 |
-
return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build();
|
216 |
-
};
|
217 |
-
|
218 |
-
assert(filtered_configs.size() > 0);
|
219 |
-
auto plan = plan_builder();
|
220 |
-
|
221 |
-
return plan;
|
222 |
-
|
223 |
-
}
|
224 |
-
|
225 |
-
void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan,
|
226 |
-
void *xDevPtr,
|
227 |
-
void *yDevPtr,
|
228 |
-
void *scaledevPtr,
|
229 |
-
void *biasdevPtr,
|
230 |
-
void *in_meandevPtr,
|
231 |
-
void *in_vardevPtr,
|
232 |
-
void *out_meandevPtr,
|
233 |
-
void *out_vardevPtr,
|
234 |
-
void *saved_meandevPtr,
|
235 |
-
void *saved_inv_vardevPtr,
|
236 |
-
const std::vector<void*> &peer_devPtrs,
|
237 |
-
double epsilon_val,
|
238 |
-
double exponential_decay_factor,
|
239 |
-
size_t peer_size,
|
240 |
-
int rank_id) {
|
241 |
-
|
242 |
-
// get handle
|
243 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
244 |
-
|
245 |
-
// get stream
|
246 |
-
cudaStream_t stream;
|
247 |
-
cudnnGetStream(handle_, &stream);
|
248 |
-
|
249 |
-
try {
|
250 |
-
// allocate workspace
|
251 |
-
auto workspace_size = plan.getWorkspaceSize();
|
252 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
253 |
-
void* workPtr = nullptr;
|
254 |
-
if (workspace_size > 0) {
|
255 |
-
workPtr = workspace_tensor.data_ptr<float>();
|
256 |
-
}
|
257 |
-
|
258 |
-
// first the data pointers
|
259 |
-
std::vector<void*> data_ptrs {xDevPtr, yDevPtr, scaledevPtr, biasdevPtr,
|
260 |
-
in_meandevPtr, in_vardevPtr, out_meandevPtr, out_vardevPtr,
|
261 |
-
saved_meandevPtr, saved_inv_vardevPtr,
|
262 |
-
&epsilon_val, &exponential_decay_factor};
|
263 |
-
data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end());
|
264 |
-
// then the uids
|
265 |
-
std::vector<int64_t> uids;
|
266 |
-
for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) {
|
267 |
-
uids.push_back(i);
|
268 |
-
}
|
269 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
270 |
-
.setWorkspacePointer(workPtr)
|
271 |
-
.setDataPointers(data_ptrs.size(), data_ptrs.data())
|
272 |
-
.setUids(uids.size(), uids.data())
|
273 |
-
.build();
|
274 |
-
//std::cout << "variantPack " << variantPack.describe() << std::endl;
|
275 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
276 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
277 |
-
|
278 |
-
// Reset local communication buffer
|
279 |
-
cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream);
|
280 |
-
|
281 |
-
} catch (cudnn_frontend::cudnnException &e) {
|
282 |
-
struct cudaDeviceProp prop;
|
283 |
-
checkCudaErr(cudaGetDeviceProperties(&prop, 0));
|
284 |
-
if (prop.major == 8) {
|
285 |
-
std::cout << "[ERROR] Exception " << e.what() << std::endl;
|
286 |
-
assert(false);
|
287 |
-
}
|
288 |
-
}
|
289 |
-
}
|
290 |
-
|
291 |
-
cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims,
|
292 |
-
int64_t *perChannelSum,
|
293 |
-
int64_t *epsilon,
|
294 |
-
int64_t *peerDims,
|
295 |
-
cudnnDataType_t data_type) {
|
296 |
-
|
297 |
-
// get cudnn handle
|
298 |
-
cudnnHandle_t handle = torch::native::getCudnnHandle();
|
299 |
-
|
300 |
-
// Creates the necessary tensor descriptors
|
301 |
-
int64_t tensor_stride[4];
|
302 |
-
int64_t stride[4];
|
303 |
-
int64_t peer_stride[4];
|
304 |
-
|
305 |
-
// NHWC format. GenerateStrides() takes care of this. Howeever, tensor dims should still be NCHW
|
306 |
-
generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
307 |
-
generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
308 |
-
|
309 |
-
auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) {
|
310 |
-
return cudnn_frontend::TensorBuilder()
|
311 |
-
.setDim(4, tensorDims)
|
312 |
-
.setStrides(4, tensor_stride)
|
313 |
-
.setId(id)
|
314 |
-
.setAlignment(16)
|
315 |
-
.setDataType(type)
|
316 |
-
.build();
|
317 |
-
};
|
318 |
-
|
319 |
-
auto peer_tensor_create = [&peer_stride, &peerDims](cudnnDataType_t type, int64_t id) {
|
320 |
-
return cudnn_frontend::TensorBuilder()
|
321 |
-
.setDim(4, peerDims)
|
322 |
-
.setStrides(4, peer_stride)
|
323 |
-
.setId(id)
|
324 |
-
.setAlignment(16)
|
325 |
-
.setDataType(type)
|
326 |
-
.build();
|
327 |
-
};
|
328 |
-
|
329 |
-
generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
330 |
-
|
331 |
-
auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) {
|
332 |
-
return cudnn_frontend::TensorBuilder()
|
333 |
-
.setDim(4, perChannelSum)
|
334 |
-
.setStrides(4, stride)
|
335 |
-
.setId(id)
|
336 |
-
.setAlignment(16)
|
337 |
-
.setDataType(type)
|
338 |
-
.build();
|
339 |
-
};
|
340 |
-
|
341 |
-
auto xTensor = tensor_create(data_type, 100);
|
342 |
-
auto dyTensor = tensor_create(data_type, 101);
|
343 |
-
auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102);
|
344 |
-
auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103);
|
345 |
-
auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104);
|
346 |
-
auto dxTensor = tensor_create(data_type, 105);
|
347 |
-
auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106);
|
348 |
-
auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107);
|
349 |
-
|
350 |
-
int64_t epsilon_stride[4];
|
351 |
-
generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC);
|
352 |
-
auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) {
|
353 |
-
return cudnn_frontend::TensorBuilder()
|
354 |
-
.setDim(4, epsilon)
|
355 |
-
.setStrides(4, epsilon_stride)
|
356 |
-
.setId(id)
|
357 |
-
.setAlignment(16)
|
358 |
-
.setDataType(type)
|
359 |
-
.setByValue(true)
|
360 |
-
.build();
|
361 |
-
};
|
362 |
-
|
363 |
-
auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108);
|
364 |
-
|
365 |
-
std::vector<cudnn_frontend::Tensor_v8> peerStatTensors;
|
366 |
-
for (size_t i = 109; i < 109 + peerDims[0]; ++i) {
|
367 |
-
peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i));
|
368 |
-
}
|
369 |
-
|
370 |
-
#if (CUDNN_VERSION >= 8500)
|
371 |
-
// Batch normalization
|
372 |
-
cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM;
|
373 |
-
|
374 |
-
//Create a Finalize node
|
375 |
-
auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)
|
376 |
-
.setNormalizationMode(normalizationMode)
|
377 |
-
.setxDesc(xTensor)
|
378 |
-
.setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor)
|
379 |
-
.setdyDesc(dyTensor)
|
380 |
-
.setScale(scaleTensor)
|
381 |
-
.setEpsilonTensor(epsilonTensor)
|
382 |
-
.setDScaleAndDBias(dScaleTensor, dBiasTensor)
|
383 |
-
.setdxDesc(dxTensor)
|
384 |
-
.setPeerStatTensor(peerStatTensors)
|
385 |
-
.build();
|
386 |
-
|
387 |
-
std::array<cudnn_frontend::Operation const*, 1> ops = {&batch_norm_op};
|
388 |
-
#else
|
389 |
-
std::array<cudnn_frontend::Operation const*, 0> ops = {};
|
390 |
-
#endif
|
391 |
-
|
392 |
-
auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build();
|
393 |
-
//std::cout << opGraph.describe() << std::endl;
|
394 |
-
|
395 |
-
cudnn_frontend::EngineConfigList filtered_configs;
|
396 |
-
auto statuses =
|
397 |
-
cudnn_frontend::get_heuristics_list<2>({"heuristics_instant"
|
398 |
-
, "heuristics_fallback"
|
399 |
-
}, opGraph,::AllowAll, filtered_configs, true);
|
400 |
-
|
401 |
-
auto plan_builder = [&filtered_configs, &opGraph, &handle]() {
|
402 |
-
for (auto i = 0u; i < filtered_configs.size(); i++) {
|
403 |
-
try {
|
404 |
-
auto plan = cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[i], opGraph.getTag()).build();
|
405 |
-
return plan;
|
406 |
-
} catch (cudnn_frontend::cudnnException &e) {
|
407 |
-
continue;
|
408 |
-
}
|
409 |
-
}
|
410 |
-
return cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(filtered_configs[0], opGraph.getTag()).build();
|
411 |
-
};
|
412 |
-
|
413 |
-
assert(filtered_configs.size() > 0);
|
414 |
-
auto plan = plan_builder();
|
415 |
-
|
416 |
-
return plan;
|
417 |
-
}
|
418 |
-
|
419 |
-
void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan,
|
420 |
-
void *xDevPtr,
|
421 |
-
void *dyDevPtr,
|
422 |
-
void *scaledevPtr,
|
423 |
-
void *saved_meandevPtr,
|
424 |
-
void *saved_inv_vardevPtr,
|
425 |
-
const std::vector<void*> &peer_devPtrs,
|
426 |
-
void *dxDevPtr,
|
427 |
-
void *dscaledevPtr,
|
428 |
-
void *dbiasdevPtr,
|
429 |
-
double epsilon_val,
|
430 |
-
size_t peer_size,
|
431 |
-
int rank_id) {
|
432 |
-
|
433 |
-
// get handle
|
434 |
-
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
|
435 |
-
|
436 |
-
// get stream
|
437 |
-
cudaStream_t stream;
|
438 |
-
cudnnGetStream(handle_, &stream);
|
439 |
-
|
440 |
-
try {
|
441 |
-
// allocate workspace
|
442 |
-
auto workspace_size = plan.getWorkspaceSize();
|
443 |
-
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
|
444 |
-
void* workPtr = nullptr;
|
445 |
-
if (workspace_size > 0) {
|
446 |
-
workPtr = workspace_tensor.data_ptr<float>();
|
447 |
-
}
|
448 |
-
|
449 |
-
// create helper arrays
|
450 |
-
std::vector<void*> data_ptrs {xDevPtr, dyDevPtr, scaledevPtr,
|
451 |
-
saved_meandevPtr, saved_inv_vardevPtr,
|
452 |
-
dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val};
|
453 |
-
data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end());
|
454 |
-
std::vector<int64_t> uids;
|
455 |
-
for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) {
|
456 |
-
uids.push_back(i);
|
457 |
-
}
|
458 |
-
|
459 |
-
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
460 |
-
.setWorkspacePointer(workPtr)
|
461 |
-
.setDataPointers(data_ptrs.size(), data_ptrs.data())
|
462 |
-
.setUids(uids.size(), uids.data())
|
463 |
-
.build();
|
464 |
-
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
|
465 |
-
|
466 |
-
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
|
467 |
-
|
468 |
-
// Reset local communication buffer
|
469 |
-
cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size*4, stream);
|
470 |
-
|
471 |
-
} catch (cudnn_frontend::cudnnException &e) {
|
472 |
-
struct cudaDeviceProp prop;
|
473 |
-
checkCudaErr(cudaGetDeviceProperties(&prop, 0));
|
474 |
-
if (prop.major == 8) {
|
475 |
-
std::cout << "[ERROR] Exception " << e.what() << std::endl;
|
476 |
-
assert(false);
|
477 |
-
}
|
478 |
-
}
|
479 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/cudnn_gbn/norm_sample.h
DELETED
@@ -1,153 +0,0 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
/*
|
4 |
-
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
5 |
-
*
|
6 |
-
* Permission is hereby granted, free of charge, to any person obtaining a
|
7 |
-
* copy of this software and associated documentation files (the "Software"),
|
8 |
-
* to deal in the Software without restriction, including without limitation
|
9 |
-
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
10 |
-
* and/or sell copies of the Software, and to permit persons to whom the
|
11 |
-
* Software is furnished to do so, subject to the following conditions:
|
12 |
-
*
|
13 |
-
* The above copyright notice and this permission notice shall be included in
|
14 |
-
* all copies or substantial portions of the Software.
|
15 |
-
*
|
16 |
-
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
-
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
-
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
19 |
-
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
-
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
21 |
-
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
22 |
-
* DEALINGS IN THE SOFTWARE.
|
23 |
-
*/
|
24 |
-
|
25 |
-
#pragma once
|
26 |
-
|
27 |
-
#include <iostream>
|
28 |
-
#include <inttypes.h>
|
29 |
-
#include <stdlib.h>
|
30 |
-
#include <string.h>
|
31 |
-
#include <ctype.h>
|
32 |
-
#include <assert.h>
|
33 |
-
#include <tuple>
|
34 |
-
#include <functional>
|
35 |
-
|
36 |
-
#include <cudnn.h>
|
37 |
-
#include <cudnn_frontend.h>
|
38 |
-
|
39 |
-
/* some helpers
|
40 |
-
*/
|
41 |
-
void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat);
|
42 |
-
|
43 |
-
int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line);
|
44 |
-
int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line);
|
45 |
-
|
46 |
-
#define checkCudaErr(...) \
|
47 |
-
do { \
|
48 |
-
int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
|
49 |
-
assert(err == 0); \
|
50 |
-
} while (0)
|
51 |
-
|
52 |
-
#define checkCudnnErr(...) \
|
53 |
-
do { \
|
54 |
-
int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
|
55 |
-
assert(err == 0); \
|
56 |
-
} while (0)
|
57 |
-
|
58 |
-
/**
|
59 |
-
* @brief Run a Group BN forward sample with 2 peer stat tensors.
|
60 |
-
*
|
61 |
-
* @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of memory format
|
62 |
-
* @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor
|
63 |
-
* @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN
|
64 |
-
* @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN
|
65 |
-
|
66 |
-
*
|
67 |
-
*/
|
68 |
-
cudnn_frontend::ExecutionPlan run_batch_norm_forward(
|
69 |
-
int64_t *tensorDims,
|
70 |
-
int64_t *perChannelSum,
|
71 |
-
int64_t *epsilon,
|
72 |
-
int64_t *peerDims,
|
73 |
-
cudnnDataType_t in_out_data_type);
|
74 |
-
/**
|
75 |
-
* @param xDevPtr input tensor device pointer
|
76 |
-
* @param yDevPtr output tensor device pointer
|
77 |
-
* @param scaledevPtr input scale device pointer for BN scaling
|
78 |
-
* @param biasdevPtr input scale device pointer for BN bias
|
79 |
-
* @param in_meandevPtr Input mean device pointer
|
80 |
-
* @param in_vardevPtr Input variance device pointer
|
81 |
-
* @param out_meandevPtr output mean device pointer
|
82 |
-
* @param out_vardevPtr output variance device pointer
|
83 |
-
* @param saved_meandevPtr saved mean device pointer for BN backward
|
84 |
-
* @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward
|
85 |
-
* @param peer_devPtr1 peer stat tensor 1 device pointer
|
86 |
-
* @param peer_devPtr2 peer stat tensor 2 device pointer
|
87 |
-
* @param epsilon_val episilon value as a double
|
88 |
-
* @param exponential_decay_factor exponential_decay_factor as a value
|
89 |
-
*
|
90 |
-
**/
|
91 |
-
void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan,
|
92 |
-
void *xDevPtr,
|
93 |
-
void *yDevPtr,
|
94 |
-
void *scaledevPtr,
|
95 |
-
void *biasdevPtr,
|
96 |
-
void *in_meandevPtr,
|
97 |
-
void *in_vardevPtr,
|
98 |
-
void *out_meandevPtr,
|
99 |
-
void *out_vardevPtr,
|
100 |
-
void *saved_meandevPtr,
|
101 |
-
void *saved_inv_vardevPtr,
|
102 |
-
const std::vector<void*> &peer_devPtrs,
|
103 |
-
double epsilon_val,
|
104 |
-
double exponential_decay_factor,
|
105 |
-
size_t peer_size,
|
106 |
-
int rank_id);
|
107 |
-
|
108 |
-
/**
|
109 |
-
* @brief Run a Group BN backward sample with 2 peer stat tensors.
|
110 |
-
*
|
111 |
-
* @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of memory format
|
112 |
-
* @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor
|
113 |
-
* @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN
|
114 |
-
* @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in GBN
|
115 |
-
*
|
116 |
-
*/
|
117 |
-
cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t *tensorDims,
|
118 |
-
int64_t *perChannelSum,
|
119 |
-
int64_t *epsilon,
|
120 |
-
int64_t *peerDims,
|
121 |
-
cudnnDataType_t data_type);
|
122 |
-
|
123 |
-
/**
|
124 |
-
* @brief Run a Group BN backward sample with 2 peer stat tensors.
|
125 |
-
*
|
126 |
-
* @param xDevPtr input tensor device pointer
|
127 |
-
* @param yDevPtr output tensor device pointer
|
128 |
-
* @param scaledevPtr input scale device pointer for BN scaling
|
129 |
-
* @param biasdevPtr input scale device pointer for BN bias
|
130 |
-
* @param in_meandevPtr Input mean device pointer
|
131 |
-
* @param in_vardevPtr Input variance device pointer
|
132 |
-
* @param out_meandevPtr output mean device pointer
|
133 |
-
* @param out_vardevPtr output variance device pointer
|
134 |
-
* @param saved_meandevPtr saved mean device pointer for BN backward
|
135 |
-
* @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward
|
136 |
-
* @param peer_devPtr1 peer stat tensor 1 device pointer
|
137 |
-
* @param peer_devPtr2 peer stat tensor 2 device pointer
|
138 |
-
* @param epsilon_val episilon value as a double
|
139 |
-
*
|
140 |
-
*/
|
141 |
-
void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan,
|
142 |
-
void *xDevPtr,
|
143 |
-
void *dyDevPtr,
|
144 |
-
void *scaledevPtr,
|
145 |
-
void *saved_meandevPtr,
|
146 |
-
void *saved_inv_vardevPtr,
|
147 |
-
const std::vector<void*> &peer_devPtrs,
|
148 |
-
void *dxDevPtr,
|
149 |
-
void *dscaledevPtr,
|
150 |
-
void *dbiasdevPtr,
|
151 |
-
double epsilon_val,
|
152 |
-
size_t peer_size,
|
153 |
-
int rank_id);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/fmha/fmha_api.cpp
DELETED
@@ -1,365 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Redistribution and use in source and binary forms, with or without
|
5 |
-
* modification, are permitted provided that the following conditions are met:
|
6 |
-
* * Redistributions of source code must retain the above copyright
|
7 |
-
* notice, this list of conditions and the following disclaimer.
|
8 |
-
* * Redistributions in binary form must reproduce the above copyright
|
9 |
-
* notice, this list of conditions and the following disclaimer in the
|
10 |
-
* documentation and/or other materials provided with the distribution.
|
11 |
-
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
* names of its contributors may be used to endorse or promote products
|
13 |
-
* derived from this software without specific prior written permission.
|
14 |
-
*
|
15 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
-
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
-
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
-
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
*
|
26 |
-
******************************************************************************/
|
27 |
-
|
28 |
-
#include <torch/extension.h>
|
29 |
-
#include <ATen/cuda/CUDAContext.h>
|
30 |
-
|
31 |
-
#include "fmha.h"
|
32 |
-
|
33 |
-
extern at::Tensor & mha_fill(at::Tensor &self, const at::Tensor &start_index);
|
34 |
-
void set_params(Fused_multihead_attention_fprop_params ¶ms,
|
35 |
-
// sizes
|
36 |
-
const size_t b,
|
37 |
-
const size_t s,
|
38 |
-
const size_t h,
|
39 |
-
const size_t d,
|
40 |
-
// device pointers
|
41 |
-
void *qkv_packed_d,
|
42 |
-
void *cu_seqlens_d,
|
43 |
-
void *o_packed_d,
|
44 |
-
void *s_d,
|
45 |
-
float p_dropout) {
|
46 |
-
|
47 |
-
Data_type acc_type = DATA_TYPE_FP32;
|
48 |
-
Data_type data_type = DATA_TYPE_FP16;
|
49 |
-
|
50 |
-
// Reset the parameters
|
51 |
-
memset(¶ms, 0, sizeof(params));
|
52 |
-
|
53 |
-
// Set the pointers and strides.
|
54 |
-
params.qkv_ptr = qkv_packed_d;
|
55 |
-
params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
|
56 |
-
params.o_ptr = o_packed_d;
|
57 |
-
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
|
58 |
-
|
59 |
-
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
|
60 |
-
|
61 |
-
// S = softmax(P)
|
62 |
-
params.s_ptr = s_d;
|
63 |
-
params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);
|
64 |
-
|
65 |
-
// Set the dimensions.
|
66 |
-
params.b = b;
|
67 |
-
params.h = h;
|
68 |
-
params.s = s;
|
69 |
-
params.d = d;
|
70 |
-
|
71 |
-
// Set the different scale values.
|
72 |
-
const float scale_bmm1 = 1.f / sqrtf(d);
|
73 |
-
constexpr float scale_softmax = 1.f;
|
74 |
-
constexpr float scale_bmm2 = 1.f;
|
75 |
-
|
76 |
-
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
|
77 |
-
set_alpha(params.scale_softmax, scale_softmax, acc_type);
|
78 |
-
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
|
79 |
-
|
80 |
-
// Set this to probability of keeping an element to simplify things.
|
81 |
-
params.p_dropout = 1.f - p_dropout;
|
82 |
-
params.rp_dropout = 1.f / params.p_dropout;
|
83 |
-
TORCH_CHECK(p_dropout < 1.f);
|
84 |
-
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
|
85 |
-
}
|
86 |
-
|
87 |
-
std::vector<at::Tensor>
|
88 |
-
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
|
89 |
-
const at::Tensor &cu_seqlens, // b+1
|
90 |
-
const float p_dropout,
|
91 |
-
const int max_seq_len,
|
92 |
-
const bool is_training,
|
93 |
-
const bool is_nl,
|
94 |
-
const bool zero_tensors,
|
95 |
-
c10::optional<at::Generator> gen_) {
|
96 |
-
|
97 |
-
using namespace torch::indexing;
|
98 |
-
auto dprops = at::cuda::getCurrentDeviceProperties();
|
99 |
-
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
|
100 |
-
(dprops->major == 9 && dprops->minor == 0));
|
101 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
102 |
-
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
|
103 |
-
|
104 |
-
int seq_len = 512;
|
105 |
-
auto launch = &run_fmha_fp16_512_64_sm80;
|
106 |
-
if( max_seq_len <= 128 ) {
|
107 |
-
seq_len = 128;
|
108 |
-
launch = &run_fmha_fp16_128_64_sm80;
|
109 |
-
} else if( max_seq_len <= 256 ) {
|
110 |
-
seq_len = 256;
|
111 |
-
launch = &run_fmha_fp16_256_64_sm80;
|
112 |
-
} else if( max_seq_len <= 384 ) {
|
113 |
-
seq_len = 384;
|
114 |
-
launch = &run_fmha_fp16_384_64_sm80;
|
115 |
-
} else if( max_seq_len <= 512 ) {
|
116 |
-
seq_len = 512;
|
117 |
-
launch = &run_fmha_fp16_512_64_sm80;
|
118 |
-
} else {
|
119 |
-
TORCH_CHECK(false);
|
120 |
-
}
|
121 |
-
|
122 |
-
TORCH_CHECK(qkv.is_cuda())
|
123 |
-
TORCH_CHECK(cu_seqlens.is_cuda())
|
124 |
-
|
125 |
-
TORCH_CHECK(qkv.is_contiguous())
|
126 |
-
TORCH_CHECK(cu_seqlens.is_contiguous())
|
127 |
-
|
128 |
-
TORCH_CHECK(cu_seqlens.dim() == 1);
|
129 |
-
TORCH_CHECK(qkv.dim() == 4);
|
130 |
-
|
131 |
-
const auto sizes = qkv.sizes();
|
132 |
-
|
133 |
-
TORCH_CHECK(sizes[THREE_DIM] == 3);
|
134 |
-
|
135 |
-
const int batch_size = cu_seqlens.numel() - 1;
|
136 |
-
const int total = sizes[TOTAL_DIM];
|
137 |
-
const int num_heads = sizes[H_DIM];
|
138 |
-
const int head_size = sizes[D_DIM];
|
139 |
-
TORCH_CHECK(batch_size > 0);
|
140 |
-
TORCH_CHECK(head_size == 64);
|
141 |
-
auto opts = qkv.options();
|
142 |
-
|
143 |
-
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
|
144 |
-
|
145 |
-
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
|
146 |
-
|
147 |
-
if( zero_tensors ) {
|
148 |
-
mha_fill(ctx, cu_seqlens.index({Slice(-1,None)}));
|
149 |
-
}
|
150 |
-
|
151 |
-
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
152 |
-
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
153 |
-
|
154 |
-
|
155 |
-
set_params(launch_params.params,
|
156 |
-
batch_size,
|
157 |
-
seq_len,
|
158 |
-
num_heads,
|
159 |
-
head_size,
|
160 |
-
qkv.data_ptr(),
|
161 |
-
cu_seqlens.data_ptr(),
|
162 |
-
ctx.data_ptr(),
|
163 |
-
s.data_ptr(),
|
164 |
-
p_dropout);
|
165 |
-
|
166 |
-
launch(launch_params, /*configure=*/ true);
|
167 |
-
// number of times random will be generated per thread, to offset philox counter in thc random
|
168 |
-
// state
|
169 |
-
int64_t counter_offset = launch_params.elts_per_thread;
|
170 |
-
at::PhiloxCudaState rng_engine_inputs;
|
171 |
-
|
172 |
-
if( is_training ) {
|
173 |
-
// See Note [Acquire lock when using random generators]
|
174 |
-
std::lock_guard<std::mutex> lock(gen->mutex_);
|
175 |
-
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
|
176 |
-
}
|
177 |
-
|
178 |
-
launch(launch_params, /*configure=*/ false);
|
179 |
-
|
180 |
-
return { ctx, s };
|
181 |
-
}
|
182 |
-
|
183 |
-
|
184 |
-
std::vector<at::Tensor>
|
185 |
-
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
|
186 |
-
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
|
187 |
-
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
|
188 |
-
const at::Tensor &cu_seqlens, // b+1
|
189 |
-
const float p_dropout, // probability to drop
|
190 |
-
const int max_seq_len, // max sequence length to choose the kernel
|
191 |
-
const bool zero_tensors
|
192 |
-
) {
|
193 |
-
using namespace torch::indexing;
|
194 |
-
auto dprops = at::cuda::getCurrentDeviceProperties();
|
195 |
-
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
|
196 |
-
(dprops->major == 9 && dprops->minor == 0));
|
197 |
-
int seq_len = 512;
|
198 |
-
auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
|
199 |
-
if( max_seq_len <= 128 ) {
|
200 |
-
seq_len = 128;
|
201 |
-
launch = &run_fmha_dgrad_fp16_128_64_sm80;
|
202 |
-
} else if( max_seq_len <= 256 ) {
|
203 |
-
seq_len = 256;
|
204 |
-
launch = &run_fmha_dgrad_fp16_256_64_sm80;
|
205 |
-
} else if( max_seq_len <= 384 ) {
|
206 |
-
seq_len = 384;
|
207 |
-
launch = &run_fmha_dgrad_fp16_384_64_sm80;
|
208 |
-
} else if( max_seq_len <= 512 ) {
|
209 |
-
seq_len = 512;
|
210 |
-
launch = &run_fmha_dgrad_fp16_512_64_sm80;
|
211 |
-
} else {
|
212 |
-
TORCH_CHECK(false);
|
213 |
-
}
|
214 |
-
|
215 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
216 |
-
|
217 |
-
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
|
218 |
-
TORCH_CHECK(dout.dtype() == torch::kFloat16);
|
219 |
-
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
|
220 |
-
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
|
221 |
-
|
222 |
-
TORCH_CHECK(qkv.is_cuda());
|
223 |
-
TORCH_CHECK(cu_seqlens.is_cuda());
|
224 |
-
|
225 |
-
TORCH_CHECK(qkv.is_contiguous());
|
226 |
-
TORCH_CHECK(cu_seqlens.is_contiguous());
|
227 |
-
|
228 |
-
TORCH_CHECK(cu_seqlens.dim() == 1);
|
229 |
-
TORCH_CHECK(qkv.dim() == 4);
|
230 |
-
|
231 |
-
const auto sizes = qkv.sizes();
|
232 |
-
|
233 |
-
TORCH_CHECK(sizes[THREE_DIM] == 3);
|
234 |
-
|
235 |
-
const int batch_size = cu_seqlens.numel() - 1;
|
236 |
-
const int num_heads = sizes[H_DIM];
|
237 |
-
const int head_size = sizes[D_DIM];
|
238 |
-
TORCH_CHECK(batch_size > 0);
|
239 |
-
TORCH_CHECK(head_size == 64);
|
240 |
-
|
241 |
-
auto dqkv = torch::empty_like(qkv);
|
242 |
-
|
243 |
-
if( zero_tensors ) {
|
244 |
-
mha_fill(dqkv, cu_seqlens.index({Slice(-1,None)}));
|
245 |
-
}
|
246 |
-
|
247 |
-
Fused_multihead_attention_fprop_params params;
|
248 |
-
|
249 |
-
set_params(params,
|
250 |
-
batch_size,
|
251 |
-
seq_len,
|
252 |
-
num_heads,
|
253 |
-
head_size,
|
254 |
-
qkv.data_ptr(),
|
255 |
-
cu_seqlens.data_ptr(),
|
256 |
-
dout.data_ptr(), // we set o_ptr to dout
|
257 |
-
softmax.data_ptr(), // softmax gets overwritten by dP!
|
258 |
-
p_dropout);
|
259 |
-
|
260 |
-
// we're re-using these scales
|
261 |
-
Data_type acc_type = DATA_TYPE_FP32;
|
262 |
-
set_alpha(params.scale_bmm1, 1.f, acc_type);
|
263 |
-
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
|
264 |
-
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
|
265 |
-
params.dqkv_ptr = dqkv.data_ptr();
|
266 |
-
|
267 |
-
launch(params, stream);
|
268 |
-
return { dqkv, softmax };
|
269 |
-
}
|
270 |
-
|
271 |
-
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
|
272 |
-
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
|
273 |
-
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
|
274 |
-
const at::Tensor &cu_seqlens, // b+1
|
275 |
-
const float p_dropout, // probability to drop
|
276 |
-
const int max_seq_len, // max sequence length to choose the kernel
|
277 |
-
const bool zero_tensors
|
278 |
-
) {
|
279 |
-
|
280 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
281 |
-
|
282 |
-
TORCH_CHECK(qkv.is_cuda())
|
283 |
-
TORCH_CHECK(cu_seqlens.is_cuda())
|
284 |
-
|
285 |
-
TORCH_CHECK(qkv.is_contiguous())
|
286 |
-
TORCH_CHECK(cu_seqlens.is_contiguous())
|
287 |
-
|
288 |
-
TORCH_CHECK(cu_seqlens.dim() == 1);
|
289 |
-
|
290 |
-
TORCH_CHECK(qkv.dim() == 4);
|
291 |
-
|
292 |
-
const auto sizes = qkv.sizes();
|
293 |
-
|
294 |
-
TORCH_CHECK(sizes[THREE_DIM] == 3);
|
295 |
-
|
296 |
-
const int batch_size = cu_seqlens.numel() - 1;
|
297 |
-
|
298 |
-
const int total = sizes[TOTAL_DIM];
|
299 |
-
const int num_heads = sizes[H_DIM];
|
300 |
-
const int head_size = sizes[D_DIM];
|
301 |
-
TORCH_CHECK(batch_size > 0);
|
302 |
-
TORCH_CHECK(head_size == 64);
|
303 |
-
|
304 |
-
int seq_len = 512;
|
305 |
-
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
|
306 |
-
|
307 |
-
auto opts = qkv.options();
|
308 |
-
|
309 |
-
auto dqkv = torch::empty_like(qkv);
|
310 |
-
|
311 |
-
if( zero_tensors ) {
|
312 |
-
dqkv.zero_();
|
313 |
-
}
|
314 |
-
|
315 |
-
int num_chunks = 2;
|
316 |
-
if( batch_size == 1 ) {
|
317 |
-
num_chunks = 4;
|
318 |
-
}else if( batch_size == 2 ) {
|
319 |
-
num_chunks = 3;
|
320 |
-
}
|
321 |
-
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
|
322 |
-
|
323 |
-
Fused_multihead_attention_fprop_params params;
|
324 |
-
|
325 |
-
set_params(params,
|
326 |
-
batch_size,
|
327 |
-
seq_len,
|
328 |
-
num_heads,
|
329 |
-
head_size,
|
330 |
-
qkv.data_ptr(),
|
331 |
-
cu_seqlens.data_ptr(),
|
332 |
-
dout.data_ptr(), // o_ptr = dout
|
333 |
-
softmax.data_ptr(), // softmax gets overwritten by dP!
|
334 |
-
p_dropout);
|
335 |
-
|
336 |
-
params.dkv_ptr = dkv.data_ptr();
|
337 |
-
|
338 |
-
Data_type acc_type = DATA_TYPE_FP32;
|
339 |
-
set_alpha(params.scale_bmm1, 1.f, acc_type);
|
340 |
-
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
|
341 |
-
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
|
342 |
-
params.dqkv_ptr = dqkv.data_ptr();
|
343 |
-
|
344 |
-
launch(params, num_chunks, stream);
|
345 |
-
|
346 |
-
//SPLIT-K reduction of num_chunks dK, dV parts
|
347 |
-
|
348 |
-
// The equivalent of the following Pytorch code:
|
349 |
-
// using namespace torch::indexing;
|
350 |
-
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
|
351 |
-
// torch::sum_out(view_out, dkv, 1);
|
352 |
-
|
353 |
-
const int hidden_size = num_heads * head_size;
|
354 |
-
fmha_run_noloop_reduce(
|
355 |
-
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
|
356 |
-
|
357 |
-
return { dqkv, softmax, dkv };
|
358 |
-
}
|
359 |
-
|
360 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
361 |
-
m.doc() = "Fused Multi-head Self-attention for BERT";
|
362 |
-
m.def("fwd", &mha_fwd, "Forward pass");
|
363 |
-
m.def("bwd", &mha_bwd, "Backward pass");
|
364 |
-
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
|
365 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/fmha/src/fmha.h
DELETED
@@ -1,163 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Redistribution and use in source and binary forms, with or without
|
5 |
-
* modification, are permitted provided that the following conditions are met:
|
6 |
-
* * Redistributions of source code must retain the above copyright
|
7 |
-
* notice, this list of conditions and the following disclaimer.
|
8 |
-
* * Redistributions in binary form must reproduce the above copyright
|
9 |
-
* notice, this list of conditions and the following disclaimer in the
|
10 |
-
* documentation and/or other materials provided with the distribution.
|
11 |
-
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
* names of its contributors may be used to endorse or promote products
|
13 |
-
* derived from this software without specific prior written permission.
|
14 |
-
*
|
15 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
-
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
-
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
-
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
*
|
26 |
-
******************************************************************************/
|
27 |
-
|
28 |
-
#pragma once
|
29 |
-
|
30 |
-
#include <cuda.h>
|
31 |
-
#include <vector>
|
32 |
-
|
33 |
-
#ifdef OLD_GENERATOR_PATH
|
34 |
-
#include <ATen/CUDAGeneratorImpl.h>
|
35 |
-
#else
|
36 |
-
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
37 |
-
#endif
|
38 |
-
|
39 |
-
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
40 |
-
|
41 |
-
#include <fmha_utils.h>
|
42 |
-
|
43 |
-
|
44 |
-
constexpr int TOTAL_DIM = 0;
|
45 |
-
constexpr int THREE_DIM = 1;
|
46 |
-
constexpr int H_DIM = 2;
|
47 |
-
constexpr int D_DIM = 3;
|
48 |
-
|
49 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
50 |
-
|
51 |
-
struct Qkv_params {
|
52 |
-
// The QKV matrices.
|
53 |
-
void * __restrict__ qkv_ptr;
|
54 |
-
|
55 |
-
// The stride between rows of the Q, K and V matrices.
|
56 |
-
size_t qkv_stride_in_bytes;
|
57 |
-
|
58 |
-
// The number of heads.
|
59 |
-
int h;
|
60 |
-
};
|
61 |
-
|
62 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
63 |
-
|
64 |
-
struct Fused_multihead_attention_fprop_params : public Qkv_params {
|
65 |
-
|
66 |
-
// The dQKV matrices.
|
67 |
-
void * __restrict__ dqkv_ptr;
|
68 |
-
|
69 |
-
// Temporary for dKV.
|
70 |
-
void * __restrict__ dkv_ptr;
|
71 |
-
|
72 |
-
// The O matrix (output).
|
73 |
-
void * __restrict__ o_ptr;
|
74 |
-
|
75 |
-
// The stride between rows of O.
|
76 |
-
int64_t o_stride_in_bytes;
|
77 |
-
|
78 |
-
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
|
79 |
-
void * __restrict__ s_ptr;
|
80 |
-
// The stride between rows of the S matrix.
|
81 |
-
int64_t s_stride_in_bytes;
|
82 |
-
|
83 |
-
// The dimensions.
|
84 |
-
int b, s, d;
|
85 |
-
|
86 |
-
// The scaling factors for the kernel.
|
87 |
-
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
|
88 |
-
|
89 |
-
// array of length b+1 holding starting offset of each sequence.
|
90 |
-
int * __restrict__ cu_seqlens;
|
91 |
-
|
92 |
-
// The dropout probability (probability of keeping an activation).
|
93 |
-
float p_dropout;
|
94 |
-
|
95 |
-
// Scale factor of 1 / (1 - p_dropout).
|
96 |
-
float rp_dropout;
|
97 |
-
|
98 |
-
// Scale factor of 1 / (1 - p_dropout), in half2.
|
99 |
-
uint32_t scale_dropout;
|
100 |
-
|
101 |
-
// Random state.
|
102 |
-
at::PhiloxCudaState philox_args;
|
103 |
-
};
|
104 |
-
|
105 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
106 |
-
|
107 |
-
template<typename Kernel_params>
|
108 |
-
struct Launch_params{
|
109 |
-
Launch_params(cudaDeviceProp * props_,
|
110 |
-
cudaStream_t stream_,
|
111 |
-
bool is_training_,
|
112 |
-
bool is_nl_)
|
113 |
-
: elts_per_thread(0)
|
114 |
-
, props(props_)
|
115 |
-
, stream(stream_)
|
116 |
-
, is_training(is_training_)
|
117 |
-
, is_nl(is_nl_) {
|
118 |
-
}
|
119 |
-
|
120 |
-
size_t elts_per_thread;
|
121 |
-
|
122 |
-
cudaDeviceProp * props;
|
123 |
-
|
124 |
-
cudaStream_t stream;
|
125 |
-
|
126 |
-
bool is_training;
|
127 |
-
|
128 |
-
Kernel_params params;
|
129 |
-
int num_full_heads;
|
130 |
-
int num_main_groups;
|
131 |
-
int heads_last_wave;
|
132 |
-
int main_steps;
|
133 |
-
int rest_steps;
|
134 |
-
bool is_nl;
|
135 |
-
|
136 |
-
};
|
137 |
-
|
138 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
139 |
-
|
140 |
-
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
|
141 |
-
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
|
142 |
-
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
|
143 |
-
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
|
144 |
-
|
145 |
-
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream);
|
146 |
-
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream);
|
147 |
-
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream);
|
148 |
-
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream);
|
149 |
-
|
150 |
-
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, const int num_chunks, cudaStream_t stream);
|
151 |
-
|
152 |
-
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream);
|
153 |
-
|
154 |
-
void fmha_run_noloop_reduce(void *out,
|
155 |
-
const void *in,
|
156 |
-
const int *cu_seqlens,
|
157 |
-
const int hidden_size,
|
158 |
-
const int batch_size,
|
159 |
-
const int total,
|
160 |
-
const int num_chunks,
|
161 |
-
cudaStream_t stream);
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/fmha/src/fmha/gemm.h
DELETED
@@ -1,314 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Redistribution and use in source and binary forms, with or without
|
5 |
-
* modification, are permitted provided that the following conditions are met:
|
6 |
-
* * Redistributions of source code must retain the above copyright
|
7 |
-
* notice, this list of conditions and the following disclaimer.
|
8 |
-
* * Redistributions in binary form must reproduce the above copyright
|
9 |
-
* notice, this list of conditions and the following disclaimer in the
|
10 |
-
* documentation and/or other materials provided with the distribution.
|
11 |
-
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
* names of its contributors may be used to endorse or promote products
|
13 |
-
* derived from this software without specific prior written permission.
|
14 |
-
*
|
15 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
-
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
-
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
-
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
*
|
26 |
-
******************************************************************************/
|
27 |
-
|
28 |
-
#pragma once
|
29 |
-
|
30 |
-
#include <fmha/utils.h>
|
31 |
-
|
32 |
-
#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))
|
33 |
-
|
34 |
-
namespace fmha {
|
35 |
-
|
36 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
37 |
-
|
38 |
-
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
|
39 |
-
struct Fragment_base_ {
|
40 |
-
|
41 |
-
// The data type.
|
42 |
-
using Data_type = Data_type_;
|
43 |
-
// default input type
|
44 |
-
using Input_type_ = Data_type_;
|
45 |
-
// Does it store the array of elements.
|
46 |
-
enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
|
47 |
-
// The number of elements.
|
48 |
-
enum { NUM_ELTS = NUM_ELTS_ };
|
49 |
-
// The size of element in bits.
|
50 |
-
enum { BITS_PER_ELT = BITS_PER_ELT_ };
|
51 |
-
// The size of byte of a single register.
|
52 |
-
enum { BYTES_PER_REG = 4 };
|
53 |
-
// The size in bits.
|
54 |
-
enum { BITS_PER_REG = BYTES_PER_REG * 8 };
|
55 |
-
// The number of registers needed to store the fragment.
|
56 |
-
enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };
|
57 |
-
// The size in bytes (as returned by sizeof(Fragment_base<>).
|
58 |
-
enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
|
59 |
-
// The alignment.
|
60 |
-
enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };
|
61 |
-
};
|
62 |
-
|
63 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
64 |
-
|
65 |
-
template<
|
66 |
-
// The type of the elements.
|
67 |
-
typename Data_type_,
|
68 |
-
// The number of elements.
|
69 |
-
int NUM_ELTS_,
|
70 |
-
// The alignment if you want to force a value -- use 0 otherwise.
|
71 |
-
int ALIGNMENT_ = 0,
|
72 |
-
// The base class.
|
73 |
-
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
|
74 |
-
>
|
75 |
-
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
|
76 |
-
|
77 |
-
// The size of a load/store.
|
78 |
-
enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };
|
79 |
-
|
80 |
-
// Clear the fragment. Using PTX in that code seems to produce better SASS...
|
81 |
-
inline __device__ void clear() {
|
82 |
-
#pragma unroll
|
83 |
-
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
84 |
-
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
|
85 |
-
}
|
86 |
-
}
|
87 |
-
|
88 |
-
// Immutable access to a register.
|
89 |
-
inline __device__ const uint32_t& reg(int ii) const {
|
90 |
-
return this->regs_[ii];
|
91 |
-
}
|
92 |
-
|
93 |
-
// Mutable access to a register.
|
94 |
-
inline __device__ uint32_t& reg(int ii) {
|
95 |
-
return this->regs_[ii];
|
96 |
-
}
|
97 |
-
|
98 |
-
uint32_t regs_[Base_::NUM_REGS];
|
99 |
-
|
100 |
-
// Immutable access to the elements.
|
101 |
-
inline __device__ const Data_type_& elt(int ii) const {
|
102 |
-
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
|
103 |
-
}
|
104 |
-
|
105 |
-
// Mutable access to the elements.
|
106 |
-
inline __device__ Data_type_& elt(int ii) {
|
107 |
-
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
|
108 |
-
}
|
109 |
-
|
110 |
-
// Immutable access to the elements with a cast.
|
111 |
-
template< typename Cast_type >
|
112 |
-
inline __device__ const Cast_type& elt_as(int ii) const {
|
113 |
-
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
|
114 |
-
}
|
115 |
-
|
116 |
-
// Mutable access to the elements.
|
117 |
-
template< typename Cast_type >
|
118 |
-
inline __device__ Cast_type& elt_as(int ii) {
|
119 |
-
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
|
120 |
-
}
|
121 |
-
|
122 |
-
// Add another fragment.
|
123 |
-
inline __device__ void add(const Fragment &other) {
|
124 |
-
#pragma unroll
|
125 |
-
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
|
126 |
-
this->elt(ii) += other.elt(ii);
|
127 |
-
}
|
128 |
-
}
|
129 |
-
};
|
130 |
-
|
131 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
132 |
-
|
133 |
-
template< typename Layout >
|
134 |
-
struct Fragment_a : public Fragment<uint16_t, 8> {
|
135 |
-
};
|
136 |
-
|
137 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
138 |
-
|
139 |
-
template< typename Layout >
|
140 |
-
struct Fragment_b : public Fragment<uint16_t, 8> {
|
141 |
-
};
|
142 |
-
|
143 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
144 |
-
|
145 |
-
struct Fragment_accumulator : public Fragment<float, 8> {
|
146 |
-
|
147 |
-
// The base class.
|
148 |
-
using Base = Fragment<float, 8>;
|
149 |
-
|
150 |
-
// Add two fragments.
|
151 |
-
template< typename Other_fragment_ >
|
152 |
-
inline __device__ void add(const Other_fragment_ &other) {
|
153 |
-
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
|
154 |
-
this->elt(ii) = this->elt(ii) + other.elt(ii);
|
155 |
-
}
|
156 |
-
}
|
157 |
-
|
158 |
-
// Do the HMMA.
|
159 |
-
template< typename Layout_a, typename Layout_b >
|
160 |
-
inline __device__ void mma(const Fragment_a<Layout_a> &a,
|
161 |
-
const Fragment_b<Layout_b> &b) {
|
162 |
-
asm volatile( \
|
163 |
-
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
164 |
-
" {%0, %1, %2, %3}, \n" \
|
165 |
-
" {%4, %5, %6, %7}, \n" \
|
166 |
-
" {%8, %9}, \n" \
|
167 |
-
" {%0, %1, %2, %3}; \n" \
|
168 |
-
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
|
169 |
-
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
170 |
-
, "r"(b.reg(0)), "r"(b.reg(1)));
|
171 |
-
asm volatile( \
|
172 |
-
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
173 |
-
" {%0, %1, %2, %3}, \n" \
|
174 |
-
" {%4, %5, %6, %7}, \n" \
|
175 |
-
" {%8, %9}, \n" \
|
176 |
-
" {%0, %1, %2, %3}; \n" \
|
177 |
-
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
|
178 |
-
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
179 |
-
, "r"(b.reg(2)), "r"(b.reg(3)));
|
180 |
-
}
|
181 |
-
|
182 |
-
};
|
183 |
-
|
184 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
185 |
-
|
186 |
-
template< typename Fragment, int M, int N >
|
187 |
-
inline __device__ void clear(Fragment (&frag)[M][N]) {
|
188 |
-
#pragma unroll
|
189 |
-
for( int mi = 0; mi < M; ++mi ) {
|
190 |
-
#pragma unroll
|
191 |
-
for( int ni = 0; ni < N; ++ni ) {
|
192 |
-
frag[mi][ni].clear();
|
193 |
-
}
|
194 |
-
}
|
195 |
-
}
|
196 |
-
|
197 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
198 |
-
|
199 |
-
template< typename Accumulator_type, int WARPS_K >
|
200 |
-
struct Clear_accumulator {
|
201 |
-
};
|
202 |
-
|
203 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
204 |
-
|
205 |
-
template< int WARPS_K >
|
206 |
-
struct Clear_accumulator<float, WARPS_K> {
|
207 |
-
template< typename Acc, int M, int N >
|
208 |
-
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
|
209 |
-
fmha::clear(acc);
|
210 |
-
}
|
211 |
-
};
|
212 |
-
|
213 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
214 |
-
|
215 |
-
template<typename Acc, typename A, typename B, int M, int N>
|
216 |
-
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
|
217 |
-
|
218 |
-
#pragma unroll
|
219 |
-
for( int mi = 0; mi < M; ++mi ) {
|
220 |
-
#pragma unroll
|
221 |
-
for( int ni = 0; ni < N; ++ni ) {
|
222 |
-
acc[mi][ni].mma(a[mi], b[ni]);
|
223 |
-
}
|
224 |
-
}
|
225 |
-
}
|
226 |
-
|
227 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
228 |
-
|
229 |
-
template<
|
230 |
-
// The number of rows in the CTA tile.
|
231 |
-
int M_,
|
232 |
-
// The number of cols in the CTA tile.
|
233 |
-
int N_,
|
234 |
-
// The number of elements in the the K dimension of the GEMM loop.
|
235 |
-
int K_,
|
236 |
-
// The number of rows of warps.
|
237 |
-
int WARPS_M_,
|
238 |
-
// The number of cols of warps.
|
239 |
-
int WARPS_N_,
|
240 |
-
// The number of warps in the K dimension of the GEMM loop.
|
241 |
-
int WARPS_K_>
|
242 |
-
struct Cta_tile_ {
|
243 |
-
|
244 |
-
enum { M = M_, N = N_, K = K_ };
|
245 |
-
// The number of warps.
|
246 |
-
enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };
|
247 |
-
// The number of warps per CTA.
|
248 |
-
enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
|
249 |
-
// The number of threads per warp.
|
250 |
-
enum { THREADS_PER_WARP = 32 };
|
251 |
-
// The number of threads per CTA.
|
252 |
-
enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
|
253 |
-
};
|
254 |
-
|
255 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
256 |
-
|
257 |
-
template<typename Cta_tile>
|
258 |
-
struct Hmma_tile {
|
259 |
-
// The number of elements computed with a single warp-MMA.
|
260 |
-
enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
|
261 |
-
|
262 |
-
// The number of elements computed with a single CTA-MMA.
|
263 |
-
enum {
|
264 |
-
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
|
265 |
-
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
|
266 |
-
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
|
267 |
-
};
|
268 |
-
|
269 |
-
// The number of MMAs needed to compute the GEMM.
|
270 |
-
enum {
|
271 |
-
MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
|
272 |
-
MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
|
273 |
-
MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
|
274 |
-
};
|
275 |
-
|
276 |
-
// The number of elements computed per warp.
|
277 |
-
enum {
|
278 |
-
M_PER_WARP = MMAS_M * M_PER_MMA,
|
279 |
-
N_PER_WARP = MMAS_N * N_PER_MMA,
|
280 |
-
K_PER_WARP = MMAS_K * K_PER_MMA,
|
281 |
-
};
|
282 |
-
|
283 |
-
};
|
284 |
-
|
285 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
286 |
-
|
287 |
-
using A_type = uint16_t;
|
288 |
-
using B_type = uint16_t;
|
289 |
-
using C_type = uint16_t;
|
290 |
-
using Accumulator_type = float;
|
291 |
-
using Epilogue_type = float;
|
292 |
-
|
293 |
-
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
|
294 |
-
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
|
295 |
-
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
|
296 |
-
|
297 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
298 |
-
|
299 |
-
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
|
300 |
-
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
|
301 |
-
|
302 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
303 |
-
|
304 |
-
template<typename Cta_tile_>
|
305 |
-
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
|
306 |
-
Cta_tile_::N,
|
307 |
-
Next_power_of_two<Cta_tile_::K>::VALUE,
|
308 |
-
Cta_tile_::WARPS_M,
|
309 |
-
Cta_tile_::WARPS_N,
|
310 |
-
Cta_tile_::WARPS_K>;
|
311 |
-
|
312 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
313 |
-
|
314 |
-
} // namespace fmha
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apex/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
DELETED
@@ -1,456 +0,0 @@
|
|
1 |
-
/******************************************************************************
|
2 |
-
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Redistribution and use in source and binary forms, with or without
|
5 |
-
* modification, are permitted provided that the following conditions are met:
|
6 |
-
* * Redistributions of source code must retain the above copyright
|
7 |
-
* notice, this list of conditions and the following disclaimer.
|
8 |
-
* * Redistributions in binary form must reproduce the above copyright
|
9 |
-
* notice, this list of conditions and the following disclaimer in the
|
10 |
-
* documentation and/or other materials provided with the distribution.
|
11 |
-
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
* names of its contributors may be used to endorse or promote products
|
13 |
-
* derived from this software without specific prior written permission.
|
14 |
-
*
|
15 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
-
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
-
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
-
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
*
|
26 |
-
******************************************************************************/
|
27 |
-
|
28 |
-
#pragma once
|
29 |
-
|
30 |
-
namespace fmha {
|
31 |
-
|
32 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
33 |
-
|
34 |
-
template<
|
35 |
-
// The dimensions of the tile computed by the CTA.
|
36 |
-
typename Cta_tile,
|
37 |
-
// The number of bits per element.
|
38 |
-
int BITS_PER_ELEMENT,
|
39 |
-
// The number of rows of Q, K or V loaded by this tile.
|
40 |
-
int ROWS,
|
41 |
-
// The number of columns.
|
42 |
-
int COLS,
|
43 |
-
// The number of matrics.
|
44 |
-
int NUM_MATS = 3
|
45 |
-
>
|
46 |
-
struct Gmem_tile_qkv {
|
47 |
-
|
48 |
-
// The size of each LDG.
|
49 |
-
enum { BYTES_PER_LDG = 16 };
|
50 |
-
// The size of a row in bytes.
|
51 |
-
enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };
|
52 |
-
|
53 |
-
// The number of threads to load a "row" of the matrix.
|
54 |
-
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };
|
55 |
-
|
56 |
-
// The number of "rows" loaded per LDG.
|
57 |
-
enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
|
58 |
-
// The number of LDGs needed to load a chunk of the Q matrix.
|
59 |
-
enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };
|
60 |
-
|
61 |
-
// Ctor.
|
62 |
-
template< typename Params, typename BInfo >
|
63 |
-
inline __device__ Gmem_tile_qkv(const Params ¶ms, const int qkv_offset, const BInfo &binfo, const int tidx)
|
64 |
-
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
|
65 |
-
, actual_seqlen(binfo.actual_seqlen)
|
66 |
-
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
|
67 |
-
|
68 |
-
// Compute the position in the sequence (within the CTA for the moment).
|
69 |
-
int row = tidx / THREADS_PER_ROW;
|
70 |
-
// Compute the position of the thread in the row.
|
71 |
-
int col = tidx % THREADS_PER_ROW;
|
72 |
-
|
73 |
-
// Store the row as we need it to disable the loads.
|
74 |
-
row_ = row;
|
75 |
-
|
76 |
-
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
|
77 |
-
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
|
78 |
-
// Add the block index.
|
79 |
-
row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
|
80 |
-
|
81 |
-
// Assemble the final pointer.
|
82 |
-
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
|
83 |
-
}
|
84 |
-
|
85 |
-
// Store data to shared memory.
|
86 |
-
template< typename Smem_tile >
|
87 |
-
inline __device__ void commit(Smem_tile &smem_tile) {
|
88 |
-
smem_tile.store(fetch_);
|
89 |
-
}
|
90 |
-
|
91 |
-
// Load data from memory.
|
92 |
-
template< typename Smem_tile >
|
93 |
-
inline __device__ void load(Smem_tile &smem_tile) {
|
94 |
-
const void *ptrs[LDGS];
|
95 |
-
uint32_t preds[LDGS];
|
96 |
-
#pragma unroll
|
97 |
-
for( int ii = 0; ii < LDGS; ++ii ) {
|
98 |
-
ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
|
99 |
-
preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
|
100 |
-
fetch_[ii] = make_uint4(0, 0, 0, 0);
|
101 |
-
}
|
102 |
-
|
103 |
-
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
|
104 |
-
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
|
105 |
-
#pragma unroll
|
106 |
-
for( int ii = 0; ii < LDGS; ++ii ) {
|
107 |
-
fct.load(ii, preds[ii]);
|
108 |
-
}
|
109 |
-
}
|
110 |
-
|
111 |
-
// Store data to memory.
|
112 |
-
inline __device__ void store(const uint4 (&data)[LDGS]) {
|
113 |
-
#pragma unroll
|
114 |
-
for( int ii = 0; ii < LDGS; ++ii ) {
|
115 |
-
char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
|
116 |
-
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
|
117 |
-
fmha::stg(ptr, data[ii]);
|
118 |
-
}
|
119 |
-
}
|
120 |
-
}
|
121 |
-
|
122 |
-
// Move the pointer to the next location.
|
123 |
-
inline __device__ void move() {
|
124 |
-
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
|
125 |
-
actual_seqlen -= ROWS;
|
126 |
-
}
|
127 |
-
|
128 |
-
inline __device__ void move(int steps) {
|
129 |
-
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
|
130 |
-
actual_seqlen -= ROWS * steps;
|
131 |
-
}
|
132 |
-
|
133 |
-
// The stride between rows for the QKV matrice.
|
134 |
-
int64_t params_qkv_stride_in_bytes_;
|
135 |
-
// The pointer.
|
136 |
-
char *qkv_ptr_;
|
137 |
-
// The fetch registers.
|
138 |
-
uint4 fetch_[LDGS];
|
139 |
-
// Keep track of the row the thread is processing as we move the tile.
|
140 |
-
int row_;
|
141 |
-
// The length of the sequence loaded by that memory tile.
|
142 |
-
int actual_seqlen;
|
143 |
-
};
|
144 |
-
|
145 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
146 |
-
|
147 |
-
template< typename Cta_tile >
|
148 |
-
struct Gmem_tile_o {
|
149 |
-
|
150 |
-
// The mma tile.
|
151 |
-
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
152 |
-
|
153 |
-
// The size of each element.
|
154 |
-
enum { BYTES_PER_ELEMENT = 2 };
|
155 |
-
// The size of a row in bytes.
|
156 |
-
enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };
|
157 |
-
|
158 |
-
// The number of threads to store a "row" of the matrix.
|
159 |
-
enum { THREADS_PER_ROW = 16 };
|
160 |
-
// The size of each STG.
|
161 |
-
enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };
|
162 |
-
|
163 |
-
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
164 |
-
enum { ROWS = Cta_tile::M };
|
165 |
-
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
166 |
-
enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
|
167 |
-
// The number of outter loop for the stores.
|
168 |
-
enum { LOOPS = ROWS / ROWS_PER_LOOP };
|
169 |
-
|
170 |
-
// The number of "rows" stored per STG.
|
171 |
-
enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
|
172 |
-
// Do we have to guard against partial writes/reads.
|
173 |
-
enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };
|
174 |
-
// The number of STGs needed to store a chunk of the Q matrix.
|
175 |
-
enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };
|
176 |
-
// The number of STGs needed to store a chunk of the Q matrix in total.
|
177 |
-
enum { STGS = STGS_PER_LOOP * LOOPS };
|
178 |
-
|
179 |
-
// Ctor.
|
180 |
-
template<typename Params, typename BInfo>
|
181 |
-
inline __device__ Gmem_tile_o(const Params ¶ms, const BInfo &binfo, int tidx)
|
182 |
-
: params_o_stride_in_bytes_(params.o_stride_in_bytes)
|
183 |
-
, actual_seqlen_(binfo.actual_seqlen)
|
184 |
-
, o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {
|
185 |
-
|
186 |
-
// Compute the position in the sequence (within the CTA for the moment).
|
187 |
-
int row = tidx / THREADS_PER_ROW;
|
188 |
-
// Compute the position of the thread in the row.
|
189 |
-
int col = tidx % THREADS_PER_ROW;
|
190 |
-
|
191 |
-
// Store the row as we need it to disable loads.
|
192 |
-
row_ = row;
|
193 |
-
|
194 |
-
// The row offset in the batched GEMM.
|
195 |
-
int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
|
196 |
-
// Assemble the final pointer.
|
197 |
-
o_ptr_ += row_offset + col * BYTES_PER_STG;
|
198 |
-
|
199 |
-
// Is that thread active on the last STG?
|
200 |
-
if( HAS_INCOMPLETE_STG ) {
|
201 |
-
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
|
202 |
-
}
|
203 |
-
}
|
204 |
-
|
205 |
-
// Store data to global memory.
|
206 |
-
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
|
207 |
-
|
208 |
-
#pragma unroll
|
209 |
-
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
210 |
-
int jj = mi * STGS_PER_LOOP + ii;
|
211 |
-
if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
|
212 |
-
break;
|
213 |
-
}
|
214 |
-
|
215 |
-
float x = reinterpret_cast<const float &>(src[ii].x);
|
216 |
-
float y = reinterpret_cast<const float &>(src[ii].y);
|
217 |
-
float z = reinterpret_cast<const float &>(src[ii].z);
|
218 |
-
float w = reinterpret_cast<const float &>(src[ii].w);
|
219 |
-
uint2 out = float4_to_half4(x, y, z, w);
|
220 |
-
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
221 |
-
fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);
|
222 |
-
}
|
223 |
-
}
|
224 |
-
}
|
225 |
-
|
226 |
-
// Move the pointer to the next location.
|
227 |
-
inline __device__ void move() {
|
228 |
-
row_ += ROWS;
|
229 |
-
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
|
230 |
-
}
|
231 |
-
|
232 |
-
inline __device__ void move(const int steps) {
|
233 |
-
row_ += ROWS * steps;
|
234 |
-
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
|
235 |
-
}
|
236 |
-
|
237 |
-
// The stride between rows for the QKV matrice.
|
238 |
-
int64_t params_o_stride_in_bytes_;
|
239 |
-
// The pointer.
|
240 |
-
char *o_ptr_;
|
241 |
-
// Is the thread active for the last STG?
|
242 |
-
int is_active_for_last_stg_;
|
243 |
-
// Keep track of the row to disable loads.
|
244 |
-
int row_;
|
245 |
-
// The length of the sequence loaded by that memory tile.
|
246 |
-
int actual_seqlen_;
|
247 |
-
};
|
248 |
-
|
249 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
250 |
-
|
251 |
-
template< typename Cta_tile, int BYTES_PER_ELEMENT >
|
252 |
-
struct Gmem_tile_mma_sd {
|
253 |
-
|
254 |
-
// The mma tile.
|
255 |
-
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
256 |
-
|
257 |
-
// Each STG stores 8 elements.
|
258 |
-
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
|
259 |
-
// The number of MMAs in the M dimension.
|
260 |
-
enum { MMAS_M = Mma_tile::MMAS_M };
|
261 |
-
// The number of MMAs in the N dimension.
|
262 |
-
enum { MMAS_N = Mma_tile::MMAS_N };
|
263 |
-
// The number of rows computed per MMA per thread block.
|
264 |
-
enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };
|
265 |
-
// The number of cols computed per MMA per thread block.
|
266 |
-
enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };
|
267 |
-
// The number of threads per block.
|
268 |
-
enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };
|
269 |
-
// The size of each row in bytes. I.e. how many bytes are stored per STG.
|
270 |
-
enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };
|
271 |
-
// The fixed sequence length.
|
272 |
-
enum { SEQLEN = Cta_tile::N };
|
273 |
-
// The distance between two blocks (in bytes).
|
274 |
-
enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };
|
275 |
-
// The distance between elements stored per loop (in bytes).
|
276 |
-
enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };
|
277 |
-
|
278 |
-
// The type of elements stored per STG.
|
279 |
-
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
|
280 |
-
|
281 |
-
// Ctor.
|
282 |
-
template<typename Params>
|
283 |
-
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx)
|
284 |
-
: ptr_(static_cast<char *>(ptr)) {
|
285 |
-
|
286 |
-
// The block index.
|
287 |
-
size_t bidx = bidb * params.h + bidh;
|
288 |
-
|
289 |
-
// Set store location for each thread at the beginning of the loop
|
290 |
-
ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
|
291 |
-
}
|
292 |
-
|
293 |
-
// Store to global memory.
|
294 |
-
inline __device__ void store(const Type &data, const int mi, const int ni) {
|
295 |
-
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
296 |
-
fmha::stg(ptr_ + offset, data);
|
297 |
-
}
|
298 |
-
|
299 |
-
// Load from global memory.
|
300 |
-
inline __device__ void load(Type &data, const int mi, const int ni) {
|
301 |
-
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
302 |
-
fmha::ldg(data, ptr_ + offset);
|
303 |
-
}
|
304 |
-
|
305 |
-
// Move to the next tile.
|
306 |
-
inline __device__ void move() {
|
307 |
-
ptr_ += LOOP_STRIDE_BYTES;
|
308 |
-
}
|
309 |
-
inline __device__ void move(const int steps) {
|
310 |
-
ptr_ += LOOP_STRIDE_BYTES * steps;
|
311 |
-
}
|
312 |
-
|
313 |
-
// The pointer in global memory.
|
314 |
-
char *ptr_;
|
315 |
-
};
|
316 |
-
|
317 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
318 |
-
|
319 |
-
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
|
320 |
-
struct Gmem_tile_mma_s : public Base {
|
321 |
-
|
322 |
-
// The number of mmas in the vertical dimension.
|
323 |
-
enum { M = Base::MMAS_M };
|
324 |
-
// The number of mmas in the horizontal dimension.
|
325 |
-
enum { N = Base::MMAS_N };
|
326 |
-
// The type of the vectors stored by each STG.
|
327 |
-
using Type = typename Base::Type;
|
328 |
-
|
329 |
-
// Ctor.
|
330 |
-
template< typename Params, typename Block_info >
|
331 |
-
inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx)
|
332 |
-
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
|
333 |
-
}
|
334 |
-
|
335 |
-
// Store to global memory.
|
336 |
-
template<typename Mask>
|
337 |
-
inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
|
338 |
-
#pragma unroll
|
339 |
-
for( int mi = 0; mi < M; mi++ ) {
|
340 |
-
#pragma unroll
|
341 |
-
for( int ni = 0; ni < N; ni++ ) {
|
342 |
-
|
343 |
-
float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
|
344 |
-
float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
|
345 |
-
float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
|
346 |
-
float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
|
347 |
-
|
348 |
-
float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
|
349 |
-
float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
|
350 |
-
float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
|
351 |
-
float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
|
352 |
-
|
353 |
-
uint4 dst;
|
354 |
-
dst.x = fmha::float2_to_half2(tmp00, tmp01);
|
355 |
-
dst.y = fmha::float2_to_half2(tmp02, tmp03);
|
356 |
-
dst.z = fmha::float2_to_half2(tmp10, tmp11);
|
357 |
-
dst.w = fmha::float2_to_half2(tmp12, tmp13);
|
358 |
-
if( mask.is_valid(mi, ni, 0, 0) ) {
|
359 |
-
Base::store(dst, mi, ni);
|
360 |
-
}
|
361 |
-
}
|
362 |
-
}
|
363 |
-
}
|
364 |
-
|
365 |
-
// Store to global memory.
|
366 |
-
template<typename Mask, typename Fragment>
|
367 |
-
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
|
368 |
-
#pragma unroll
|
369 |
-
for( int mi = 0; mi < M; mi++ ) {
|
370 |
-
#pragma unroll
|
371 |
-
for( int ni = 0; ni < N; ni++ ) {
|
372 |
-
uint4 dst;
|
373 |
-
dst.x = frag[ni][mi].reg(0);
|
374 |
-
dst.y = frag[ni][mi].reg(2);
|
375 |
-
dst.z = frag[ni][mi].reg(1);
|
376 |
-
dst.w = frag[ni][mi].reg(3);
|
377 |
-
if( mask.any_valid(mi, ni) ) {
|
378 |
-
Base::store(dst, mi, ni);
|
379 |
-
}
|
380 |
-
}
|
381 |
-
}
|
382 |
-
}
|
383 |
-
|
384 |
-
// Load from global memory.
|
385 |
-
template<typename Mask>
|
386 |
-
inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) {
|
387 |
-
#pragma unroll
|
388 |
-
for( int mi = 0; mi < M; mi++ ) {
|
389 |
-
#pragma unroll
|
390 |
-
for( int ni = 0; ni < N; ni++ ) {
|
391 |
-
regs[mi][ni] = make_uint4(0, 0, 0, 0);
|
392 |
-
if( mask.any_valid(mi, ni) ) {
|
393 |
-
Base::load(regs[mi][ni], mi, ni);
|
394 |
-
}
|
395 |
-
}
|
396 |
-
}
|
397 |
-
}
|
398 |
-
};
|
399 |
-
|
400 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
401 |
-
|
402 |
-
template<
|
403 |
-
// The dimensions of the tile computed by the CTA.
|
404 |
-
typename Cta_tile,
|
405 |
-
// The base class.
|
406 |
-
typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
|
407 |
-
>
|
408 |
-
struct Gmem_tile_dout : public Base {
|
409 |
-
|
410 |
-
// Ctor.
|
411 |
-
template<typename Params, typename BInfo>
|
412 |
-
inline __device__ Gmem_tile_dout(const Params ¶ms, const BInfo &binfo, int tidx)
|
413 |
-
: Base(params, 0, binfo, tidx) {
|
414 |
-
|
415 |
-
this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);
|
416 |
-
this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
|
417 |
-
|
418 |
-
// Compute the position of the thread in the row.
|
419 |
-
int col = tidx % Base::THREADS_PER_ROW;
|
420 |
-
|
421 |
-
// The row offset in the batched GEMM. For each seq element, we store O in that order.
|
422 |
-
int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
|
423 |
-
|
424 |
-
// Assemble the final pointer.
|
425 |
-
this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
|
426 |
-
}
|
427 |
-
};
|
428 |
-
|
429 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
430 |
-
|
431 |
-
template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
|
432 |
-
struct Gmem_tile_dq : public Base {
|
433 |
-
|
434 |
-
// Ctor.
|
435 |
-
template<typename Params, typename BInfo>
|
436 |
-
inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx)
|
437 |
-
: Base(params, binfo, tidx) {
|
438 |
-
this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
|
439 |
-
this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move
|
440 |
-
|
441 |
-
// Compute the position of the thread in the row.
|
442 |
-
int col = tidx % Base::THREADS_PER_ROW;
|
443 |
-
|
444 |
-
// The row offset in the batched GEMM. For each seq element, we store O in that order.
|
445 |
-
int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
|
446 |
-
(binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
|
447 |
-
|
448 |
-
// Assemble the final pointer.
|
449 |
-
this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;
|
450 |
-
}
|
451 |
-
};
|
452 |
-
|
453 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
454 |
-
|
455 |
-
} // namespace fmha
|
456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|