theonlyengine commited on
Commit
3f9c425
·
verified ·
1 Parent(s): b218831

Upload 421 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
AUTHORS ADDED
@@ -0,0 +1 @@
 
 
1
+ Tri Dao, trid@cs.stanford.edu
Dockerfile ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
2
+ # ARG COMPAT=0
3
+ ARG PERSONAL=0
4
+ # FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
5
+ FROM nvcr.io/nvidia/pytorch:22.12-py3 as base
6
+
7
+ ENV HOST docker
8
+ ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
9
+ # https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
10
+ ENV TZ America/Los_Angeles
11
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
12
+
13
+ # git for installing dependencies
14
+ # tzdata to set time zone
15
+ # wget and unzip to download data
16
+ # [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
17
+ # [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
18
+ RUN apt-get update && apt-get install -y --no-install-recommends \
19
+ build-essential \
20
+ cmake \
21
+ curl \
22
+ ca-certificates \
23
+ sudo \
24
+ less \
25
+ htop \
26
+ git \
27
+ tzdata \
28
+ wget \
29
+ tmux \
30
+ zip \
31
+ unzip \
32
+ zsh stow subversion fasd \
33
+ && rm -rf /var/lib/apt/lists/*
34
+ # openmpi-bin \
35
+
36
+ # Allow running runmpi as root
37
+ # ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
38
+
39
+ # # Create a non-root user and switch to it
40
+ # RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
41
+ # && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
42
+ # USER user
43
+
44
+ # All users can use /home/user as their home directory
45
+ ENV HOME=/home/user
46
+ RUN mkdir -p /home/user && chmod 777 /home/user
47
+ WORKDIR /home/user
48
+
49
+ # Set up personal environment
50
+ # FROM base-${COMPAT} as env-0
51
+ FROM base as env-0
52
+ FROM env-0 as env-1
53
+ # Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
54
+ # https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
55
+ ONBUILD COPY dotfiles ./dotfiles
56
+ ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
57
+ # nvcr pytorch image sets SHELL=/bin/bash
58
+ ONBUILD ENV SHELL=/bin/zsh
59
+
60
+ FROM env-${PERSONAL} as packages
61
+
62
+ # Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
63
+ ENV PIP_NO_CACHE_DIR=1
64
+
65
+ # # apex and pytorch-fast-transformers take a while to compile so we install them first
66
+ # TD [2022-04-28] apex is already installed. In case we need a newer commit:
67
+ # RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
68
+
69
+ # xgboost conflicts with deepspeed
70
+ RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7
71
+
72
+ # General packages that we don't care about the version
73
+ # zstandard to extract the_pile dataset
74
+ # psutil to get the number of cpu physical cores
75
+ # twine to upload package to PyPI
76
+ RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
77
+ && python -m spacy download en_core_web_sm
78
+ # hydra
79
+ RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
80
+ # Core packages
81
+ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 triton==2.0.0.dev20221202 wandb==0.13.7 timm==0.6.12 torchmetrics==0.10.3
82
+ # torchmetrics 0.11.0 broke hydra's instantiate
83
+
84
+ # For MLPerf
85
+ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
86
+
87
+ # Install FlashAttention
88
+ RUN pip install flash-attn==2.6.3
89
+
90
+ # Install CUDA extensions for fused dense
91
+ RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
LICENSE ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ * Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ recursive-include csrc *.cu
2
+ recursive-include csrc *.h
3
+ recursive-include csrc *.cuh
4
+ recursive-include csrc *.cpp
5
+ recursive-include csrc *.hpp
6
+
7
+ recursive-include flash_attn *.cu
8
+ recursive-include flash_attn *.h
9
+ recursive-include flash_attn *.cuh
10
+ recursive-include flash_attn *.cpp
11
+ recursive-include flash_attn *.hpp
Makefile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ clean_dist:
3
+ rm -rf dist/*
4
+
5
+ create_dist: clean_dist
6
+ python setup.py sdist
7
+
8
+ upload_package: create_dist
9
+ twine upload dist/*
README.md ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Optimized Transformer implementation
2
+ This repo contains examples of how FlashAttention can be integrated into a model
3
+ (e.g., GPT, ViT) and trained end-to-end. We also provide optimized
4
+ implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
5
+ rotary embedding). Overall this speeds up training by 3-5x compared to the
6
+ baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
7
+ equivalent to 60.6\% model FLOPs utilization (we don't need any activation
8
+ checkpointing). All without changing the model architecture (i.e., no
9
+ approximation).
10
+
11
+ Goals:
12
+ - Performance: we optimize for model speed and memory, especially on 1-node
13
+ (e.g., with 8 A100s).
14
+ - Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
15
+ and the model code illustrates how these components can be put together.
16
+ The training code also aims to be model- & task-agnostic.
17
+
18
+ Non-goals (and other resources):
19
+ - Support as many models as possible: Huggingface's
20
+ [transformers](https://github.com/huggingface/transformers) and
21
+ [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
22
+ - Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
23
+ training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
24
+ training techniques (e.g., pipeline parallelism, tensor parallelism),
25
+ check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
26
+ [DeepSpeed](https://github.com/microsoft/deepspeed).
27
+ - Inference: we currently focus on training (this might change in the future).
28
+ If you want fast inference, take a look at
29
+ [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
30
+ - Production: this codebase was written during several research projects to validate ideas
31
+ on speeding up ML models.
32
+
33
+ ## Model Components
34
+
35
+ The GPT model is implemented
36
+ [here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
37
+ And here's an example to construct the GPT3-1.3B model with rotary embedding:
38
+ ```python
39
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
40
+ from flash_attn.models.gpt import GPTLMHeadModel
41
+
42
+ seqlen = 2048
43
+ hidden_dim = 2048
44
+ nheads = 16
45
+ n_layer = 24
46
+ rotary_emb_fraction = 0.5
47
+ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
48
+ n_layer=n_layer, n_head=nheads,
49
+ scale_attn_by_inverse_layer_idx=True,
50
+ rotary_emb_fraction=rotary_emb_fraction,
51
+ use_flash_attn=True, fused_mlp=True,
52
+ fused_bias_fc=True, fused_dropout_add_ln=True,
53
+ pad_vocab_size_multiple=8)
54
+ model = GPTLMHeadModel(config)
55
+ ```
56
+
57
+ We provide the following optimized components:
58
+
59
+ 1. FlashAttention: fast and memory-efficient exact attention. This makes
60
+ attention much faster and saves a lot of activation memory. As a result we don't need
61
+ to use any activation checkpointing.
62
+ ```sh
63
+ pip install flash-attn
64
+ ```
65
+
66
+ 2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
67
+ (forward and backward), adapted from Apex's
68
+ [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
69
+ make it work for bfloat16. For best performance, you should use CUDA >= 11.8. CuBLAS versions before
70
+ this doesn't have the best matmul + bias + gelu performance for bfloat16.
71
+ ```sh
72
+ cd ../csrc/fused_dense_lib && pip install .
73
+ ```
74
+ 3. Optimized cross-entropy loss, adapted from Apex's
75
+ [Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
76
+ ```sh
77
+ cd ../csrc/xentropy && pip install .
78
+ ```
79
+ 4. Fused rotary embedding:
80
+ ```sh
81
+ cd ../csrc/rotary && pip install .
82
+ ```
83
+ 5. Fused dropout + residual + LayerNorm, adapted from Apex's
84
+ [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
85
+ This supports dimensions divisible by 8, up to 6144.
86
+ ```sh
87
+ cd ../csrc/layer_norm && pip install .
88
+ ```
89
+
90
+ ## Training
91
+
92
+ We also provide here training scripts to train GPT2 on Openwebtext and GPT3 on
93
+ The Pile as examples. Feel free to use the model in your own training setup as
94
+ well.
95
+
96
+ We use [Hydra](https://hydra.cc/) for configuration,
97
+ [Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
98
+ [Wandb](https://wandb.ai/) for logging.
99
+
100
+ We use the template from `https://github.com/ashleve/lightning-hydra-template`.
101
+ Please read the instructions there to understand the repo structure.
102
+
103
+ ### Requirements
104
+
105
+ Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
106
+ hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
107
+ We recommend CUDA 11.8 (e.g., using the Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
108
+
109
+ We provide a Dockerfile that lists all the required packages.
110
+
111
+ ### Dataset preparation
112
+
113
+ Running the training command would automatically download the datasets
114
+ (Openwebtext, Pile), tokenize with the GPT2 tokenizer, concatenate all the
115
+ tokens, then save this cache to disk. Alternatively, you can also prepare the
116
+ datasets as a separate step.
117
+
118
+ The cached datasets are saved to `${DATA_DIR}/openwebtext` and
119
+ `${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
120
+ `./data/{openwebtext,the_pile}`.
121
+
122
+ - Openwebtext:
123
+ ```sh
124
+ export PYTHONPATH=$PWD:$PYTHONPATH
125
+ pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
126
+ ```
127
+ This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.
128
+
129
+ - The Pile:
130
+ ```sh
131
+ export PYTHONPATH=$PWD:$PYTHONPATH
132
+ pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
133
+ ```
134
+ This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.
135
+
136
+ ### GPT2 training on Openwebtext
137
+ To train GPT2 on Openwebtext with 8 GPUs:
138
+ ```sh
139
+ python run.py experiment=owt/gpt2s-flash trainer.devices=8 # 125M
140
+ python run.py experiment=owt/gpt2m-flash trainer.devices=8 # 355M
141
+ python run.py experiment=owt/gpt2l-flash trainer.devices=8 # 760M
142
+ python run.py experiment=owt/gpt2xl-flash trainer.devices=8 # 1.6B
143
+ ```
144
+ The default parameters are set for 8 x A100 80GB.
145
+
146
+ To train with bf16 instead of fp16, add `trainer.precision=bf16`.
147
+
148
+ ### GPT3 training on The Pile
149
+ To train GPT3 on The Pile with 8 GPUs:
150
+ ```sh
151
+ python run.py experiment=pile/gpt3s-flash trainer.devices=8 # 125M
152
+ python run.py experiment=pile/gpt3m-flash trainer.devices=8 # 355M
153
+ python run.py experiment=pile/gpt3l-flash trainer.devices=8 # 760M
154
+ python run.py experiment=pile/gpt3xl-flash trainer.devices=8 # 1.3B
155
+ python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8 # 2.7B
156
+ ```
157
+ The default parameters are set for 8 x A100 80GB. We train with bf16 by default.
158
+
159
+ To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.
160
+
161
+ ### Training options
162
+
163
+ **Gradient accumulation**: to adjust device batch size to fit into GPU memory
164
+ (the global batch size stays the same, and gradient accumulation is calculated
165
+ automatically), set `datamodule.batch_size=blah`.
166
+
167
+ **Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.
168
+
169
+ **Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.
170
+
171
+ **Resumable training**: set a name to the run, and then set `resume=True` when
172
+ you resume. Training will restart at exactly the same batch.
173
+ ```sh
174
+ python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
175
+ ```
176
+
177
+ ## Training speed
178
+
179
+ We measure the wallclock training speed on one node with 8 x A100 80GB SXM4 80GB (400W) with NVLink.
180
+
181
+ FLOPs are calculated using the formula from the [Megatron-LM
182
+ paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
183
+ to get the model FLOPs (instead of hardware FLOPs with activation
184
+ checkpointing).
185
+
186
+
187
+ ### GPT2 (sequence length 1024)
188
+
189
+ ![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)
190
+
191
+ The implementation in this repo (FlashAttention) is 3-4x faster than the
192
+ baseline implementation from Huggingface.
193
+
194
+ ### GPT3 (sequence length 2048)
195
+
196
+ ![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)
197
+
198
+ The implementation in this repo (FlashAttention) is 3-5x faster than the
199
+ baseline implementation from Huggingface.
200
+
201
+ For the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.
202
+
203
+ We include here more details on the training speed with FlashAttention on 8 x
204
+ A100 80GB.
205
+
206
+ | Model | Batch size (tokens) | Through put (tokens/sec) | Hours / 1B tokens |
207
+ | --------- | ------------------- | ------------------------ | ----------------- |
208
+ | GPT3-125M | 0.5M | 1310k | 0.21 |
209
+ | GPT3-355M | 0.5M | 503k | 0.55 |
210
+ | GPT3-760M | 0.5M | 245k | 1.13 |
211
+ | GPT3-1.3B | 1M | 169k | 1.64 |
212
+ | GPT3-2.7B | 1M | 85k | 3.27 |
213
+
214
+ As an example, this means that one can train a GPT3-1.3B model on 26B tokens
215
+ (compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.
216
+
217
+ ## Training quality
218
+
219
+ We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
220
+ For GPT2, the runs with FlashAttention yield the same loss curve as the runs
221
+ with the baseline implementation from Huggingface for 125M and 355M models. For
222
+ larger models the baseline implementation just takes too long.
223
+
224
+ ![GPT2 training curve](../assets/gpt2_training_curve.jpg)
225
+
226
+ We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
227
+ The 125M, 355M, 760M models have batch size 512k tokens so this translates to
228
+ 800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
229
+ which translates to 400k training steps.
230
+
231
+ ![GPT3 training curve](../assets/gpt3_training_curve.jpg)
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "3.0.0.b1"
acc.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package eval.metrics
2
+ acc:
3
+ _target_: src.metrics.accuracy.AccuracyMine
acc_ignore_index.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package eval.metrics
2
+ acc:
3
+ _target_: torchmetrics.Accuracy
4
+ ignore_index: -100
acctop5.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package eval.metrics
2
+ acctop5:
3
+ _target_: src.metrics.accuracy.AccuracyMine
4
+ top_k: 5
activations.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ # 1/sqrt(2*pi)-> 0.3989423
9
+ # 1/sqrt(2) -> 0.70710678
10
+ # sqrt(2/pi) -> 0.79788456
11
+
12
+ # this function is tanh approximation of gelu
13
+ # actual gelu is:
14
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
15
+ @torch.jit.script
16
+ def bias_gelu(y, bias):
17
+ x = bias + y
18
+ return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
19
+
20
+
21
+ # gradient of tanh approximation of gelu
22
+ # gradient of actual gelu is:
23
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
24
+ @torch.jit.script
25
+ def bias_gelu_back(g, y, bias):
26
+ """Assume that y has shape (B, D) and bias has shape (D)"""
27
+ x = bias + y
28
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
29
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
30
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
31
+ 1 + tanh_out
32
+ )
33
+ grad_y = ff * g
34
+ return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
35
+
36
+
37
+ class GeLUFunction(torch.autograd.Function):
38
+ @staticmethod
39
+ # bias is an optional argument
40
+ def forward(ctx, input, bias):
41
+ ctx.save_for_backward(input, bias)
42
+ return bias_gelu(input, bias)
43
+
44
+ @staticmethod
45
+ def backward(ctx, grad_output):
46
+ input, bias = ctx.saved_tensors
47
+ tmp = bias_gelu_back(grad_output, input, bias)
48
+ return tmp, tmp
49
+
50
+
51
+ bias_gelu_impl = GeLUFunction.apply
52
+
53
+ # this function is tanh approximation of gelu
54
+ # actual gelu is:
55
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
56
+ @torch.jit.script
57
+ def gelu_fwd(x):
58
+ return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
59
+
60
+
61
+ # gradient of tanh approximation of gelu
62
+ # gradient of actual gelu is:
63
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
64
+ @torch.jit.script
65
+ def gelu_bwd(g, x):
66
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
67
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
68
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
69
+ 1 + tanh_out
70
+ )
71
+ return (ff * g).to(dtype=x.dtype)
72
+
73
+
74
+ class FastGeLUFunction(torch.autograd.Function):
75
+ @staticmethod
76
+ # bias is an optional argument
77
+ def forward(ctx, input):
78
+ ctx.save_for_backward(input)
79
+ return gelu_fwd(input)
80
+
81
+ @staticmethod
82
+ def backward(ctx, grad_output):
83
+ (input,) = ctx.saved_tensors
84
+ tmp = gelu_bwd(grad_output, input)
85
+ return tmp
86
+
87
+
88
+ fast_gelu_impl = FastGeLUFunction.apply
89
+
90
+
91
+ @torch.jit.script
92
+ def relu_bwd(g, x):
93
+ return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
94
+
95
+
96
+ @torch.jit.script
97
+ def sqrelu_fwd(x):
98
+ r = F.relu(x)
99
+ return (r * r).to(dtype=x.dtype)
100
+
101
+
102
+ @torch.jit.script
103
+ def sqrelu_bwd(g, x):
104
+ return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
105
+
106
+
107
+ swiglu_fwd_codestring = """
108
+ template <typename T> T swiglu_fwd(T x, T y) {
109
+ return float(x) * float(y) / (1.0f + ::exp(-float(x)));
110
+ }
111
+ """
112
+ swiglu_bwd_codestring = """
113
+ template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
114
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
115
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
116
+ dy = float(x) * x_sigmoid * float(g);
117
+ }
118
+ """
119
+ swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
120
+ swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
121
+
122
+
123
+ class SwiGLUFunction(torch.autograd.Function):
124
+
125
+ @staticmethod
126
+ def forward(ctx, x, y):
127
+ ctx.save_for_backward(x, y)
128
+ return swiglu_fwd(x, y)
129
+
130
+ @staticmethod
131
+ def backward(ctx, dout):
132
+ x, y = ctx.saved_tensors
133
+ return swiglu_bwd(x, y, dout)
134
+
135
+ swiglu = SwiGLUFunction.apply
adam.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # @package train.optimizer
2
+ _target_: torch.optim.Adam
adamw-apex-distributed.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package train.optimizer
2
+ _target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam
3
+ adam_w_mode: True
adamw-apex-zero.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # @package train.optimizer
2
+ _target_: torch.distributed.optim.ZeroRedundancyOptimizer
3
+ _recursive_: True
4
+ optimizer_class:
5
+ _target_: apex.optimizers.FusedAdam
6
+ _partial_: True
7
+ adam_w_mode: True
adamw-apex.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package train.optimizer
2
+ _target_: apex.optimizers.FusedAdam
3
+ adam_w_mode: True
adamw-zero.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # @package train.optimizer
2
+ _target_: torch.distributed.optim.ZeroRedundancyOptimizer
3
+ _recursive_: True
4
+ optimizer_class:
5
+ _target_: torch.optim.__getattribute__
6
+ _args_:
7
+ - "AdamW"
adamw.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # @package train.optimizer
2
+ _target_: torch.optim.AdamW
alibi.h ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cmath>
2
+
3
+ #include <cute/tensor.hpp>
4
+
5
+ #include <cutlass/cutlass.h>
6
+ #include <cutlass/array.h>
7
+
8
+ #include "utils.h"
9
+
10
+ namespace flash {
11
+
12
+ using namespace cute;
13
+
14
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
15
+
16
+ template <bool Is_causal>
17
+ struct Alibi {
18
+
19
+ const float alibi_slope;
20
+ const int max_seqlen_k, max_seqlen_q;
21
+
22
+ __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
23
+ : alibi_slope(alibi_slope)
24
+ , max_seqlen_k(max_seqlen_k)
25
+ , max_seqlen_q(max_seqlen_q) {
26
+ };
27
+
28
+
29
+ template <typename Engine, typename Layout>
30
+ __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
31
+ const int col_idx_offset_,
32
+ const int row_idx_offset,
33
+ const int warp_row_stride) {
34
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
35
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
36
+ const int lane_id = threadIdx.x % 32;
37
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
38
+ if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
39
+ #pragma unroll
40
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
41
+ const int col_idx_base = col_idx_offset + nj * 8;
42
+ #pragma unroll
43
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
44
+ const int col_idx = col_idx_base + j;
45
+ #pragma unroll
46
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
47
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
48
+ }
49
+ }
50
+ }
51
+ } else { // Bias depends on both row_idx and col_idx
52
+ #pragma unroll
53
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
54
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
55
+ #pragma unroll
56
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
57
+ const int row_idx = row_idx_base + i * 8;
58
+ #pragma unroll
59
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
60
+ const int col_idx_base = col_idx_offset + nj * 8;
61
+ #pragma unroll
62
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
63
+ const int col_idx = col_idx_base + j;
64
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
65
+ }
66
+ }
67
+ }
68
+ }
69
+ }
70
+ }
71
+
72
+ };
73
+
74
+ } // namespace flash
all_params.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ # default values for all trainer parameters
4
+ checkpoint_callback: True
5
+ default_root_dir: null
6
+ gradient_clip_val: 0.0
7
+ process_position: 0
8
+ num_nodes: 1
9
+ num_processes: 1
10
+ gpus: null
11
+ auto_select_gpus: False
12
+ tpu_cores: null
13
+ log_gpu_memory: null
14
+ overfit_batches: 0.0
15
+ track_grad_norm: -1
16
+ check_val_every_n_epoch: 1
17
+ fast_dev_run: False
18
+ accumulate_grad_batches: 1
19
+ max_epochs: 1
20
+ min_epochs: 1
21
+ max_steps: null
22
+ min_steps: null
23
+ limit_train_batches: 1.0
24
+ limit_val_batches: 1.0
25
+ limit_test_batches: 1.0
26
+ val_check_interval: 1.0
27
+ flush_logs_every_n_steps: 100
28
+ log_every_n_steps: 50
29
+ accelerator: null
30
+ sync_batchnorm: False
31
+ precision: 32
32
+ weights_summary: "top"
33
+ weights_save_path: null
34
+ num_sanity_val_steps: 2
35
+ truncated_bptt_steps: null
36
+ resume_from_checkpoint: null
37
+ profiler: null
38
+ benchmark: False
39
+ deterministic: False
40
+ reload_dataloaders_every_epoch: False
41
+ auto_lr_find: False
42
+ replace_sampler_ddp: True
43
+ terminate_on_nan: False
44
+ auto_scale_batch_size: False
45
+ prepare_data_per_node: True
46
+ plugins: null
47
+ amp_backend: "native"
48
+ amp_level: "O2"
49
+ move_metrics_to_cpu: False
baichuan.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, GGGGGGXY, Tri Dao.
2
+
3
+ import math
4
+ import json
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from einops import rearrange
14
+ from transformers import GPT2Config, AutoConfig, PretrainedConfig
15
+
16
+
17
+ def remap_state_dict_hf_baichuan(state_dict, config):
18
+ def key_mapping_layers(key):
19
+ return re.sub(r"^model.", "transformer.", key)
20
+
21
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
22
+
23
+ # Word embedding
24
+ def key_mapping_emb(key):
25
+ return re.sub(
26
+ r"^transformer.embed_tokens.",
27
+ "transformer.embeddings.word_embeddings.",
28
+ key,
29
+ )
30
+
31
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
32
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
33
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
34
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
35
+ vocab_size = (
36
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
37
+ * pad_vocab_size_multiple
38
+ )
39
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
40
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
41
+ )
42
+ if getattr(config, "tie_word_embeddings"):
43
+ state_dict["lm_head.weight"] = state_dict[
44
+ "transformer.embeddings.word_embeddings.weight"
45
+ ]
46
+ else:
47
+ output_embeddings = state_dict.pop("lm_head.weight")
48
+ # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
49
+ # differently.
50
+ vocab_size = (
51
+ math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
52
+ * pad_vocab_size_multiple
53
+ )
54
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
55
+ state_dict["lm_head.weight"] = F.pad(
56
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
57
+ )
58
+
59
+ # LayerNorm
60
+ def key_mapping_ln(key):
61
+ key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
62
+ key = re.sub(
63
+ r"^transformer.layers.(\d+).input_layernorm.",
64
+ r"transformer.layers.\1.norm1.",
65
+ key,
66
+ )
67
+ key = re.sub(
68
+ r"^transformer.layers.(\d+).post_attention_layernorm.",
69
+ r"transformer.layers.\1.norm2.",
70
+ key,
71
+ )
72
+ return key
73
+
74
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
75
+
76
+ # MLP
77
+ for l in range(config.n_layer):
78
+ w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight")
79
+ w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight")
80
+ # Our ordering is different
81
+ state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat(
82
+ [w3, w1], dim=0
83
+ )
84
+
85
+ def key_mapping_mlp(key):
86
+ return re.sub(
87
+ r"^transformer.layers.(\d+).mlp.down_proj.",
88
+ r"transformer.layers.\1.mlp.fc2.",
89
+ key,
90
+ )
91
+
92
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
93
+
94
+ # Attention
95
+ def key_mapping_attn(key):
96
+ key = re.sub(
97
+ r"^transformer.layers.(\d+).self_attn.W_pack.",
98
+ r"transformer.layers.\1.mixer.Wqkv.",
99
+ key,
100
+ )
101
+ key = re.sub(
102
+ r"^transformer.layers.(\d+).self_attn.o_proj.",
103
+ r"transformer.layers.\1.mixer.out_proj.",
104
+ key,
105
+ )
106
+ return key
107
+
108
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
109
+ for l in range(config.n_layer):
110
+ # pop rotary_emb.inv_freq from state dict
111
+ state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
112
+ return state_dict
113
+
114
+
115
+ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
116
+ # HACK: the config doesn't have say whether it's rotary or alibi.
117
+ # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
118
+ # HACK: the config doesn't have say whether it uses norm head.
119
+ # So we have to infer from the vocab size
120
+ # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
121
+ use_rotary = baichuan_config.hidden_size < 5000
122
+ return GPT2Config(
123
+ vocab_size=baichuan_config.vocab_size,
124
+ n_positions=0, # No absolute position embedding
125
+ n_embd=baichuan_config.hidden_size,
126
+ n_layer=baichuan_config.num_hidden_layers,
127
+ n_head=baichuan_config.num_attention_heads,
128
+ n_inner=baichuan_config.intermediate_size,
129
+ activation_function="swiglu", # Hardcode since HF calls it 'silu'
130
+ # baichuan doesn't have dropout, idk if it's because they only release the inference code
131
+ resid_pdrop=0.0,
132
+ embd_pdrop=0.0,
133
+ attn_pdrop=0.0,
134
+ layer_norm_epsilon=baichuan_config.rms_norm_eps,
135
+ initializer_range=baichuan_config.initializer_range,
136
+ bos_token_id=baichuan_config.bos_token_id,
137
+ eos_token_id=baichuan_config.eos_token_id,
138
+ # These are new arguments not in the original GPT2Config
139
+ pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything
140
+ rms_norm=True,
141
+ rotary_emb_fraction=1.0 if use_rotary else 0.0,
142
+ rotary_emb_interleaved=False,
143
+ use_alibi=not use_rotary,
144
+ use_flash_attn=not use_rotary, # Alibi code path requires flash_attn
145
+ tie_word_embeddings=False,
146
+ norm_head=baichuan_config.vocab_size > 70000,
147
+ qkv_proj_bias=False,
148
+ out_proj_bias=False,
149
+ mlp_fc1_bias=False,
150
+ mlp_fc2_bias=False,
151
+ )
base.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /trainer: default # choose trainer from 'configs/trainer/'
4
+ - override /model: null
5
+ - override /datamodule: openwebtext
6
+ # FusedAdam from apex speeds up the optimizer step a bit, for GPT2-small time
7
+ # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
8
+ # For GPT2-medium time per global goes from 997ms to 972ms.
9
+ - override /optimizer: adamw-apex
10
+ - override /scheduler: linear-warmup
11
+ - override /callbacks: [default, norm-monitor]
12
+ - override /metrics: [perplexity, num-tokens]
13
+ - override /logger: wandb
14
+
15
+ # all parameters below will be merged with parameters from default configurations set above
16
+ # this allows you to overwrite only specified parameters
17
+
18
+ task:
19
+ _target_: src.tasks.seq.SequenceLMModel
20
+
21
+ seed: 1111
22
+
23
+ trainer:
24
+ accelerator: gpu
25
+ devices: 8
26
+ num_nodes: 1
27
+ accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
28
+ max_steps: 400000
29
+ val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
30
+ check_val_every_n_epoch: null # We don't care about epoch boundary
31
+ precision: 16
32
+ gradient_clip_val: 1.0
33
+ strategy: null
34
+
35
+ datamodule:
36
+ batch_size: 16 # Per GPU
37
+ batch_size_eval: ${.batch_size} # Fused dense only support batch size at most 64k
38
+ max_length: 1024
39
+ fault_tolerant: True
40
+ ddp: ${eval:"${trainer.devices} > 1"}
41
+
42
+ train:
43
+ gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
44
+ global_batch_size: 512
45
+ optimizer:
46
+ lr: 6e-4
47
+ weight_decay: 0.1
48
+ optimizer_param_grouping:
49
+ bias_weight_decay: False
50
+ normalization_weight_decay: False
51
+ scheduler:
52
+ num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
53
+ num_training_steps: ${trainer.max_steps}
54
+ loss_fn:
55
+ # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
56
+ # It's also more numerically stable if we're using DeepSpeed 16 bits.
57
+ _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
58
+ inplace_backward: True # to save memory
59
+
60
+ eval:
61
+ log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step
62
+
63
+ callbacks:
64
+ model_checkpoint:
65
+ monitor: val/loss
66
+ mode: min
67
+ save_top_k: 3
68
+ save_last: True
69
+ every_n_train_steps: 1000
70
+ dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
71
+ filename: step_{step}
72
+ auto_insert_metric_name: False
73
+ model_checkpoint_progress:
74
+ _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
75
+ fault_tolerant: True
76
+ every_n_train_steps: 50000
77
+ save_last: False
78
+ save_top_k: -1 # Save all the checkpoints
79
+ dirpath: ${..model_checkpoint.dirpath}
80
+ filename: progress_step_{step}
81
+ auto_insert_metric_name: False
82
+ early_stopping: null
benchmark.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+ """ Useful functions for writing test code. """
3
+
4
+ import torch
5
+ import torch.utils.benchmark as benchmark
6
+
7
+
8
+ def benchmark_forward(
9
+ fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
10
+ ):
11
+ """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
12
+ if verbose:
13
+ print(desc, "- Forward pass")
14
+
15
+ def amp_wrapper(*inputs, **kwinputs):
16
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
17
+ fn(*inputs, **kwinputs)
18
+
19
+ t = benchmark.Timer(
20
+ stmt="fn_amp(*inputs, **kwinputs)",
21
+ globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
22
+ num_threads=torch.get_num_threads(),
23
+ )
24
+ m = t.timeit(repeats)
25
+ if verbose:
26
+ print(m)
27
+ return t, m
28
+
29
+
30
+ def benchmark_backward(
31
+ fn,
32
+ *inputs,
33
+ grad=None,
34
+ repeats=10,
35
+ desc="",
36
+ verbose=True,
37
+ amp=False,
38
+ amp_dtype=torch.float16,
39
+ **kwinputs,
40
+ ):
41
+ """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
42
+ if verbose:
43
+ print(desc, "- Backward pass")
44
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
45
+ y = fn(*inputs, **kwinputs)
46
+ if type(y) is tuple:
47
+ y = y[0]
48
+ if grad is None:
49
+ grad = torch.randn_like(y)
50
+ else:
51
+ if grad.shape != y.shape:
52
+ raise RuntimeError("Grad shape does not match output shape")
53
+
54
+ def f(*inputs, y, grad):
55
+ # Set .grad to None to avoid extra operation of gradient accumulation
56
+ for x in inputs:
57
+ if isinstance(x, torch.Tensor):
58
+ x.grad = None
59
+ y.backward(grad, retain_graph=True)
60
+
61
+ t = benchmark.Timer(
62
+ stmt="f(*inputs, y=y, grad=grad)",
63
+ globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
64
+ num_threads=torch.get_num_threads(),
65
+ )
66
+ m = t.timeit(repeats)
67
+ if verbose:
68
+ print(m)
69
+ return t, m
70
+
71
+
72
+ def benchmark_combined(
73
+ fn,
74
+ *inputs,
75
+ grad=None,
76
+ repeats=10,
77
+ desc="",
78
+ verbose=True,
79
+ amp=False,
80
+ amp_dtype=torch.float16,
81
+ **kwinputs,
82
+ ):
83
+ """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
84
+ if verbose:
85
+ print(desc, "- Forward + Backward pass")
86
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
87
+ y = fn(*inputs, **kwinputs)
88
+ if type(y) is tuple:
89
+ y = y[0]
90
+ if grad is None:
91
+ grad = torch.randn_like(y)
92
+ else:
93
+ if grad.shape != y.shape:
94
+ raise RuntimeError("Grad shape does not match output shape")
95
+
96
+ def f(grad, *inputs, **kwinputs):
97
+ for x in inputs:
98
+ if isinstance(x, torch.Tensor):
99
+ x.grad = None
100
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
101
+ y = fn(*inputs, **kwinputs)
102
+ if type(y) is tuple:
103
+ y = y[0]
104
+ y.backward(grad, retain_graph=True)
105
+
106
+ t = benchmark.Timer(
107
+ stmt="f(grad, *inputs, **kwinputs)",
108
+ globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
109
+ num_threads=torch.get_num_threads(),
110
+ )
111
+ m = t.timeit(repeats)
112
+ if verbose:
113
+ print(m)
114
+ return t, m
115
+
116
+
117
+ def benchmark_fwd_bwd(
118
+ fn,
119
+ *inputs,
120
+ grad=None,
121
+ repeats=10,
122
+ desc="",
123
+ verbose=True,
124
+ amp=False,
125
+ amp_dtype=torch.float16,
126
+ **kwinputs,
127
+ ):
128
+ """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
129
+ return (
130
+ benchmark_forward(
131
+ fn,
132
+ *inputs,
133
+ repeats=repeats,
134
+ desc=desc,
135
+ verbose=verbose,
136
+ amp=amp,
137
+ amp_dtype=amp_dtype,
138
+ **kwinputs,
139
+ ),
140
+ benchmark_backward(
141
+ fn,
142
+ *inputs,
143
+ grad=grad,
144
+ repeats=repeats,
145
+ desc=desc,
146
+ verbose=verbose,
147
+ amp=amp,
148
+ amp_dtype=amp_dtype,
149
+ **kwinputs,
150
+ ),
151
+ )
152
+
153
+
154
+ def benchmark_all(
155
+ fn,
156
+ *inputs,
157
+ grad=None,
158
+ repeats=10,
159
+ desc="",
160
+ verbose=True,
161
+ amp=False,
162
+ amp_dtype=torch.float16,
163
+ **kwinputs,
164
+ ):
165
+ """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
166
+ return (
167
+ benchmark_forward(
168
+ fn,
169
+ *inputs,
170
+ repeats=repeats,
171
+ desc=desc,
172
+ verbose=verbose,
173
+ amp=amp,
174
+ amp_dtype=amp_dtype,
175
+ **kwinputs,
176
+ ),
177
+ benchmark_backward(
178
+ fn,
179
+ *inputs,
180
+ grad=grad,
181
+ repeats=repeats,
182
+ desc=desc,
183
+ verbose=verbose,
184
+ amp=amp,
185
+ amp_dtype=amp_dtype,
186
+ **kwinputs,
187
+ ),
188
+ benchmark_combined(
189
+ fn,
190
+ *inputs,
191
+ grad=grad,
192
+ repeats=repeats,
193
+ desc=desc,
194
+ verbose=verbose,
195
+ amp=amp,
196
+ amp_dtype=amp_dtype,
197
+ **kwinputs,
198
+ ),
199
+ )
200
+
201
+
202
+ def pytorch_profiler(
203
+ fn,
204
+ *inputs,
205
+ trace_filename=None,
206
+ backward=False,
207
+ amp=False,
208
+ amp_dtype=torch.float16,
209
+ cpu=False,
210
+ verbose=True,
211
+ **kwinputs,
212
+ ):
213
+ """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
214
+ if backward:
215
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
216
+ out = fn(*inputs, **kwinputs)
217
+ if type(out) is tuple:
218
+ out = out[0]
219
+ g = torch.randn_like(out)
220
+ for _ in range(30): # Warm up
221
+ if backward:
222
+ for x in inputs:
223
+ if isinstance(x, torch.Tensor):
224
+ x.grad = None
225
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
226
+ out = fn(*inputs, **kwinputs)
227
+ if type(out) is tuple:
228
+ out = out[0]
229
+ # Backward should be done outside autocast
230
+ if backward:
231
+ out.backward(g, retain_graph=True)
232
+ activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
233
+ torch.profiler.ProfilerActivity.CUDA
234
+ ]
235
+ with torch.profiler.profile(
236
+ activities=activities,
237
+ record_shapes=True,
238
+ # profile_memory=True,
239
+ with_stack=True,
240
+ ) as prof:
241
+ if backward:
242
+ for x in inputs:
243
+ if isinstance(x, torch.Tensor):
244
+ x.grad = None
245
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
246
+ out = fn(*inputs, **kwinputs)
247
+ if type(out) is tuple:
248
+ out = out[0]
249
+ if backward:
250
+ out.backward(g, retain_graph=True)
251
+ if verbose:
252
+ # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
253
+ print(prof.key_averages().table(row_limit=50))
254
+ if trace_filename is not None:
255
+ prof.export_chrome_trace(trace_filename)
256
+
257
+
258
+ def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
259
+ torch.cuda.empty_cache()
260
+ torch.cuda.reset_peak_memory_stats()
261
+ torch.cuda.synchronize()
262
+ fn(*inputs, **kwinputs)
263
+ torch.cuda.synchronize()
264
+ mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
265
+ if verbose:
266
+ print(f"{desc} max memory: {mem}GB")
267
+ torch.cuda.empty_cache()
268
+ return mem
benchmark_alibi.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Sanghun Cho, Tri Dao.
2
+
3
+ import pickle
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+ from flash_attn.layers.rotary import apply_rotary_emb
11
+
12
+ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
13
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
14
+
15
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
16
+
17
+ try:
18
+ import xformers.ops as xops
19
+ except ImportError:
20
+ xops = None
21
+
22
+
23
+ def generate_cos_sin(seqlen, rotary_dim, device, dtype):
24
+ assert rotary_dim % 2 == 0
25
+ angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
26
+ cos = torch.cos(angle).to(dtype=dtype)
27
+ sin = torch.sin(angle).to(dtype=dtype)
28
+ return cos, sin
29
+
30
+
31
+ def flash_rotary(q, k, v, cos, sin, causal=False):
32
+ # corrected by @tridao comments
33
+ q = apply_rotary_emb(
34
+ q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
35
+ )
36
+ k = apply_rotary_emb(
37
+ k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
38
+ )
39
+
40
+ return flash_attn_func(q, k, v, causal=causal)
41
+
42
+
43
+ def attn_bias_from_alibi_slopes(
44
+ slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
45
+ ):
46
+ batch, nheads = slopes.shape
47
+ device = slopes.device
48
+ slopes = rearrange(slopes, "b h -> b h 1 1")
49
+ if causal:
50
+ return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
51
+ else:
52
+ row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
53
+ col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
54
+ sk = (
55
+ seqlen_k
56
+ if key_padding_mask is None
57
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
58
+ )
59
+ sq = (
60
+ seqlen_q
61
+ if query_padding_mask is None
62
+ else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
63
+ )
64
+ relative_pos = torch.abs(row_idx + sk - sq - col_idx)
65
+ return -slopes * relative_pos.to(dtype=slopes.dtype)
66
+
67
+
68
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
69
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
70
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
71
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
72
+
73
+
74
+ def efficiency(flop, time):
75
+ return (flop / time / 10**12) if not math.isnan(time) else 0.0
76
+
77
+
78
+ def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
79
+ """
80
+ Arguments:
81
+ q, k, v: (batch_size, seqlen, nheads, head_dim)
82
+ dropout_p: float
83
+ attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
84
+ Output:
85
+ output: (batch_size, seqlen, nheads, head_dim)
86
+ """
87
+ batch_size, seqlen, nheads, d = q.shape
88
+ q = rearrange(q, 'b t h d -> (b h) t d')
89
+ k = rearrange(k, 'b s h d -> (b h) d s')
90
+ softmax_scale = 1.0 / math.sqrt(d)
91
+ # Preallocate attn_weights for `baddbmm`
92
+ if attn_bias is not None:
93
+ scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
94
+ else:
95
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
96
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
97
+ '(b h) t s -> b h t s', h=nheads)
98
+ if causal:
99
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
100
+ # So we have to construct the mask in float
101
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
102
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
103
+ scores = scores + causal_mask.to(dtype=scores.dtype)
104
+ attention = torch.softmax(scores, dim=-1)
105
+ attention_drop = F.dropout(attention, dropout_p)
106
+ output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
107
+ return output.to(dtype=q.dtype)
108
+
109
+
110
+ def time_fwd_bwd(func, *args, **kwargs):
111
+ time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
112
+ return time_f[1].mean, time_b[1].mean
113
+
114
+
115
+ repeats = 30
116
+ device = 'cuda'
117
+ dtype = torch.float16
118
+
119
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
120
+ causal_vals = [False, True]
121
+ headdim_vals = [64, 128]
122
+ dim = 2048
123
+ dropout_p = 0.0
124
+
125
+ methods = (["fa2_alibi", "torch"]
126
+ + (["xformers"] if xops is not None else [])
127
+ + ["sdpa"]
128
+ + ["fa2_baseline"]
129
+ + ["fa2_rotary"])
130
+
131
+ time_f = {}
132
+ time_b = {}
133
+ time_f_b = {}
134
+ speed_f = {}
135
+ speed_b = {}
136
+ speed_f_b = {}
137
+ for causal in causal_vals:
138
+ for headdim in headdim_vals:
139
+ for batch_size, seqlen in bs_seqlen_vals:
140
+ config = (causal, headdim, batch_size, seqlen)
141
+ nheads = dim // headdim
142
+ q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
143
+ requires_grad=True) for _ in range(3)]
144
+ # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
145
+ alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
146
+ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
147
+ attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
148
+ f, b = time_fwd_bwd(
149
+ flash_attn_func,
150
+ q, k, v,
151
+ dropout_p,
152
+ causal=causal,
153
+ # alibi_slopes=alibi_slopes,
154
+ alibi_slopes=None,
155
+ repeats=repeats,
156
+ verbose=False
157
+ )
158
+ time_f[config, "fa2_baseline"] = f
159
+ time_b[config, "fa2_baseline"] = b
160
+
161
+ q = q.detach().requires_grad_(True)
162
+ k = k.detach().requires_grad_(True)
163
+ v = v.detach().requires_grad_(True)
164
+ f, b = time_fwd_bwd(
165
+ flash_attn_func,
166
+ q, k, v,
167
+ dropout_p,
168
+ causal=causal,
169
+ alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
170
+ # alibi_slopes=None,
171
+ repeats=repeats,
172
+ verbose=False
173
+ )
174
+ time_f[config, "fa2_alibi"] = f
175
+ time_b[config, "fa2_alibi"] = b
176
+
177
+ try:
178
+ q = q.detach().requires_grad_(True)
179
+ k = k.detach().requires_grad_(True)
180
+ v = v.detach().requires_grad_(True)
181
+ f, b = time_fwd_bwd(
182
+ attention_pytorch,
183
+ q, k, v,
184
+ dropout_p,
185
+ causal=causal,
186
+ attn_bias=attn_bias,
187
+ repeats=repeats,
188
+ verbose=False
189
+ )
190
+ except: # Skip if OOM
191
+ f, b = float('nan'), float('nan')
192
+ time_f[config, "torch"] = f
193
+ time_b[config, "torch"] = b
194
+
195
+ # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
196
+ with torch.backends.cuda.sdp_kernel(enable_flash=False):
197
+ q_pt = q.detach().requires_grad_(True).transpose(1, 2)
198
+ k_pt = k.detach().requires_grad_(True).transpose(1, 2)
199
+ v_pt = v.detach().requires_grad_(True).transpose(1, 2)
200
+ f, b = time_fwd_bwd(
201
+ F.scaled_dot_product_attention,
202
+ q_pt, k_pt, v_pt,
203
+ attn_mask=attn_bias,
204
+ dropout_p=dropout_p,
205
+ is_causal=causal,
206
+ repeats=repeats,
207
+ verbose=False
208
+ )
209
+ time_f[config, "sdpa"] = f
210
+ time_b[config, "sdpa"] = b
211
+
212
+ if xops is not None:
213
+ q = q.detach().requires_grad_(True)
214
+ k = k.detach().requires_grad_(True)
215
+ v = v.detach().requires_grad_(True)
216
+ if causal:
217
+ attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
218
+ # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
219
+ # `flshattB@v2.3.6` is not supported because:
220
+ # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
221
+ # `cutlassB` is not supported because:
222
+ # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
223
+ attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
224
+ else:
225
+ attn_bias_xops = attn_bias.to(dtype=q.dtype)
226
+ f, b = time_fwd_bwd(
227
+ xops.memory_efficient_attention,
228
+ q, k, v,
229
+ attn_bias_xops,
230
+ dropout_p,
231
+ repeats=repeats,
232
+ verbose=False
233
+ )
234
+ time_f[config, "xformers"] = f
235
+ time_b[config, "xformers"] = b
236
+
237
+ q = q.detach().requires_grad_(True)
238
+ k = k.detach().requires_grad_(True)
239
+ v = v.detach().requires_grad_(True)
240
+ cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
241
+ f, b = time_fwd_bwd(
242
+ flash_rotary,
243
+ q, k, v,
244
+ cos, sin,
245
+ causal,
246
+ repeats=repeats,
247
+ verbose=False
248
+ )
249
+ time_f[config, "fa2_rotary"] = f
250
+ time_b[config, "fa2_rotary"] = b
251
+
252
+ print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
253
+ csv_output = ""
254
+ csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
255
+ for method in methods:
256
+ time_f_b[config, method] = time_f[config, method] + time_b[config, method]
257
+ speed_f[config, method] = efficiency(
258
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
259
+ time_f[config, method]
260
+ )
261
+ speed_b[config, method] = efficiency(
262
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
263
+ time_b[config, method]
264
+ )
265
+ speed_f_b[config, method] = efficiency(
266
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
267
+ time_f_b[config, method]
268
+ )
269
+ print(
270
+ f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
271
+ f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
272
+ f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
273
+ )
274
+ csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
275
+ print(csv_output)
benchmark_attn.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import time
8
+
9
+ try:
10
+ import cudnn
11
+ except ImportError:
12
+ cudnn = None
13
+
14
+
15
+ from einops import rearrange, repeat
16
+
17
+ # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
18
+ from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
19
+ from flash_attn.flash_attn_interface import flash_attn_func
20
+ from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
21
+
22
+ # Need to install triton nightly:
23
+ # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
24
+
25
+ try:
26
+ from triton_fused_attention import attention as triton_attention
27
+ except ImportError:
28
+ triton_attention = None
29
+
30
+ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
31
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
32
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
33
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
34
+
35
+
36
+ def convert_to_cudnn_type(torch_type):
37
+ if torch_type == torch.float16:
38
+ return cudnn.data_type.HALF
39
+ elif torch_type == torch.bfloat16:
40
+ return cudnn.data_type.BFLOAT16
41
+ elif torch_type == torch.float32:
42
+ return cudnn.data_type.FLOAT
43
+ elif torch_type == torch.int32:
44
+ return cudnn.data_type.INT32
45
+ elif torch_type == torch.int64:
46
+ return cudnn.data_type.INT64
47
+ else:
48
+ raise ValueError("Unsupported tensor data type.")
49
+
50
+
51
+ def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
52
+ b, nheads, seqlen_q, headdim = q.shape
53
+ _, nheads_kv, seqlen_k, _ = k.shape
54
+ assert v.shape == (b, nheads_kv, seqlen_k, headdim)
55
+ assert cudnn is not None, 'CUDNN is not available'
56
+ q_gpu, k_gpu, v_gpu = q, k, v
57
+ o_gpu, stats_gpu = o, stats
58
+ graph_forward = cudnn.pygraph(
59
+ io_data_type=convert_to_cudnn_type(q.dtype),
60
+ intermediate_data_type=cudnn.data_type.FLOAT,
61
+ compute_data_type=cudnn.data_type.FLOAT,
62
+ )
63
+ q_forward = graph_forward.tensor_like(q_gpu.detach())
64
+ k_forward = graph_forward.tensor_like(k_gpu.detach())
65
+ v_forward = graph_forward.tensor_like(v_gpu.detach())
66
+
67
+ seqlens_reshaped = seqlens if varlen else None
68
+ seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
69
+ seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
70
+
71
+ o_forward, stats_forward = graph_forward.sdpa(
72
+ name="sdpa",
73
+ q=q_forward,
74
+ k=k_forward,
75
+ v=v_forward,
76
+ is_inference=False,
77
+ attn_scale=1.0 / math.sqrt(headdim),
78
+ use_causal_mask=causal,
79
+ use_padding_mask=varlen,
80
+ seq_len_q=seq_len_q,
81
+ seq_len_kv=seq_len_kv,
82
+ )
83
+
84
+ o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
85
+ stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
86
+
87
+ graph_forward.validate()
88
+ graph_forward.build_operation_graph()
89
+ graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
90
+ graph_forward.check_support()
91
+ graph_forward.build_plans()
92
+
93
+ variant_pack_forward = {
94
+ q_forward: q_gpu,
95
+ k_forward: k_gpu,
96
+ v_forward: v_gpu,
97
+ o_forward: o_gpu,
98
+ stats_forward: stats_gpu,
99
+ seq_len_q: seqlens_reshaped,
100
+ seq_len_kv: seqlens_reshaped,
101
+ }
102
+
103
+ dQ_gpu = torch.empty_like(q_gpu)
104
+ dK_gpu = torch.empty_like(k_gpu)
105
+ dV_gpu = torch.empty_like(v_gpu)
106
+ dO_gpu = grad
107
+
108
+ graph_backward = cudnn.pygraph(
109
+ io_data_type=cudnn.data_type.HALF,
110
+ intermediate_data_type=cudnn.data_type.FLOAT,
111
+ compute_data_type=cudnn.data_type.FLOAT,
112
+ )
113
+
114
+ q_backward = graph_backward.tensor_like(q_gpu.detach())
115
+ k_backward = graph_backward.tensor_like(k_gpu.detach())
116
+ v_backward = graph_backward.tensor_like(v_gpu.detach())
117
+ o_backward = graph_backward.tensor_like(o_gpu.detach())
118
+ dO_backward = graph_backward.tensor_like(dO_gpu.detach())
119
+ stats_backward = graph_backward.tensor_like(stats_gpu.detach())
120
+ seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
121
+ seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
122
+
123
+ dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
124
+ name="sdpa_backward",
125
+ q=q_backward,
126
+ k=k_backward,
127
+ v=v_backward,
128
+ o=o_backward,
129
+ dO=dO_backward,
130
+ stats=stats_backward,
131
+ attn_scale=1.0 / math.sqrt(headdim),
132
+ use_causal_mask=causal,
133
+ use_padding_mask=varlen,
134
+ seq_len_q=seq_len_q,
135
+ seq_len_kv=seq_len_kv,
136
+ )
137
+
138
+ dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
139
+ dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
140
+ dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
141
+
142
+ graph_backward.validate()
143
+ graph_backward.build_operation_graph()
144
+ graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
145
+ graph_backward.check_support()
146
+ graph_backward.build_plans()
147
+
148
+ variant_pack_backward = {
149
+ q_backward: q_gpu,
150
+ k_backward: k_gpu,
151
+ v_backward: v_gpu,
152
+ o_backward: o_gpu,
153
+ dO_backward: dO_gpu,
154
+ stats_backward: stats_gpu,
155
+ dQ_backward: dQ_gpu,
156
+ dK_backward: dK_gpu,
157
+ dV_backward: dV_gpu,
158
+ seq_len_q: seqlens_reshaped,
159
+ seq_len_kv: seqlens_reshaped,
160
+ }
161
+
162
+ workspace = torch.empty(
163
+ max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
164
+ device="cuda", dtype=torch.uint8
165
+ )
166
+
167
+ def run_fwd(*args, **kwargs):
168
+ graph_forward.execute(variant_pack_forward, workspace)
169
+ return o_gpu, stats_gpu
170
+
171
+ def run_bwd(*args, **kwargs):
172
+ graph_backward.execute(variant_pack_backward, workspace)
173
+ return dQ_gpu, dK_gpu, dV_gpu
174
+
175
+ return run_fwd, run_bwd
176
+
177
+
178
+ torch.manual_seed(0)
179
+ repeats = 100
180
+ dropout_p = 0.0
181
+ causal = False
182
+ dtype = torch.float16
183
+ device = 'cuda'
184
+ verbose = False
185
+ batch_size = 2
186
+ # seqlen = 2048
187
+ seqlen = 8192
188
+ # seqlen = 4096
189
+ # seqlen = 2047
190
+ dim = 2048
191
+ # headdim = 128
192
+ # headdim = 64
193
+ headdim = 256
194
+
195
+ for mode in ['fwd', 'bwd']:
196
+ # for mode in ['bwd']:
197
+ for headdim in [64, 128, 256]:
198
+ # for headdim in [128]:
199
+ for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
200
+ # for seqlen in [8192]:
201
+ nheads = dim // headdim
202
+ # nheads = 24
203
+ # headdim = 64
204
+ # batch_size = 64
205
+ # seqlen = 512
206
+ # nheads = 8
207
+ # headdim = 128
208
+ # nheads = 16
209
+ # headdim = 128
210
+ nheads_kv = nheads
211
+ # nheads_kv = 1
212
+
213
+ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
214
+ requires_grad=True)
215
+ q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
216
+ k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
217
+ v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
218
+ q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
219
+ k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
220
+ v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
221
+ grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
222
+ grad_t = grad.transpose(1, 2).contiguous()
223
+ o_t = torch.empty_like(q.transpose(1, 2))
224
+ stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
225
+
226
+ bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
227
+
228
+ for causal in [False, True]:
229
+ # for causal in [True]:
230
+ print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
231
+ # For var-seq-len
232
+ lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
233
+ seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
234
+ cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
235
+ if headdim <= 128 and cudnn is not None:
236
+ cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
237
+ cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
238
+ f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
239
+ ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
240
+ _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
241
+ if mode == 'bwd':
242
+ ref_dv, v.grad = v.grad.clone(), None
243
+ ref_dk, k.grad = k.grad.clone(), None
244
+ ref_dq, q.grad = q.grad.clone(), None
245
+ # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
246
+ if headdim <= 128:
247
+ if triton_attention is not None and nheads_kv == nheads:
248
+ if mode == 'fwd':
249
+ time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
250
+ _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
251
+ # TODO: fix Triton numeric errors.
252
+ # if mode == 'bwd':
253
+ # dv, v_t.grad = v_t.grad.clone(), None
254
+ # dk, k_t.grad = k_t.grad.clone(), None
255
+ # dq, q_t.grad = q_t.grad.clone(), None
256
+ # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
257
+ # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
258
+ # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
259
+ if cudnn is not None:
260
+ time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
261
+ if mode == 'fwd':
262
+ _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
263
+ _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
264
+ cudnn_sdpa_fwd()
265
+ torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
266
+ cudnn_sdpa_fwd_varlen()
267
+ torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
268
+ else:
269
+ cudnn_sdpa_fwd()
270
+ _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
271
+ _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
272
+ dq, dk, dv = cudnn_sdpa_bwd()
273
+ torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
274
+ torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
275
+ torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
276
+ dq, dk, dv = cudnn_sdpa_bwd_varlen()
277
+ torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
278
+ torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
279
+ torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
280
+ # pytorch_profiler(cudnn_sdpa, backward=False)
281
+
282
+ if headdim <= 128 or mode == 'fwd':
283
+ time.sleep(1)
284
+ _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
285
+ q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
286
+ k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
287
+ v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
288
+ time.sleep(1)
289
+ if mode == 'bwd':
290
+ dv, v.grad = v.grad.clone(), None
291
+ dk, k.grad = k.grad.clone(), None
292
+ dq, q.grad = q.grad.clone(), None
293
+ torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
294
+ torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
295
+ torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
296
+
297
+ bench_var_fn = bench_fn
298
+ if mode == 'bwd':
299
+ grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
300
+ bench_var_fn = partial(benchmark_backward, grad=grad_var)
301
+ _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
302
+
303
+ # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
304
+ print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
305
+ if headdim <= 128:
306
+ if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
307
+ print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
308
+ if cudnn is not None:
309
+ print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
310
+ print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
311
+ if headdim <= 128 or mode == 'fwd':
312
+ print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
313
+ print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
314
+
benchmark_causal.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
10
+ from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
12
+ # # from flash_attn.triton.fused_attention import attention as attention
13
+ # from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
14
+ # from flash_attn.flash_attn_triton_og import attention as attention_og
15
+
16
+ # from triton.ops.flash_attention import attention as attention_triton
17
+
18
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
19
+
20
+ try:
21
+ from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
22
+ except ImportError:
23
+ scaled_upper_triang_masked_softmax = None
24
+
25
+
26
+ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
27
+ """
28
+ Arguments:
29
+ qkv: (batch_size, seqlen, 3, nheads, head_dim)
30
+ dropout_p: float
31
+ Output:
32
+ output: (batch_size, seqlen, nheads, head_dim)
33
+ """
34
+ batch_size, seqlen, _, nheads, d = qkv.shape
35
+ q, k, v = qkv.unbind(dim=2)
36
+ q = rearrange(q, 'b t h d -> (b h) t d')
37
+ k = rearrange(k, 'b s h d -> (b h) d s')
38
+ softmax_scale = 1.0 / math.sqrt(d)
39
+ # Preallocate attn_weights for `baddbmm`
40
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
41
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
42
+ '(b h) t s -> b h t s', h=nheads)
43
+ if causal:
44
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
45
+ # So we have to construct the mask in float
46
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
47
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
48
+ scores = scores + causal_mask.to(dtype=scores.dtype)
49
+ attention = torch.softmax(scores, dim=-1)
50
+ attention_drop = F.dropout(attention, dropout_p)
51
+ output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
52
+ return output.to(dtype=qkv.dtype)
53
+
54
+
55
+ def attention_megatron(qkv):
56
+ """
57
+ Arguments:
58
+ qkv: (batch_size, seqlen, 3, nheads, head_dim)
59
+ Output:
60
+ output: (batch_size, seqlen, nheads, head_dim)
61
+ """
62
+ batch_size, seqlen, _, nheads, d = qkv.shape
63
+ q, k, v = qkv.unbind(dim=2)
64
+ q = rearrange(q, 'b t h d -> (b h) t d')
65
+ k = rearrange(k, 'b s h d -> (b h) d s')
66
+ softmax_scale = 1.0 / math.sqrt(d)
67
+ # Preallocate attn_weights for `baddbmm`
68
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
69
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
70
+ '(b h) t s -> b h t s', h=nheads)
71
+ attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
72
+ output = torch.einsum('bhts,bshd->bthd', attention, v)
73
+ return output.to(dtype=qkv.dtype)
74
+
75
+
76
+ torch.manual_seed(0)
77
+ repeats = 30
78
+ batch_size = 8
79
+ seqlen = 2048
80
+ nheads = 12
81
+ headdim = 128
82
+ # nheads = 24
83
+ # headdim = 64
84
+ # batch_size = 64
85
+ # seqlen = 512
86
+ # nheads = 8
87
+ # headdim = 128
88
+ dropout_p = 0.0
89
+ causal = True
90
+ dtype = torch.float16
91
+ device = 'cuda'
92
+
93
+ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
94
+ requires_grad=True)
95
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
96
+ device=qkv.device)
97
+
98
+ qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
99
+ # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
100
+ # cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
101
+ # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
102
+ # cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
103
+ benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
104
+ pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
105
+
106
+ # for dropout_p in [0.1, 0.0]:
107
+ # for causal in [False, True]:
108
+ # print(f"### {dropout_p = }, {causal = } ###")
109
+ # pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
110
+
111
+
112
+ # nheads_k = 2
113
+ # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
114
+ # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
115
+ # requires_grad=True)
116
+ # if fav2_kvpacked_func is not None:
117
+ # benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
118
+ # pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)
119
+
120
+ # dropout_p = 0.0
121
+ # causal = False
122
+ # benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
123
+ # repeats=repeats, desc='PyTorch Attention')
124
+
125
+ # benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
126
+ # pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
127
+
128
+ # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
129
+ # requires_grad=True) for _ in range(3)]
130
+ # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
131
+ # # pytorch_profiler(attention, q, k, v, 1.0, backward=True)
132
+
133
+ # if scaled_upper_triang_masked_softmax is not None:
134
+ # benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
135
+
136
+ # from src.ops.fftconv import fftconv_func
137
+
138
+ # dim = nheads * headdim
139
+ # u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
140
+ # k = torch.randn(dim, seqlen, device=device, requires_grad=True)
141
+ # D = torch.randn(dim, device=device, requires_grad=True)
142
+ # benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
143
+ # pytorch_profiler(fftconv_func, u, k, D, backward=True)
144
+ # pytorch_profiler(torch.fft.rfft, u.float())
145
+
146
+ flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
147
+ ideal_a100_time = flops / 312 / 1e9
148
+ print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
149
+ exit(0)
150
+
151
+
152
+ def time_fwd_bwd(func, *args, **kwargs):
153
+ time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
154
+ return time_f[1].mean, time_b[1].mean
155
+
156
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
157
+ causal_vals = [False, True]
158
+ headdim_vals = [64, 128]
159
+ dim = 2048
160
+ dropout_p = 0.0
161
+
162
+ time_f = {}
163
+ time_b = {}
164
+ for causal in causal_vals:
165
+ for headdim in headdim_vals:
166
+ for batch_size, seqlen in bs_seqlen_vals:
167
+ nheads = dim // headdim
168
+ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
169
+ requires_grad=True)
170
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
171
+ device=qkv.device)
172
+ qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
173
+ f, b = time_fwd_bwd(
174
+ flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
175
+ causal=causal, repeats=repeats, verbose=False
176
+ )
177
+ time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
178
+ time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b
179
+
180
+ qkv = qkv.detach().requires_grad_(True)
181
+ f, b = time_fwd_bwd(
182
+ fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
183
+ )
184
+ time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
185
+ time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b
186
+
187
+ # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
188
+ # requires_grad=True) for _ in range(3)]
189
+ # # Try both values of sequence_parallel and pick the faster one
190
+ # f, b = time_fwd_bwd(
191
+ # attention_triton, q, k, v, causal, headdim**(-0.5),
192
+ # False, repeats=repeats, verbose=False
193
+ # )
194
+ # _, b0 = time_fwd_bwd(
195
+ # attention_triton, q, k, v, causal, headdim**(-0.5),
196
+ # True, repeats=repeats, verbose=False
197
+ # )
198
+ # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
199
+ # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)
200
+
201
+ if seqlen <= 8 * 1024:
202
+ qkv = qkv.detach().requires_grad_(True)
203
+ f, b = time_fwd_bwd(
204
+ attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
205
+ )
206
+ else:
207
+ f, b = float('nan'), float('nan')
208
+ time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
209
+ time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b
210
+
211
+ # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
212
+ # requires_grad=True) for _ in range(3)]
213
+ # import xformers.ops as xops
214
+ # f, b = time_fwd_bwd(
215
+ # xops.memory_efficient_attention, q, k, v,
216
+ # attn_bias=xops.LowerTriangularMask() if causal else None,
217
+ # op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
218
+ # )
219
+ # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
220
+ # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b
221
+
222
+
223
+ import pickle
224
+ with open('flash2_attn_time_h100.plk', 'wb') as fp:
225
+ pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
benchmark_flash_attention.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install the newest triton version with
2
+ # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
3
+ import pickle
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
12
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
13
+
14
+ from flash_attn import flash_attn_qkvpacked_func
15
+
16
+ try:
17
+ from triton.ops.flash_attention import attention as attention_triton
18
+ except ImportError:
19
+ attention_triton = None
20
+
21
+ try:
22
+ import xformers.ops as xops
23
+ except ImportError:
24
+ xops = None
25
+
26
+
27
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
28
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
29
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
30
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
31
+
32
+ def efficiency(flop, time):
33
+ return (flop / time / 10**12) if not math.isnan(time) else 0.0
34
+
35
+
36
+ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
37
+ """
38
+ Arguments:
39
+ qkv: (batch_size, seqlen, 3, nheads, head_dim)
40
+ dropout_p: float
41
+ Output:
42
+ output: (batch_size, seqlen, nheads, head_dim)
43
+ """
44
+ batch_size, seqlen, _, nheads, d = qkv.shape
45
+ q, k, v = qkv.unbind(dim=2)
46
+ q = rearrange(q, 'b t h d -> (b h) t d')
47
+ k = rearrange(k, 'b s h d -> (b h) d s')
48
+ softmax_scale = 1.0 / math.sqrt(d)
49
+ # Preallocate attn_weights for `baddbmm`
50
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
51
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
52
+ '(b h) t s -> b h t s', h=nheads)
53
+ if causal:
54
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
55
+ # So we have to construct the mask in float
56
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
57
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
58
+ scores = scores + causal_mask.to(dtype=scores.dtype)
59
+ attention = torch.softmax(scores, dim=-1)
60
+ attention_drop = F.dropout(attention, dropout_p)
61
+ output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
62
+ return output.to(dtype=qkv.dtype)
63
+
64
+
65
+ def time_fwd_bwd(func, *args, **kwargs):
66
+ time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
67
+ return time_f[1].mean, time_b[1].mean
68
+
69
+
70
+ repeats = 30
71
+ device = 'cuda'
72
+ dtype = torch.float16
73
+
74
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
75
+ causal_vals = [False, True]
76
+ headdim_vals = [64, 128]
77
+ dim = 2048
78
+ dropout_p = 0.0
79
+
80
+ methods = (["Flash2", "Pytorch"]
81
+ + (["Triton"] if attention_triton is not None else [])
82
+ + (["xformers.c"] if xops is not None else [])
83
+ + (["xformers.f"] if xops is not None else []))
84
+
85
+ time_f = {}
86
+ time_b = {}
87
+ time_f_b = {}
88
+ speed_f = {}
89
+ speed_b = {}
90
+ speed_f_b = {}
91
+ for causal in causal_vals:
92
+ for headdim in headdim_vals:
93
+ for batch_size, seqlen in bs_seqlen_vals:
94
+ config = (causal, headdim, batch_size, seqlen)
95
+ nheads = dim // headdim
96
+ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
97
+ requires_grad=True)
98
+ f, b = time_fwd_bwd(
99
+ flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
100
+ )
101
+ time_f[config, "Flash2"] = f
102
+ time_b[config, "Flash2"] = b
103
+
104
+ try:
105
+ qkv = qkv.detach().requires_grad_(True)
106
+ f, b = time_fwd_bwd(
107
+ attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
108
+ )
109
+ except: # Skip if OOM
110
+ f, b = float('nan'), float('nan')
111
+ time_f[config, "Pytorch"] = f
112
+ time_b[config, "Pytorch"] = b
113
+
114
+ if attention_triton is not None:
115
+ q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
116
+ requires_grad=True) for _ in range(3)]
117
+ # Try both values of sequence_parallel and pick the faster one
118
+ try:
119
+ f, b = time_fwd_bwd(
120
+ attention_triton, q, k, v, causal, headdim**(-0.5),
121
+ False, repeats=repeats, verbose=False
122
+ )
123
+ except:
124
+ f, b = float('nan'), float('inf')
125
+ try:
126
+ _, b0 = time_fwd_bwd(
127
+ attention_triton, q, k, v, causal, headdim**(-0.5),
128
+ True, repeats=repeats, verbose=False
129
+ )
130
+ except:
131
+ b0 = float('inf')
132
+ time_f[config, "Triton"] = f
133
+ time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
134
+
135
+ if xops is not None:
136
+ q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
137
+ requires_grad=True) for _ in range(3)]
138
+ f, b = time_fwd_bwd(
139
+ xops.memory_efficient_attention, q, k, v,
140
+ attn_bias=xops.LowerTriangularMask() if causal else None,
141
+ op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
142
+ )
143
+ time_f[config, "xformers.c"] = f
144
+ time_b[config, "xformers.c"] = b
145
+
146
+ if xops is not None:
147
+ q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
148
+ requires_grad=True) for _ in range(3)]
149
+ f, b = time_fwd_bwd(
150
+ xops.memory_efficient_attention, q, k, v,
151
+ attn_bias=xops.LowerTriangularMask() if causal else None,
152
+ op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
153
+ )
154
+ time_f[config, "xformers.f"] = f
155
+ time_b[config, "xformers.f"] = b
156
+
157
+ print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
158
+ for method in methods:
159
+ time_f_b[config, method] = time_f[config, method] + time_b[config, method]
160
+ speed_f[config, method] = efficiency(
161
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
162
+ time_f[config, method]
163
+ )
164
+ speed_b[config, method] = efficiency(
165
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
166
+ time_b[config, method]
167
+ )
168
+ speed_f_b[config, method] = efficiency(
169
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
170
+ time_f_b[config, method]
171
+ )
172
+ print(
173
+ f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
174
+ f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
175
+ f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
176
+ )
177
+
178
+
179
+ # with open('flash2_attn_time.plk', 'wb') as fp:
180
+ # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
benchmark_flash_attention_fp8.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install the newest triton version with
2
+ # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
3
+ import pickle
4
+ import math
5
+ import time
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
13
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
14
+
15
+ from flash_attn import flash_attn_qkvpacked_func
16
+ from flash_attn_interface import flash_attn_func
17
+
18
+ try:
19
+ from triton_fused_attention import attention as attention_triton
20
+ except ImportError:
21
+ attention_triton = None
22
+
23
+ try:
24
+ import xformers.ops as xops
25
+ except ImportError:
26
+ xops = None
27
+
28
+ try:
29
+ import cudnn
30
+ except ImportError:
31
+ cudnn = None
32
+
33
+
34
+ def convert_to_cudnn_type(torch_type):
35
+ if torch_type == torch.float16:
36
+ return cudnn.data_type.HALF
37
+ elif torch_type == torch.bfloat16:
38
+ return cudnn.data_type.BFLOAT16
39
+ elif torch_type == torch.float32:
40
+ return cudnn.data_type.FLOAT
41
+ elif torch_type == torch.int32:
42
+ return cudnn.data_type.INT32
43
+ elif torch_type == torch.int64:
44
+ return cudnn.data_type.INT64
45
+ elif torch_type == torch.float8_e4m3fn:
46
+ return cudnn.data_type.FP8_E4M3
47
+ elif torch_type == torch.float8_e4m3fn:
48
+ return cudnn.data_type.FP8_E5M2
49
+ else:
50
+ raise ValueError("Unsupported tensor data type.")
51
+
52
+ def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
53
+ b, _, _, nheads, headdim = qkv.shape
54
+ assert cudnn is not None, 'CUDNN is not available'
55
+ o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
56
+ o_gpu_transposed = torch.as_strided(
57
+ o_gpu,
58
+ [b, nheads, seqlen_q, headdim],
59
+ [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
60
+ )
61
+ stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
62
+ amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
63
+ amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
64
+ graph = cudnn.pygraph(
65
+ io_data_type=convert_to_cudnn_type(qkv.dtype),
66
+ intermediate_data_type=cudnn.data_type.FLOAT,
67
+ compute_data_type=cudnn.data_type.FLOAT,
68
+ )
69
+ new_q = torch.as_strided(
70
+ qkv,
71
+ [b, nheads, seqlen_q, headdim],
72
+ [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
73
+ storage_offset=0,
74
+ )
75
+ q = graph.tensor(
76
+ name = "Q",
77
+ dim = list(new_q.shape),
78
+ stride = list(new_q.stride()),
79
+ data_type=convert_to_cudnn_type(qkv.dtype)
80
+ )
81
+ new_k = torch.as_strided(
82
+ qkv,
83
+ [b, nheads, seqlen_k, headdim],
84
+ [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
85
+ storage_offset=nheads * headdim,
86
+ )
87
+ k = graph.tensor(
88
+ name = "K",
89
+ dim = list(new_k.shape),
90
+ stride = list(new_k.stride()),
91
+ data_type=convert_to_cudnn_type(qkv.dtype)
92
+ )
93
+ new_v = torch.as_strided(
94
+ qkv,
95
+ [b, nheads, seqlen_k, headdim],
96
+ [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
97
+ storage_offset=nheads * headdim * 2,
98
+ )
99
+ v = graph.tensor(
100
+ name = "V",
101
+ dim = list(new_v.shape),
102
+ stride = list(new_v.stride()),
103
+ data_type=convert_to_cudnn_type(qkv.dtype)
104
+ )
105
+
106
+ def get_default_scale_tensor():
107
+ return graph.tensor(
108
+ dim = [1, 1, 1, 1],
109
+ stride = [1, 1, 1, 1],
110
+ data_type=cudnn.data_type.FLOAT
111
+ )
112
+
113
+ default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
114
+ descale_q = get_default_scale_tensor()
115
+ descale_k = get_default_scale_tensor()
116
+ descale_v = get_default_scale_tensor()
117
+ descale_s = get_default_scale_tensor()
118
+ scale_s = get_default_scale_tensor()
119
+ scale_o = get_default_scale_tensor()
120
+
121
+ o, _, amax_s, amax_o = graph.sdpa_fp8(
122
+ q=q,
123
+ k=k,
124
+ v=v,
125
+ descale_q=descale_q,
126
+ descale_k=descale_k,
127
+ descale_v=descale_v,
128
+ descale_s=descale_s,
129
+ scale_s=scale_s,
130
+ scale_o=scale_o,
131
+ is_inference=True,
132
+ attn_scale=1.0 / math.sqrt(headdim),
133
+ use_causal_mask=causal,
134
+ name="sdpa",
135
+ )
136
+
137
+ o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
138
+
139
+ amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
140
+ amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
141
+ # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
142
+
143
+ graph.validate()
144
+ graph.build_operation_graph()
145
+ graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
146
+ graph.check_support()
147
+ graph.build_plans()
148
+
149
+ variant_pack = {
150
+ q: new_q,
151
+ k: new_k,
152
+ v: new_v,
153
+ descale_q: default_scale_gpu,
154
+ descale_k: default_scale_gpu,
155
+ descale_v: default_scale_gpu,
156
+ descale_s: default_scale_gpu,
157
+ scale_s: default_scale_gpu,
158
+ scale_o: default_scale_gpu,
159
+ o: o_gpu_transposed,
160
+ amax_s: amax_s_gpu,
161
+ amax_o: amax_o_gpu,
162
+ }
163
+
164
+ workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
165
+
166
+ def run(*args, **kwargs):
167
+ graph.execute(variant_pack, workspace)
168
+ return o_gpu, amax_o_gpu
169
+
170
+ return run
171
+
172
+
173
+ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
174
+ """
175
+ Arguments:
176
+ qkv: (batch_size, seqlen, 3, nheads, head_dim)
177
+ dropout_p: float
178
+ Output:
179
+ output: (batch_size, seqlen, nheads, head_dim)
180
+ """
181
+ batch_size, seqlen, _, nheads, d = qkv.shape
182
+ q, k, v = qkv.unbind(dim=2)
183
+ q = rearrange(q, 'b t h d -> (b h) t d')
184
+ k = rearrange(k, 'b s h d -> (b h) d s')
185
+ softmax_scale = 1.0 / math.sqrt(d)
186
+ # Preallocate attn_weights for `baddbmm`
187
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
188
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
189
+ '(b h) t s -> b h t s', h=nheads)
190
+ if causal:
191
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
192
+ # So we have to construct the mask in float
193
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
194
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
195
+ scores = scores + causal_mask.to(dtype=scores.dtype)
196
+ attention = torch.softmax(scores, dim=-1)
197
+ attention_drop = F.dropout(attention, dropout_p)
198
+ output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
199
+ return output.to(dtype=qkv.dtype)
200
+
201
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
202
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
203
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
204
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
205
+
206
+ def efficiency(flop, time):
207
+ return (flop / time / 10**12) if not math.isnan(time) else 0.0
208
+
209
+ def time_fwd(func, *args, **kwargs):
210
+ time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
211
+ time_f = benchmark_forward(func, *args, **kwargs)
212
+ return time_f[1].mean
213
+
214
+
215
+ torch.manual_seed(0)
216
+
217
+ repeats = 30
218
+ device = 'cuda'
219
+ # dtype = torch.float16
220
+ dtype = torch.float8_e4m3fn
221
+
222
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
223
+ # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
224
+ # bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
225
+ # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
226
+ causal_vals = [False, True]
227
+ headdim_vals = [128]
228
+ dim = 2048
229
+ # dim = 256
230
+ dropout_p = 0.0
231
+
232
+ methods = (["Pytorch", "Flash3", "cuDNN"]
233
+ # + (["Triton"] if attention_triton is not None else [])
234
+ # + (["xformers.c"] if xops is not None else [])
235
+ # + (["xformers.f"] if xops is not None else [])
236
+ )
237
+
238
+ time_f = {}
239
+ time_b = {}
240
+ time_f_b = {}
241
+ speed_f = {}
242
+ speed_b = {}
243
+ speed_f_b = {}
244
+ for causal in causal_vals:
245
+ for headdim in headdim_vals:
246
+ for batch_size, seqlen in bs_seqlen_vals:
247
+ torch.cuda.empty_cache()
248
+ config = (causal, headdim, batch_size, seqlen)
249
+ nheads = dim // headdim
250
+ q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
251
+
252
+ qkv = torch.stack([q, k, v], dim=2)
253
+ qkv = qkv.to(torch.float16)
254
+ f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
255
+ time_f[config, "Pytorch"] = f
256
+ res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
257
+
258
+ if attention_triton is not None:
259
+ q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
260
+ k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
261
+ v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
262
+ scale = 1 / math.sqrt(headdim)
263
+ f = time_fwd(
264
+ attention_triton, q_transposed, k_transposed, v_transposed,
265
+ causal, scale, repeats=5, verbose=False, desc='Triton'
266
+ )
267
+ f = time_fwd(
268
+ attention_triton, q_transposed, k_transposed, v_transposed,
269
+ causal, scale, repeats=repeats, verbose=False, desc='Triton'
270
+ )
271
+ time_f[config, "Triton"] = f
272
+ res = attention_triton(
273
+ q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
274
+ causal, scale
275
+ ).half().transpose(1, 2)
276
+ torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
277
+
278
+ # out = torch.empty_like(q)
279
+ q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
280
+ f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
281
+
282
+ # res = flash_attn_func(q, k, v, causal=causal)
283
+ # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
284
+
285
+ time_f[config, "Flash3"] = f
286
+
287
+ if cudnn is not None:
288
+ qkv_fp8 = qkv.to(dtype)
289
+ time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
290
+ f = time_fwd(
291
+ cudnn_spda_setup(
292
+ qkv_fp8, seqlen, seqlen,
293
+ causal=causal
294
+ ),
295
+ repeats=repeats, verbose=False
296
+ )
297
+ time_f[config, "cuDNN"] = f
298
+ # res, amax_o = cudnn_spda_setup(
299
+ # qkv_fp8, seqlen, seqlen,
300
+ # causal=causal
301
+ # )()
302
+ # res = res.half()
303
+ # TODO: CUDNN has numerics issues when
304
+ # num_heads=16, dim=128, seq_len=1024, batch_size=2
305
+ # or larger sizes.
306
+ # res_cpu = res.cpu().reshape(-1)
307
+ # res_baseline_cpu = res_baseline.cpu().reshape(-1)
308
+ # print(amax_o)
309
+ # print(res)
310
+ # print(res_baseline)
311
+ # for i in range(len(res_cpu)):
312
+ # item = res_cpu[i]
313
+ # item_baseline = res_baseline_cpu[i]
314
+ # if abs(item - item_baseline) > 0.5:
315
+ # print(i)
316
+ # print(item)
317
+ # print(item_baseline)
318
+ # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
319
+
320
+ print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
321
+ for method in methods:
322
+ speed_f[config, method] = efficiency(
323
+ flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
324
+ time_f[config, method]
325
+ )
326
+ #print (time_f[config,method])
327
+ print(
328
+ f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
329
+ )
330
+
331
+
332
+ # with open('flash3_attn_time.plk', 'wb') as fp:
333
+ # pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
benchmark_gemm.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import torch.utils.benchmark as benchmark
4
+
5
+ from triton.testing import do_bench
6
+
7
+
8
+ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
9
+ """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
10
+ if verbose:
11
+ print(desc, '- Forward pass')
12
+ t = benchmark.Timer(
13
+ stmt='fn(*inputs, **kwinputs)',
14
+ globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
15
+ num_threads=torch.get_num_threads(),
16
+ )
17
+ m = t.timeit(repeats)
18
+ if verbose:
19
+ print(m)
20
+ return t, m
21
+
22
+
23
+ torch.manual_seed(0)
24
+ repeats = 30
25
+ dtype = torch.float16
26
+ device = 'cuda'
27
+ verbose = False
28
+ m, n = 8192, 8192
29
+
30
+ tflops_matmul = {}
31
+ tflops_matmul1 = {}
32
+ for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
33
+ a = torch.randn(m, k, device=device, dtype=dtype)
34
+ b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
35
+ nFLOPS_matmul = 2 * m * n * k
36
+ time.sleep(2) # to reduce power throttling
37
+ timing = benchmark_forward(torch.matmul, a, b, desc='cuBLAS', verbose=verbose, repeats=repeats)[1]
38
+ tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
39
+ print(f'[torch.utils.benchmark] cuBLAS, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
40
+ time.sleep(2) # to reduce power throttling
41
+ ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
42
+ tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
43
+ print(f'[triton.test.do_bench] cuBLAS, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
bert.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
+
8
+ import logging
9
+ import re
10
+ from collections import OrderedDict
11
+ from collections.abc import Sequence
12
+ from functools import partial
13
+ from typing import Any, Mapping
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange
19
+ from transformers import BertConfig, PretrainedConfig
20
+ from transformers.models.bert.modeling_bert import (
21
+ BaseModelOutputWithPoolingAndCrossAttentions,
22
+ BertForPreTrainingOutput,
23
+ )
24
+
25
+ from flash_attn.bert_padding import (
26
+ index_first_axis,
27
+ index_first_axis_residual,
28
+ pad_input,
29
+ unpad_input,
30
+ )
31
+ from flash_attn.modules.block import Block
32
+ from flash_attn.modules.embedding import BertEmbeddings
33
+ from flash_attn.modules.mha import MHA
34
+ from flash_attn.modules.mlp import FusedMLP, Mlp
35
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
36
+
37
+ try:
38
+ from flash_attn.ops.fused_dense import FusedDense
39
+ except ImportError:
40
+ FusedDense = None
41
+
42
+ try:
43
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
44
+ except ImportError:
45
+ layer_norm_fn = None
46
+
47
+
48
+ try:
49
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
50
+ except ImportError:
51
+ CrossEntropyLoss = None
52
+
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
58
+ use_flash_attn = getattr(config, "use_flash_attn", False)
59
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
60
+ rotary_kwargs = {}
61
+ if config.position_embedding_type == "rotary":
62
+ rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
63
+ rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
64
+ rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
65
+ rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
66
+ mixer_cls = partial(
67
+ MHA,
68
+ num_heads=config.num_attention_heads,
69
+ cross_attn=cross_attn,
70
+ dropout=config.attention_probs_dropout_prob,
71
+ causal=False,
72
+ fused_bias_fc=fused_bias_fc,
73
+ use_flash_attn=use_flash_attn,
74
+ return_residual=return_residual,
75
+ **rotary_kwargs,
76
+ )
77
+ return mixer_cls
78
+
79
+
80
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
81
+ inner_dim = config.intermediate_size
82
+ fused_mlp = getattr(config, "fused_mlp", False)
83
+ if fused_mlp:
84
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
85
+ "fused_mlp only " "supports approximate gelu"
86
+ )
87
+ if not fused_mlp:
88
+ approximate = (
89
+ "tanh"
90
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
91
+ else "none"
92
+ )
93
+ mlp_cls = partial(
94
+ Mlp,
95
+ hidden_features=inner_dim,
96
+ activation=partial(F.gelu, approximate=approximate),
97
+ return_residual=return_residual,
98
+ )
99
+ else:
100
+ if FusedMLP is None:
101
+ raise ImportError("fused_dense is not installed")
102
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
103
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
104
+ if isinstance(mlp_checkpoint_lvl, Sequence):
105
+ assert layer_idx is not None
106
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
107
+ mlp_cls = partial(
108
+ FusedMLP,
109
+ hidden_features=inner_dim,
110
+ checkpoint_lvl=mlp_checkpoint_lvl,
111
+ return_residual=return_residual,
112
+ )
113
+ return mlp_cls
114
+
115
+
116
+ def create_block(config, layer_idx=None):
117
+ last_layer_subset = getattr(config, "last_layer_subset", False)
118
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
119
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
120
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
121
+ # one layer) so we just choose not to return residual in this case.
122
+ return_residual = not cross_attn
123
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
124
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
125
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
126
+ block = Block(
127
+ config.hidden_size,
128
+ mixer_cls,
129
+ mlp_cls,
130
+ norm_cls=norm_cls,
131
+ prenorm=False,
132
+ resid_dropout1=config.hidden_dropout_prob,
133
+ resid_dropout2=config.hidden_dropout_prob,
134
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
135
+ return_residual=return_residual,
136
+ )
137
+ return block
138
+
139
+
140
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
141
+ def _init_weights(module, initializer_range=0.02):
142
+ if isinstance(module, nn.Linear):
143
+ nn.init.normal_(module.weight, std=initializer_range)
144
+ if module.bias is not None:
145
+ nn.init.zeros_(module.bias)
146
+ elif isinstance(module, nn.Embedding):
147
+ nn.init.normal_(module.weight, std=initializer_range)
148
+ if module.padding_idx is not None:
149
+ nn.init.zeros_(module.weight[module.padding_idx])
150
+
151
+
152
+ class BertEncoder(nn.Module):
153
+ def __init__(self, config: BertConfig):
154
+ super().__init__()
155
+ self.use_flash_attn = getattr(config, "use_flash_attn", False)
156
+ self.layers = nn.ModuleList(
157
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
158
+ )
159
+
160
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
161
+ """If subset_mask is not None, we only want output for the subset of the sequence.
162
+ This means that we only compute the last layer output for these tokens.
163
+ subset_mask: (batch, seqlen), dtype=torch.bool
164
+ """
165
+ if key_padding_mask is None or not self.use_flash_attn:
166
+ mixer_kwargs = (
167
+ {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
168
+ )
169
+ for layer in self.layers:
170
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
171
+ if subset_mask is not None:
172
+ hidden_states = hidden_states[subset_mask]
173
+ else:
174
+ batch, seqlen = hidden_states.shape[:2]
175
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
176
+ hidden_states, key_padding_mask
177
+ )
178
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
179
+ if subset_mask is None:
180
+ for layer in self.layers:
181
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
182
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
183
+ else:
184
+ for layer in self.layers[:-1]:
185
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
186
+ if key_padding_mask is not None:
187
+ subset_idx = torch.nonzero(
188
+ subset_mask[key_padding_mask], as_tuple=False
189
+ ).flatten()
190
+ subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
191
+ subset_cu_seqlens = F.pad(
192
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
193
+ )
194
+ else:
195
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
196
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
197
+ subset_cu_seqlens = F.pad(
198
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
199
+ )
200
+ hidden_states_subset, hidden_states = index_first_axis_residual(
201
+ hidden_states, subset_idx
202
+ )
203
+ # It's ok to set max_seqlen_q to be much larger
204
+ mixer_kwargs = {
205
+ "x_kv": hidden_states,
206
+ "cu_seqlens": subset_cu_seqlens,
207
+ "max_seqlen": max_seqlen_in_batch,
208
+ "cu_seqlens_k": cu_seqlens,
209
+ "max_seqlen_k": max_seqlen_in_batch,
210
+ }
211
+ hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
212
+ return hidden_states
213
+
214
+
215
+ class BertPooler(nn.Module):
216
+ def __init__(self, config):
217
+ super().__init__()
218
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
219
+ if fused_bias_fc and FusedDense is None:
220
+ raise ImportError("fused_dense is not installed")
221
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
222
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
223
+ self.activation = nn.Tanh()
224
+
225
+ def forward(self, hidden_states, pool=True):
226
+ # We "pool" the model by simply taking the hidden state corresponding
227
+ # to the first token.
228
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
229
+ pooled_output = self.dense(first_token_tensor)
230
+ pooled_output = self.activation(pooled_output)
231
+ return pooled_output
232
+
233
+
234
+ class BertPredictionHeadTransform(nn.Module):
235
+ def __init__(self, config):
236
+ super().__init__()
237
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
238
+ if fused_bias_fc and FusedDense is None:
239
+ raise ImportError("fused_dense is not installed")
240
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
241
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
242
+ raise ImportError("Triton is not installed")
243
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
244
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
245
+ approximate = (
246
+ "tanh"
247
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
248
+ else "none"
249
+ )
250
+ self.transform_act_fn = nn.GELU(approximate=approximate)
251
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
252
+
253
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
+ hidden_states = self.dense(hidden_states)
255
+ hidden_states = self.transform_act_fn(hidden_states)
256
+ if not self.fused_dropout_add_ln:
257
+ hidden_states = self.layer_norm(hidden_states)
258
+ else:
259
+ hidden_states = layer_norm_fn(
260
+ hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
261
+ )
262
+ return hidden_states
263
+
264
+
265
+ class BertLMPredictionHead(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
269
+ if fused_bias_fc and FusedDense is None:
270
+ raise ImportError("fused_dense is not installed")
271
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
272
+
273
+ self.transform = BertPredictionHeadTransform(config)
274
+
275
+ # The output weights are the same as the input embeddings, but there is
276
+ # an output-only bias for each token.
277
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
278
+
279
+ def forward(self, hidden_states):
280
+ hidden_states = self.transform(hidden_states)
281
+ hidden_states = self.decoder(hidden_states)
282
+ return hidden_states
283
+
284
+
285
+ class BertPreTrainingHeads(nn.Module):
286
+ def __init__(self, config):
287
+ super().__init__()
288
+ self.predictions = BertLMPredictionHead(config)
289
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
290
+
291
+ def forward(self, sequence_output, pooled_output):
292
+ prediction_scores = self.predictions(sequence_output)
293
+ seq_relationship_score = self.seq_relationship(pooled_output)
294
+ return prediction_scores, seq_relationship_score
295
+
296
+
297
+ class BertPreTrainedModel(nn.Module):
298
+ """An abstract class to handle weights initialization and
299
+ a simple interface for dowloading and loading pretrained models.
300
+ """
301
+
302
+ def __init__(self, config, *inputs, **kwargs):
303
+ super().__init__()
304
+ if not isinstance(config, BertConfig):
305
+ raise ValueError(
306
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
307
+ "To create a model from a Google pretrained model use "
308
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
309
+ self.__class__.__name__, self.__class__.__name__
310
+ )
311
+ )
312
+ self.config = config
313
+
314
+ @classmethod
315
+ def from_pretrained(cls, model_name, config, *inputs, **kwargs):
316
+ """
317
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
318
+ Download and cache the pre-trained model file if needed.
319
+
320
+ Params:
321
+ pretrained_model_name_or_path: either:
322
+ - a path or url to a pretrained model archive containing:
323
+ . `bert_config.json` a configuration file for the model
324
+ . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
325
+ - a path or url to a pretrained model archive containing:
326
+ . `bert_config.json` a configuration file for the model
327
+ . `model.chkpt` a TensorFlow checkpoint
328
+ *inputs, **kwargs: additional input for the specific Bert class
329
+ (ex: num_labels for BertForSequenceClassification)
330
+ """
331
+ # Instantiate model.
332
+ model = cls(config, *inputs, **kwargs)
333
+ load_return = model.load_state_dict(
334
+ remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
335
+ )
336
+ logger.info(load_return)
337
+ return model
338
+
339
+
340
+ class BertModel(BertPreTrainedModel):
341
+ def __init__(self, config: BertConfig, add_pooling_layer=True):
342
+ super().__init__(config)
343
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
344
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
345
+ config.vocab_size += self.pad_vocab_size_multiple - (
346
+ config.vocab_size % self.pad_vocab_size_multiple
347
+ )
348
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
349
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
350
+ raise ImportError("Triton is not installed")
351
+ assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
352
+
353
+ self.embeddings = BertEmbeddings(
354
+ config.hidden_size,
355
+ config.vocab_size,
356
+ config.max_position_embeddings,
357
+ config.type_vocab_size,
358
+ padding_idx=config.pad_token_id,
359
+ )
360
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
361
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
362
+ self.encoder = BertEncoder(config)
363
+ self.pooler = BertPooler(config) if add_pooling_layer else None
364
+
365
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
366
+
367
+ def forward(
368
+ self,
369
+ input_ids,
370
+ position_ids=None,
371
+ token_type_ids=None,
372
+ attention_mask=None,
373
+ masked_tokens_mask=None,
374
+ ):
375
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
376
+ we only want the output for the masked tokens. This means that we only compute the last
377
+ layer output for these tokens.
378
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
379
+ """
380
+ hidden_states = self.embeddings(
381
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
382
+ )
383
+ # TD [2022-12:18]: Don't need to force residual in fp32
384
+ # BERT puts embedding LayerNorm before embedding dropout.
385
+ if not self.fused_dropout_add_ln:
386
+ hidden_states = self.emb_ln(hidden_states)
387
+ else:
388
+ hidden_states = layer_norm_fn(
389
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
390
+ )
391
+ hidden_states = self.emb_drop(hidden_states)
392
+
393
+ if masked_tokens_mask is not None:
394
+ batch_size, seqlen = input_ids.shape[:2]
395
+ # We also need the first column for the CLS token
396
+ first_col_mask = torch.zeros(
397
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
398
+ )
399
+ first_col_mask[:, 0] = True
400
+ subset_mask = masked_tokens_mask | first_col_mask
401
+ else:
402
+ subset_mask = None
403
+
404
+ sequence_output = self.encoder(
405
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
406
+ )
407
+
408
+ if masked_tokens_mask is None:
409
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
410
+ else:
411
+ # TD [2022-03-01]: the indexing here is very tricky.
412
+ if attention_mask is not None:
413
+ subset_idx = subset_mask[attention_mask]
414
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
415
+ sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
416
+ else:
417
+ pool_input = sequence_output[first_col_mask[subset_mask]]
418
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
419
+ pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
420
+
421
+ return BaseModelOutputWithPoolingAndCrossAttentions(
422
+ last_hidden_state=sequence_output,
423
+ pooler_output=pooled_output,
424
+ )
425
+
426
+
427
+ class BertForPreTraining(BertPreTrainedModel):
428
+ def __init__(self, config: BertConfig):
429
+ super().__init__(config)
430
+ # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
431
+ # (around 15%) to the classifier heads.
432
+ self.dense_seq_output = getattr(config, "dense_seq_output", False)
433
+ # If last_layer_subset, we only need the compute the last layer for a subset of tokens
434
+ # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
435
+ self.last_layer_subset = getattr(config, "last_layer_subset", False)
436
+ if self.last_layer_subset:
437
+ assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
438
+ use_xentropy = getattr(config, "use_xentropy", False)
439
+ if use_xentropy and CrossEntropyLoss is None:
440
+ raise ImportError("xentropy_cuda is not installed")
441
+ loss_cls = (
442
+ nn.CrossEntropyLoss
443
+ if not use_xentropy
444
+ else partial(CrossEntropyLoss, inplace_backward=True)
445
+ )
446
+
447
+ self.bert = BertModel(config)
448
+ self.cls = BertPreTrainingHeads(config)
449
+ self.mlm_loss = loss_cls(ignore_index=0)
450
+ self.nsp_loss = loss_cls(ignore_index=-1)
451
+
452
+ # Initialize weights and apply final processing
453
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
454
+ self.tie_weights()
455
+
456
+ def tie_weights(self):
457
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
458
+
459
+ def forward(
460
+ self,
461
+ input_ids,
462
+ position_ids=None,
463
+ token_type_ids=None,
464
+ attention_mask=None,
465
+ labels=None,
466
+ next_sentence_label=None,
467
+ ):
468
+ """
469
+ If labels are provided, they must be 0 for masked out tokens (as specified in the attention
470
+ mask).
471
+ Outputs:
472
+ if `labels` and `next_sentence_label` are not `None`:
473
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
474
+ sentence classification loss.
475
+ if `labels` or `next_sentence_label` is `None`:
476
+ Outputs a tuple comprising
477
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
478
+ - the next sentence classification logits of shape [batch_size, 2].
479
+
480
+ """
481
+ masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
482
+ outputs = self.bert(
483
+ input_ids,
484
+ position_ids=position_ids,
485
+ token_type_ids=token_type_ids,
486
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
487
+ masked_tokens_mask=masked_tokens_mask,
488
+ )
489
+ sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
490
+ if self.dense_seq_output and labels is not None:
491
+ masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
492
+ if not self.last_layer_subset:
493
+ sequence_output = index_first_axis(
494
+ rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
495
+ )
496
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
497
+
498
+ total_loss = None
499
+ if labels is not None and next_sentence_label is not None:
500
+ if (
501
+ self.dense_seq_output and labels is not None
502
+ ): # prediction_scores are already flattened
503
+ masked_lm_loss = self.mlm_loss(
504
+ prediction_scores, labels.flatten()[masked_token_idx]
505
+ )
506
+ else:
507
+ masked_lm_loss = self.mlm_loss(
508
+ rearrange(prediction_scores, "... v -> (...) v"),
509
+ rearrange(labels, "... -> (...)"),
510
+ )
511
+ next_sentence_loss = self.nsp_loss(
512
+ rearrange(seq_relationship_score, "... t -> (...) t"),
513
+ rearrange(next_sentence_label, "... -> (...)"),
514
+ )
515
+ total_loss = masked_lm_loss.float() + next_sentence_loss.float()
516
+
517
+ return BertForPreTrainingOutput(
518
+ loss=total_loss,
519
+ prediction_logits=prediction_scores,
520
+ seq_relationship_logits=seq_relationship_score,
521
+ )
522
+
523
+
524
+ def remap_state_dict(state_dict, config: PretrainedConfig):
525
+ """
526
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
527
+ """
528
+
529
+ # LayerNorm
530
+ def key_mapping_ln_gamma_beta(key):
531
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
532
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
533
+ return key
534
+
535
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
536
+
537
+ # Layers
538
+ def key_mapping_layers(key):
539
+ return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
540
+
541
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
542
+
543
+ # LayerNorm
544
+ def key_mapping_ln(key):
545
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
546
+ key = re.sub(
547
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
548
+ r"bert.encoder.layers.\1.norm1.\2",
549
+ key,
550
+ )
551
+ key = re.sub(
552
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
553
+ r"bert.encoder.layers.\1.norm2.\2",
554
+ key,
555
+ )
556
+ key = re.sub(
557
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
558
+ r"cls.predictions.transform.layer_norm.\1",
559
+ key,
560
+ )
561
+ return key
562
+
563
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
564
+
565
+ # MLP
566
+ def key_mapping_mlp(key):
567
+ key = re.sub(
568
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
569
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
570
+ key,
571
+ )
572
+ key = re.sub(
573
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
574
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
575
+ key,
576
+ )
577
+ return key
578
+
579
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
580
+
581
+ # Attention
582
+ last_layer_subset = getattr(config, "last_layer_subset", False)
583
+ for d in range(config.num_hidden_layers):
584
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
585
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
586
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
587
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
588
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
589
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
590
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
591
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
592
+ [Wq, Wk, Wv], dim=0
593
+ )
594
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
595
+ else:
596
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
597
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
598
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
599
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
600
+
601
+ def key_mapping_attn(key):
602
+ return re.sub(
603
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
604
+ r"bert.encoder.layers.\1.mixer.out_proj.\2",
605
+ key,
606
+ )
607
+
608
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
609
+
610
+ def key_mapping_decoder_bias(key):
611
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
612
+
613
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
614
+
615
+ # Word embedding
616
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
617
+ if pad_vocab_size_multiple > 1:
618
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
619
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
620
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
621
+ )
622
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
623
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
624
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
625
+ )
626
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
627
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
628
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
629
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
630
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
631
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
632
+ )
633
+
634
+ return state_dict
635
+
636
+
637
+ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
638
+ """
639
+ Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
640
+
641
+ This function is meant to be the inverse of remap_state_dict.
642
+ """
643
+ # Word embedding
644
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
645
+ if pad_vocab_size_multiple > 1:
646
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
647
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
648
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
649
+ # unpad embeddings
650
+ state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
651
+ : config.orig_vocab_size, :
652
+ ]
653
+ state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
654
+ state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
655
+
656
+ for d in range(config.num_hidden_layers):
657
+ last_layer_subset = getattr(config, "last_layer_subset", False)
658
+ if not last_layer_subset or d != (config.num_hidden_layers - 1):
659
+ Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
660
+ Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
661
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
662
+ : Wqkv_weights.shape[0] // 3, :
663
+ ]
664
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
665
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
666
+ ]
667
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
668
+ 2 * Wqkv_weights.shape[0] // 3 :, :
669
+ ]
670
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
671
+ : Wqkv_biases.shape[0] // 3
672
+ ]
673
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
674
+ Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
675
+ ]
676
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
677
+ 2 * Wqkv_biases.shape[0] // 3 :
678
+ ]
679
+ else:
680
+ Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
681
+ Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
682
+ Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
683
+ Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
684
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
685
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
686
+ : Wkv_weights.shape[0] // 2, :
687
+ ]
688
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
689
+ Wkv_weights.shape[0] // 2 :, :
690
+ ]
691
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
692
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
693
+ : Wkv_biases.shape[0] // 2
694
+ ]
695
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
696
+ Wkv_biases.shape[0] // 2 :
697
+ ]
698
+
699
+ def inv_key_mapping_ln(key):
700
+ key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
701
+ key = re.sub(
702
+ r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
703
+ r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
704
+ key,
705
+ )
706
+ key = re.sub(
707
+ r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
708
+ r"bert.encoder.layers.\1.output.LayerNorm.\2",
709
+ key,
710
+ )
711
+ key = re.sub(
712
+ r"cls.predictions.transform.layer_norm.(weight|bias)",
713
+ r"cls.predictions.transform.LayerNorm.\1",
714
+ key,
715
+ )
716
+ return key
717
+
718
+ def inv_key_mapping_ln_gamma_beta(key):
719
+ key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
720
+ key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
721
+ return key
722
+
723
+ def inv_key_mapping_layers(key):
724
+ return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
725
+
726
+ def inv_key_mapping_mlp(key):
727
+ key = re.sub(
728
+ r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
729
+ r"bert.encoder.layer.\1.intermediate.dense.\2",
730
+ key,
731
+ )
732
+ key = re.sub(
733
+ r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
734
+ r"bert.encoder.layer.\1.output.dense.\2",
735
+ key,
736
+ )
737
+ return key
738
+
739
+ def inv_key_mapping_attn(key):
740
+ return re.sub(
741
+ r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
742
+ r"bert.encoder.layer.\1.attention.output.dense.\2",
743
+ key,
744
+ )
745
+
746
+ def inv_key_mapping_decoder_bias(key):
747
+ return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
748
+
749
+ state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
750
+ state_dict = OrderedDict(
751
+ (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
752
+ )
753
+ state_dict = OrderedDict(
754
+ (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
755
+ )
756
+ state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
757
+ state_dict = OrderedDict(
758
+ (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
759
+ )
760
+ state_dict = OrderedDict(
761
+ (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
762
+ )
763
+
764
+ return state_dict
bert_padding.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
19
+ ).reshape(-1, *other_shape)
20
+
21
+ @staticmethod
22
+ def backward(ctx, grad_output):
23
+ (indices,) = ctx.saved_tensors
24
+ assert grad_output.ndim >= 2
25
+ other_shape = grad_output.shape[1:]
26
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
27
+ grad_input = torch.zeros(
28
+ [ctx.first_axis_dim, grad_output.shape[1]],
29
+ device=grad_output.device,
30
+ dtype=grad_output.dtype,
31
+ )
32
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
33
+ # grad_input[indices] = grad_output
34
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
35
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
36
+
37
+
38
+ index_first_axis = IndexFirstAxis.apply
39
+
40
+
41
+ class IndexPutFirstAxis(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(ctx, values, indices, first_axis_dim):
44
+ ctx.save_for_backward(indices)
45
+ assert indices.ndim == 1
46
+ assert values.ndim >= 2
47
+ output = torch.zeros(
48
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ output[indices] = values
52
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
53
+ return output
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_output):
57
+ (indices,) = ctx.saved_tensors
58
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
59
+ grad_values = grad_output[indices]
60
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
61
+ return grad_values, None, None
62
+
63
+
64
+ index_put_first_axis = IndexPutFirstAxis.apply
65
+
66
+
67
+ class IndexFirstAxisResidual(torch.autograd.Function):
68
+ @staticmethod
69
+ def forward(ctx, input, indices):
70
+ ctx.save_for_backward(indices)
71
+ assert input.ndim >= 2
72
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
73
+ second_dim = other_shape.numel()
74
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
75
+ output = input[indices]
76
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
77
+ # memory format to channel_first. In other words, input might not be contiguous.
78
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
79
+ return output, input.detach()
80
+
81
+ @staticmethod
82
+ def backward(ctx, grad_output, grad_residual):
83
+ (indices,) = ctx.saved_tensors
84
+ assert grad_output.ndim >= 2
85
+ other_shape = grad_output.shape[1:]
86
+ assert grad_residual.shape[1:] == other_shape
87
+ grad_input = grad_residual
88
+ # grad_input[indices] += grad_output
89
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
90
+ indices = indices.expand_as(grad_output)
91
+ grad_input.scatter_add_(0, indices, grad_output)
92
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
93
+
94
+
95
+ index_first_axis_residual = IndexFirstAxisResidual.apply
96
+
97
+
98
+ def unpad_input(hidden_states, attention_mask):
99
+ """
100
+ Arguments:
101
+ hidden_states: (batch, seqlen, ...)
102
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
103
+ Return:
104
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
105
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
106
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
107
+ max_seqlen_in_batch: int
108
+ """
109
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
110
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
111
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
112
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
113
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
114
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
115
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
116
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
117
+ # so we write custom forward and backward to make it a bit faster.
118
+ return (
119
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
120
+ indices,
121
+ cu_seqlens,
122
+ max_seqlen_in_batch,
123
+ )
124
+
125
+
126
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
127
+ """
128
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
129
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
130
+
131
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
132
+ ```
133
+ [
134
+ [2, 3, 0, 0, 0, 0],
135
+ [3, 2, 0, 0, 0, 0],
136
+ [6, 0, 0, 0, 0, 0]
137
+ ]
138
+ ```
139
+ , which refers to the 3D-attention mask:
140
+ ```
141
+ [
142
+ [
143
+ [1, 0, 0, 0, 0, 0],
144
+ [1, 1, 0, 0, 0, 0],
145
+ [0, 0, 1, 0, 0, 0],
146
+ [0, 0, 1, 1, 0, 0],
147
+ [0, 0, 1, 1, 1, 0],
148
+ [0, 0, 0, 0, 0, 1]
149
+ ],
150
+ [
151
+ [1, 0, 0, 0, 0, 0],
152
+ [1, 1, 0, 0, 0, 0],
153
+ [1, 1, 1, 0, 0, 0],
154
+ [0, 0, 0, 1, 0, 0],
155
+ [0, 0, 0, 1, 1, 0],
156
+ [0, 0, 0, 0, 0, 1]
157
+ ],
158
+ [
159
+ [1, 0, 0, 0, 0, 0],
160
+ [1, 1, 0, 0, 0, 0],
161
+ [1, 1, 1, 0, 0, 0],
162
+ [1, 1, 1, 1, 0, 0],
163
+ [1, 1, 1, 1, 1, 0],
164
+ [1, 1, 1, 1, 1, 1]
165
+ ]
166
+ ]
167
+ ```.
168
+
169
+ Arguments:
170
+ hidden_states: (batch, seqlen, ...)
171
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
172
+ Return:
173
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
174
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
175
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
176
+ max_seqlen_in_batch: int
177
+ """
178
+ length = attention_mask_in_length.sum(dim=-1)
179
+ seqlen = attention_mask_in_length.size(-1)
180
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
181
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
182
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
183
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
184
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
185
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
186
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
187
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
188
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
189
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
190
+ # so we write custom forward and backward to make it a bit faster.
191
+ return (
192
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
193
+ indices,
194
+ cu_seqlens,
195
+ max_seqlen_in_batch,
196
+ )
197
+
198
+
199
+ def pad_input(hidden_states, indices, batch, seqlen):
200
+ """
201
+ Arguments:
202
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
203
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
204
+ batch: int, batch size for the padded sequence.
205
+ seqlen: int, maximum sequence length for the padded sequence.
206
+ Return:
207
+ hidden_states: (batch, seqlen, ...)
208
+ """
209
+ dim = hidden_states.shape[-1]
210
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
211
+ # output[indices] = hidden_states
212
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
213
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
bigcode.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
8
+
9
+
10
+ def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
11
+ """
12
+ Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
13
+ """
14
+
15
+ # Word embedding and position embedding
16
+ def key_mapping_pos_emb(key):
17
+ return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
18
+
19
+ state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
20
+ word_embeddings = state_dict.pop("transformer.wte.weight")
21
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
22
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
23
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
24
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
25
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
26
+ )
27
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
28
+
29
+ # LayerNorm
30
+ def key_mapping_ln(key):
31
+ key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
32
+ key = re.sub(
33
+ r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
34
+ r"transformer.layers.\1.norm\2.\3",
35
+ key,
36
+ )
37
+ return key
38
+
39
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
40
+
41
+ def key_mapping_mlp(key):
42
+ key = re.sub(
43
+ r"^transformer.h.(\d+).mlp.c_fc.weight",
44
+ r"transformer.layers.\1.mlp.fc1.weight",
45
+ key,
46
+ )
47
+ key = re.sub(
48
+ r"^transformer.h.(\d+).mlp.c_proj.weight",
49
+ r"transformer.layers.\1.mlp.fc2.weight",
50
+ key,
51
+ )
52
+ key = re.sub(
53
+ r"^transformer.h.(\d+).mlp.c_fc.bias",
54
+ r"transformer.layers.\1.mlp.fc1.bias",
55
+ key,
56
+ )
57
+ key = re.sub(
58
+ r"^transformer.h.(\d+).mlp.c_proj.bias",
59
+ r"transformer.layers.\1.mlp.fc2.bias",
60
+ key,
61
+ )
62
+ return key
63
+
64
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
65
+
66
+ # TODO: add support for multi-head attention
67
+ assert config.multi_query, "Only multi-query attention is supported"
68
+
69
+ # Attention
70
+ for d in range(config.num_hidden_layers):
71
+ embed_dim = config.n_embd
72
+ head_dim = embed_dim // config.n_head
73
+
74
+ c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
75
+ # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
76
+ # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
77
+ # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
78
+ # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
79
+ q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
80
+ # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
81
+ k = torch.tile(k, (config.n_head, 1))
82
+ v = torch.tile(v, (config.n_head, 1))
83
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)
84
+
85
+ # same deal with the bias
86
+ c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
87
+ # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
88
+ q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
89
+ # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
90
+ k = torch.tile(k, (config.n_head,))
91
+ v = torch.tile(v, (config.n_head,))
92
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)
93
+
94
+ def key_mapping_attn(key):
95
+ key = re.sub(
96
+ r"^transformer.h.(\d+).attn.c_proj.weight",
97
+ r"transformer.layers.\1.mixer.out_proj.weight",
98
+ key,
99
+ )
100
+ key = re.sub(
101
+ r"^transformer.h.(\d+).attn.c_proj.bias",
102
+ r"transformer.layers.\1.mixer.out_proj.bias",
103
+ key,
104
+ )
105
+ return key
106
+
107
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
108
+
109
+ return state_dict
110
+
111
+
112
+ def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
113
+ """
114
+ Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.
115
+
116
+ This function is meant to be the inverse of remap_state_dict_hf_bigcode.
117
+ """
118
+
119
+ # Word embedding and position embeddings
120
+ def inv_key_mapping_pos_emb(key):
121
+ return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)
122
+
123
+ state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
124
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
125
+
126
+ word_embeddings = word_embeddings[:, : config.vocab_size]
127
+ state_dict["transformer.wte.weight"] = word_embeddings
128
+ state_dict["lm_head.weight"] = word_embeddings
129
+
130
+ # LayerNorm
131
+ def inv_key_mapping_ln(key):
132
+ key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
133
+ key = re.sub(
134
+ r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
135
+ r"transformer.h.\1.ln_\2.\3",
136
+ key,
137
+ )
138
+ return key
139
+
140
+ state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())
141
+
142
+ # MLPs
143
+ def inv_key_mapping_mlp(key):
144
+ key = re.sub(
145
+ r"^transformer.layers.(\d+).mlp.fc1.weight",
146
+ r"transformer.h.\1.mlp.c_fc.weight",
147
+ key,
148
+ )
149
+ key = re.sub(
150
+ r"^transformer.layers.(\d+).mlp.fc2.weight",
151
+ r"transformer.h.\1.mlp.c_proj.weight",
152
+ key,
153
+ )
154
+ key = re.sub(
155
+ r"^transformer.layers.(\d+).mlp.fc1.bias",
156
+ r"transformer.h.\1.mlp.c_fc.bias",
157
+ key,
158
+ )
159
+ key = re.sub(
160
+ r"^transformer.layers.(\d+).mlp.fc2.bias",
161
+ r"transformer.h.\1.mlp.c_proj.bias",
162
+ key,
163
+ )
164
+ return key
165
+
166
+ state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())
167
+
168
+ # Attention
169
+ for d in range(config.num_hidden_layers):
170
+ embed_dim = config.n_embd
171
+ head_dim = embed_dim // config.n_head
172
+
173
+ Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
174
+ q, k, v = torch.split(
175
+ Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
176
+ )
177
+ c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
178
+ state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight
179
+
180
+ # Same deal with the bias
181
+ Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
182
+ q, k, v = torch.split(
183
+ Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
184
+ )
185
+ c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
186
+ state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias
187
+
188
+ def inv_key_mapping_attn(key):
189
+ key = re.sub(
190
+ r"^transformer.layers.(\d+).mixer.out_proj.weight",
191
+ r"transformer.h.\1.attn.c_proj.weight",
192
+ key,
193
+ )
194
+ key = re.sub(
195
+ r"^transformer.layers.(\d+).mixer.out_proj.bias",
196
+ r"transformer.h.\1.attn.c_proj.bias",
197
+ key,
198
+ )
199
+ return key
200
+
201
+ state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())
202
+
203
+ return state_dict
204
+
205
+
206
+ def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
207
+ return GPT2Config(
208
+ activation_function=bigcode_config.activation_function,
209
+ attn_pdrop=bigcode_config.attn_pdrop,
210
+ bos_token_id=bigcode_config.bos_token_id,
211
+ embd_pdrop=bigcode_config.embd_pdrop,
212
+ eos_token_id=bigcode_config.eos_token_id,
213
+ initializer_range=bigcode_config.initializer_range,
214
+ layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
215
+ max_batch_size=bigcode_config.max_batch_size,
216
+ max_sequence_length=bigcode_config.max_sequence_length,
217
+ model_type=bigcode_config.model_type,
218
+ multi_query=bigcode_config.multi_query,
219
+ n_embd=bigcode_config.n_embd,
220
+ n_head=bigcode_config.n_head,
221
+ n_inner=bigcode_config.n_inner,
222
+ n_layer=bigcode_config.n_layer,
223
+ n_positions=bigcode_config.n_positions,
224
+ resid_pdrop=bigcode_config.resid_pdrop,
225
+ scale_attn_weights=bigcode_config.scale_attn_weights,
226
+ summary_activation=bigcode_config.summary_activation,
227
+ summary_first_dropout=bigcode_config.summary_first_dropout,
228
+ summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
229
+ summary_type=bigcode_config.summary_type,
230
+ summary_use_proj=bigcode_config.summary_use_proj,
231
+ use_cache=bigcode_config.use_cache,
232
+ vocab_size=bigcode_config.vocab_size,
233
+ )
block.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ from functools import partial
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+ from torchvision.ops import StochasticDepth
11
+
12
+ from flash_attn.modules.mha import MHA
13
+ from flash_attn.modules.mlp import Mlp
14
+
15
+ try:
16
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
17
+ except ImportError:
18
+ layer_norm_fn, RMSNorm = None, None
19
+
20
+
21
+ class Block(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim,
25
+ mixer_cls=None,
26
+ mlp_cls=None,
27
+ norm_cls=nn.LayerNorm,
28
+ dropout_cls=nn.Dropout,
29
+ prenorm=True,
30
+ resid_dropout1=0.0,
31
+ resid_dropout2=0.0,
32
+ drop_path1=0.0,
33
+ drop_path2=0.0,
34
+ fused_dropout_add_ln=False,
35
+ return_residual=False,
36
+ residual_in_fp32=False,
37
+ sequence_parallel=False,
38
+ mark_shared_params=False,
39
+ ):
40
+ """
41
+ For prenorm=True, this Block has a slightly different structure compared to a regular
42
+ prenorm Transformer block.
43
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
44
+ [Ref: https://arxiv.org/abs/2002.04745]
45
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
46
+ the hidden_states (output of the MLP) and the residual.
47
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
48
+ The residual needs to be provided (except for the very first block).
49
+
50
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
51
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
52
+
53
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
54
+ This is for performance reason: for post-norm architecture, returning the input allows us
55
+ to fuse the backward of nn.Linear with the residual connection.
56
+ """
57
+ super().__init__()
58
+ self.prenorm = prenorm
59
+ self.fused_dropout_add_ln = fused_dropout_add_ln
60
+ self.return_residual = return_residual
61
+ self.residual_in_fp32 = residual_in_fp32
62
+ if self.residual_in_fp32:
63
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
64
+ if mixer_cls is None:
65
+ mixer_cls = partial(MHA, num_heads=dim // 64)
66
+ if mlp_cls is None:
67
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
68
+ self.mixer = mixer_cls(dim)
69
+ self.dropout1 = dropout_cls(resid_dropout1)
70
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
71
+ self.norm1 = norm_cls(dim)
72
+ self.mlp = mlp_cls(dim)
73
+ if not isinstance(self.mlp, nn.Identity):
74
+ self.dropout2 = dropout_cls(resid_dropout2)
75
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
76
+ self.norm2 = norm_cls(dim)
77
+
78
+ if self.fused_dropout_add_ln:
79
+ assert layer_norm_fn is not None, "Triton is not installed"
80
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
81
+ self.dropout1, nn.Dropout
82
+ )
83
+
84
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
85
+ # then the input to each worker in the tensor parallel group will be different.
86
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
87
+ # For now this is not an issue because we always use sequence_parallel=True during training
88
+ # and only use sequence_parallel=False during inference.
89
+
90
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
91
+ if sequence_parallel:
92
+ for p in self.norm1.parameters():
93
+ p._sequence_parallel = True
94
+ if hasattr(self, "norm2"):
95
+ for p in self.norm2.parameters():
96
+ p._sequence_parallel = True
97
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
98
+ if mark_shared_params:
99
+ for p in self.norm1.parameters():
100
+ p._shared_params = True
101
+ if hasattr(self, "norm2"):
102
+ for p in self.norm2.parameters():
103
+ p._shared_params = True
104
+
105
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
106
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
107
+
108
+ def forward(
109
+ self,
110
+ hidden_states: Tensor,
111
+ residual: Optional[Tensor] = None,
112
+ mixer_subset=None,
113
+ mixer_kwargs=None,
114
+ ):
115
+ r"""Pass the input through the encoder layer.
116
+
117
+ Args:
118
+ hidden_states: the sequence to the encoder layer (required).
119
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
120
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
121
+ before applying the query projection. Useful for e.g., ViT where we only care
122
+ about the CLS token in the last layer.
123
+ """
124
+ if self.prenorm:
125
+ if not self.fused_dropout_add_ln:
126
+ dropped = self.drop_path1(self.dropout1(hidden_states))
127
+ residual = (dropped + residual) if residual is not None else dropped
128
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
129
+ if self.residual_in_fp32:
130
+ residual = residual.to(torch.float32)
131
+ else:
132
+ if self.drop_path1.p == 0 or not self.training:
133
+ rowscale1 = None
134
+ else:
135
+ rowscale1 = self.drop_path1(
136
+ torch.ones(
137
+ hidden_states.shape[:-1],
138
+ device=hidden_states.device,
139
+ dtype=hidden_states.dtype,
140
+ )
141
+ )
142
+ hidden_states, residual = layer_norm_fn(
143
+ hidden_states,
144
+ self.norm1.weight,
145
+ self.norm1.bias,
146
+ residual=residual,
147
+ eps=self.norm1.eps,
148
+ dropout_p=self.dropout1.p if self.training else 0.0,
149
+ rowscale=rowscale1,
150
+ prenorm=True,
151
+ residual_in_fp32=self.residual_in_fp32,
152
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
153
+ )
154
+ if mixer_kwargs is None:
155
+ mixer_kwargs = {}
156
+ if mixer_subset is not None:
157
+ mixer_kwargs["mixer_subset"] = mixer_subset
158
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
159
+ if mixer_subset is not None:
160
+ residual = residual[:, mixer_subset]
161
+ if not isinstance(self.mlp, nn.Identity):
162
+ if not self.fused_dropout_add_ln:
163
+ dropped = self.drop_path2(self.dropout2(hidden_states))
164
+ residual = (dropped + residual) if residual is not None else dropped
165
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
166
+ if self.residual_in_fp32:
167
+ residual = residual.to(torch.float32)
168
+ else:
169
+ if self.drop_path2.p == 0 or not self.training:
170
+ rowscale2 = None
171
+ else:
172
+ rowscale2 = self.drop_path2(
173
+ torch.ones(
174
+ hidden_states.shape[:-1],
175
+ device=hidden_states.device,
176
+ dtype=hidden_states.dtype,
177
+ )
178
+ )
179
+ hidden_states, residual = layer_norm_fn(
180
+ hidden_states,
181
+ self.norm2.weight,
182
+ self.norm2.bias,
183
+ residual=residual,
184
+ eps=self.norm2.eps,
185
+ dropout_p=self.dropout2.p if self.training else 0.0,
186
+ rowscale=rowscale2,
187
+ prenorm=True,
188
+ residual_in_fp32=self.residual_in_fp32,
189
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
190
+ )
191
+ hidden_states = self.mlp(hidden_states)
192
+ return hidden_states, residual
193
+ else:
194
+ assert residual is None
195
+ mixer_out = self.mixer(
196
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
197
+ )
198
+ if self.return_residual: # mixer out is actually a pair here
199
+ mixer_out, hidden_states = mixer_out
200
+ if not self.fused_dropout_add_ln:
201
+ hidden_states = self.norm1(
202
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
203
+ dtype=self.norm1.weight.dtype
204
+ )
205
+ )
206
+ else:
207
+ if self.drop_path1.p == 0 or not self.training:
208
+ rowscale1 = None
209
+ else:
210
+ rowscale1 = self.drop_path1(
211
+ torch.ones(
212
+ mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
213
+ )
214
+ )
215
+ hidden_states = layer_norm_fn(
216
+ mixer_out,
217
+ self.norm1.weight,
218
+ self.norm1.bias,
219
+ residual=hidden_states,
220
+ eps=self.norm1.eps,
221
+ dropout_p=self.dropout1.p if self.training else 0.0,
222
+ rowscale=rowscale1,
223
+ prenorm=False,
224
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
225
+ )
226
+ if not isinstance(self.mlp, nn.Identity):
227
+ mlp_out = self.mlp(hidden_states)
228
+ if self.return_residual: # mlp out is actually a pair here
229
+ mlp_out, hidden_states = mlp_out
230
+ if not self.fused_dropout_add_ln:
231
+ hidden_states = self.norm2(
232
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
233
+ dtype=self.norm2.weight.dtype
234
+ )
235
+ )
236
+ else:
237
+ if self.drop_path2.p == 0 or not self.training:
238
+ rowscale2 = None
239
+ else:
240
+ rowscale2 = self.drop_path2(
241
+ torch.ones(
242
+ mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
243
+ )
244
+ )
245
+ hidden_states = layer_norm_fn(
246
+ mlp_out,
247
+ self.norm2.weight,
248
+ self.norm2.bias,
249
+ residual=hidden_states,
250
+ eps=self.norm2.eps,
251
+ dropout_p=self.dropout2.p if self.training else 0.0,
252
+ rowscale=rowscale2,
253
+ prenorm=False,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
255
+ )
256
+ return hidden_states
257
+
258
+
259
+ class ParallelBlock(nn.Module):
260
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
261
+ and PaLM.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ dim,
267
+ mixer_cls=None,
268
+ mlp_cls=None,
269
+ norm_cls=nn.LayerNorm,
270
+ dropout_cls=nn.Dropout,
271
+ resid_dropout1=0.0,
272
+ resid_dropout2=0.0,
273
+ tied_norm=False,
274
+ fused_dropout_add_ln=False,
275
+ residual_in_fp32=False,
276
+ sequence_parallel=False,
277
+ mark_shared_params=False,
278
+ ):
279
+ """
280
+ This Block has a slightly different structure compared to a regular
281
+ prenorm Transformer block.
282
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
283
+ [Ref: https://arxiv.org/abs/2002.04745]
284
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
285
+ the hidden_states (output1 of the MHA / MLP) and the residual.
286
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
287
+ The residual needs to be provided (except for the very first block).
288
+ """
289
+ super().__init__()
290
+ self.tied_norm = tied_norm
291
+ self.fused_dropout_add_ln = fused_dropout_add_ln
292
+ self.residual_in_fp32 = residual_in_fp32
293
+ if mixer_cls is None:
294
+ mixer_cls = partial(MHA, num_heads=dim // 64)
295
+ if mlp_cls is None:
296
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
297
+ self.mixer = mixer_cls(dim)
298
+ self.dropout1 = dropout_cls(resid_dropout1)
299
+ self.norm1 = norm_cls(dim)
300
+ self.mlp = mlp_cls(dim)
301
+ self.dropout2 = dropout_cls(resid_dropout2)
302
+ if not self.tied_norm:
303
+ self.norm2 = norm_cls(dim)
304
+
305
+ if self.fused_dropout_add_ln:
306
+ assert layer_norm_fn is not None, "Triton is not installed"
307
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
308
+ self.dropout1, nn.Dropout
309
+ )
310
+
311
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
312
+ # then the input to each worker in the tensor parallel group will be different.
313
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
314
+ # For now this is not an issue because we always use sequence_parallel=True during training
315
+ # and only use sequence_parallel=False during inference.
316
+
317
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
318
+ if sequence_parallel:
319
+ for p in self.norm1.parameters():
320
+ p._sequence_parallel = True
321
+ if hasattr(self, "norm2"):
322
+ for p in self.norm2.parameters():
323
+ p._sequence_parallel = True
324
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
325
+ if mark_shared_params:
326
+ for p in self.norm1.parameters():
327
+ p._shared_params = True
328
+ if hasattr(self, "norm2"):
329
+ for p in self.norm2.parameters():
330
+ p._shared_params = True
331
+
332
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
333
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
334
+
335
+ def forward(
336
+ self,
337
+ hidden_states1: Tensor,
338
+ hidden_states2: Optional[Tensor] = None,
339
+ residual: Optional[Tensor] = None,
340
+ mixer_kwargs=None,
341
+ ):
342
+ r"""Pass the input through the encoder layer.
343
+
344
+ Args:
345
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
346
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
347
+ residual.
348
+ """
349
+ # TODO: Ideally we should only do the allgather / allreduce once for
350
+ # the Linear to MLP & Attention
351
+ if not self.fused_dropout_add_ln:
352
+ dropped1 = self.dropout1(hidden_states1)
353
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
354
+ if hidden_states2 is not None:
355
+ dropped2 = self.dropout2(hidden_states2)
356
+ residual = (
357
+ (residual + dropped1 + dropped2)
358
+ if residual is not None
359
+ else dropped1 + dropped2
360
+ )
361
+ else:
362
+ residual = (residual + dropped1) if residual is not None else dropped1
363
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
364
+ hidden_states2 = (
365
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
366
+ if not self.tied_norm
367
+ else hidden_states1
368
+ )
369
+ if self.residual_in_fp32:
370
+ residual = residual.to(torch.float32)
371
+ else:
372
+ weight2, bias2 = (
373
+ (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
374
+ )
375
+ hidden_states1, *rest, residual = layer_norm_fn(
376
+ hidden_states1,
377
+ self.norm1.weight,
378
+ self.norm1.bias,
379
+ residual=residual,
380
+ x1=hidden_states2,
381
+ weight1=weight2,
382
+ bias1=bias2,
383
+ eps=self.norm1.eps,
384
+ dropout_p=self.dropout1.p if self.training else 0.0,
385
+ prenorm=True,
386
+ residual_in_fp32=self.residual_in_fp32,
387
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
388
+ )
389
+ if self.tied_norm:
390
+ hidden_states2 = hidden_states1
391
+ else:
392
+ hidden_states2, = rest
393
+ if mixer_kwargs is None:
394
+ mixer_kwargs = {}
395
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
396
+ hidden_states2 = self.mlp(hidden_states2)
397
+ return hidden_states1, hidden_states2, residual
block_info.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ namespace flash {
8
+
9
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
10
+
11
+ template<bool Varlen=true>
12
+ struct BlockInfo {
13
+
14
+ template<typename Params>
15
+ __device__ BlockInfo(const Params &params, const int bidb)
16
+ : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
17
+ , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
18
+ , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
19
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
20
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
21
+ , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
22
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
23
+ {
24
+ }
25
+
26
+ template <typename index_t>
27
+ __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
28
+ return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
29
+ }
30
+
31
+ template <typename index_t>
32
+ __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
33
+ return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
34
+ }
35
+
36
+ const int sum_s_q;
37
+ const int sum_s_k;
38
+ const int actual_seqlen_q;
39
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
40
+ const int seqlen_k_cache;
41
+ const int actual_seqlen_k;
42
+ };
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ } // namespace flash
btlm.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import json
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from einops import rearrange
14
+ from transformers import GPT2Config, AutoConfig, PretrainedConfig
15
+
16
+
17
+ def remap_state_dict_hf_btlm(state_dict, config):
18
+ # Word embedding and position embedding
19
+ def key_mapping_pos_emb(key):
20
+ return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
21
+
22
+ if "transformer.wpe.weight" in state_dict:
23
+ state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
24
+ word_embeddings = state_dict.pop("transformer.wte.weight")
25
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
26
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
27
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
28
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
29
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
30
+ )
31
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
32
+
33
+ # LayerNorm
34
+ def key_mapping_ln(key):
35
+ key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
36
+ key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
37
+ return key
38
+
39
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
40
+
41
+ # MLP
42
+ for d in range(config.num_hidden_layers):
43
+ W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
44
+ W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
45
+ state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
46
+ b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
47
+ b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
48
+ state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
49
+ W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
50
+ state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
51
+
52
+ def key_mapping_mlp(key):
53
+ key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
54
+ return key
55
+
56
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
57
+
58
+ # Attention
59
+ for d in range(config.num_hidden_layers):
60
+ Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
61
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
62
+ Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
63
+ state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
64
+ state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes
65
+
66
+ def key_mapping_attn(key):
67
+ key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
68
+ key = re.sub(
69
+ r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
70
+ )
71
+ return key
72
+
73
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
74
+
75
+ return state_dict
76
+
77
+
78
+ def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
79
+ return GPT2Config(
80
+ vocab_size=btlm_config.vocab_size,
81
+ n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
82
+ n_embd=btlm_config.hidden_size,
83
+ n_layer=btlm_config.num_hidden_layers,
84
+ n_head=btlm_config.num_attention_heads,
85
+ n_inner=btlm_config.n_inner,
86
+ activation_function=btlm_config.activation_function,
87
+ resid_pdrop=btlm_config.resid_pdrop,
88
+ embd_pdrop=btlm_config.embd_pdrop,
89
+ attn_pdrop=btlm_config.attn_pdrop,
90
+ layer_norm_epsilon=btlm_config.layer_norm_epsilon,
91
+ initializer_range=btlm_config.initializer_range,
92
+ bos_token_id=btlm_config.bos_token_id,
93
+ eos_token_id=btlm_config.eos_token_id,
94
+ # These are new arguments not in the original GPT2Config
95
+ use_alibi=btlm_config.position_embedding_type == "alibi",
96
+ use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn
97
+ mup_width_scale=btlm_config.mup_width_scale,
98
+ mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
99
+ mup_output_multiplier=btlm_config.mup_output_alpha,
100
+ mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
101
+ mlp_multiple_of=1,
102
+ )
causality-monitor.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ causality-monitor:
2
+ _target_: src.callbacks.causality_monitor.CausalityMonitor
comet.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # https://www.comet.ml
2
+
3
+ comet:
4
+ _target_: pytorch_lightning.loggers.comet.CometLogger
5
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6
+ project_name: "template-tests"
7
+ experiment_name: ${name}
config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default training configuration
4
+ defaults:
5
+ - _self_
6
+ - trainer: default
7
+ - optimizer: adamw
8
+ - scheduler: null
9
+ - task: sequence-model
10
+ - model: null
11
+ - datamodule: null
12
+ - callbacks: default # set this to null if you don't want to use callbacks
13
+ - metrics: null
14
+ - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
15
+
16
+ - mode: default
17
+
18
+ - experiment: null
19
+ - hparams_search: null
20
+
21
+ # enable color logging
22
+ - override hydra/hydra_logging: colorlog
23
+ - override hydra/job_logging: colorlog
24
+
25
+ # path to original working directory
26
+ # hydra hijacks working directory by changing it to the current log directory,
27
+ # so it's useful to have this path as a special variable
28
+ # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
29
+ work_dir: ${hydra:runtime.cwd}
30
+
31
+ # path to folder with data
32
+ data_dir: ${work_dir}/data/
33
+
34
+ # pretty print config at the start of the run using Rich library
35
+ print_config: True
36
+
37
+ # disable python warnings if they annoy you
38
+ ignore_warnings: True
39
+
40
+ # check performance on test set, using the best model achieved during training
41
+ # lightning chooses best model based on metric specified in checkpoint callback
42
+ test_after_training: True
43
+
44
+ resume: False
45
+
46
+ # seed for random number generators in pytorch, numpy and python.random
47
+ seed: null
48
+
49
+ # name of the run, accessed by loggers
50
+ name: null
cosine-warmup-timm.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # @package train.scheduler
2
+ _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
cosine-warmup.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # @package train.scheduler
2
+ _target_: transformers.get_cosine_schedule_with_warmup
cross_entropy.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Tuple, Optional, Union
4
+
5
+ import torch
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
11
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
12
+ # version of PyTorch. The following 2 lines are for backward compatibility with
13
+ # older PyTorch.
14
+ if "all_gather_into_tensor" not in dir(torch.distributed):
15
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
16
+
17
+
18
+ @triton.heuristics(
19
+ {
20
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
21
+ }
22
+ )
23
+ @triton.jit
24
+ def cross_entropy_fwd_kernel(
25
+ loss_ptr, # data ptrs
26
+ lse_ptr,
27
+ z_loss_ptr,
28
+ logits_ptr,
29
+ labels_ptr,
30
+ smoothing,
31
+ logit_scale,
32
+ lse_square_scale,
33
+ ignore_index,
34
+ total_classes,
35
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
36
+ n_cols, # shapes
37
+ n_rows,
38
+ logits_row_stride, # strides
39
+ BLOCK_SIZE: tl.constexpr,
40
+ HAS_SMOOTHING: tl.constexpr,
41
+ # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
42
+ SPLIT: tl.constexpr,
43
+ ):
44
+ row_idx = tl.program_id(0)
45
+ col_block_idx = tl.program_id(1)
46
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
47
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
48
+ label_idx = tl.load(labels_ptr + row_idx)
49
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
50
+ tl.float32
51
+ ) * logit_scale
52
+ max_logits = tl.max(logits, 0)
53
+ if HAS_SMOOTHING:
54
+ sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
55
+ lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
56
+ tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
57
+ if label_idx == ignore_index:
58
+ loss = 0.0
59
+ z_loss = 0.0
60
+ else:
61
+ label_idx -= class_start_idx
62
+ if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
63
+ n_cols, (col_block_idx + 1) * BLOCK_SIZE
64
+ ):
65
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
66
+ if HAS_SMOOTHING:
67
+ loss = (
68
+ (lse if not SPLIT else 0.0)
69
+ - smoothing * sum_logits / total_classes
70
+ - (1 - smoothing) * logits_label
71
+ )
72
+ else:
73
+ loss = (lse if not SPLIT else 0.0) - logits_label
74
+ else:
75
+ # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
76
+ if HAS_SMOOTHING:
77
+ loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
78
+ else:
79
+ loss = 0.0
80
+ if not SPLIT:
81
+ z_loss = lse_square_scale * lse * lse
82
+ loss += z_loss
83
+ else:
84
+ z_loss = 0.0
85
+ tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
86
+ if not SPLIT:
87
+ tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
88
+
89
+
90
+ @triton.heuristics(
91
+ {
92
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
93
+ }
94
+ )
95
+ @triton.jit
96
+ def cross_entropy_bwd_kernel(
97
+ dlogits_ptr, # data ptrs
98
+ dloss_ptr,
99
+ logits_ptr,
100
+ lse_ptr,
101
+ labels_ptr,
102
+ smoothing,
103
+ logit_scale,
104
+ lse_square_scale,
105
+ ignore_index,
106
+ total_classes,
107
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
108
+ n_cols, # shapes
109
+ logits_row_stride, # strides
110
+ dlogits_row_stride,
111
+ dloss_row_stride,
112
+ BLOCK_SIZE: tl.constexpr,
113
+ HAS_SMOOTHING: tl.constexpr,
114
+ ):
115
+ row_idx = tl.program_id(0)
116
+ col_block_idx = tl.program_id(1)
117
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
118
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
119
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
120
+ label_idx = tl.load(labels_ptr + row_idx)
121
+ if label_idx != ignore_index:
122
+ dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
123
+ else:
124
+ dloss = 0.0
125
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
126
+ tl.float32
127
+ ) * logit_scale
128
+ lse = tl.load(lse_ptr + row_idx)
129
+ probs = tl.exp(logits - lse)
130
+ probs += 2.0 * lse_square_scale * lse * probs
131
+ label_idx -= class_start_idx
132
+ if HAS_SMOOTHING:
133
+ smooth_positive = 1.0 - smoothing
134
+ smooth_negative = smoothing / total_classes
135
+ probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
136
+ else:
137
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
138
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
139
+
140
+
141
+ class CrossEntropyLoss(torch.autograd.Function):
142
+
143
+ @staticmethod
144
+ def forward(
145
+ ctx,
146
+ logits,
147
+ labels,
148
+ smoothing=0.0,
149
+ logit_scale=1.0,
150
+ lse_square_scale=0.0,
151
+ ignore_index=-100,
152
+ inplace_backward=False,
153
+ process_group=None,
154
+ ):
155
+ n_rows, n_cols = logits.shape
156
+ assert labels.shape == (n_rows,)
157
+ world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
158
+ total_classes = world_size * n_cols
159
+ rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
160
+ class_start_idx = rank * n_cols
161
+
162
+ if logits.stride(-1) != 1:
163
+ logits = logits.contiguous()
164
+ # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
165
+ MAX_BLOCK_SIZE = 64 * 1024
166
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
167
+ num_warps = (
168
+ 4
169
+ if BLOCK_SIZE < 2048
170
+ else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
171
+ )
172
+ # We may split the lse computation across multiple blocks, then do a reduction
173
+ # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
174
+ # where having just one thread block processing more than 64k elements is slow.
175
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
176
+ n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
177
+ loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
178
+ losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
179
+ lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
180
+ z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
181
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
182
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
183
+ with torch.cuda.device(logits.device.index):
184
+ cross_entropy_fwd_kernel[(n_rows, n_splits)](
185
+ losses, # data ptrs
186
+ lse,
187
+ z_losses,
188
+ logits,
189
+ labels,
190
+ smoothing,
191
+ logit_scale,
192
+ lse_square_scale,
193
+ ignore_index,
194
+ total_classes,
195
+ class_start_idx,
196
+ n_cols, # shapes
197
+ n_rows,
198
+ logits.stride(0), # strides
199
+ BLOCK_SIZE=BLOCK_SIZE, # constants
200
+ num_warps=num_warps,
201
+ SPLIT=split,
202
+ )
203
+
204
+ if split:
205
+ # If there's no smoothing, if labels are in the vocab of this partition, losses contains
206
+ # - predicted logit, and 0 otherwise.
207
+ # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
208
+ # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
209
+ # For labels not in the vocab of this partition, losses contains
210
+ # -0.1 * sum logit / total_classes.
211
+ if n_splits > 1:
212
+ lse = torch.logsumexp(lse, dim=0)
213
+ losses = losses.sum(dim=0)
214
+ if world_size > 1:
215
+ lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
216
+ torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
217
+ handle_losses = torch.distributed.all_reduce(
218
+ losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
219
+ )
220
+ lse = torch.logsumexp(lse_allgather, dim=0)
221
+ handle_losses.wait()
222
+ # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
223
+ # we just have to add the (global) lse.
224
+ # If there's smoothing=0.1, the total losses are
225
+ # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
226
+ # Again, we just have to add the (global) lse.
227
+ losses += lse
228
+ if lse_square_scale != 0.0:
229
+ z_losses = lse_square_scale * lse.square()
230
+ z_losses.masked_fill_(labels == ignore_index, 0.0)
231
+ losses += z_losses
232
+ else:
233
+ z_losses = torch.zeros_like(losses)
234
+ losses.masked_fill_(labels == ignore_index, 0.0)
235
+
236
+ ctx.save_for_backward(logits, lse, labels)
237
+ ctx.mark_non_differentiable(z_losses)
238
+ ctx.smoothing = smoothing
239
+ ctx.logit_scale = logit_scale
240
+ ctx.lse_square_scale = lse_square_scale
241
+ ctx.ignore_index = ignore_index
242
+ ctx.total_classes = total_classes
243
+ ctx.class_start_idx = class_start_idx
244
+ ctx.inplace_backward = inplace_backward
245
+
246
+ return losses, z_losses
247
+
248
+ @staticmethod
249
+ def backward(ctx, grad_losses, grad_z_losses):
250
+ del grad_z_losses # z_losses are only for logging.
251
+
252
+ logits, lse, labels = ctx.saved_tensors
253
+ dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
254
+ n_rows, n_cols = logits.shape
255
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
256
+ num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
257
+ grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
258
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
259
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
260
+ with torch.cuda.device(logits.device.index):
261
+ cross_entropy_bwd_kernel[grid](
262
+ dlogits, # data ptrs
263
+ grad_losses,
264
+ logits,
265
+ lse,
266
+ labels,
267
+ ctx.smoothing,
268
+ ctx.logit_scale,
269
+ ctx.lse_square_scale,
270
+ ctx.ignore_index,
271
+ ctx.total_classes,
272
+ ctx.class_start_idx,
273
+ n_cols, # shapes
274
+ logits.stride(0), # strides
275
+ dlogits.stride(0),
276
+ grad_losses.stride(0),
277
+ BLOCK_SIZE=BLOCK_SIZE, # constants
278
+ num_warps=num_warps,
279
+ )
280
+ return dlogits, None, None, None, None, None, None, None, None
281
+
282
+ def cross_entropy_loss(
283
+ logits: torch.Tensor,
284
+ labels: torch.Tensor,
285
+ label_smoothing: float = 0.0,
286
+ logit_scale: float = 1.0,
287
+ lse_square_scale: float = 0.0,
288
+ ignore_index=-100,
289
+ inplace_backward: bool = False,
290
+ process_group=None,
291
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
292
+ """
293
+ Arguments:
294
+ logits: (batch, vocab_size)
295
+ labels: (batch,)
296
+ label_smoothing: float
297
+ logit_scale: float. Multiply logits by this scale before calculating the loss.
298
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
299
+ This is also referred to as "z-loss".
300
+ ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
301
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
302
+ This saves memory.
303
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
304
+ one part of the vocab. The loss will be aggregated across processes.
305
+ Returns:
306
+ losses: (batch,), float
307
+ z_losses: (batch,), float
308
+ """
309
+ return CrossEntropyLoss.apply(
310
+ logits,
311
+ labels,
312
+ label_smoothing,
313
+ logit_scale,
314
+ lse_square_scale,
315
+ ignore_index,
316
+ inplace_backward,
317
+ process_group,
318
+ )
csv.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # csv logger built in lightning
2
+
3
+ csv:
4
+ _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5
+ save_dir: "."
6
+ name: "csv/"
7
+ version: ${name}
8
+ prefix: ""
cuda_bf16_fallbacks.cuh ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
3
+ /*
4
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <cuda_fp16.h>
23
+
24
+ namespace fastertransformer {
25
+
26
+ #ifdef ENABLE_BF16
27
+ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
28
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
29
+ float2 f_val;
30
+ f_val.x = __low2float(val);
31
+ f_val.y = __high2float(val);
32
+ return f_val;
33
+ #else
34
+ return __bfloat1622float2(val);
35
+ #endif
36
+ }
37
+
38
+ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
39
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
40
+ float2 f_val;
41
+ f_val.x = max(min(__low2float(val), 127.f), -128.f);
42
+ f_val.y = max(min(__high2float(val), 127.f), -128.f);
43
+ union { int8_t int8[2]; int16_t int16; };
44
+ int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
45
+ int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
46
+ return int16;
47
+ #else
48
+ val = __hmin2(val, make_bfloat162(127., 127.));
49
+ val = __hmax2(val, make_bfloat162(-128., -128.));
50
+ union { int8_t int8[2]; int16_t int16; };
51
+ int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
52
+ int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
53
+ return int16;
54
+ #endif
55
+ }
56
+
57
+ inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
58
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
59
+ return __floats2bfloat162_rn(val.x, val.y);
60
+ #else
61
+ return __float22bfloat162_rn(val);
62
+ #endif
63
+ }
64
+
65
+ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
66
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
67
+ __nv_bfloat162 val2;
68
+ val2.x = val;
69
+ val2.y = val;
70
+ return val2;
71
+ #else
72
+ return __bfloat162bfloat162(val);
73
+ #endif
74
+ }
75
+
76
+ inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
77
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
78
+ float fxl, fxh, fyl, fyh;
79
+ fxl = __low2float(x);
80
+ fxh = __high2float(x);
81
+ fyl = __low2float(y);
82
+ fyh = __high2float(y);
83
+ return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
84
+ #else
85
+ return __hadd2(x, y);
86
+ #endif
87
+ }
88
+
89
+ inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
90
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
91
+ return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
92
+ #else
93
+ return __hadd(x, y);
94
+ #endif
95
+ }
96
+
97
+ inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
98
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
99
+ float fxl, fxh, fyl, fyh;
100
+ fxl = __low2float(x);
101
+ fxh = __high2float(x);
102
+ fyl = __low2float(y);
103
+ fyh = __high2float(y);
104
+ return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
105
+ #else
106
+ return __hsub2(x, y);
107
+ #endif
108
+ }
109
+
110
+ inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
111
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
112
+ return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
113
+ #else
114
+ return __hsub(x, y);
115
+ #endif
116
+ }
117
+
118
+ inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
119
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
120
+ float fxl, fxh, fyl, fyh;
121
+ fxl = __low2float(x);
122
+ fxh = __high2float(x);
123
+ fyl = __low2float(y);
124
+ fyh = __high2float(y);
125
+ return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
126
+ #else
127
+ return __hmul2(x, y);
128
+ #endif
129
+ }
130
+
131
+ inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
132
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
133
+ return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
134
+ #else
135
+ return __hmul(x, y);
136
+ #endif
137
+ }
138
+
139
+ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
140
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
141
+ float fxl, fxh, fyl, fyh, fzl, fzh;
142
+ fxl = __low2float(x);
143
+ fxh = __high2float(x);
144
+ fyl = __low2float(y);
145
+ fyh = __high2float(y);
146
+ fzl = __low2float(z);
147
+ fzh = __high2float(z);
148
+ return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
149
+ #else
150
+ return __hfma2(x, y, z);
151
+ #endif
152
+ }
153
+
154
+ inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
155
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
156
+ return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
157
+ #else
158
+ return __hfma(x, y, z);
159
+ #endif
160
+ }
161
+
162
+ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
163
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
164
+ float fxl, fxh;
165
+ fxl = __low2float(x);
166
+ fxh = __high2float(x);;
167
+ return __floats2bfloat162_rn(expf(fxl), expf(fxh));
168
+ #else
169
+ return h2exp(x);
170
+ #endif
171
+ }
172
+
173
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
174
+ inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
175
+ inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
176
+
177
+ inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
178
+ {
179
+ __nv_bfloat162 t; t.x = x; t.y = y; return t;
180
+ }
181
+
182
+ #endif
183
+
184
+ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
185
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
186
+ return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
187
+ #else
188
+ return a + b + c;
189
+ #endif
190
+ }
191
+
192
+ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
193
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
194
+ return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
195
+ #else
196
+ return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
197
+ #endif
198
+ }
199
+
200
+ inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
201
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
202
+ float fal, fah, fbl, fbh, fcl, fch;
203
+ fal = __low2float(a);
204
+ fah = __high2float(a);
205
+ fbl = __low2float(b);
206
+ fbh = __high2float(b);
207
+ fcl = __low2float(c);
208
+ fch = __high2float(c);
209
+ return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
210
+ #else
211
+ return a + b + c;
212
+ #endif
213
+ }
214
+
215
+ inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
216
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
217
+ return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
218
+ #else
219
+ return a * b * c;
220
+ #endif
221
+ }
222
+
223
+ inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
224
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
225
+ float fal, fah, fbl, fbh, fcl, fch;
226
+ fal = __low2float(a);
227
+ fah = __high2float(a);
228
+ fbl = __low2float(b);
229
+ fbh = __high2float(b);
230
+ fcl = __low2float(c);
231
+ fch = __high2float(c);
232
+ return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
233
+ #else
234
+ return a * b * c;
235
+ #endif
236
+ }
237
+
238
+ inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
239
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
240
+ float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
241
+ fal = __low2float(a);
242
+ fah = __high2float(a);
243
+ fbl = __low2float(b);
244
+ fbh = __high2float(b);
245
+ fcl = __low2float(c);
246
+ fch = __high2float(c);
247
+ fdl = __low2float(d);
248
+ fdh = __high2float(d);
249
+ return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
250
+ #else
251
+ return a * b * c + d;
252
+ #endif
253
+ }
254
+
255
+ #endif // ENABLE_BF16
256
+
257
+ } // namespace fastertransformer
cuda_bf16_wrapper.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
3
+ /*
4
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #ifdef ENABLE_BF16
22
+ #include <cuda_bf16.h>
23
+ #endif
ddp.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+
4
+ accelerator: gpu
5
+ devices: 4
6
+ strategy: ddp
debug.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # run in debug mode with:
4
+ # `python run.py mode=debug`
5
+
6
+ defaults:
7
+ - override /trainer: debug.yaml
8
+
9
+ debug_mode: True
10
+
11
+ hydra:
12
+ # sets level of all command line loggers to 'DEBUG'
13
+ verbose: True
14
+
15
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
16
+ # sets level of only chosen command line loggers to 'DEBUG'
17
+ # verbose: [src.train, src.utils.utils]
18
+
19
+ # sets output paths for all file logs to 'logs/debug/'
20
+ run:
21
+ dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}
22
+ sweep:
23
+ dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}
24
+ subdir: ${hydra.job.num}
25
+
26
+ # disable rich config printing, since it will be already printed by hydra when `verbose: True`
27
+ print_config: False
decoder_masked_multihead_attention.cu ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Adapted from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #include "decoder_masked_multihead_attention.h"
20
+ #include "decoder_masked_multihead_attention_utils.h"
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <assert.h>
23
+ #include <float.h>
24
+ #include <type_traits>
25
+
26
+ #include "decoder_masked_multihead_attention_template.hpp"
27
+
28
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
29
+
30
+ #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
31
+ size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
32
+ auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
33
+ THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
34
+ cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
35
+ dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \
36
+ kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ // !!! Specialize the launcher for Cross attention
41
+ template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
42
+ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
43
+ {
44
+ constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
45
+ constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
46
+ int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
47
+ // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
48
+ if (tlength < 32) {
49
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
50
+ }
51
+ else if (tlength < 2048) {
52
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
53
+ }
54
+ else {
55
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
56
+ }
57
+ }
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ #undef MMHA_LAUNCH_KERNEL
62
+
63
+ template<typename T, typename KERNEL_PARAMS_TYPE>
64
+ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
65
+ {
66
+ switch (params.hidden_size_per_head) {
67
+ case 32:
68
+ mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
69
+ break;
70
+ case 48:
71
+ mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
72
+ break;
73
+ case 64:
74
+ mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
75
+ break;
76
+ case 80:
77
+ mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
78
+ break;
79
+ case 96:
80
+ mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
81
+ break;
82
+ case 128:
83
+ mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
84
+ break;
85
+ case 160:
86
+ mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
87
+ break;
88
+ case 192:
89
+ mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
90
+ break;
91
+ case 224:
92
+ mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
93
+ break;
94
+ case 256:
95
+ mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
96
+ break;
97
+ default:
98
+ assert(false);
99
+ }
100
+ }
101
+
102
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
103
+
104
+ void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
105
+ {
106
+ multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
107
+ }
108
+
109
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
110
+
111
+ void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
112
+ {
113
+ multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
114
+ }
115
+
116
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
117
+
118
+ #ifdef ENABLE_BF16
119
+ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
120
+ const cudaStream_t& stream)
121
+ {
122
+ multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
123
+ }
124
+ #endif
125
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
126
+
127
+ void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
128
+ {
129
+ multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
130
+ }
131
+
132
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
133
+
134
+ void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
135
+ {
136
+ multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
137
+ }
138
+
139
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
140
+
141
+ #ifdef ENABLE_BF16
142
+ void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
143
+ const cudaStream_t& stream)
144
+ {
145
+ multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
146
+ }
147
+ #endif
148
+
149
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
decoder_masked_multihead_attention.h ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <cuda_fp16.h>
23
+ #include <cuda_runtime_api.h>
24
+ #include <stdint.h>
25
+ #include <stdio.h>
26
+ #include <stdlib.h>
27
+
28
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
29
+
30
+ #define CHECK_CUDA(call) \
31
+ do { \
32
+ cudaError_t status_ = call; \
33
+ if (status_ != cudaSuccess) { \
34
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
35
+ exit(1); \
36
+ } \
37
+ } while (0)
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ // The structure of parameters for the masked multihead attention kernel.
42
+ //
43
+ // We use the following terminology to describe the different dimensions.
44
+ //
45
+ // B: Batch size (number of sequences),
46
+ // L: Sequence length,
47
+ // D: Hidden dimension,
48
+ // H: Number of heads,
49
+ // Dh: Hidden dimension per head - Dh = D / H.
50
+
51
+ template<typename T>
52
+ struct Multihead_attention_params_base {
53
+
54
+ // The output buffer. Dimensions B x D.
55
+ T* out = nullptr;
56
+
57
+ // The input Qs and the associated bias. Dimensions B x D and D, resp.
58
+ const T *q = nullptr, *q_bias = nullptr;
59
+ // The input Ks and the associated bias. Dimensions B x D and D, resp.
60
+ const T *k = nullptr, *k_bias = nullptr;
61
+ // The input Vs and the associated bias. Dimensions B x D and D, resp.
62
+ const T *v = nullptr, *v_bias = nullptr;
63
+
64
+ // The cache for the Ks. The size must be at least B x L x D.
65
+ T* k_cache = nullptr;
66
+ // The cache for the Vs. The size must be at least B x L x D.
67
+ T* v_cache = nullptr;
68
+ // The indirections to use for cache when beam sampling.
69
+ const int* cache_indir = nullptr;
70
+
71
+ // Stride to handle the case when KQV is a single buffer
72
+ int stride_q = 0;
73
+ int stride_k = 0;
74
+ int stride_v = 0;
75
+
76
+ // The batch size.
77
+ int batch_size = 0;
78
+ // The beam width
79
+ int beam_width = 0;
80
+ // The sequence length.
81
+ int memory_max_len = 0;
82
+ // The number of heads (H).
83
+ int num_heads = 0;
84
+ int num_heads_kv = 0;
85
+ int num_heads_q_kv_ratio = 0;
86
+ // The hidden dimension per head (Dh).
87
+ int hidden_size_per_head = 0;
88
+ // The per-head latent space reserved for rotary embeddings.
89
+ int rotary_embedding_dim = 0;
90
+ bool neox_rotary_style = false;
91
+ float rotary_base = 0.0f;
92
+ // The maximum length of input sentences.
93
+ int max_input_length = 0;
94
+ // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
95
+ int timestep = 0;
96
+ // The current timestep of each sentences (support different timestep for different sentences)
97
+
98
+ // The 1.f / sqrt(Dh). Computed on the host.
99
+ float inv_sqrt_dh = 0.0f;
100
+
101
+ // Used when we have some input context like gpt
102
+ const int* total_padding_tokens = nullptr;
103
+
104
+ const bool* masked_tokens = nullptr;
105
+ const int* prefix_prompt_lengths = nullptr;
106
+ int max_prefix_prompt_length = 0;
107
+
108
+ const T* relative_attention_bias = nullptr;
109
+ int relative_attention_bias_stride = 0;
110
+ // The slope per head of linear position bias to attention score (H).
111
+ const T* linear_bias_slopes = nullptr;
112
+
113
+ const T* ia3_key_weights = nullptr;
114
+ const T* ia3_value_weights = nullptr;
115
+ const int* ia3_tasks = nullptr;
116
+
117
+ const float* qkv_scale_out = nullptr;
118
+ const float* attention_out_scale = nullptr;
119
+ int int8_mode = 0;
120
+
121
+ const T *rotary_cos = nullptr;
122
+ const T *rotary_sin = nullptr;
123
+
124
+ const int *nnz_head_idx = nullptr;
125
+ int nnz_heads = 0;
126
+ };
127
+
128
+ template<typename T, bool CROSS_ATTENTION>
129
+ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
130
+ // output cross attentions
131
+ float* cross_attention_out = nullptr;
132
+ int max_decoder_seq_len = 0;
133
+ bool is_return_cross_attentions = false;
134
+
135
+ // allows to exist attention eary
136
+ bool* finished = nullptr;
137
+
138
+ // required in case of cross attention
139
+ // will need it here till if constexpr in c++17
140
+ int* memory_length_per_sample = nullptr;
141
+
142
+ // required in case of masked attention with different length
143
+ const int* length_per_sample = nullptr;
144
+ };
145
+
146
+ template<typename T>
147
+ struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
148
+ // output cross attentions
149
+ float* cross_attention_out = nullptr;
150
+ int max_decoder_seq_len = 0;
151
+ bool is_return_cross_attentions = false;
152
+
153
+ // allows to exist attention eary
154
+ bool* finished = nullptr;
155
+
156
+ // required in case of cross attention
157
+ int* memory_length_per_sample = nullptr;
158
+
159
+ // required in case of masked attention with different length
160
+ const int* length_per_sample = nullptr;
161
+ };
162
+
163
+ template<class T>
164
+ using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
165
+
166
+ template<class T>
167
+ using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
168
+
169
+ template<typename T>
170
+ struct outputCrossAttentionParam {
171
+ // max decoder output length
172
+ int max_decoder_seq_len = 0;
173
+ T* cross_attention_out = nullptr;
174
+ bool is_return_cross_attentions = false;
175
+ };
176
+
177
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
178
+
179
+ void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
180
+ void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
181
+ #ifdef ENABLE_BF16
182
+ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
183
+ const cudaStream_t& stream);
184
+ #endif
185
+ void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
186
+ void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
187
+ #ifdef ENABLE_BF16
188
+ void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
189
+ const cudaStream_t& stream);
190
+ #endif
191
+
192
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
decoder_masked_multihead_attention_template.hpp ADDED
@@ -0,0 +1,1619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+ #pragma once
19
+
20
+ #include "decoder_masked_multihead_attention.h"
21
+ #include "decoder_masked_multihead_attention_utils.h"
22
+ #include "cuda_bf16_wrapper.h"
23
+ #include "cuda_bf16_fallbacks.cuh"
24
+ #include <assert.h>
25
+ #include <float.h>
26
+ #include <type_traits>
27
+
28
+ // #define MMHA_USE_HMMA_FOR_REDUCTION
29
+
30
+ // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
31
+
32
+ // Does not seem to affect the accuracy that much
33
+ #define MMHA_USE_FP32_ACUM_FOR_FMA
34
+
35
+ // Seems to slightly improve the accuracy
36
+ #define MMHA_USE_FP32_ACUM_FOR_OUT
37
+
38
+ #if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
39
+ // Does not seem to improve the accuracy
40
+ //#define MMHA_USE_FP32_ACUM_FOR_LOGITS
41
+ #endif
42
+
43
+ namespace mmha {
44
+
45
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ //
48
+ // We use the following terminology to describe the different dimensions.
49
+ //
50
+ // B: Batch size (number of sequences),
51
+ // L: Sequence length,
52
+ // D: Hidden dimension,
53
+ // H: Number of heads,
54
+ // Dh: Hidden dimension per head - Dh = D / H.
55
+ //
56
+ // The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
57
+ // 64, 128 and 256 threads per block.
58
+ //
59
+ // Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
60
+ // compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
61
+ // cache buffer helps with memory accesses and contains keys with bias.
62
+ //
63
+ // The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
64
+ // x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
65
+ // values for x are chosen to create chunks of 16 bytes.
66
+ //
67
+ // The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
68
+ // depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
69
+ // the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
70
+ // HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
71
+ //
72
+ // After that loop, a parallel softmax is computed across the different Q * K^T values stored in
73
+ // shared memory.
74
+ //
75
+ // The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
76
+ // timesteps are computed by loop iteration. As with the keys, the values are read from a cache
77
+ // except for the current timestep. The layout of the cache buffer for the values is much simpler
78
+ // as it is [B, H, L, Dh].
79
+ //
80
+
81
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
+ template<typename T, int Dh>
84
+ struct Qk_vec_ {
85
+ };
86
+
87
+ template<>
88
+ struct Qk_vec_<float, 32> {
89
+ using Type = float;
90
+ };
91
+ template<>
92
+ struct Qk_vec_<float, 64> {
93
+ using Type = float2;
94
+ };
95
+ template<>
96
+ struct Qk_vec_<float, 128> {
97
+ using Type = float4;
98
+ };
99
+ template<>
100
+ struct Qk_vec_<float, 256> {
101
+ using Type = float4;
102
+ };
103
+ template<>
104
+ struct Qk_vec_<uint16_t, 32> {
105
+ using Type = uint32_t;
106
+ };
107
+ template<>
108
+ struct Qk_vec_<uint16_t, 64> {
109
+ using Type = uint32_t;
110
+ };
111
+ template<>
112
+ struct Qk_vec_<uint16_t, 128> {
113
+ using Type = uint2;
114
+ };
115
+ template<>
116
+ struct Qk_vec_<uint16_t, 256> {
117
+ using Type = uint4;
118
+ };
119
+ #ifdef ENABLE_BF16
120
+ template<>
121
+ struct Qk_vec_<__nv_bfloat16, 32> {
122
+ using Type = __nv_bfloat162;
123
+ };
124
+ template<>
125
+ struct Qk_vec_<__nv_bfloat16, 64> {
126
+ using Type = __nv_bfloat162;
127
+ };
128
+ template<>
129
+ struct Qk_vec_<__nv_bfloat16, 128> {
130
+ using Type = bf16_4_t;
131
+ };
132
+ template<>
133
+ struct Qk_vec_<__nv_bfloat16, 256> {
134
+ using Type = bf16_8_t;
135
+ };
136
+ #endif // ENABLE_BF16
137
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
138
+
139
+ template<typename T, int THREADS_PER_KEY>
140
+ struct K_vec_ {
141
+ };
142
+
143
+ template<>
144
+ struct K_vec_<float, 4> {
145
+ using Type = float;
146
+ };
147
+ template<>
148
+ struct K_vec_<float, 2> {
149
+ using Type = float2;
150
+ };
151
+ template<>
152
+ struct K_vec_<float, 1> {
153
+ using Type = float4;
154
+ };
155
+ template<>
156
+ struct K_vec_<uint16_t, 4> {
157
+ using Type = uint32_t;
158
+ };
159
+ template<>
160
+ struct K_vec_<uint16_t, 2> {
161
+ using Type = uint2;
162
+ };
163
+ template<>
164
+ struct K_vec_<uint16_t, 1> {
165
+ using Type = uint4;
166
+ };
167
+ #ifdef ENABLE_BF16
168
+ template<>
169
+ struct K_vec_<__nv_bfloat16, 4> {
170
+ using Type = __nv_bfloat162;
171
+ };
172
+ template<>
173
+ struct K_vec_<__nv_bfloat16, 2> {
174
+ using Type = bf16_4_t;
175
+ };
176
+ template<>
177
+ struct K_vec_<__nv_bfloat16, 1> {
178
+ using Type = bf16_8_t;
179
+ };
180
+ #endif // ENABLE_BF16
181
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
182
+
183
+ template<typename T, int V_VEC_SIZE>
184
+ struct V_vec_ {
185
+ };
186
+
187
+ template<>
188
+ struct V_vec_<float, 1> {
189
+ using Type = float;
190
+ };
191
+ template<>
192
+ struct V_vec_<float, 2> {
193
+ using Type = float2;
194
+ };
195
+ template<>
196
+ struct V_vec_<float, 4> {
197
+ using Type = float4;
198
+ };
199
+ template<>
200
+ struct V_vec_<uint16_t, 2> {
201
+ using Type = uint32_t;
202
+ };
203
+ template<>
204
+ struct V_vec_<uint16_t, 4> {
205
+ using Type = uint2;
206
+ };
207
+ template<>
208
+ struct V_vec_<uint16_t, 8> {
209
+ using Type = uint4;
210
+ };
211
+ #ifdef ENABLE_BF16
212
+ template<>
213
+ struct V_vec_<__nv_bfloat16, 2> {
214
+ using Type = __nv_bfloat162;
215
+ };
216
+ template<>
217
+ struct V_vec_<__nv_bfloat16, 4> {
218
+ using Type = bf16_4_t;
219
+ };
220
+ template<>
221
+ struct V_vec_<__nv_bfloat16, 8> {
222
+ using Type = bf16_8_t;
223
+ };
224
+ #endif // ENABLE_BF16
225
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
226
+
227
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
228
+ template<typename T>
229
+ struct Qk_vec_acum_fp32_ {
230
+ };
231
+
232
+ template<>
233
+ struct Qk_vec_acum_fp32_<float> {
234
+ using Type = float;
235
+ };
236
+ template<>
237
+ struct Qk_vec_acum_fp32_<float2> {
238
+ using Type = float2;
239
+ };
240
+ template<>
241
+ struct Qk_vec_acum_fp32_<float4> {
242
+ using Type = float4;
243
+ };
244
+ // template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
245
+ template<>
246
+ struct Qk_vec_acum_fp32_<uint32_t> {
247
+ using Type = float2;
248
+ };
249
+ template<>
250
+ struct Qk_vec_acum_fp32_<uint2> {
251
+ using Type = Float4_;
252
+ };
253
+ template<>
254
+ struct Qk_vec_acum_fp32_<uint4> {
255
+ using Type = Float8_;
256
+ };
257
+ template<>
258
+ struct Qk_vec_acum_fp32_<__nv_bfloat16> {
259
+ using Type = float;
260
+ };
261
+ template<>
262
+ struct Qk_vec_acum_fp32_<__nv_bfloat162> {
263
+ using Type = float2;
264
+ };
265
+ template<>
266
+ struct Qk_vec_acum_fp32_<bf16_4_t> {
267
+ using Type = Float4_;
268
+ };
269
+ template<>
270
+ struct Qk_vec_acum_fp32_<bf16_8_t> {
271
+ using Type = Float8_;
272
+ };
273
+
274
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
275
+
276
+ template<typename T>
277
+ struct K_vec_acum_fp32_ {
278
+ };
279
+
280
+ template<>
281
+ struct K_vec_acum_fp32_<float> {
282
+ using Type = float;
283
+ };
284
+ template<>
285
+ struct K_vec_acum_fp32_<float2> {
286
+ using Type = float2;
287
+ };
288
+ template<>
289
+ struct K_vec_acum_fp32_<float4> {
290
+ using Type = float4;
291
+ };
292
+ template<>
293
+ struct K_vec_acum_fp32_<uint32_t> {
294
+ using Type = float2;
295
+ };
296
+ template<>
297
+ struct K_vec_acum_fp32_<uint2> {
298
+ using Type = Float4_;
299
+ };
300
+ template<>
301
+ struct K_vec_acum_fp32_<uint4> {
302
+ using Type = Float8_;
303
+ };
304
+ template<>
305
+ struct K_vec_acum_fp32_<__nv_bfloat16> {
306
+ using Type = float;
307
+ };
308
+ template<>
309
+ struct K_vec_acum_fp32_<__nv_bfloat162> {
310
+ using Type = float2;
311
+ };
312
+ template<>
313
+ struct K_vec_acum_fp32_<bf16_4_t> {
314
+ using Type = Float4_;
315
+ };
316
+ template<>
317
+ struct K_vec_acum_fp32_<bf16_8_t> {
318
+ using Type = Float8_;
319
+ };
320
+ #endif
321
+
322
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
323
+
324
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
325
+ template<typename T>
326
+ struct V_vec_acum_fp32_ {
327
+ };
328
+
329
+ template<>
330
+ struct V_vec_acum_fp32_<float> {
331
+ using Type = float;
332
+ };
333
+ template<>
334
+ struct V_vec_acum_fp32_<float2> {
335
+ using Type = float2;
336
+ };
337
+ template<>
338
+ struct V_vec_acum_fp32_<float4> {
339
+ using Type = float4;
340
+ };
341
+ template<>
342
+ struct V_vec_acum_fp32_<uint32_t> {
343
+ using Type = float2;
344
+ };
345
+ template<>
346
+ struct V_vec_acum_fp32_<uint2> {
347
+ using Type = Float4_;
348
+ };
349
+ template<>
350
+ struct V_vec_acum_fp32_<uint4> {
351
+ using Type = Float8_;
352
+ };
353
+ #ifdef ENABLE_BF16
354
+ template<>
355
+ struct V_vec_acum_fp32_<__nv_bfloat162> {
356
+ using Type = float2;
357
+ };
358
+ template<>
359
+ struct V_vec_acum_fp32_<bf16_4_t> {
360
+ using Type = Float4_;
361
+ };
362
+ template<>
363
+ struct V_vec_acum_fp32_<bf16_8_t> {
364
+ using Type = Float8_;
365
+ };
366
+ #endif // ENABLE_BF16
367
+ #endif
368
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
369
+
370
+ template<int THREADS_PER_KEY, typename K_vec, int N>
371
+ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
372
+ {
373
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
374
+ using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
375
+ #else
376
+ using K_vec_acum = K_vec;
377
+ #endif
378
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
379
+ K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
380
+ #pragma unroll
381
+ for (int ii = 1; ii < N; ++ii) {
382
+ qk_vec = fma(q[ii], k[ii], qk_vec);
383
+ }
384
+
385
+ // Finalize the reduction across lanes.
386
+ float qk = sum(qk_vec);
387
+ #pragma unroll
388
+ for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
389
+ qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
390
+ }
391
+ return qk;
392
+ }
393
+
394
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
395
+
396
+ template<typename T, int THREADS_PER_KEY>
397
+ struct Qk_dot {
398
+ template<typename K_vec, int N>
399
+ static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
400
+ {
401
+ return qk_dot_<THREADS_PER_KEY>(q, k);
402
+ }
403
+ };
404
+
405
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
406
+
407
+ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
408
+ {
409
+ float4 c;
410
+ float zero = 0.f;
411
+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
412
+ " {%0, %1, %2, %3}, \n"
413
+ " {%4, %5}, \n"
414
+ " {%6}, \n"
415
+ " {%7, %7, %7, %7}; \n"
416
+
417
+ : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
418
+ : "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
419
+ return c;
420
+ }
421
+
422
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
423
+
424
+ template<int N>
425
+ inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
426
+ {
427
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
428
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
429
+ using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
430
+ #else
431
+ using K_vec_acum = uint32_t;
432
+ #endif
433
+ K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
434
+ #pragma unroll
435
+ for (int ii = 1; ii < N; ++ii) {
436
+ qk_vec = fma(q[ii], k[ii], qk_vec);
437
+ }
438
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
439
+ uint32_t qk_vec_ = float2_to_half2(qk_vec);
440
+ return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
441
+ #else
442
+ return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
443
+ #endif
444
+ #else
445
+ return 0.f;
446
+ #endif
447
+ }
448
+
449
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
450
+
451
+ template<>
452
+ struct Qk_dot<uint16_t, 4> {
453
+ template<int N>
454
+ static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
455
+ {
456
+ #if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
457
+ return qk_hmma_dot_(q, k);
458
+ #else
459
+ return qk_dot_<4>(q, k);
460
+ #endif // defined MMHA_USE_HMMA_FOR_REDUCTION
461
+ }
462
+ };
463
+
464
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
465
+
466
+ template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
467
+ inline __device__ float block_sum(float* red_smem, float sum)
468
+ {
469
+
470
+ // Decompose the thread index into warp / lane.
471
+ int warp = threadIdx.x / WARP_SIZE;
472
+ int lane = threadIdx.x % WARP_SIZE;
473
+
474
+ // Compute the sum per warp.
475
+ #pragma unroll
476
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
477
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
478
+ }
479
+
480
+ // Warp leaders store the data to shared memory.
481
+ if (lane == 0) {
482
+ red_smem[warp] = sum;
483
+ }
484
+
485
+ // Make sure the data is in shared memory.
486
+ __syncthreads();
487
+
488
+ // The warps compute the final sums.
489
+ if (lane < WARPS_PER_BLOCK) {
490
+ sum = red_smem[lane];
491
+ }
492
+
493
+ // Parallel reduction inside the warp.
494
+ #pragma unroll
495
+ for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
496
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
497
+ }
498
+
499
+ // Broadcast to other threads.
500
+ return __shfl_sync(uint32_t(-1), sum, 0);
501
+ }
502
+
503
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
504
+
505
+ inline __device__ void convert_from_float(float& dst, float src)
506
+ {
507
+ dst = src;
508
+ }
509
+
510
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
511
+
512
+ inline __device__ void convert_from_float(uint16_t& dst, float src)
513
+ {
514
+ dst = float_to_half(src);
515
+ }
516
+
517
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
518
+
519
+ inline __device__ void convert_from_float(uint32_t& dst, float2 src)
520
+ {
521
+ dst = float2_to_half2(src);
522
+ }
523
+
524
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
525
+ #ifdef ENABLE_BF16
526
+ inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
527
+ {
528
+ dst = __float2bfloat16(src);
529
+ }
530
+
531
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
532
+
533
+ inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
534
+ {
535
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
536
+ dst = __float22bfloat162_rn(src);
537
+ #else
538
+ dst = __floats2bfloat162_rn(src.x, src.y);
539
+ #endif
540
+ }
541
+ #endif // ENABLE_BF16
542
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
543
+
544
+ inline __device__ void convert_from_float(uint2& dst, Float4_ src)
545
+ {
546
+ dst.x = float2_to_half2(src.x);
547
+ dst.y = float2_to_half2(src.y);
548
+ }
549
+
550
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
551
+
552
+ inline __device__ void convert_from_float(uint2& dst, float4 src)
553
+ {
554
+ convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
555
+ }
556
+
557
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
558
+
559
+ inline __device__ void convert_from_float(uint4& dst, Float8_ src)
560
+ {
561
+ dst.x = float2_to_half2(src.x);
562
+ dst.y = float2_to_half2(src.y);
563
+ dst.z = float2_to_half2(src.z);
564
+ dst.w = float2_to_half2(src.w);
565
+ }
566
+
567
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
568
+
569
+ #ifdef ENABLE_BF16
570
+ inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
571
+ {
572
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
573
+ dst.x = __float22bfloat162_rn(src.x);
574
+ dst.y = __float22bfloat162_rn(src.y);
575
+ #else
576
+ dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
577
+ dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
578
+ #endif
579
+ }
580
+
581
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
582
+
583
+ inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
584
+ {
585
+ convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
586
+ }
587
+
588
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
589
+
590
+ inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
591
+ {
592
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
593
+ dst.x = __float22bfloat162_rn(src.x);
594
+ dst.y = __float22bfloat162_rn(src.y);
595
+ dst.z = __float22bfloat162_rn(src.z);
596
+ dst.w = __float22bfloat162_rn(src.w);
597
+ #else
598
+ dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
599
+ dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
600
+ dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
601
+ dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
602
+ #endif
603
+ }
604
+ #endif // ENABLE_BF16
605
+
606
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
607
+
608
+ inline __device__ void convert_from_float(float2& dst, float2 src)
609
+ {
610
+ dst = src;
611
+ }
612
+
613
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
614
+
615
+ inline __device__ void convert_from_float(float4& dst, float4 src)
616
+ {
617
+ dst = src;
618
+ }
619
+
620
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
621
+
622
+ inline __device__ float convert_to_float(float4 u)
623
+ {
624
+ return u.x;
625
+ }
626
+
627
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
628
+
629
+ inline __device__ float convert_to_float(uint4 u)
630
+ {
631
+ float2 tmp = half2_to_float2(u.x);
632
+ return tmp.x;
633
+ }
634
+
635
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
636
+
637
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
638
+
639
+ inline __device__ float cast_to_float(float u)
640
+ {
641
+ return u;
642
+ }
643
+
644
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
645
+
646
+ inline __device__ float2 cast_to_float(float2 u)
647
+ {
648
+ return u;
649
+ }
650
+
651
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
652
+
653
+ inline __device__ float4 cast_to_float(float4 u)
654
+ {
655
+ return u;
656
+ }
657
+
658
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
659
+
660
+ inline __device__ Float4_ cast_to_float(Float4_ u)
661
+ {
662
+ return u;
663
+ }
664
+
665
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
666
+
667
+ inline __device__ Float8_ cast_to_float(Float8_ u)
668
+ {
669
+ return u;
670
+ }
671
+
672
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
673
+
674
+ inline __device__ float2 cast_to_float(uint32_t u)
675
+ {
676
+ return half2_to_float2(u);
677
+ }
678
+
679
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
680
+
681
+ inline __device__ Float4_ cast_to_float(uint2 u)
682
+ {
683
+ Float4_ tmp;
684
+ tmp.x = half2_to_float2(u.x);
685
+ tmp.y = half2_to_float2(u.y);
686
+ return tmp;
687
+ }
688
+
689
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
690
+
691
+ inline __device__ Float8_ cast_to_float(uint4 u)
692
+ {
693
+ Float8_ tmp;
694
+ tmp.x = half2_to_float2(u.x);
695
+ tmp.y = half2_to_float2(u.y);
696
+ tmp.z = half2_to_float2(u.z);
697
+ tmp.w = half2_to_float2(u.w);
698
+ return tmp;
699
+ }
700
+
701
+ #endif
702
+
703
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
704
+
705
+ inline __device__ float float_from_int8(int8_t u)
706
+ {
707
+ return u;
708
+ }
709
+
710
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
711
+
712
+ inline __device__ float2 float_from_int8(int16_t u)
713
+ {
714
+ union {
715
+ int16_t int16;
716
+ int8_t int8[2];
717
+ };
718
+ int16 = u;
719
+ return make_float2(int8[0], int8[1]);
720
+ }
721
+
722
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
723
+
724
+ inline __device__ float4 float_from_int8(int32_t u)
725
+ {
726
+ union {
727
+ int32_t int32;
728
+ int8_t int8[4];
729
+ };
730
+ int32 = u;
731
+ return make_float4(int8[0], int8[1], int8[2], int8[3]);
732
+ }
733
+
734
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
735
+
736
+ // clang-format off
737
+ inline __device__ Float8_ float_from_int8(int64_t u)
738
+ {
739
+ union {
740
+ int64_t int64;
741
+ int16_t int16[4];
742
+ };
743
+ int64 = u;
744
+ return Float8_ {float_from_int8(int16[0]),
745
+ float_from_int8(int16[1]),
746
+ float_from_int8(int16[2]),
747
+ float_from_int8(int16[3])};
748
+ }
749
+ // clang-format on
750
+
751
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
752
+
753
+ inline __device__ int8_t cast_to_int8(float val)
754
+ {
755
+ union {
756
+ int8_t int8[2];
757
+ int16_t int16;
758
+ };
759
+ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
760
+ return int8[0];
761
+ }
762
+
763
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
764
+
765
+ inline __device__ int32_t cast_to_int8(float4 val)
766
+ {
767
+ union {
768
+ int8_t int8[4];
769
+ int32_t int32;
770
+ };
771
+ int8[0] = cast_to_int8(val.x);
772
+ int8[1] = cast_to_int8(val.y);
773
+ int8[2] = cast_to_int8(val.z);
774
+ int8[3] = cast_to_int8(val.w);
775
+ return int32;
776
+ }
777
+
778
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
779
+
780
+ inline __device__ int64_t cast_to_int8(Float8_ val)
781
+ {
782
+ union {
783
+ int8_t int8[8];
784
+ int64_t int64;
785
+ };
786
+ int8[0] = cast_to_int8(val.x.x);
787
+ int8[1] = cast_to_int8(val.x.y);
788
+ int8[2] = cast_to_int8(val.y.x);
789
+ int8[3] = cast_to_int8(val.y.y);
790
+ int8[4] = cast_to_int8(val.z.x);
791
+ int8[5] = cast_to_int8(val.z.y);
792
+ int8[6] = cast_to_int8(val.w.x);
793
+ int8[7] = cast_to_int8(val.w.y);
794
+ return int64;
795
+ }
796
+
797
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
798
+
799
+ template<typename T>
800
+ inline __device__ __host__ T div_up(T m, T n)
801
+ {
802
+ return (m + n - 1) / n;
803
+ }
804
+
805
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
806
+
807
+ template<typename T, bool DO_CROSS_ATTENTION>
808
+ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
809
+ int threads_per_value,
810
+ int threads_per_block)
811
+ {
812
+ // The amount of shared memory needed to store the Q*K^T values in float.
813
+ const int max_timesteps = min(params.timestep, params.memory_max_len);
814
+ size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
815
+
816
+ // The extra memory needed if we are not using floats for the final logits.
817
+ size_t logits_sz = 0;
818
+ #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
819
+ if (sizeof(T) != 4) {
820
+ // TDOD
821
+ logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
822
+ div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
823
+ }
824
+ #endif
825
+
826
+ // The total size needed during softmax.
827
+ size_t softmax_sz = qk_sz + logits_sz;
828
+
829
+ // The number of partial rows to reduce in the final reduction.
830
+ int rows_per_red = threads_per_block / threads_per_value;
831
+ // The amount of storage needed to finalize the outputs.
832
+ size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
833
+
834
+ size_t transpose_rotary_size = 0;
835
+ if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
836
+ transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
837
+ }
838
+
839
+ // The max.
840
+ return max(max(softmax_sz, red_sz), transpose_rotary_size);
841
+ }
842
+
843
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
844
+
845
+ inline __device__ constexpr uint32_t shfl_mask(int threads)
846
+ {
847
+ return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
848
+ }
849
+
850
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
851
+
852
+ template<
853
+ // The type of the inputs. Supported types: float and half.
854
+ typename T,
855
+ // The hidden dimension per head.
856
+ int Dh,
857
+ int Dh_MAX,
858
+ // The number of threads per key.
859
+ int THREADS_PER_KEY,
860
+ // The number of threads per value.
861
+ int THREADS_PER_VALUE,
862
+ // The number of threads in a threadblock.
863
+ int THREADS_PER_BLOCK,
864
+ bool DO_CROSS_ATTENTION>
865
+ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
866
+ {
867
+
868
+ // Make sure the hidden dimension per head is a multiple of the number of threads per key.
869
+ static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
870
+ // Make sure the hidden dimension per head is a multiple of the number of threads per value.
871
+ static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
872
+
873
+ // The size of a warp.
874
+ constexpr int WARP_SIZE = 32;
875
+ // The number of warps in a threadblock.
876
+ constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
877
+
878
+ // Use smem_size_in_bytes (above) to determine the amount of shared memory.
879
+ extern __shared__ char smem_[];
880
+
881
+ // The shared memory for the Q*K^T values and partial logits in softmax.
882
+ float* qk_smem = reinterpret_cast<float*>(smem_);
883
+
884
+ // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
885
+ char* logits_smem_ = smem_;
886
+ #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
887
+ if (sizeof(T) != 4) {
888
+ // TODO - change to tlength
889
+ const int max_timesteps = min(params.timestep, params.memory_max_len);
890
+ logits_smem_ +=
891
+ (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
892
+ }
893
+ T* logits_smem = reinterpret_cast<T*>(logits_smem_);
894
+ #else
895
+ float* logits_smem = reinterpret_cast<float*>(logits_smem_);
896
+ #endif
897
+
898
+ // The shared memory to do the final reduction for the output values. Reuse qk_smem.
899
+ T* out_smem = reinterpret_cast<T*>(smem_);
900
+
901
+ // The shared memory buffers for the block-wide reductions. One for max, one for sum.
902
+ __shared__ float red_smem[WARPS_PER_BLOCK * 2];
903
+
904
+ // A vector of Q or K elements for the current timestep.
905
+ using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
906
+
907
+ // Use alignment for safely casting the shared buffers as Qk_vec.
908
+ // Shared memory to store Q inputs.
909
+ __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
910
+
911
+ // This is one of the reasons we should have a separate kernel for cross attention
912
+ __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
913
+
914
+ // A vector of Q or K elements for the current timestep.
915
+ using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
916
+ // The number of elements per vector.
917
+ constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
918
+ // Make sure the hidden size per head is a multiple of the vector size.
919
+ static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
920
+ // We will use block wide reduction if needed
921
+ // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
922
+ // The number of vectors per warp.
923
+ constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
924
+
925
+ // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
926
+ // owns x elements, we have to decompose the linear index into chunks of x values and the posi-
927
+ // tion of the thread in that chunk.
928
+
929
+ // The number of elements in a chunk of 16B (that's the x in the above formula).
930
+ constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
931
+ // The number of K vectors in 16B.
932
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
933
+
934
+ // The batch/beam idx
935
+ const int bi = blockIdx.y;
936
+ if (params.finished != nullptr && params.finished[bi] == true) {
937
+ return;
938
+ }
939
+ // The beam idx
940
+ const int beami = bi % params.beam_width;
941
+ // The "beam-aware" batch idx
942
+ const int bbi = bi / params.beam_width;
943
+ // The head.
944
+ // const int hi = blockIdx.x;
945
+ const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
946
+ const int hi_kv = hi / params.num_heads_q_kv_ratio;
947
+ // Combine the batch and the head indices.
948
+ const int bhi = bi * params.num_heads + hi;
949
+ const int bhi_kv = bi * params.num_heads_kv + hi_kv;
950
+ // Combine the "beam-aware" batch idx and the head indices.
951
+ const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
952
+ // The thread in the block.
953
+ const int tidx = threadIdx.x;
954
+
955
+ const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
956
+
957
+ // While doing the product Q*K^T for the different keys we track the max.
958
+ float qk_max = -FLT_MAX;
959
+
960
+ float qk = 0.0F;
961
+
962
+ int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
963
+ int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
964
+ int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
965
+
966
+ const size_t bi_seq_len_offset = bi * params.memory_max_len;
967
+
968
+ // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
969
+ int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
970
+ (params.length_per_sample == nullptr) ?
971
+ params.timestep :
972
+ params.length_per_sample[bi] + params.max_prefix_prompt_length;
973
+ const int first_step = max(0, tlength + 1 - params.memory_max_len);
974
+ const int tlength_circ = tlength % params.memory_max_len;
975
+
976
+ // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
977
+ const bool is_masked = tidx >= QK_VECS_PER_WARP;
978
+
979
+ // The offset in the Q and K buffer also accounts for the batch.
980
+ int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
981
+ int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
982
+ // The offset in the bias buffer.
983
+ int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
984
+ int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
985
+
986
+ const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
987
+ const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
988
+
989
+ // Trigger the loads from the Q and K buffers.
990
+ Qk_vec q;
991
+ zero(q);
992
+ if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
993
+ if (params.int8_mode == 2) {
994
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
995
+ using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
996
+ const auto q_scaling = params.qkv_scale_out[0];
997
+ const auto q_quant =
998
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
999
+
1000
+ convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
1001
+ }
1002
+ else {
1003
+ q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
1004
+ }
1005
+ }
1006
+
1007
+ Qk_vec k;
1008
+ zero(k);
1009
+ if (DO_CROSS_ATTENTION) {
1010
+ // The 16B chunk written by the thread.
1011
+ int co = tidx / QK_VECS_IN_16B;
1012
+ // The position of the thread in that 16B chunk.
1013
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
1014
+
1015
+ // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
1016
+ int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
1017
+ // params.timestep*QK_ELTS_IN_16B +
1018
+ tlength * QK_ELTS_IN_16B + ci;
1019
+ k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
1020
+ *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
1021
+ k;
1022
+ }
1023
+ else {
1024
+ if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
1025
+ if (params.int8_mode == 2) {
1026
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
1027
+ using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
1028
+ const auto k_scaling = params.qkv_scale_out[1];
1029
+ const auto k_quant =
1030
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
1031
+
1032
+ convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
1033
+ }
1034
+ else {
1035
+ k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
1036
+ }
1037
+ }
1038
+ }
1039
+
1040
+ // Trigger the loads from the Q and K bias buffers.
1041
+ Qk_vec q_bias;
1042
+ zero(q_bias);
1043
+ q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
1044
+ *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
1045
+ q_bias;
1046
+
1047
+ Qk_vec k_bias;
1048
+ zero(k_bias);
1049
+ if (handle_kv) {
1050
+ k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
1051
+ *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
1052
+ k_bias;
1053
+ }
1054
+
1055
+ // Computes the Q/K values with bias.
1056
+ q = add(q, q_bias);
1057
+ if (handle_kv) {
1058
+ k = add(k, k_bias);
1059
+ }
1060
+ if (do_ia3 && !is_masked) {
1061
+ k = mul<Qk_vec, Qk_vec, Qk_vec>(
1062
+ k,
1063
+ *reinterpret_cast<const Qk_vec*>(
1064
+ &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
1065
+ }
1066
+
1067
+ // Padded len
1068
+ const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
1069
+ if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
1070
+ if (handle_kv) {
1071
+ if (params.rotary_cos == nullptr) {
1072
+ apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1073
+ } else {
1074
+ apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
1075
+ params.rotary_cos + bi * params.rotary_embedding_dim / 2,
1076
+ params.rotary_sin + bi * params.rotary_embedding_dim / 2);
1077
+ }
1078
+ }
1079
+ else {
1080
+ if (params.rotary_cos == nullptr) {
1081
+ apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1082
+ } else {
1083
+ apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
1084
+ params.rotary_cos + bi * params.rotary_embedding_dim / 2,
1085
+ params.rotary_sin + bi * params.rotary_embedding_dim / 2);
1086
+ }
1087
+ }
1088
+ }
1089
+ else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
1090
+ const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
1091
+
1092
+ T* q_smem = reinterpret_cast<T*>(smem_);
1093
+ T* k_smem = q_smem + params.rotary_embedding_dim;
1094
+
1095
+ const int half_rotary_dim = params.rotary_embedding_dim / 2;
1096
+ const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
1097
+ const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
1098
+ const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
1099
+
1100
+ assert(half_rotary_dim % QK_VEC_SIZE == 0);
1101
+
1102
+ if (do_rotary) {
1103
+ *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
1104
+
1105
+ if (handle_kv) {
1106
+ *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
1107
+ }
1108
+ }
1109
+
1110
+ __syncthreads();
1111
+
1112
+ const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
1113
+ constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
1114
+ if (do_rotary) {
1115
+ mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
1116
+
1117
+ if (handle_kv) {
1118
+ mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
1119
+
1120
+ if (params.rotary_cos == nullptr) {
1121
+ mmha::apply_rotary_embedding(
1122
+ q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1123
+ } else {
1124
+ mmha::apply_rotary_embedding(
1125
+ q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len,
1126
+ params.rotary_cos + bi * params.rotary_embedding_dim / 2,
1127
+ params.rotary_sin + bi * params.rotary_embedding_dim / 2);
1128
+ }
1129
+
1130
+ mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
1131
+ }
1132
+ else {
1133
+ if (params.rotary_cos == nullptr) {
1134
+ mmha::apply_rotary_embedding(
1135
+ q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
1136
+ } else {
1137
+ mmha::apply_rotary_embedding(
1138
+ q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength,
1139
+ params.rotary_cos + bi * params.rotary_embedding_dim / 2,
1140
+ params.rotary_sin + bi * params.rotary_embedding_dim / 2);
1141
+ }
1142
+ }
1143
+ mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
1144
+ }
1145
+
1146
+ __syncthreads();
1147
+
1148
+ if (do_rotary) {
1149
+ q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
1150
+ if (handle_kv) {
1151
+ k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
1152
+ }
1153
+ }
1154
+
1155
+ __syncthreads();
1156
+ }
1157
+
1158
+ if (!is_masked) {
1159
+ // Store the Q values to shared memory.
1160
+ *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
1161
+
1162
+ // Store Dh values of k_bias into smem, since will need to add later
1163
+ // if params.timestep == 0
1164
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1165
+ *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
1166
+ }
1167
+
1168
+ // Write the K values to the global memory cache.
1169
+ //
1170
+ // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
1171
+ // system. We designed it this way as it allows much better memory loads (and there are many
1172
+ // more loads) + the stores are really "write and forget" since we won't need the ack before
1173
+ // the end of the kernel. There's plenty of time for the transactions to complete.
1174
+
1175
+ // The 16B chunk written by the thread.
1176
+ int co = tidx / QK_VECS_IN_16B;
1177
+ // The position of the thread in that 16B chunk.
1178
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
1179
+
1180
+ // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
1181
+ int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
1182
+ // params.timestep*QK_ELTS_IN_16B +
1183
+ tlength_circ * QK_ELTS_IN_16B + ci;
1184
+
1185
+ if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
1186
+ // Trigger the stores to global memory.
1187
+ if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
1188
+ *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
1189
+ }
1190
+ }
1191
+
1192
+ // Compute \sum_i Q[i] * K^T[i] for the current timestep.
1193
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
1194
+ using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
1195
+ #else
1196
+ using Qk_vec_acum = Qk_vec;
1197
+ #endif
1198
+ qk = dot<Qk_vec_acum, Qk_vec>(q, k);
1199
+ if (QK_VECS_PER_WARP <= WARP_SIZE) {
1200
+ #pragma unroll
1201
+ for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
1202
+ qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
1203
+ }
1204
+ }
1205
+ }
1206
+
1207
+ if (QK_VECS_PER_WARP > WARP_SIZE) {
1208
+ constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
1209
+ qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
1210
+ }
1211
+
1212
+ // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
1213
+ if (tidx == 0) {
1214
+ // Normalize qk.
1215
+ qk *= params.inv_sqrt_dh;
1216
+ if (params.relative_attention_bias != nullptr) {
1217
+ qk = add(qk,
1218
+ params.relative_attention_bias[hi * params.relative_attention_bias_stride
1219
+ * params.relative_attention_bias_stride
1220
+ + (tlength - padd_len) * params.relative_attention_bias_stride
1221
+ + (tlength - padd_len)]);
1222
+ }
1223
+ // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
1224
+
1225
+ qk_max = qk;
1226
+ qk_smem[tlength - first_step] = qk;
1227
+ // qk_smem[params.timestep] = qk;
1228
+ }
1229
+
1230
+ // Make sure the data is in shared memory.
1231
+ __syncthreads();
1232
+
1233
+ // The type of queries and keys for the math in the Q*K^T product.
1234
+ using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
1235
+ // The number of elements per vector.
1236
+ constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
1237
+ // Make sure the hidden size per head is a multiple of the vector size.
1238
+ static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
1239
+ // The number of elements per thread.
1240
+ constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
1241
+ // The number of vectors per thread.
1242
+ constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
1243
+
1244
+ // The position the first key loaded by each thread from the cache buffer (for this B * H).
1245
+ int ko = tidx / THREADS_PER_KEY;
1246
+ // The position of the thread in the chunk of keys.
1247
+ int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
1248
+
1249
+ static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
1250
+
1251
+ // Load the Q values from shared memory. The values are reused during the loop on K.
1252
+ K_vec q_vec[K_VECS_PER_THREAD];
1253
+ #pragma unroll
1254
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1255
+ q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
1256
+ }
1257
+
1258
+ K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
1259
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1260
+ #pragma unroll
1261
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1262
+ k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
1263
+ }
1264
+ }
1265
+
1266
+ // The number of timesteps loaded per iteration.
1267
+ constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
1268
+ // The number of keys per warp.
1269
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
1270
+
1271
+ // The base pointer for the key in the cache buffer.
1272
+ T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
1273
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
1274
+ T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
1275
+
1276
+ // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
1277
+ // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
1278
+ int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
1279
+
1280
+ // prefix prompt length if has
1281
+ const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
1282
+
1283
+ // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
1284
+ const bool has_beams = params.cache_indir != nullptr;
1285
+ const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
1286
+
1287
+ for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
1288
+ const int ti_circ = ti % params.memory_max_len;
1289
+
1290
+ // The keys loaded from the key cache.
1291
+ K_vec k[K_VECS_PER_THREAD];
1292
+ K_vec k_vec_zero;
1293
+ zero(k_vec_zero);
1294
+ #pragma unroll
1295
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1296
+ int jj = ii * params.memory_max_len + ti_circ;
1297
+ // if( ti < params.timestep ) {
1298
+ const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
1299
+ if (ti < tlength) {
1300
+ if (!within_bounds) {
1301
+ k[ii] = k_vec_zero;
1302
+ }
1303
+ else {
1304
+ if (has_beams) {
1305
+ const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
1306
+ k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
1307
+ }
1308
+ else {
1309
+ k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
1310
+ }
1311
+ }
1312
+ // add bias and update k_cache
1313
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1314
+ k[ii] = add(k[ii], k_bias_vec[ii]);
1315
+
1316
+ if (do_ia3) {
1317
+ k[ii] = mul<K_vec, K_vec, K_vec>(
1318
+ k[ii],
1319
+ *reinterpret_cast<const K_vec*>(
1320
+ &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
1321
+ + ii * THREADS_PER_KEY * K_VEC_SIZE]));
1322
+ }
1323
+
1324
+ if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
1325
+ *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
1326
+ }
1327
+ }
1328
+ }
1329
+ }
1330
+
1331
+ // Perform the dot product and normalize qk.
1332
+ //
1333
+ // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
1334
+ float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
1335
+ bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
1336
+
1337
+ // Store the product to shared memory. There's one qk value per timestep. Update the max.
1338
+ // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
1339
+ if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
1340
+ if (params.relative_attention_bias != nullptr) {
1341
+ qk = add(qk,
1342
+ params.relative_attention_bias[hi * params.relative_attention_bias_stride
1343
+ * params.relative_attention_bias_stride
1344
+ + tlength * params.relative_attention_bias_stride + ti]);
1345
+ }
1346
+ if (params.linear_bias_slopes != nullptr) {
1347
+ // Apply the linear position bias: (ki - qi) * slope[hi].
1348
+ // The padding token locates between the input context and the generated tokens.
1349
+ // We need to remove the number of padding tokens in the distance computation.
1350
+ // ti : 0 1 2 3 4 5 6 7 8 9(tlength)
1351
+ // token: i i i i p p p o o o where i=input, p=pad, o=output.
1352
+ // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
1353
+ int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
1354
+ float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
1355
+
1356
+ qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
1357
+ }
1358
+ qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
1359
+ qk_smem[ti - first_step] = qk;
1360
+ }
1361
+ }
1362
+
1363
+ // Perform the final reduction to compute the max inside each warp.
1364
+ //
1365
+ // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
1366
+ // group so it's not needed to run the reduction inside the group (again).
1367
+ #pragma unroll
1368
+ for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
1369
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
1370
+ }
1371
+
1372
+ // Decompose the thread index into warp and lane.
1373
+ const int warp = tidx / WARP_SIZE;
1374
+ const int lane = tidx % WARP_SIZE;
1375
+
1376
+ // The warp leader writes the max to shared memory.
1377
+ if (lane == 0) {
1378
+ red_smem[warp] = qk_max;
1379
+ }
1380
+
1381
+ // Make sure the products are in shared memory.
1382
+ __syncthreads();
1383
+
1384
+ // The warps finalize the reduction.
1385
+ qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
1386
+ #pragma unroll
1387
+ for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
1388
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
1389
+ }
1390
+
1391
+ // Broadcast to all the threads in the warp.
1392
+ qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
1393
+
1394
+ // Compute the logits and start the sum.
1395
+ float sum = 0.f;
1396
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
1397
+ for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
1398
+ bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
1399
+ float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
1400
+ sum += logit;
1401
+ qk_smem[ti - first_step] = logit;
1402
+ }
1403
+
1404
+ // Compute the sum.
1405
+ sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
1406
+
1407
+ // Normalize the logits.
1408
+ float inv_sum = __fdividef(1.f, sum + 1.e-6f);
1409
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
1410
+ const size_t cross_attention_out_offset =
1411
+ params.is_return_cross_attentions ?
1412
+ bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
1413
+ 0;
1414
+ for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
1415
+ float logit = qk_smem[ti - first_step] * inv_sum;
1416
+ if (params.is_return_cross_attentions) {
1417
+ params.cross_attention_out[cross_attention_out_offset + ti] = logit;
1418
+ }
1419
+ convert_from_float(logits_smem[ti - first_step], logit);
1420
+ }
1421
+
1422
+ // Put Values part below so we leverage __syncthreads
1423
+ // from the previous step
1424
+
1425
+ // The number of elements per vector.
1426
+ constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
1427
+ // A vector of V elements for the current timestep.
1428
+ using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
1429
+
1430
+ // The value computed by this thread.
1431
+ int vo = tidx / THREADS_PER_VALUE;
1432
+ // The hidden dimensions computed by this particular thread.
1433
+ int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
1434
+
1435
+ // The base pointer for the value in the cache buffer.
1436
+ T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
1437
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
1438
+ T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
1439
+
1440
+ // The number of values processed per iteration of the loop.
1441
+ constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
1442
+
1443
+ // One group of threads computes the product(s) for the current timestep.
1444
+ V_vec v_bias;
1445
+ zero(v_bias);
1446
+ // if( vo == params.timestep % V_PER_ITER ) {
1447
+ if (Dh == Dh_MAX || vi < Dh) {
1448
+ if (handle_kv) {
1449
+ if (vo == tlength % V_PER_ITER) {
1450
+ // Trigger the loads from the V bias buffer.
1451
+ if (params.v_bias != nullptr) {
1452
+ v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
1453
+ }
1454
+ if (DO_CROSS_ATTENTION) {
1455
+ *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
1456
+ }
1457
+ }
1458
+ }
1459
+ }
1460
+
1461
+ // From previous, before values, step
1462
+ // Also make sure the logits are in shared memory.
1463
+ __syncthreads();
1464
+
1465
+ // Values continued
1466
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1467
+ using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
1468
+ #else
1469
+ using V_vec_acum = V_vec;
1470
+ #endif
1471
+ // The partial outputs computed by each thread.
1472
+ V_vec_acum out;
1473
+ zero(out);
1474
+
1475
+ // Loop over the timesteps to compute the partial outputs.
1476
+ // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
1477
+ if (Dh == Dh_MAX || vi < Dh) {
1478
+ for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
1479
+ const int ti_circ = ti % params.memory_max_len;
1480
+
1481
+ // Fetch offset based on cache_indir when beam sampling
1482
+ const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
1483
+ const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
1484
+ // Load the values from the cache.
1485
+ V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
1486
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1487
+ v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
1488
+ if (do_ia3) {
1489
+ v = mul<V_vec, V_vec, V_vec>(
1490
+ v,
1491
+ *reinterpret_cast<const V_vec*>(
1492
+ &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
1493
+ }
1494
+ *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
1495
+ }
1496
+ // Load the logits from shared memory.
1497
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
1498
+ float logit = logits_smem[ti - first_step];
1499
+ out = fma(logit, cast_to_float(v), out);
1500
+ #else
1501
+ T logit = logits_smem[ti - first_step];
1502
+
1503
+ // Update the partial sums.
1504
+ out = fma(logit, v, out);
1505
+ #endif
1506
+ }
1507
+ }
1508
+
1509
+ // One group of threads computes the product(s) for the current timestep.
1510
+ // if( vo == params.timestep % V_PER_ITER ) {
1511
+ if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
1512
+
1513
+ V_vec v;
1514
+ if (DO_CROSS_ATTENTION) {
1515
+ v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
1516
+ }
1517
+ else {
1518
+ // Trigger the loads from the V buffer.
1519
+ const auto v_offset = v_base_offset + vi;
1520
+ if (params.int8_mode == 2) {
1521
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
1522
+ using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
1523
+ const auto v_scaling = params.qkv_scale_out[2];
1524
+ const auto v_quant =
1525
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
1526
+
1527
+ convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
1528
+ }
1529
+ else {
1530
+ v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
1531
+ }
1532
+ // Trigger the loads from the V bias buffer.
1533
+ // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
1534
+ }
1535
+
1536
+ // Compute the V values with bias.
1537
+ if (handle_kv) {
1538
+ v = add(v, v_bias);
1539
+
1540
+ if (do_ia3) {
1541
+ v = mul<V_vec, V_vec, V_vec>(
1542
+ v,
1543
+ *reinterpret_cast<const V_vec*>(
1544
+ &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
1545
+ }
1546
+
1547
+ // Store the values with bias back to global memory in the cache for V.
1548
+ if (hi % params.num_heads_q_kv_ratio == 0) {
1549
+ //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
1550
+ *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
1551
+ }
1552
+ }
1553
+
1554
+ // Initialize the output value with the current timestep.
1555
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
1556
+ // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
1557
+ out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
1558
+ #else
1559
+ // out = fma(logits_smem[params.timestep], v, out);
1560
+ out = fma(logits_smem[tlength - first_step], v, out);
1561
+ #endif
1562
+ }
1563
+
1564
+ // Make sure we can start writing to shared memory.
1565
+ __syncthreads();
1566
+
1567
+ // Run the final reduction amongst the different groups computing different partial outputs.
1568
+ if (Dh == Dh_MAX || vi < Dh) {
1569
+ #pragma unroll
1570
+ for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
1571
+
1572
+ // The midpoint in the number of active groups.
1573
+ int midpoint = active_groups / 2;
1574
+
1575
+ // The upper part of active threads store to shared memory.
1576
+ if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
1577
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1578
+ convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
1579
+ #else
1580
+ *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
1581
+ #endif
1582
+ }
1583
+ __syncthreads();
1584
+
1585
+ // The bottom warps update their values.
1586
+ if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
1587
+ out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
1588
+ }
1589
+ __syncthreads();
1590
+ }
1591
+ }
1592
+
1593
+ // Output the final values.
1594
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
1595
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1596
+ if (params.int8_mode == 2) {
1597
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
1598
+ out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
1599
+ *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
1600
+ cast_to_int8(out);
1601
+ }
1602
+ else {
1603
+ convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
1604
+ }
1605
+ #else
1606
+ // TODO: support int8_mode?
1607
+ *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
1608
+ #endif
1609
+ }
1610
+ }
1611
+
1612
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1613
+
1614
+ } // namespace mmha
1615
+
1616
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1617
+
1618
+ template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
1619
+ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
decoder_masked_multihead_attention_utils.h ADDED
@@ -0,0 +1,2017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include "cuda_bf16_fallbacks.cuh"
23
+ #include <stdint.h>
24
+
25
+ using namespace fastertransformer;
26
+
27
+ namespace mmha {
28
+
29
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
30
+
31
+ struct Float8_ {
32
+ float2 x;
33
+ float2 y;
34
+ float2 z;
35
+ float2 w;
36
+ };
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ struct Float4_ {
41
+ float2 x;
42
+ float2 y;
43
+ };
44
+
45
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ #ifdef ENABLE_BF16
48
+ struct bf16_4_t {
49
+ __nv_bfloat162 x;
50
+ __nv_bfloat162 y;
51
+ };
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ struct bf16_8_t {
56
+ __nv_bfloat162 x;
57
+ __nv_bfloat162 y;
58
+ __nv_bfloat162 z;
59
+ __nv_bfloat162 w;
60
+ };
61
+ #endif
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ template<typename T>
66
+ struct num_elems;
67
+ template<>
68
+ struct num_elems<float> {
69
+ static constexpr int value = 1;
70
+ };
71
+ template<>
72
+ struct num_elems<float2> {
73
+ static constexpr int value = 2;
74
+ };
75
+ template<>
76
+ struct num_elems<float4> {
77
+ static constexpr int value = 4;
78
+ };
79
+ template<>
80
+ struct num_elems<Float4_> {
81
+ static constexpr int value = 4;
82
+ };
83
+ template<>
84
+ struct num_elems<Float8_> {
85
+ static constexpr int value = 8;
86
+ };
87
+
88
+ template<>
89
+ struct num_elems<uint32_t> {
90
+ static constexpr int value = 2;
91
+ };
92
+ template<>
93
+ struct num_elems<uint2> {
94
+ static constexpr int value = 4;
95
+ };
96
+ template<>
97
+ struct num_elems<uint4> {
98
+ static constexpr int value = 8;
99
+ };
100
+
101
+ #ifdef ENABLE_BF16
102
+ template<>
103
+ struct num_elems<__nv_bfloat162> {
104
+ static constexpr int value = 2;
105
+ };
106
+ template<>
107
+ struct num_elems<bf16_4_t> {
108
+ static constexpr int value = 4;
109
+ };
110
+ template<>
111
+ struct num_elems<bf16_8_t> {
112
+ static constexpr int value = 8;
113
+ };
114
+ #endif
115
+
116
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
117
+
118
+ template<typename T, int N>
119
+ struct packed_type;
120
+ template<typename T>
121
+ struct packed_type<T, 1> {
122
+ using type = T;
123
+ };
124
+ template<>
125
+ struct packed_type<int8_t, 2> {
126
+ using type = int16_t;
127
+ };
128
+ template<>
129
+ struct packed_type<int8_t, 4> {
130
+ using type = int32_t;
131
+ };
132
+ template<>
133
+ struct packed_type<int8_t, 8> {
134
+ using type = int64_t;
135
+ };
136
+
137
+ template<>
138
+ struct packed_type<float, 2> {
139
+ using type = float2;
140
+ };
141
+ template<>
142
+ struct packed_type<float, 4> {
143
+ using type = float4;
144
+ };
145
+ template<>
146
+ struct packed_type<float, 8> {
147
+ using type = Float8_;
148
+ };
149
+
150
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
151
+
152
+ inline __device__ float add(float a, float b)
153
+ {
154
+ return a + b;
155
+ }
156
+
157
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
158
+
159
+ inline __device__ float2 add(float2 a, float2 b)
160
+ {
161
+ float2 c;
162
+ c.x = add(a.x, b.x);
163
+ c.y = add(a.y, b.y);
164
+ return c;
165
+ }
166
+
167
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
168
+
169
+ inline __device__ float4 add(float4 a, float4 b)
170
+ {
171
+ float4 c;
172
+ c.x = add(a.x, b.x);
173
+ c.y = add(a.y, b.y);
174
+ c.z = add(a.z, b.z);
175
+ c.w = add(a.w, b.w);
176
+ return c;
177
+ }
178
+
179
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
180
+
181
+ #ifdef ENABLE_BF16
182
+ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
183
+ {
184
+ return a + b;
185
+ }
186
+
187
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
188
+
189
+ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
190
+ {
191
+ return bf16hadd2(a, b);
192
+ }
193
+
194
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
+ inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
197
+ {
198
+ bf16_4_t c;
199
+ c.x = add(a.x, b.x);
200
+ c.y = add(a.y, b.y);
201
+ return c;
202
+ }
203
+
204
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
205
+
206
+ inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
207
+ {
208
+ bf16_8_t c;
209
+ c.x = add(a.x, b.x);
210
+ c.y = add(a.y, b.y);
211
+ c.z = add(a.z, b.z);
212
+ c.w = add(a.w, b.w);
213
+ return c;
214
+ }
215
+ #endif // ENABLE_BF16
216
+
217
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
218
+
219
+ inline __device__ uint16_t add(uint16_t a, uint16_t b)
220
+ {
221
+ uint16_t c;
222
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
223
+ return c;
224
+ }
225
+
226
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
227
+
228
+ inline __device__ uint32_t add(uint32_t a, uint32_t b)
229
+ {
230
+ uint32_t c;
231
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
232
+ return c;
233
+ }
234
+
235
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
236
+
237
+ inline __device__ uint2 add(uint2 a, uint2 b)
238
+ {
239
+ uint2 c;
240
+ c.x = add(a.x, b.x);
241
+ c.y = add(a.y, b.y);
242
+ return c;
243
+ }
244
+
245
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
246
+
247
+ inline __device__ uint4 add(uint4 a, uint4 b)
248
+ {
249
+ uint4 c;
250
+ c.x = add(a.x, b.x);
251
+ c.y = add(a.y, b.y);
252
+ c.z = add(a.z, b.z);
253
+ c.w = add(a.w, b.w);
254
+ return c;
255
+ }
256
+
257
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
258
+
259
+ inline __device__ uint16_t float_to_half(float f)
260
+ {
261
+ union {
262
+ uint32_t u32;
263
+ uint16_t u16[2];
264
+ } tmp;
265
+ #if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
266
+ float zero = 0.f;
267
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
268
+ #else
269
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
270
+ #endif
271
+ return tmp.u16[0];
272
+ }
273
+
274
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
275
+
276
+ inline __device__ uint32_t float2_to_half2(float2 f)
277
+ {
278
+ union {
279
+ uint32_t u32;
280
+ uint16_t u16[2];
281
+ } tmp;
282
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
283
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
284
+ #else
285
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
286
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
287
+ #endif
288
+ return tmp.u32;
289
+ }
290
+
291
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
292
+
293
+ inline __device__ float half_to_float(uint16_t h)
294
+ {
295
+ float f;
296
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
297
+ return f;
298
+ }
299
+
300
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
301
+
302
+ inline __device__ float2 half2_to_float2(uint32_t v)
303
+ {
304
+ uint16_t lo, hi;
305
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
306
+ return make_float2(half_to_float(lo), half_to_float(hi));
307
+ }
308
+
309
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
310
+
311
+ inline __device__ float add(float a, uint16_t b)
312
+ {
313
+ return a + half_to_float(b);
314
+ }
315
+
316
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
+ #ifdef ENABLE_BF16
319
+ inline __device__ float add(float a, __nv_bfloat16 b)
320
+ {
321
+ return a + __bfloat162float(b);
322
+ }
323
+ #endif
324
+
325
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
326
+
327
+ inline __device__ float2 add(uint32_t a, float2 fb)
328
+ {
329
+ float2 fa = half2_to_float2(a);
330
+ return add(fa, fb);
331
+ }
332
+
333
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
334
+
335
+ inline __device__ Float4_ add(uint2 a, Float4_ fb)
336
+ {
337
+ Float4_ fc;
338
+ fc.x = add(a.x, fb.x);
339
+ fc.y = add(a.y, fb.y);
340
+ return fc;
341
+ }
342
+
343
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
344
+
345
+ inline __device__ Float8_ add(uint4 a, Float8_ fb)
346
+ {
347
+ Float8_ fc;
348
+ fc.x = add(a.x, fb.x);
349
+ fc.y = add(a.y, fb.y);
350
+ fc.z = add(a.z, fb.z);
351
+ fc.w = add(a.w, fb.w);
352
+ return fc;
353
+ }
354
+
355
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
356
+
357
+ inline __device__ uint32_t h0_h0(uint16_t a)
358
+ {
359
+ uint32_t b;
360
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
361
+ return b;
362
+ }
363
+
364
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
365
+
366
+ inline __device__ float fma(float a, float b, float c)
367
+ {
368
+ return a * b + c;
369
+ }
370
+
371
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
372
+
373
+ inline __device__ float2 fma(float2 a, float2 b, float2 c)
374
+ {
375
+ float2 d;
376
+ d.x = fma(a.x, b.x, c.x);
377
+ d.y = fma(a.y, b.y, c.y);
378
+ return d;
379
+ }
380
+
381
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
382
+
383
+ inline __device__ float2 fma(float a, float2 b, float2 c)
384
+ {
385
+ float2 d;
386
+ d.x = fma(a, b.x, c.x);
387
+ d.y = fma(a, b.y, c.y);
388
+ return d;
389
+ }
390
+
391
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
392
+
393
+ inline __device__ float4 fma(float4 a, float4 b, float4 c)
394
+ {
395
+ float4 d;
396
+ d.x = fma(a.x, b.x, c.x);
397
+ d.y = fma(a.y, b.y, c.y);
398
+ d.z = fma(a.z, b.z, c.z);
399
+ d.w = fma(a.w, b.w, c.w);
400
+ return d;
401
+ }
402
+
403
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
404
+
405
+ inline __device__ float4 fma(float a, float4 b, float4 c)
406
+ {
407
+ float4 d;
408
+ d.x = fma(a, b.x, c.x);
409
+ d.y = fma(a, b.y, c.y);
410
+ d.z = fma(a, b.z, c.z);
411
+ d.w = fma(a, b.w, c.w);
412
+ return d;
413
+ }
414
+
415
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
416
+
417
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
418
+ {
419
+ Float4_ d;
420
+ d.x = fma(a, b.x, c.x);
421
+ d.y = fma(a, b.y, c.y);
422
+ return d;
423
+ }
424
+
425
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
426
+
427
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
428
+ {
429
+ Float8_ d;
430
+ d.x = fma(a, b.x, c.x);
431
+ d.y = fma(a, b.y, c.y);
432
+ d.z = fma(a, b.z, c.z);
433
+ d.w = fma(a, b.w, c.w);
434
+ return d;
435
+ }
436
+
437
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
438
+
439
+ #ifdef ENABLE_BF16
440
+ inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
441
+ {
442
+ float2 fa = bf1622float2(a);
443
+ return add(fa, fb);
444
+ }
445
+
446
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
447
+
448
+ inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
449
+ {
450
+ Float4_ fc;
451
+ fc.x = add(a.x, fb.x);
452
+ fc.y = add(a.y, fb.y);
453
+ return fc;
454
+ }
455
+
456
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
457
+
458
+ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
459
+ {
460
+ Float8_ fc;
461
+ fc.x = add(a.x, fb.x);
462
+ fc.y = add(a.y, fb.y);
463
+ fc.z = add(a.z, fb.z);
464
+ fc.w = add(a.w, fb.w);
465
+ return fc;
466
+ }
467
+ #endif // ENABLE_BF16
468
+
469
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
470
+
471
+ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
472
+ {
473
+ uint32_t d;
474
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
475
+ return d;
476
+ }
477
+
478
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
479
+
480
+ inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
481
+ {
482
+ return fma(h0_h0(a), b, c);
483
+ }
484
+
485
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
486
+
487
+ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
488
+ {
489
+ uint2 d;
490
+ d.x = fma(a.x, b.x, c.x);
491
+ d.y = fma(a.y, b.y, c.y);
492
+ return d;
493
+ }
494
+
495
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
496
+
497
+ inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
498
+ {
499
+ uint32_t s = h0_h0(a);
500
+ uint2 d;
501
+ d.x = fma(s, b.x, c.x);
502
+ d.y = fma(s, b.y, c.y);
503
+ return d;
504
+ }
505
+
506
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
507
+
508
+ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
509
+ {
510
+ uint4 d;
511
+ d.x = fma(a.x, b.x, c.x);
512
+ d.y = fma(a.y, b.y, c.y);
513
+ d.z = fma(a.z, b.z, c.z);
514
+ d.w = fma(a.w, b.w, c.w);
515
+ return d;
516
+ }
517
+
518
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
519
+
520
+ inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
521
+ {
522
+ uint32_t s = h0_h0(a);
523
+ uint4 d;
524
+ d.x = fma(s, b.x, c.x);
525
+ d.y = fma(s, b.y, c.y);
526
+ d.z = fma(s, b.z, c.z);
527
+ d.w = fma(s, b.w, c.w);
528
+ return d;
529
+ }
530
+
531
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
532
+
533
+ inline __device__ float fma(uint16_t a, uint16_t b, float fc)
534
+ {
535
+ float fa = half_to_float(a);
536
+ float fb = half_to_float(b);
537
+ return fa * fb + fc;
538
+ }
539
+
540
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
541
+
542
+ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
543
+ {
544
+ float2 fa = half2_to_float2(a);
545
+ float2 fb = half2_to_float2(b);
546
+ return fma(fa, fb, fc);
547
+ }
548
+
549
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
550
+
551
+ inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
552
+ {
553
+ return fma(h0_h0(a), b, fc);
554
+ }
555
+
556
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
557
+
558
+ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
559
+ {
560
+ Float4_ fd;
561
+ fd.x = fma(a.x, b.x, fc.x);
562
+ fd.y = fma(a.y, b.y, fc.y);
563
+ return fd;
564
+ }
565
+
566
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
567
+
568
+ inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
569
+ {
570
+ uint32_t s = h0_h0(a);
571
+ Float4_ fd;
572
+ fd.x = fma(s, b.x, fc.x);
573
+ fd.y = fma(s, b.y, fc.y);
574
+ return fd;
575
+ }
576
+
577
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
578
+
579
+ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
580
+ {
581
+ Float8_ fd;
582
+ fd.x = fma(a.x, b.x, fc.x);
583
+ fd.y = fma(a.y, b.y, fc.y);
584
+ fd.z = fma(a.z, b.z, fc.z);
585
+ fd.w = fma(a.w, b.w, fc.w);
586
+ return fd;
587
+ }
588
+
589
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
590
+
591
+ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
592
+ {
593
+ uint32_t s = h0_h0(a);
594
+ Float8_ fd;
595
+ fd.x = fma(s, b.x, fc.x);
596
+ fd.y = fma(s, b.y, fc.y);
597
+ fd.z = fma(s, b.z, fc.z);
598
+ fd.w = fma(s, b.w, fc.w);
599
+ return fd;
600
+ }
601
+
602
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
603
+ #ifdef ENABLE_BF16
604
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
605
+ {
606
+ return bf16hfma2(a, b, c);
607
+ }
608
+
609
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
610
+
611
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
612
+ {
613
+ return bf16hfma2(bf162bf162(a), b, c);
614
+ }
615
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
616
+
617
+ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
618
+ {
619
+ bf16_4_t d;
620
+ d.x = fma(a.x, b.x, c.x);
621
+ d.y = fma(a.y, b.y, c.y);
622
+ return d;
623
+ }
624
+
625
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
626
+
627
+ inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
628
+ {
629
+ __nv_bfloat162 s = bf162bf162(a);
630
+ bf16_4_t d;
631
+ d.x = fma(s, b.x, c.x);
632
+ d.y = fma(s, b.y, c.y);
633
+ return d;
634
+ }
635
+
636
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
637
+
638
+ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
639
+ {
640
+ bf16_8_t d;
641
+ d.x = fma(a.x, b.x, c.x);
642
+ d.y = fma(a.y, b.y, c.y);
643
+ d.z = fma(a.z, b.z, c.z);
644
+ d.w = fma(a.w, b.w, c.w);
645
+ return d;
646
+ }
647
+
648
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
649
+
650
+ inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
651
+ {
652
+ __nv_bfloat162 s = bf162bf162(a);
653
+ bf16_8_t d;
654
+ d.x = fma(s, b.x, c.x);
655
+ d.y = fma(s, b.y, c.y);
656
+ d.z = fma(s, b.z, c.z);
657
+ d.w = fma(s, b.w, c.w);
658
+ return d;
659
+ }
660
+
661
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
662
+
663
+ inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
664
+ {
665
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
666
+ }
667
+
668
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
669
+
670
+ inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
671
+ {
672
+ float2 fa = bf1622float2(a);
673
+ float2 fb = bf1622float2(b);
674
+ return fma(fa, fb, fc);
675
+ }
676
+
677
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
678
+
679
+ inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
680
+ {
681
+ return fma(bf162bf162(a), b, fc);
682
+ }
683
+
684
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
685
+
686
+ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
687
+ {
688
+ Float4_ fd;
689
+ fd.x = fma(a.x, b.x, fc.x);
690
+ fd.y = fma(a.y, b.y, fc.y);
691
+ return fd;
692
+ }
693
+
694
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
695
+
696
+ inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
697
+ {
698
+ __nv_bfloat162 s = bf162bf162(a);
699
+ Float4_ fd;
700
+ fd.x = fma(s, b.x, fc.x);
701
+ fd.y = fma(s, b.y, fc.y);
702
+ return fd;
703
+ }
704
+
705
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
706
+
707
+ inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
708
+ {
709
+ Float8_ fd;
710
+ fd.x = fma(a.x, b.x, fc.x);
711
+ fd.y = fma(a.y, b.y, fc.y);
712
+ fd.z = fma(a.z, b.z, fc.z);
713
+ fd.w = fma(a.w, b.w, fc.w);
714
+ return fd;
715
+ }
716
+
717
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
718
+
719
+ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
720
+ {
721
+ __nv_bfloat162 s = bf162bf162(a);
722
+ Float8_ fd;
723
+ fd.x = fma(s, b.x, fc.x);
724
+ fd.y = fma(s, b.y, fc.y);
725
+ fd.z = fma(s, b.z, fc.z);
726
+ fd.w = fma(s, b.w, fc.w);
727
+ return fd;
728
+ }
729
+ #endif // ENABLE_BF16
730
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
731
+
732
+ template<typename Acc, typename A, typename B>
733
+ inline __device__ Acc mul(A a, B b)
734
+ {
735
+ return a * b;
736
+ }
737
+
738
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
739
+
740
+ template<>
741
+ inline __device__ float mul<float, float>(float a, float b)
742
+ {
743
+ return a * b;
744
+ }
745
+
746
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
747
+
748
+ template<>
749
+ inline __device__ float2 mul(float2 a, float2 b)
750
+ {
751
+ float2 c;
752
+ c.x = a.x * b.x;
753
+ c.y = a.y * b.y;
754
+ return c;
755
+ }
756
+
757
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
758
+
759
+ template<>
760
+ inline __device__ float2 mul(float a, float2 b)
761
+ {
762
+ float2 c;
763
+ c.x = a * b.x;
764
+ c.y = a * b.y;
765
+ return c;
766
+ }
767
+
768
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
769
+
770
+ template<>
771
+ inline __device__ float4 mul(float4 a, float4 b)
772
+ {
773
+ float4 c;
774
+ c.x = a.x * b.x;
775
+ c.y = a.y * b.y;
776
+ c.z = a.z * b.z;
777
+ c.w = a.w * b.w;
778
+ return c;
779
+ }
780
+
781
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
782
+
783
+ template<>
784
+ inline __device__ float4 mul(float a, float4 b)
785
+ {
786
+ float4 c;
787
+ c.x = a * b.x;
788
+ c.y = a * b.y;
789
+ c.z = a * b.z;
790
+ c.w = a * b.w;
791
+ return c;
792
+ }
793
+
794
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
795
+
796
+ template<>
797
+ inline __device__ Float8_ mul(float a, Float8_ b)
798
+ {
799
+ Float8_ c;
800
+ c.x = make_float2(a * b.x.x, a * b.x.y);
801
+ c.y = make_float2(a * b.y.x, a * b.y.y);
802
+ c.z = make_float2(a * b.z.x, a * b.z.y);
803
+ c.w = make_float2(a * b.w.x, a * b.w.y);
804
+ return c;
805
+ }
806
+
807
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
808
+
809
+ template<>
810
+ inline __device__ uint16_t mul(uint16_t a, uint16_t b)
811
+ {
812
+ uint16_t c;
813
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
814
+ return c;
815
+ }
816
+
817
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
818
+
819
+ template<>
820
+ inline __device__ uint32_t mul(uint32_t a, uint32_t b)
821
+ {
822
+ uint32_t c;
823
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
824
+ return c;
825
+ }
826
+
827
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
828
+
829
+ template<>
830
+ inline __device__ uint32_t mul(uint16_t a, uint32_t b)
831
+ {
832
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
833
+ }
834
+
835
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
836
+
837
+ template<>
838
+ inline __device__ uint2 mul(uint2 a, uint2 b)
839
+ {
840
+ uint2 c;
841
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
842
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
843
+ return c;
844
+ }
845
+
846
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
847
+
848
+ template<>
849
+ inline __device__ uint2 mul(uint16_t a, uint2 b)
850
+ {
851
+ uint32_t s = h0_h0(a);
852
+ uint2 c;
853
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
854
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
855
+ return c;
856
+ }
857
+
858
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
859
+
860
+ template<>
861
+ inline __device__ uint4 mul(uint4 a, uint4 b)
862
+ {
863
+ uint4 c;
864
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
865
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
866
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
867
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
868
+ return c;
869
+ }
870
+
871
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
872
+
873
+ template<>
874
+ inline __device__ uint4 mul(uint16_t a, uint4 b)
875
+ {
876
+ uint32_t s = h0_h0(a);
877
+ uint4 c;
878
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
879
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
880
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
881
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
882
+ return c;
883
+ }
884
+
885
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
886
+
887
+ template<>
888
+ inline __device__ float mul(uint16_t a, uint16_t b)
889
+ {
890
+ float fa = half_to_float(a);
891
+ float fb = half_to_float(b);
892
+ return fa * fb;
893
+ }
894
+
895
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
896
+
897
+ template<>
898
+ inline __device__ float mul(uint16_t a, float b)
899
+ {
900
+ return half_to_float(a) * b;
901
+ }
902
+
903
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
904
+
905
+ template<>
906
+ inline __device__ float2 mul(uint32_t a, uint32_t b)
907
+ {
908
+ float2 fa = half2_to_float2(a);
909
+ float2 fb = half2_to_float2(b);
910
+ return mul<float2, float2, float2>(fa, fb);
911
+ }
912
+
913
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
914
+
915
+ template<>
916
+ inline __device__ float2 mul(uint16_t a, uint32_t b)
917
+ {
918
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
919
+ }
920
+
921
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
922
+
923
+ template<>
924
+ inline __device__ Float4_ mul(uint2 a, uint2 b)
925
+ {
926
+ Float4_ fc;
927
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
928
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
929
+ return fc;
930
+ }
931
+
932
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
933
+
934
+ template<>
935
+ inline __device__ Float4_ mul(uint16_t a, uint2 b)
936
+ {
937
+ uint32_t s = h0_h0(a);
938
+ Float4_ fc;
939
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
940
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
941
+ return fc;
942
+ }
943
+
944
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
945
+
946
+ template<>
947
+ inline __device__ Float8_ mul(uint4 a, uint4 b)
948
+ {
949
+ Float8_ fc;
950
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
951
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
952
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
953
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
954
+ return fc;
955
+ }
956
+
957
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
958
+
959
+ template<>
960
+ inline __device__ Float8_ mul(uint16_t a, uint4 b)
961
+ {
962
+ uint32_t s = h0_h0(a);
963
+ Float8_ fc;
964
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
965
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
966
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
967
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
968
+ return fc;
969
+ }
970
+
971
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
972
+
973
+ #ifdef ENABLE_BF16
974
+ template<>
975
+ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
976
+ {
977
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
978
+ return __hmul(a, b);
979
+ #else
980
+ return bf16hmul(a, b);
981
+ #endif
982
+ }
983
+
984
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
985
+
986
+ template<>
987
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
988
+ {
989
+ return bf16hmul2(a, b);
990
+ }
991
+
992
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
993
+
994
+ template<>
995
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
996
+ {
997
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
998
+ }
999
+
1000
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1001
+
1002
+ template<>
1003
+ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
1004
+ {
1005
+ bf16_4_t c;
1006
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1007
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1008
+ return c;
1009
+ }
1010
+
1011
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1012
+
1013
+ template<>
1014
+ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
1015
+ {
1016
+ __nv_bfloat162 s = bf162bf162(a);
1017
+ bf16_4_t c;
1018
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1019
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1020
+ return c;
1021
+ }
1022
+
1023
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1024
+
1025
+ template<>
1026
+ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
1027
+ {
1028
+ bf16_8_t c;
1029
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1030
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1031
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
1032
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
1033
+ return c;
1034
+ }
1035
+
1036
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1037
+
1038
+ template<>
1039
+ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
1040
+ {
1041
+ __nv_bfloat162 s = bf162bf162(a);
1042
+ bf16_8_t c;
1043
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1044
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1045
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
1046
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
1047
+ return c;
1048
+ }
1049
+
1050
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1051
+
1052
+ template<>
1053
+ inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
1054
+ {
1055
+ float fa = (float)a;
1056
+ float fb = (float)b;
1057
+ return fa * fb;
1058
+ }
1059
+
1060
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1061
+
1062
+ template<>
1063
+ inline __device__ float mul(__nv_bfloat16 a, float b)
1064
+ {
1065
+ return __bfloat162float(a) * b;
1066
+ }
1067
+
1068
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1069
+
1070
+ template<>
1071
+ inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
1072
+ {
1073
+ float2 fa = bf1622float2(a);
1074
+ float2 fb = bf1622float2(b);
1075
+ return mul<float2, float2, float2>(fa, fb);
1076
+ }
1077
+
1078
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1079
+
1080
+ template<>
1081
+ inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
1082
+ {
1083
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
1084
+ }
1085
+
1086
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1087
+
1088
+ template<>
1089
+ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
1090
+ {
1091
+ Float4_ fc;
1092
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1093
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1094
+ return fc;
1095
+ }
1096
+
1097
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1098
+
1099
+ template<>
1100
+ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
1101
+ {
1102
+ __nv_bfloat162 s = bf162bf162(a);
1103
+ Float4_ fc;
1104
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1105
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1106
+ return fc;
1107
+ }
1108
+
1109
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1110
+
1111
+ template<>
1112
+ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
1113
+ {
1114
+ Float8_ fc;
1115
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1116
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1117
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
1118
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
1119
+ return fc;
1120
+ }
1121
+
1122
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1123
+
1124
+ template<>
1125
+ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
1126
+ {
1127
+ __nv_bfloat162 s = bf162bf162(a);
1128
+ Float8_ fc;
1129
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1130
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1131
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
1132
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
1133
+ return fc;
1134
+ }
1135
+ #endif // ENABLE_BF16
1136
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1137
+
1138
+ inline __device__ float sum(float v)
1139
+ {
1140
+ return v;
1141
+ }
1142
+
1143
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1144
+
1145
+ inline __device__ float sum(float2 v)
1146
+ {
1147
+ return v.x + v.y;
1148
+ }
1149
+
1150
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1151
+
1152
+ inline __device__ float sum(float4 v)
1153
+ {
1154
+ return v.x + v.y + v.z + v.w;
1155
+ }
1156
+
1157
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1158
+
1159
+ #ifdef ENABLE_BF16
1160
+ inline __device__ float sum(__nv_bfloat162 v)
1161
+ {
1162
+ float2 vf = bf1622float2(v);
1163
+ return vf.x + vf.y;
1164
+ }
1165
+
1166
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1167
+
1168
+ inline __device__ float sum(bf16_4_t v)
1169
+ {
1170
+ return sum(v.x) + sum(v.y);
1171
+ }
1172
+
1173
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1174
+
1175
+ inline __device__ float sum(bf16_8_t v)
1176
+ {
1177
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
1178
+ }
1179
+ #endif // ENABLE_BF16
1180
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1181
+
1182
+ inline __device__ float sum(uint16_t v)
1183
+ {
1184
+ return half_to_float(v);
1185
+ }
1186
+
1187
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1188
+
1189
+ inline __device__ float sum(uint32_t v)
1190
+ {
1191
+ float2 tmp = half2_to_float2(v);
1192
+ return tmp.x + tmp.y;
1193
+ }
1194
+
1195
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1196
+
1197
+ inline __device__ float sum(uint2 v)
1198
+ {
1199
+ uint32_t c = add(v.x, v.y);
1200
+ return sum(c);
1201
+ }
1202
+
1203
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1204
+
1205
+ inline __device__ float sum(uint4 v)
1206
+ {
1207
+ #if 1
1208
+ uint32_t c = add(v.x, v.y);
1209
+ c = add(c, v.z);
1210
+ c = add(c, v.w);
1211
+ #else
1212
+ uint32_t c = add(v.x, v.y);
1213
+ uint32_t d = add(v.z, v.w);
1214
+ c = add(c, d);
1215
+ #endif
1216
+ return sum(c);
1217
+ }
1218
+
1219
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1220
+
1221
+ inline __device__ float sum(Float4_ v)
1222
+ {
1223
+ return v.x.x + v.x.y + v.y.x + v.y.y;
1224
+ }
1225
+
1226
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1227
+
1228
+ inline __device__ float sum(Float8_ v)
1229
+ {
1230
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
1231
+ }
1232
+
1233
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1234
+
1235
+ template<typename T>
1236
+ inline __device__ float dot(T a, T b)
1237
+ {
1238
+ return sum(mul<T, T, T>(a, b));
1239
+ }
1240
+
1241
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1242
+
1243
+ template<typename A, typename T>
1244
+ inline __device__ float dot(T a, T b)
1245
+ {
1246
+ return sum(mul<A, T, T>(a, b));
1247
+ }
1248
+
1249
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1250
+
1251
+ inline __device__ void zero(uint16_t& dst)
1252
+ {
1253
+ dst = uint16_t(0);
1254
+ }
1255
+
1256
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1257
+
1258
+ template<typename T>
1259
+ inline __device__ void zero(T& dst)
1260
+ {
1261
+ constexpr int WORDS = sizeof(T) / 4;
1262
+ union {
1263
+ T raw;
1264
+ uint32_t words[WORDS];
1265
+ } tmp;
1266
+ #pragma unroll
1267
+ for (int ii = 0; ii < WORDS; ++ii) {
1268
+ tmp.words[ii] = 0u;
1269
+ }
1270
+ dst = tmp.raw;
1271
+ }
1272
+
1273
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1274
+
1275
+ inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base)
1276
+ {
1277
+ const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
1278
+ return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)};
1279
+ }
1280
+
1281
+ inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
1282
+ {
1283
+ float2 rot_v;
1284
+ rot_v.x = coef.x * v.x - coef.y * v.y;
1285
+ rot_v.y = coef.x * v.y + coef.y * v.x;
1286
+ return rot_v;
1287
+ }
1288
+
1289
+ inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
1290
+ {
1291
+ float2 fv = half2_to_float2(v);
1292
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
1293
+ return float2_to_half2(rot_fv);
1294
+ }
1295
+
1296
+ #ifdef ENABLE_BF16
1297
+ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
1298
+ {
1299
+ float2 fv = bf1622float2(v);
1300
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
1301
+ return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
1302
+ }
1303
+ #endif
1304
+
1305
+ inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
1306
+ {
1307
+ return;
1308
+ }
1309
+
1310
+ inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
1311
+ {
1312
+ return;
1313
+ }
1314
+
1315
+ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1316
+ {
1317
+ if (2 * tid >= rot_embed_dim) {
1318
+ return;
1319
+ }
1320
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1321
+ q = rotary_embedding_transform(q, coef);
1322
+ }
1323
+
1324
+ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1325
+ {
1326
+ if (2 * tid >= rot_embed_dim) {
1327
+ return;
1328
+ }
1329
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1330
+ q = rotary_embedding_transform(q, coef);
1331
+ k = rotary_embedding_transform(k, coef);
1332
+ }
1333
+
1334
+ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1335
+ {
1336
+ if (4 * tid >= rot_embed_dim) {
1337
+ return;
1338
+ }
1339
+
1340
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1341
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1342
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1343
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1344
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1345
+ }
1346
+
1347
+ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1348
+ {
1349
+ if (4 * tid >= rot_embed_dim) {
1350
+ return;
1351
+ }
1352
+
1353
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1354
+ Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
1355
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1356
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1357
+ k_.x = rotary_embedding_transform(k_.x, coef0);
1358
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1359
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1360
+ k_.y = rotary_embedding_transform(k_.y, coef1);
1361
+ }
1362
+
1363
+ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1364
+ {
1365
+ if (2 * tid >= rot_embed_dim) {
1366
+ return;
1367
+ }
1368
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1369
+ q = rotary_embedding_transform(q, coef);
1370
+ }
1371
+
1372
+ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1373
+ {
1374
+ if (2 * tid >= rot_embed_dim) {
1375
+ return;
1376
+ }
1377
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1378
+ q = rotary_embedding_transform(q, coef);
1379
+ k = rotary_embedding_transform(k, coef);
1380
+ }
1381
+
1382
+ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1383
+ {
1384
+ if (4 * tid >= rot_embed_dim) {
1385
+ return;
1386
+ }
1387
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1388
+ q.x = rotary_embedding_transform(q.x, coef0);
1389
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1390
+ q.y = rotary_embedding_transform(q.y, coef1);
1391
+ }
1392
+
1393
+ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1394
+ {
1395
+ if (4 * tid >= rot_embed_dim) {
1396
+ return;
1397
+ }
1398
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1399
+ q.x = rotary_embedding_transform(q.x, coef0);
1400
+ k.x = rotary_embedding_transform(k.x, coef0);
1401
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1402
+ q.y = rotary_embedding_transform(q.y, coef1);
1403
+ k.y = rotary_embedding_transform(k.y, coef1);
1404
+ }
1405
+
1406
+ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1407
+ {
1408
+ if (8 * tid >= rot_embed_dim) {
1409
+ return;
1410
+ }
1411
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1412
+ q.x = rotary_embedding_transform(q.x, coef0);
1413
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1414
+ q.y = rotary_embedding_transform(q.y, coef1);
1415
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1416
+ q.z = rotary_embedding_transform(q.z, coef2);
1417
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1418
+ q.w = rotary_embedding_transform(q.w, coef3);
1419
+ }
1420
+
1421
+ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1422
+ {
1423
+ if (8 * tid >= rot_embed_dim) {
1424
+ return;
1425
+ }
1426
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1427
+ q.x = rotary_embedding_transform(q.x, coef0);
1428
+ k.x = rotary_embedding_transform(k.x, coef0);
1429
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1430
+ q.y = rotary_embedding_transform(q.y, coef1);
1431
+ k.y = rotary_embedding_transform(k.y, coef1);
1432
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1433
+ q.z = rotary_embedding_transform(q.z, coef2);
1434
+ k.z = rotary_embedding_transform(k.z, coef2);
1435
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1436
+ q.w = rotary_embedding_transform(q.w, coef3);
1437
+ k.w = rotary_embedding_transform(k.w, coef3);
1438
+ }
1439
+
1440
+ #ifdef ENABLE_BF16
1441
+ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1442
+ {
1443
+ if (2 * tid >= rot_embed_dim) {
1444
+ return;
1445
+ }
1446
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1447
+ q = rotary_embedding_transform(q, coef);
1448
+ }
1449
+
1450
+ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1451
+ {
1452
+ if (2 * tid >= rot_embed_dim) {
1453
+ return;
1454
+ }
1455
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1456
+ q = rotary_embedding_transform(q, coef);
1457
+ k = rotary_embedding_transform(k, coef);
1458
+ }
1459
+
1460
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1461
+ {
1462
+ if (4 * tid >= rot_embed_dim) {
1463
+ return;
1464
+ }
1465
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1466
+ q.x = rotary_embedding_transform(q.x, coef0);
1467
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1468
+ q.y = rotary_embedding_transform(q.y, coef1);
1469
+ }
1470
+
1471
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1472
+ {
1473
+ if (4 * tid >= rot_embed_dim) {
1474
+ return;
1475
+ }
1476
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1477
+ q.x = rotary_embedding_transform(q.x, coef0);
1478
+ k.x = rotary_embedding_transform(k.x, coef0);
1479
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1480
+ q.y = rotary_embedding_transform(q.y, coef1);
1481
+ k.y = rotary_embedding_transform(k.y, coef1);
1482
+ }
1483
+
1484
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1485
+ {
1486
+ if (8 * tid >= rot_embed_dim) {
1487
+ return;
1488
+ }
1489
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1490
+ q.x = rotary_embedding_transform(q.x, coef0);
1491
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1492
+ q.y = rotary_embedding_transform(q.y, coef1);
1493
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1494
+ q.z = rotary_embedding_transform(q.z, coef2);
1495
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1496
+ q.w = rotary_embedding_transform(q.w, coef3);
1497
+ }
1498
+
1499
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1500
+ {
1501
+ if (8 * tid >= rot_embed_dim) {
1502
+ return;
1503
+ }
1504
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1505
+ q.x = rotary_embedding_transform(q.x, coef0);
1506
+ k.x = rotary_embedding_transform(k.x, coef0);
1507
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1508
+ q.y = rotary_embedding_transform(q.y, coef1);
1509
+ k.y = rotary_embedding_transform(k.y, coef1);
1510
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1511
+ q.z = rotary_embedding_transform(q.z, coef2);
1512
+ k.z = rotary_embedding_transform(k.z, coef2);
1513
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1514
+ q.w = rotary_embedding_transform(q.w, coef3);
1515
+ k.w = rotary_embedding_transform(k.w, coef3);
1516
+ }
1517
+ #endif // ENABLE_BF16
1518
+
1519
+ template <typename T>
1520
+ inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
1521
+ {
1522
+ // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
1523
+ // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
1524
+ return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
1525
+ }
1526
+
1527
+ // fp16 is special because we use uint16_t for reading the data, for backward compatibility.
1528
+ template <>
1529
+ inline __device__ float2 rotary_embedding_coefficient<uint16_t>(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1530
+ {
1531
+ // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
1532
+ // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
1533
+ return {float(reinterpret_cast<const __half*>(rotary_cos)[zid / 2]),
1534
+ float(reinterpret_cast<const __half*>(rotary_sin)[zid / 2])};
1535
+ }
1536
+
1537
+ inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1538
+ {
1539
+ return;
1540
+ }
1541
+
1542
+ inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1543
+ {
1544
+ return;
1545
+ }
1546
+
1547
+ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1548
+ {
1549
+ if (2 * tid >= rot_embed_dim) {
1550
+ return;
1551
+ }
1552
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1553
+ q = rotary_embedding_transform(q, coef);
1554
+ }
1555
+
1556
+ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1557
+ {
1558
+ if (2 * tid >= rot_embed_dim) {
1559
+ return;
1560
+ }
1561
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1562
+ q = rotary_embedding_transform(q, coef);
1563
+ k = rotary_embedding_transform(k, coef);
1564
+ }
1565
+
1566
+ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1567
+ {
1568
+ if (4 * tid >= rot_embed_dim) {
1569
+ return;
1570
+ }
1571
+
1572
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1573
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1574
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1575
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1576
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1577
+ }
1578
+
1579
+ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
1580
+ {
1581
+ if (4 * tid >= rot_embed_dim) {
1582
+ return;
1583
+ }
1584
+
1585
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1586
+ Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
1587
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1588
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1589
+ k_.x = rotary_embedding_transform(k_.x, coef0);
1590
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1591
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1592
+ k_.y = rotary_embedding_transform(k_.y, coef1);
1593
+ }
1594
+
1595
+ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1596
+ {
1597
+ if (2 * tid >= rot_embed_dim) {
1598
+ return;
1599
+ }
1600
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1601
+ q = rotary_embedding_transform(q, coef);
1602
+ }
1603
+
1604
+ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1605
+ {
1606
+ if (2 * tid >= rot_embed_dim) {
1607
+ return;
1608
+ }
1609
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1610
+ q = rotary_embedding_transform(q, coef);
1611
+ k = rotary_embedding_transform(k, coef);
1612
+ }
1613
+
1614
+ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1615
+ {
1616
+ if (4 * tid >= rot_embed_dim) {
1617
+ return;
1618
+ }
1619
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1620
+ q.x = rotary_embedding_transform(q.x, coef0);
1621
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1622
+ q.y = rotary_embedding_transform(q.y, coef1);
1623
+ }
1624
+
1625
+ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1626
+ {
1627
+ if (4 * tid >= rot_embed_dim) {
1628
+ return;
1629
+ }
1630
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1631
+ q.x = rotary_embedding_transform(q.x, coef0);
1632
+ k.x = rotary_embedding_transform(k.x, coef0);
1633
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1634
+ q.y = rotary_embedding_transform(q.y, coef1);
1635
+ k.y = rotary_embedding_transform(k.y, coef1);
1636
+ }
1637
+
1638
+ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1639
+ {
1640
+ if (8 * tid >= rot_embed_dim) {
1641
+ return;
1642
+ }
1643
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
1644
+ q.x = rotary_embedding_transform(q.x, coef0);
1645
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
1646
+ q.y = rotary_embedding_transform(q.y, coef1);
1647
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
1648
+ q.z = rotary_embedding_transform(q.z, coef2);
1649
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
1650
+ q.w = rotary_embedding_transform(q.w, coef3);
1651
+ }
1652
+
1653
+ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
1654
+ {
1655
+ if (8 * tid >= rot_embed_dim) {
1656
+ return;
1657
+ }
1658
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
1659
+ q.x = rotary_embedding_transform(q.x, coef0);
1660
+ k.x = rotary_embedding_transform(k.x, coef0);
1661
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
1662
+ q.y = rotary_embedding_transform(q.y, coef1);
1663
+ k.y = rotary_embedding_transform(k.y, coef1);
1664
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
1665
+ q.z = rotary_embedding_transform(q.z, coef2);
1666
+ k.z = rotary_embedding_transform(k.z, coef2);
1667
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
1668
+ q.w = rotary_embedding_transform(q.w, coef3);
1669
+ k.w = rotary_embedding_transform(k.w, coef3);
1670
+ }
1671
+
1672
+ #ifdef ENABLE_BF16
1673
+ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1674
+ {
1675
+ if (2 * tid >= rot_embed_dim) {
1676
+ return;
1677
+ }
1678
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1679
+ q = rotary_embedding_transform(q, coef);
1680
+ }
1681
+
1682
+ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1683
+ {
1684
+ if (2 * tid >= rot_embed_dim) {
1685
+ return;
1686
+ }
1687
+ const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
1688
+ q = rotary_embedding_transform(q, coef);
1689
+ k = rotary_embedding_transform(k, coef);
1690
+ }
1691
+
1692
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1693
+ {
1694
+ if (4 * tid >= rot_embed_dim) {
1695
+ return;
1696
+ }
1697
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1698
+ q.x = rotary_embedding_transform(q.x, coef0);
1699
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1700
+ q.y = rotary_embedding_transform(q.y, coef1);
1701
+ }
1702
+
1703
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1704
+ {
1705
+ if (4 * tid >= rot_embed_dim) {
1706
+ return;
1707
+ }
1708
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
1709
+ q.x = rotary_embedding_transform(q.x, coef0);
1710
+ k.x = rotary_embedding_transform(k.x, coef0);
1711
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
1712
+ q.y = rotary_embedding_transform(q.y, coef1);
1713
+ k.y = rotary_embedding_transform(k.y, coef1);
1714
+ }
1715
+
1716
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1717
+ {
1718
+ if (8 * tid >= rot_embed_dim) {
1719
+ return;
1720
+ }
1721
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
1722
+ q.x = rotary_embedding_transform(q.x, coef0);
1723
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
1724
+ q.y = rotary_embedding_transform(q.y, coef1);
1725
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
1726
+ q.z = rotary_embedding_transform(q.z, coef2);
1727
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
1728
+ q.w = rotary_embedding_transform(q.w, coef3);
1729
+ }
1730
+
1731
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
1732
+ {
1733
+ if (8 * tid >= rot_embed_dim) {
1734
+ return;
1735
+ }
1736
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
1737
+ q.x = rotary_embedding_transform(q.x, coef0);
1738
+ k.x = rotary_embedding_transform(k.x, coef0);
1739
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
1740
+ q.y = rotary_embedding_transform(q.y, coef1);
1741
+ k.y = rotary_embedding_transform(k.y, coef1);
1742
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
1743
+ q.z = rotary_embedding_transform(q.z, coef2);
1744
+ k.z = rotary_embedding_transform(k.z, coef2);
1745
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
1746
+ q.w = rotary_embedding_transform(q.w, coef3);
1747
+ k.w = rotary_embedding_transform(k.w, coef3);
1748
+ }
1749
+ #endif // ENABLE_BF16
1750
+
1751
+ template<typename Vec_T, typename T>
1752
+ __device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
1753
+
1754
+ template<>
1755
+ __device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
1756
+ {
1757
+ return;
1758
+ }
1759
+
1760
+ template<>
1761
+ __device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1762
+ {
1763
+ union {
1764
+ uint32_t u32;
1765
+ uint16_t u16[2];
1766
+ } tmp;
1767
+ tmp.u16[0] = smem[transpose_idx];
1768
+ tmp.u16[1] = smem[smem_pitch + transpose_idx];
1769
+
1770
+ vec = tmp.u32;
1771
+ }
1772
+
1773
+ template<>
1774
+ __device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1775
+ {
1776
+ union {
1777
+ uint32_t u32;
1778
+ uint16_t u16[2];
1779
+ } tmp_1, tmp_2;
1780
+ tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
1781
+ tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
1782
+
1783
+ union {
1784
+ uint2 u32x2;
1785
+ uint16_t u16[4];
1786
+ } tmp_3;
1787
+ tmp_3.u16[0] = tmp_1.u16[0];
1788
+ tmp_3.u16[1] = tmp_2.u16[0];
1789
+ tmp_3.u16[2] = tmp_1.u16[1];
1790
+ tmp_3.u16[3] = tmp_2.u16[1];
1791
+
1792
+ vec = tmp_3.u32x2;
1793
+ }
1794
+
1795
+ template<>
1796
+ __device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1797
+ {
1798
+ union {
1799
+ uint64_t u64;
1800
+ uint16_t u16[4];
1801
+ } tmp_1, tmp_2;
1802
+ tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
1803
+ tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
1804
+
1805
+ union {
1806
+ uint4 u32x4;
1807
+ uint16_t u16[8];
1808
+ } tmp_3;
1809
+ tmp_3.u16[0] = tmp_1.u16[0];
1810
+ tmp_3.u16[1] = tmp_2.u16[0];
1811
+ tmp_3.u16[2] = tmp_1.u16[1];
1812
+ tmp_3.u16[3] = tmp_2.u16[1];
1813
+ tmp_3.u16[4] = tmp_1.u16[2];
1814
+ tmp_3.u16[5] = tmp_2.u16[2];
1815
+ tmp_3.u16[6] = tmp_1.u16[3];
1816
+ tmp_3.u16[7] = tmp_2.u16[3];
1817
+
1818
+ vec = tmp_3.u32x4;
1819
+ }
1820
+
1821
+ #ifdef ENABLE_BF16
1822
+ template<>
1823
+ __device__ __inline__ void
1824
+ vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1825
+ {
1826
+ union {
1827
+ uint32_t u32;
1828
+ __nv_bfloat16 bf16[2];
1829
+ } tmp_1, tmp_2;
1830
+ tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
1831
+ tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
1832
+
1833
+ vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
1834
+ vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
1835
+ }
1836
+
1837
+ template<>
1838
+ __device__ __inline__ void
1839
+ vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1840
+ {
1841
+ union {
1842
+ uint64_t u64;
1843
+ __nv_bfloat16 bf16[4];
1844
+ } tmp_1, tmp_2;
1845
+ tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
1846
+ tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
1847
+
1848
+ vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
1849
+ vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
1850
+ vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
1851
+ vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
1852
+ }
1853
+ #endif // ENABLE_BF16
1854
+
1855
+ template<>
1856
+ __device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
1857
+ {
1858
+ vec.x = smem[transpose_idx];
1859
+ vec.z = smem[transpose_idx + 1];
1860
+ vec.y = smem[smem_pitch + transpose_idx];
1861
+ vec.w = smem[smem_pitch + transpose_idx + 1];
1862
+ }
1863
+
1864
+ template<>
1865
+ __device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
1866
+ {
1867
+ union {
1868
+ uint32_t u32;
1869
+ half u16[2];
1870
+ } tmp;
1871
+ tmp.u16[0] = smem[transpose_idx];
1872
+ tmp.u16[1] = smem[smem_pitch + transpose_idx];
1873
+
1874
+ vec = tmp.u32;
1875
+ }
1876
+
1877
+ #ifdef ENABLE_BF16
1878
+ template<>
1879
+ __device__ __inline__ void
1880
+ vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1881
+ {
1882
+ vec.x = smem[transpose_idx];
1883
+ vec.y = smem[smem_pitch + transpose_idx];
1884
+ }
1885
+ #endif
1886
+
1887
+ template<>
1888
+ __device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
1889
+ {
1890
+ vec.x = smem[transpose_idx];
1891
+ vec.y = smem[smem_pitch + transpose_idx];
1892
+ }
1893
+
1894
+ template<typename Vec_T, typename T>
1895
+ __device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
1896
+
1897
+ template<>
1898
+ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
1899
+ {
1900
+ return;
1901
+ }
1902
+
1903
+ template<>
1904
+ __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1905
+ {
1906
+ union {
1907
+ uint64_t u64;
1908
+ uint16_t u16[4];
1909
+ } tmp_1, tmp_2;
1910
+
1911
+ union {
1912
+ uint4 u32x4;
1913
+ uint16_t u16[8];
1914
+ } tmp_3;
1915
+ tmp_3.u32x4 = vec;
1916
+ tmp_1.u16[0] = tmp_3.u16[0];
1917
+ tmp_2.u16[0] = tmp_3.u16[1];
1918
+ tmp_1.u16[1] = tmp_3.u16[2];
1919
+ tmp_2.u16[1] = tmp_3.u16[3];
1920
+ tmp_1.u16[2] = tmp_3.u16[4];
1921
+ tmp_2.u16[2] = tmp_3.u16[5];
1922
+ tmp_1.u16[3] = tmp_3.u16[6];
1923
+ tmp_2.u16[3] = tmp_3.u16[7];
1924
+
1925
+ *reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
1926
+ *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
1927
+ }
1928
+
1929
+ template<>
1930
+ __device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1931
+ {
1932
+ union {
1933
+ uint32_t u32;
1934
+ uint16_t u16[2];
1935
+ } tmp_1, tmp_2;
1936
+
1937
+ union {
1938
+ uint2 u32x2;
1939
+ uint16_t u16[4];
1940
+ } tmp_3;
1941
+ tmp_3.u32x2 = vec;
1942
+ tmp_1.u16[0] = tmp_3.u16[0];
1943
+ tmp_2.u16[0] = tmp_3.u16[1];
1944
+ tmp_1.u16[1] = tmp_3.u16[2];
1945
+ tmp_2.u16[1] = tmp_3.u16[3];
1946
+
1947
+ *reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
1948
+ *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
1949
+ }
1950
+
1951
+ template<>
1952
+ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1953
+ {
1954
+ union {
1955
+ uint32_t u32;
1956
+ uint16_t u16[2];
1957
+ } tmp;
1958
+ tmp.u32 = vec;
1959
+
1960
+ smem[transpose_idx] = tmp.u16[0];
1961
+ smem[smem_pitch + transpose_idx] = tmp.u16[1];
1962
+ }
1963
+
1964
+ template<>
1965
+ __device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
1966
+ {
1967
+ smem[transpose_idx] = vec.x;
1968
+ smem[transpose_idx + 1] = vec.z;
1969
+ smem[smem_pitch + transpose_idx] = vec.y;
1970
+ smem[smem_pitch + transpose_idx + 1] = vec.w;
1971
+ }
1972
+
1973
+ template<>
1974
+ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
1975
+ {
1976
+ union {
1977
+ uint32_t u32;
1978
+ half u16[2];
1979
+ } tmp;
1980
+
1981
+ tmp.u32 = vec;
1982
+ smem[transpose_idx] = tmp.u16[0];
1983
+ smem[smem_pitch + transpose_idx] = tmp.u16[1];
1984
+ }
1985
+
1986
+ #ifdef ENABLE_BF16
1987
+ template<>
1988
+ __device__ __inline__ void
1989
+ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1990
+ {
1991
+ smem[transpose_idx] = vec.x;
1992
+ smem[smem_pitch + transpose_idx] = vec.y;
1993
+ }
1994
+
1995
+ template<>
1996
+ __device__ __inline__ void
1997
+ write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1998
+ {
1999
+ write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
2000
+ }
2001
+
2002
+ template<>
2003
+ __device__ __inline__ void
2004
+ write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
2005
+ {
2006
+ write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
2007
+ }
2008
+ #endif
2009
+
2010
+ template<>
2011
+ __device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
2012
+ {
2013
+ smem[transpose_idx] = vec.x;
2014
+ smem[smem_pitch + transpose_idx] = vec.y;
2015
+ }
2016
+
2017
+ } // namespace mmha